# AI-Music-Detection-FST / inference.py
import argparse
import os
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import scipy.signal as signal
import torch
import torchaudio
import pytorch_lightning as pl

from dataset_f import FakeMusicCapsDataset
from model import MusicAudioClassifier
from networks import MERT_AudioCNN
from preprocess import get_segments_from_wav, find_optimal_segment_length
def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
"""
์˜ค๋””์˜ค ํŒŒ์ผ์„ ๋ถˆ๋Ÿฌ์™€ ์„ธ๊ทธ๋จผํŠธ๋กœ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค.
๊ณ ์ •๋œ ๊ธธ์ด์˜ ์„ธ๊ทธ๋จผํŠธ๋ฅผ ์ตœ๋Œ€ 48๊ฐœ ์ถ”์ถœํ•˜๊ณ , ๋ถ€์กฑํ•œ ๊ฒฝ์šฐ ํŒจ๋”ฉ์„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.
Args:
audio_path: ์˜ค๋””์˜ค ํŒŒ์ผ ๊ฒฝ๋กœ
sr: ๋ชฉํ‘œ ์ƒ˜ํ”Œ๋ง ๋ ˆ์ดํŠธ (๊ธฐ๋ณธ๊ฐ’ 24000)
Returns:
Tuple containing:
- ์˜ค๋””์˜ค ํŒŒํ˜•์ด ๋‹ด๊ธด ํ…์„œ (48, 1, 240000)
- ํŒจ๋”ฉ ๋งˆ์Šคํฌ ํ…์„œ (48), True = ํŒจ๋”ฉ, False = ์‹ค์ œ ์˜ค๋””์˜ค
"""
beats, downbeats = get_segments_from_wav(audio_path)
optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)
waveform, sample_rate = torchaudio.load(audio_path)
# ๋ฐ์ดํ„ฐ ํƒ€์ž…์„ float32๋กœ ๋ณ€ํ™˜
waveform = waveform.to(torch.float32)
if sample_rate != sr:
resampler = torchaudio.transforms.Resample(sample_rate, sr)
waveform = resampler(waveform)
    # Convert to mono if necessary
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0, keepdim=True)
    # 240000 samples = 10 seconds @ 24 kHz
    fixed_samples = 240000
    # For short files, append 5 seconds (120000 samples) of silence padding
    if waveform.shape[1] <= fixed_samples:
        padding = torch.zeros(1, 120000, dtype=torch.float32)
        # Append the padding after the original audio
        waveform = torch.cat([waveform, padding], dim=1)
    # Create a segment starting at each downbeat
segments = []
for i, start_time in enumerate(cleaned_downbeats):
        # Compute the start sample index
        start_sample = int(start_time * sr)
        # Compute the end sample index (start + fixed length)
        end_sample = start_sample + fixed_samples
        # Skip segments that would run past the end of the file
        if end_sample > waveform.size(1):
            continue
        # Extract a segment of exactly fixed_samples length
        segment = waveform[:, start_sample:end_sample]
        # Optional high-pass filtering (disabled); running inference after other preprocessing steps could also be tried
        # filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0)
        filtered = torch.tensor(segment.squeeze().numpy(), dtype=torch.float32).unsqueeze(0)  # NOTE: the MERT processor is not applied here
        # Model-specific preprocessing would normally go here.
segments.append(filtered)
        # Use at most 48 segments
if len(segments) >= 48:
break
    # Handle the case where no segments were extracted
    if not segments:
        return torch.zeros((1, 1, fixed_samples), dtype=torch.float32), torch.ones(1, dtype=torch.bool)
    # Stack into a tensor, keeping the (n_segments, 1, time_samples) shape
stacked_segments = torch.stack(segments)
    # Number of real (non-padded) segments
num_segments = stacked_segments.shape[0]
# ํŒจ๋”ฉ ๋งˆ์Šคํฌ ์ƒ์„ฑ (False = ์‹ค์ œ ์˜ค๋””์˜ค, True = ํŒจ๋”ฉ)
padding_mask = torch.zeros(48, dtype=torch.bool)
    # Pad up to 48 segments if there are fewer
    if num_segments < 48:
        # Pad with empty (all-zero) segments
        padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
        stacked_segments = torch.cat([stacked_segments, padding], dim=0)
        # Mark the padded positions (True = padding)
        padding_mask[num_segments:] = True
return stacked_segments, padding_mask
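# Hedged usage sketch (not part of the original script; the file path below is a placeholder):
# it only illustrates the output shapes load_audio is expected to produce under the logic above.
def _load_audio_shape_check(audio_path: str = "example_24khz.wav") -> None:
    """Print the shapes returned by load_audio; purely illustrative."""
    segments, mask = load_audio(audio_path, sr=24000)
    # Expected: segments -> (48, 1, 240000), mask -> (48,) with True marking zero-padded slots
    print(f"segments: {tuple(segments.shape)}, padding_mask: {tuple(mask.shape)}, real segments: {int((~mask).sum())}")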
def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
"""
Run inference on audio segments.
Args:
model: The loaded model
        audio_segments: Preprocessed audio segment tensor of shape (48, 1, 240000), or per-segment embeddings of shape (48, embed_dim)
        padding_mask: Boolean mask of shape (48,); True marks padded segments
        device: Device to run inference on
Returns:
Dictionary with prediction results
"""
model.eval()
model.to(device)
model = model.half()
with torch.no_grad():
# ๋ฐ์ดํ„ฐ ํ˜•ํƒœ ํ™•์ธ ๋ฐ ์กฐ์ •
# wav_collate_with_mask ํ•จ์ˆ˜์™€ ์ผ์น˜ํ•˜๋„๋ก ์ฒ˜๋ฆฌ
if audio_segments.shape[1] == 1: # (48, 1, 240000) ํ˜•ํƒœ
# ์ฑ„๋„ ์ฐจ์› ์ œ๊ฑฐํ•˜๊ณ  ๋ฐฐ์น˜ ์ฐจ์› ์ถ”๊ฐ€
audio_segments = audio_segments[:, 0, :].unsqueeze(0) # (1, 48, 240000)
else:
audio_segments = audio_segments.unsqueeze(0) # (1, 48, 768) # ์‚ฌ์‹ค audio๊ฐ€ ์•„๋‹ˆ๋ผ embedding segments์ผ์ˆ˜๋„
# ๋ฐ์ดํ„ฐ๋ฅผ half ํƒ€์ž…์œผ๋กœ ๋ณ€ํ™˜
if padding_mask.dim() == 1:
padding_mask = padding_mask.unsqueeze(0) # [48] -> [1, 48]
audio_segments = audio_segments.to(device)
mask = padding_mask.to(device)
# ์ถ”๋ก  ์‹คํ–‰ (๋งˆ์Šคํฌ ํฌํ•จ)
outputs = model(audio_segments, mask)
# ๋ชจ๋ธ ์ถœ๋ ฅ ๊ตฌ์กฐ์— ๋”ฐ๋ผ ์ฒ˜๋ฆฌ
if isinstance(outputs, dict):
result = outputs
else:
# ๋‹จ์ผ ํ…์„œ์ธ ๊ฒฝ์šฐ (๋กœ์ง“)
logits = outputs.squeeze()
prob = scaled_sigmoid(logits, scale_factor=1.0, linear_property=0.0).item()
result = {
"prediction": "Fake" if prob > 0.5 else "Real",
"confidence": f"{max(prob, 1-prob)*100:.2f}",
"fake_probability": f"{prob:.4f}",
"real_probability": f"{1-prob:.4f}",
"raw_output": logits.cpu().numpy().tolist()
}
return result
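# Hedged example of the dictionary returned above for a single logit (values illustrative only):
# {
#     "prediction": "Fake",
#     "confidence": "97.32",        # percentage of the winning class
#     "fake_probability": "0.9732",
#     "real_probability": "0.0268",
#     "raw_output": 3.59            # .tolist() on a 0-d tensor yields a plain float
# }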
# Custom scaling function to moderate extreme sigmoid values
def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
    """Blend a temperature-scaled sigmoid with a linear ramp, then clamp to [0.01, 0.99]."""
    # Scale the logits to make the sigmoid less extreme
    scaled_x = x * scale_factor
    # Combine the sigmoid with a linear component that maps logits in [-25, 25] onto [0, 1]
    raw_prob = torch.sigmoid(scaled_x) * (1 - linear_property) + linear_property * ((x + 25) / 50)
    # Clamp to keep the result within probability bounds
    return torch.clamp(raw_prob, min=0.01, max=0.99)
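# Hedged worked example (not part of the original file): shows how the default settings soften a
# confident logit, while the pass-through settings used in run_inference reduce to a clamped sigmoid.
def _scaled_sigmoid_example() -> None:
    x = torch.tensor(10.0)
    softened = scaled_sigmoid(x)                                      # sigmoid(2)*0.7 + 0.3*0.7 ≈ 0.83
    plain = scaled_sigmoid(x, scale_factor=1.0, linear_property=0.0)  # sigmoid(10) ≈ 0.99995, clamped to 0.99
    print(f"default: {softened.item():.4f}, pass-through: {plain.item():.4f}")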
def get_model(model_type, device):
"""Load the specified model."""
if model_type == "MERT":
ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'
        # Pass map_location so the checkpoint loads onto the requested device
        model = MERT_AudioCNN.load_from_checkpoint(
            ckpt_file,
            map_location=device  # e.g. 'cuda:0' or 'cpu'
        ).to(device)
        embed_dim = 768
elif model_type == "pure_MERT":
from ICASSP_2026.MERT.networks import MERTFeatureExtractor
model = MERTFeatureExtractor().to(device)
embed_dim = 768
else:
raise ValueError(f"Unknown model type: {model_type}")
model.eval()
return model, embed_dim
def inference(audio_path):
    # Make the device selection explicit
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
backbone_model, input_dim = get_model('MERT', device)
segments, padding_mask = load_audio(audio_path, sr=24000)
segments = segments.to(device).to(torch.float32)
padding_mask = padding_mask.to(device).unsqueeze(0)
logits, embedding = backbone_model(segments.squeeze(1))
# ๋ชจ๋ธ ๋กœ๋“œํ•  ๋•Œ๋„ map_location ์ถ”๊ฐ€
model = MusicAudioClassifier.load_from_checkpoint(
checkpoint_path='checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt',
input_dim=input_dim,
map_location=device # ์ด ๋ถ€๋ถ„ ์ถ”๊ฐ€
)
# Run inference
print(f"Segments shape: {segments.shape}")
print("Running inference...")
results = run_inference(model, embedding, padding_mask, device)
    # Print the results
print(f"Results: {results}")
return results
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run AI-music detection on a single audio file.")
    parser.add_argument("audio_path", help="Path to the audio file to classify")
    inference(parser.parse_args().audio_path)