AI-Music-Detection-FST / inference.py
import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import scipy.signal as signal
import torch
import torchaudio
import pytorch_lightning as pl

from model import MusicAudioClassifier
from networks import MERT_AudioCNN
from dataset_f import FakeMusicCapsDataset
from preprocess import get_segments_from_wav, find_optimal_segment_length

def highpass_filter(y, sr, cutoff=1000, order=5):
    """Apply a Butterworth high-pass filter to a 1-D waveform."""
    if isinstance(sr, np.ndarray):
        sr = np.mean(sr)
    if not isinstance(sr, (int, float)):
        raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")
    nyquist = 0.5 * sr
    # Keep the cutoff strictly inside (0, nyquist) so signal.butter stays valid
    if cutoff <= 0 or cutoff >= nyquist:
        cutoff = max(10, min(cutoff, nyquist - 1))
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    y_filtered = signal.lfilter(b, a, y)
    return y_filtered
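
# A minimal usage sketch (not part of the original pipeline): run the
# high-pass filter over one second of a synthetic 440 Hz tone at 24 kHz.
def _highpass_filter_demo(sr: int = 24000) -> np.ndarray:
    t = np.linspace(0, 1.0, sr, endpoint=False)
    tone = np.sin(2 * np.pi * 440.0 * t)
    return highpass_filter(tone, sr, cutoff=1000)  # output has the same length as the input
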
def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an audio file and split it into segments.
    Extracts up to 48 fixed-length segments and pads with silence when fewer are found.
    Args:
        audio_path: path to the audio file
        sr: target sampling rate (default 24000)
    Returns:
        Tuple containing:
        - tensor of audio waveforms with shape (48, 1, 240000)
        - padding-mask tensor with shape (48); True = padding, False = real audio
    """
    beats, downbeats = get_segments_from_wav(audio_path)
    optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert the data type to float32
    waveform = waveform.to(torch.float32)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(sample_rate, sr)
        waveform = resampler(waveform)
    # Downmix to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # 240000 samples = 10 seconds @ 24 kHz
    fixed_samples = 240000
    # If the audio is no longer than 10 s, append 5 s of silence so a full segment still fits
    if waveform.shape[1] <= 240000:
        padding = torch.zeros(1, 120000, dtype=torch.float32)
        # Append the padding after the original audio
        waveform = torch.cat([waveform, padding], dim=1)
    # Build one segment starting at each downbeat
    segments = []
    for i, start_time in enumerate(cleaned_downbeats):
        # Start sample index
        start_sample = int(start_time * sr)
        # End sample index (start + fixed length)
        end_sample = start_sample + fixed_samples
        # Skip segments that would run past the end of the file
        if end_sample > waveform.size(1):
            continue
        # Extract a segment of exactly fixed_samples samples
        segment = waveform[:, start_sample:end_sample]
        # High-pass filtering (keeping the channel dimension) is currently disabled
        #filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0)  # Not sure about this one; running inference after different preprocessing steps would also be worth trying
        filtered = torch.tensor(segment.squeeze().numpy(), dtype=torch.float32).unsqueeze(0)  # Note: the model's processor is not applied here
        # Model-specific preprocessing would normally go here.
        segments.append(filtered)
        # Use at most 48 segments
        if len(segments) >= 48:
            break
    # Handle the case where no segments were extracted
    if not segments:
        return torch.zeros((1, 1, fixed_samples), dtype=torch.float32), torch.ones(1, dtype=torch.bool)
    # Stack into a tensor, keeping the (n_segments, 1, time_samples) shape
    stacked_segments = torch.stack(segments)
    # Number of real (non-padded) segments
    num_segments = stacked_segments.shape[0]
    # Padding mask (False = real audio, True = padding)
    padding_mask = torch.zeros(48, dtype=torch.bool)
    # Pad with empty segments if fewer than 48 were found
    if num_segments < 48:
        padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
        stacked_segments = torch.cat([stacked_segments, padding], dim=0)
        # Mark the padded positions (True = padding)
        padding_mask[num_segments:] = True
    return stacked_segments, padding_mask
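
# Usage sketch: "example.wav" is a placeholder path, not a file from this repo.
# load_audio returns 10-second segments plus a mask marking zero-padded slots.
def _load_audio_demo(path: str = "example.wav") -> None:
    segments, mask = load_audio(path, sr=24000)
    print(segments.shape)                                   # torch.Size([48, 1, 240000]) in the usual case
    print(int((~mask).sum()), "real segments,", int(mask.sum()), "padded")
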
def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
    """
    Run inference on audio segments.
    Args:
        model: The loaded model
        audio_segments: Preprocessed audio segments tensor (48, 1, 240000)
        padding_mask: Boolean mask of shape (48,); True = padded segment
        device: Device to run inference on
    Returns:
        Dictionary with prediction results
    """
    model.eval()
    model.to(device)
    model = model.half()
    with torch.no_grad():
        # Check and adjust the input shape to match the wav_collate_with_mask collate function
        if audio_segments.shape[1] == 1:  # (48, 1, 240000) waveform segments
            # Drop the channel dimension and add a batch dimension
            audio_segments = audio_segments[:, 0, :].unsqueeze(0)  # (1, 48, 240000)
        else:
            audio_segments = audio_segments.unsqueeze(0)  # (1, 48, 768); these may be embedding segments rather than raw audio
        if padding_mask.dim() == 1:
            padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]
        # Move the data to the device and convert it to half precision to match the model
        audio_segments = audio_segments.to(device).half()
        mask = padding_mask.to(device)
        # Run the forward pass (with the padding mask)
        outputs = model(audio_segments, mask)
        # Handle the different output structures the model may return
        if isinstance(outputs, dict):
            result = outputs
        else:
            # Single tensor output (logits)
            logits = outputs.squeeze()
            prob = scaled_sigmoid(logits, scale_factor=1.0, linear_property=0.0).item()
            result = {
                "prediction": "Fake" if prob > 0.5 else "Real",
                "confidence": f"{max(prob, 1-prob)*100:.2f}",
                "fake_probability": f"{prob:.4f}",
                "real_probability": f"{1-prob:.4f}",
                "raw_output": logits.cpu().numpy().tolist()
            }
    return result
# Custom scaling function to moderate extreme sigmoid values
def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
    # Apply scaling to make the sigmoid less extreme
    scaled_x = x * scale_factor
    # Combine the sigmoid with a linear component
    raw_prob = torch.sigmoid(scaled_x) * (1 - linear_property) + linear_property * ((x + 25) / 50)
    # Clip to ensure bounds
    return torch.clamp(raw_prob, min=0.011, max=0.989)
# Apply the scaled sigmoid
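# A small sketch (added for illustration) of the two parameterizations: the
# defaults soften extreme logits, while run_inference uses scale_factor=1.0 and
# linear_property=0.0, i.e. a plain sigmoid clamped to [0.011, 0.989].
def _scaled_sigmoid_demo() -> None:
    logits = torch.tensor([-10.0, 0.0, 10.0])
    print(scaled_sigmoid(logits))                                          # softened defaults
    print(scaled_sigmoid(logits, scale_factor=1.0, linear_property=0.0))   # as used in run_inference
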
def get_model(model_type, device):
    """Load the specified model."""
    if model_type == "MERT":
        #from model import MusicAudioClassifier
        #model = MusicAudioClassifier(input_dim=768, is_emb=True, mode = 'both', share_parameter = False).to(device)
        ckpt_file = 'checkpoints/step=007000-val_loss=0.1831-val_acc=0.9278.ckpt'  #'mert_finetune_10.pth'
        model = MERT_AudioCNN.load_from_checkpoint(ckpt_file).to(device)
        # model.load_state_dict(torch.load(ckpt_file, map_location=device))
        embed_dim = 768
    elif model_type == "pure_MERT":
        from ISMIR_2025.MERT.networks import MERTFeatureExtractor
        model = MERTFeatureExtractor().to(device)
        embed_dim = 768
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.eval()
    return model, embed_dim
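
# Usage sketch: load the MERT backbone the way inference() below does; the only
# assumption added here is the fallback to CPU when CUDA is unavailable.
def _get_model_demo():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    backbone, embed_dim = get_model('MERT', device)  # loads the hard-coded checkpoint path
    return backbone, embed_dim
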
def inference(audio_path):
    backbone_model, input_dim = get_model('MERT', 'cuda')
    segments, padding_mask = load_audio(audio_path, sr=24000)
    segments = segments.to('cuda').to(torch.float32)
    padding_mask = padding_mask.to('cuda').unsqueeze(0)
    logits, embedding = backbone_model(segments.squeeze(1))
    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
    test_data, test_target = test_dataset[0]
    test_data = test_data.to('cuda').to(torch.float32)
    test_target = test_target.to('cuda')
    output, _ = backbone_model(test_data.unsqueeze(0))
    # Load the classifier that operates on the backbone embeddings
    model = MusicAudioClassifier.load_from_checkpoint(
        checkpoint_path = 'checkpoints/EmbeddingModel_MERT_768-epoch=0073-val_loss=0.1058-val_acc=0.9585-val_f1=0.9366-val_precision=0.9936-val_recall=0.8857.ckpt',
        input_dim=input_dim,
        #emb_model=backbone_model
        is_emb = True,
        #mode = 'both'
    )
    # Run inference
    print(f"Segments shape: {segments.shape}")
    print("Running inference...")
    results = run_inference(model, embedding, padding_mask, 'cuda')
    # Print the results
    print(f"Results: {results}")
    return results
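
# main() is referenced in the __main__ guard below but was not defined in this
# file; a minimal sketch, assuming a single positional audio-path argument
# (argparse is already imported above).
def main():
    parser = argparse.ArgumentParser(description="AI-generated music detection inference")
    parser.add_argument("audio_path", type=str, help="path to the audio file to classify")
    args = parser.parse_args()
    inference(args.audio_path)
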
if __name__ == "__main__":
    main()