# AI-Music-Detection-FST / inference.py
import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import scipy.signal as signal
import torch
import torchaudio
import pytorch_lightning as pl

from dataset_f import FakeMusicCapsDataset
from model import MusicAudioClassifier
from preprocess import get_segments_from_wav, find_optimal_segment_length

def highpass_filter(y, sr, cutoff=1000, order=5):
    """Apply a Butterworth high-pass filter to waveform `y` sampled at `sr` Hz."""
if isinstance(sr, np.ndarray):
sr = np.mean(sr)
if not isinstance(sr, (int, float)):
raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")
nyquist = 0.5 * sr
if cutoff <= 0 or cutoff >= nyquist:
cutoff = max(10, min(cutoff, nyquist - 1))
normal_cutoff = cutoff / nyquist
b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
y_filtered = signal.lfilter(b, a, y)
return y_filtered
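
# Illustrative usage sketch for highpass_filter (assumption: one second of
# white noise stands in for real audio). Defined as a function so importing
# this module has no side effects; call it manually to inspect the result.
def _example_highpass_usage():
    sr = 24000
    y = np.random.randn(sr)                     # 1 s of synthetic "audio"
    y_hp = highpass_filter(y, sr, cutoff=1000)  # attenuate content below ~1 kHz
    print(y.shape, y_hp.shape)                  # shapes are unchanged: (24000,)
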
def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
"""
์˜ค๋””์˜ค ํŒŒ์ผ์„ ๋ถˆ๋Ÿฌ์™€ ์„ธ๊ทธ๋จผํŠธ๋กœ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค.
๊ณ ์ •๋œ ๊ธธ์ด์˜ ์„ธ๊ทธ๋จผํŠธ๋ฅผ ์ตœ๋Œ€ 48๊ฐœ ์ถ”์ถœํ•˜๊ณ , ๋ถ€์กฑํ•œ ๊ฒฝ์šฐ ํŒจ๋”ฉ์„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.
Args:
audio_path: ์˜ค๋””์˜ค ํŒŒ์ผ ๊ฒฝ๋กœ
sr: ๋ชฉํ‘œ ์ƒ˜ํ”Œ๋ง ๋ ˆ์ดํŠธ (๊ธฐ๋ณธ๊ฐ’ 24000)
Returns:
Tuple containing:
- ์˜ค๋””์˜ค ํŒŒํ˜•์ด ๋‹ด๊ธด ํ…์„œ (48, 1, 240000)
- ํŒจ๋”ฉ ๋งˆ์Šคํฌ ํ…์„œ (48), True = ํŒจ๋”ฉ, False = ์‹ค์ œ ์˜ค๋””์˜ค
"""
beats, downbeats = get_segments_from_wav(audio_path)
optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)
waveform, sample_rate = torchaudio.load(audio_path)
# ๋ฐ์ดํ„ฐ ํƒ€์ž…์„ float32๋กœ ๋ณ€ํ™˜
waveform = waveform.to(torch.float32)
if sample_rate != sr:
resampler = torchaudio.transforms.Resample(sample_rate, sr)
waveform = resampler(waveform)
# ๋ชจ๋…ธ๋กœ ๋ณ€ํ™˜ (ํ•„์š”ํ•œ ๊ฒฝ์šฐ)
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
# 120000 ์ƒ˜ํ”Œ = 5์ดˆ @ 24kHz
fixed_samples = 240000
# 5์ดˆ ๊ธธ์ด์˜ ๋ฌด์Œ(silence) ํŒจ๋”ฉ ์ƒ์„ฑ
if waveform.shape[1]<= 240000:
padding = torch.zeros(1, 120000, dtype=torch.float32)
# ์›๋ณธ ์˜ค๋””์˜ค ๋’ค์— ํŒจ๋”ฉ ์ถ”๊ฐ€
waveform = torch.cat([waveform, padding], dim=1)
# ๊ฐ downbeat์—์„œ ์‹œ์ž‘ํ•˜๋Š” segment ์ƒ์„ฑ
segments = []
for i, start_time in enumerate(cleaned_downbeats):
# ์‹œ์ž‘ ์ƒ˜ํ”Œ ์ธ๋ฑ์Šค ๊ณ„์‚ฐ
start_sample = int(start_time * sr)
# ๋ ์ƒ˜ํ”Œ ์ธ๋ฑ์Šค ๊ณ„์‚ฐ (์‹œ์ž‘ ์ง€์  + ๊ณ ์ • ๊ธธ์ด)
end_sample = start_sample + fixed_samples
# ํŒŒ์ผ ๋์„ ๋„˜์–ด๊ฐ€๋Š”์ง€ ํ™•์ธ
if end_sample > waveform.size(1):
continue
# ์ •ํ™•ํžˆ fixed_samples ๊ธธ์ด์˜ ์„ธ๊ทธ๋จผํŠธ ์ถ”์ถœ
segment = waveform[:, start_sample:end_sample]
# ํ•˜์ดํŒจ์Šค ํ•„ํ„ฐ ์ ์šฉ - ์ฑ„๋„ ์ฐจ์› ์œ ์ง€
#filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0) # ์ด๊ฑฐ ๋ชจ๋ฅด๊ฒ ๋‹ค์•ผ..? ๋‹ค์–‘ํ•œ ์ „์ฒ˜๋ฆฌ ํ›„ inferenceํ•ด๋ณด๋Š”๊ฑฐ๋„ ๊ดœ์ฐฎ๊ฒ ๋„ค
filtered = torch.tensor(segment.squeeze().numpy(), dtype=torch.float32).unsqueeze(0) # processor ์•ˆ์“ฐ๋„ค?
#์—ฌ๊ธฐ์— ๋ชจ๋ธ๋ณ„ preprocess๊ฐ€ ์›๋ž˜๋Š” ๋“ค์–ด๊ฐ€๋Š”๊ฒŒ ๋งž์Œ.
segments.append(filtered)
# ์ตœ๋Œ€ 48๊ฐœ ์„ธ๊ทธ๋จผํŠธ๋งŒ ์‚ฌ์šฉ
if len(segments) >= 48:
break
    # Handle the case where no segments were extracted
    if not segments:
        return torch.zeros((1, 1, fixed_samples), dtype=torch.float32), torch.ones(1, dtype=torch.bool)
    # Stack into a tensor, keeping the (n_segments, 1, time_samples) shape
    stacked_segments = torch.stack(segments)
    # Number of real (non-padded) segments
    num_segments = stacked_segments.shape[0]
    # Create the padding mask (False = real audio, True = padding)
    padding_mask = torch.zeros(48, dtype=torch.bool)
    # Pad up to 48 segments if necessary
    if num_segments < 48:
        # Pad with empty (all-zero) segments
        padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
        stacked_segments = torch.cat([stacked_segments, padding], dim=0)
        # Mark the padded positions in the mask (True = padding)
        padding_mask[num_segments:] = True
return stacked_segments, padding_mask
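
# Illustrative usage sketch for load_audio (assumption: "example.wav" is any
# local audio file with detectable downbeats). Shows the fixed-size output and
# how the boolean mask separates real segments (False) from padding (True).
def _example_load_audio_usage():
    segments, mask = load_audio("example.wav", sr=24000)
    print(segments.shape)       # torch.Size([48, 1, 240000]) in the usual case
    print(int((~mask).sum()))   # number of real (non-padded) segments
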
def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
"""
Run inference on audio segments.
Args:
model: The loaded model
audio_segments: Preprocessed audio segments tensor (48, 1, 240000)
device: Device to run inference on
Returns:
Dictionary with prediction results
"""
model.eval()
model.to(device)
model = model.half()
    with torch.no_grad():
        # Check and adjust the input shape so it matches the
        # wav_collate_with_mask collate function used during training.
        if audio_segments.shape[1] == 1:  # (48, 1, 240000) shape
            # Drop the channel dimension and add a batch dimension
            audio_segments = audio_segments[:, 0, :].unsqueeze(0)  # (1, 48, 240000)
        else:
            audio_segments = audio_segments.unsqueeze(0)  # (1, 48, 768); these may be embedding segments rather than raw audio
        if padding_mask.dim() == 1:
            padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]
        # Move data to the device and convert to half precision
        audio_segments = audio_segments.to(device).half()
        mask = padding_mask.to(device)
        # Run inference (with the padding mask)
        outputs = model(audio_segments, mask)
        # Handle the model's output structure
        if isinstance(outputs, dict):
            result = outputs
        else:
            # Single tensor (logits)
            logits = outputs.squeeze()
prob = scaled_sigmoid(logits, scale_factor=1.0, linear_property=0.0).item()
result = {
"prediction": "Fake" if prob > 0.5 else "Real",
"confidence": f"{max(prob, 1-prob)*100:.2f}",
"fake_probability": f"{prob:.4f}",
"real_probability": f"{1-prob:.4f}",
"raw_output": logits.cpu().numpy().tolist()
}
return result
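
# Illustrative usage sketch for run_inference (assumptions: `model` is an
# already-loaded MusicAudioClassifier and `segments`/`mask` come from
# load_audio above). The keys in the example output mirror the dict built in
# run_inference.
def _example_run_inference_usage(model, segments, mask):
    result = run_inference(model, segments, mask)
    # e.g. {"prediction": "Real", "confidence": "87.31",
    #       "fake_probability": "0.1269", "real_probability": "0.8731",
    #       "raw_output": [...]}
    return result
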
# Custom scaling function to moderate extreme sigmoid values
def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
# Apply scaling to make sigmoid less extreme
scaled_x = x * scale_factor
# Combine sigmoid with linear component
raw_prob = torch.sigmoid(scaled_x) * (1-linear_property) + linear_property * ((x + 25) / 50)
# Clip to ensure bounds
return torch.clamp(raw_prob, min=0.011, max=0.989)
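
# Quick numerical sketch of how scaled_sigmoid moderates extreme logits
# compared with a plain sigmoid (the commented values follow from the formula
# above, not from measured model outputs).
def _example_scaled_sigmoid():
    logits = torch.tensor([-20.0, 0.0, 20.0])
    print(torch.sigmoid(logits))   # ~[0.0000, 0.5000, 1.0000]
    print(scaled_sigmoid(logits))  # ~[0.0426, 0.5000, 0.9574]: pulled toward 0.5
                                   # and clamped to [0.011, 0.989]
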
def get_model(model_type, device):
"""Load the specified model."""
if model_type == "MERT":
from ISMIR_2025.MERT.networks import CCV
#from model import MusicAudioClassifier
model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
#model = MusicAudioClassifier(input_dim=768, is_emb=True, mode = 'both', share_parameter = False).to(device)
ckpt_file = 'mert_finetune_10.pth'
model.load_state_dict(torch.load(ckpt_file, map_location=device))
embed_dim = 768
elif model_type == "pure_MERT":
from ISMIR_2025.MERT.networks import MERTFeatureExtractor
model = MERTFeatureExtractor().to(device)
embed_dim = 768
else:
raise ValueError(f"Unknown model type: {model_type}")
model.eval()
return model, embed_dim
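
# Illustrative usage sketch for get_model (assumptions: the ISMIR_2025 package
# and the 'mert_finetune_10.pth' checkpoint referenced above are available
# locally).
def _example_get_model_usage():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    backbone, embed_dim = get_model("MERT", device)
    print(type(backbone).__name__, embed_dim)  # e.g. CCV 768
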
def inference(audio_path):
parser = argparse.ArgumentParser(description="Music classifier inference")
parser.add_argument("--model_type", type=str, required=True, choices=["MERT", "AudioCNN"], help="Type of model")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to model checkpoint")
parser.add_argument("--output_path", type=str, default=None, help="Path to save results (default: print to console)")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run inference on")
args = parser.parse_args()
audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3"
# Note: Model loading would be handled by your code
print(f"Loading model of type {args.model_type} from {args.checkpoint_path}")
backbone_model, input_dim = get_model('MERT', 'cuda')
segments, padding_mask = load_audio(audio_path, sr=24000)
segments = segments.to(args.device).to(torch.float32)
padding_mask = padding_mask.to(args.device).unsqueeze(0)
    logits, embedding = backbone_model(segments.squeeze(1))
    # Additional forward pass through the backbone on a fixed 10-second crop
    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
test_data, test_target = test_dataset[0]
test_data = test_data.to(args.device).to(torch.float32)
test_target = test_target.to(args.device)
output, _ = backbone_model(test_data.unsqueeze(0))
# ๋ชจ๋ธ ๋กœ๋“œ ๋ถ€๋ถ„ ์ถ”๊ฐ€
    model = MusicAudioClassifier.load_from_checkpoint(
        args.checkpoint_path,
        input_dim=input_dim,
        # emb_model=backbone_model
        is_emb=True,
        # mode='both'
    )
# Run inference
print(f"Segments shape: {segments.shape}")
print("Running inference...")
results = run_inference(model, embedding, padding_mask, device=args.device)
    # Print the results
print(f"Results: {results}")
    # Save the results
if args.output_path:
with open(args.output_path, 'w') as f:
json.dump(results, f, indent=4)
print(f"Results saved to {args.output_path}")
return results
if __name__ == "__main__":
    inference("The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3")