import os
import json
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import scipy.signal as signal
import torch
import torchaudio
import pytorch_lightning as pl

from model import MusicAudioClassifier
from dataset_f import FakeMusicCapsDataset
from networks import MERT_AudioCNN
from preprocess import get_segments_from_wav, find_optimal_segment_length

def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an audio file and split it into segments.

    Extracts up to 48 fixed-length segments and pads with silence when fewer are found.
    (A usage sketch follows this function.)

    Args:
        audio_path: Path to the audio file
        sr: Target sampling rate (default 24000)

    Returns:
        Tuple containing:
            - Audio waveform tensor of shape (48, 1, 240000)
            - Padding mask tensor of shape (48); True = padding, False = real audio
    """
    beats, downbeats = get_segments_from_wav(audio_path)
    optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert the data type to float32
    waveform = waveform.to(torch.float32)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(sample_rate, sr)
        waveform = resampler(waveform)

    # Convert to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # 240000 samples = 10 seconds @ 24 kHz
    fixed_samples = 240000

    # Append 5 seconds (120000 samples) of silence when the file is short,
    # so segments near the end can still reach the fixed length
    if waveform.shape[1] <= fixed_samples:
        padding = torch.zeros(1, 120000, dtype=torch.float32)
        # Append the padding after the original audio
        waveform = torch.cat([waveform, padding], dim=1)
    # Build one segment starting at each downbeat
    segments = []
    for i, start_time in enumerate(cleaned_downbeats):
        # Start sample index
        start_sample = int(start_time * sr)
        # End sample index (start + fixed length)
        end_sample = start_sample + fixed_samples
        # Skip segments that would run past the end of the file
        if end_sample > waveform.size(1):
            continue
        # Extract a segment of exactly fixed_samples samples
        segment = waveform[:, start_sample:end_sample]
        # A high-pass filter could be applied here while keeping the channel dimension;
        # trying different preprocessing before inference is a possible extension:
        # filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0)
        filtered = torch.tensor(segment.squeeze().numpy(), dtype=torch.float32).unsqueeze(0)  # no processor applied here
        # Model-specific preprocessing, if any, would belong here.
        segments.append(filtered)
        # Use at most 48 segments
        if len(segments) >= 48:
            break
    # Handle the case where no segments were extracted
    if not segments:
        return torch.zeros((1, 1, fixed_samples), dtype=torch.float32), torch.ones(1, dtype=torch.bool)

    # Stack into a tensor of shape (n_segments, 1, time_samples)
    stacked_segments = torch.stack(segments)

    # Number of real segments (before padding)
    num_segments = stacked_segments.shape[0]

    # Build the padding mask (False = real audio, True = padding)
    padding_mask = torch.zeros(48, dtype=torch.bool)

    # Pad with empty segments when fewer than 48 were extracted
    if num_segments < 48:
        # Pad with empty (zero) segments
        padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
        stacked_segments = torch.cat([stacked_segments, padding], dim=0)
        # Mark the padded slots (True = padding)
        padding_mask[num_segments:] = True

    return stacked_segments, padding_mask
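
# Usage sketch (hypothetical file path; the number of real segments depends on the detected downbeats):
#   segments, mask = load_audio("example_song.wav", sr=24000)
#   segments.shape  -> torch.Size([48, 1, 240000])   # 48 slots of 10 s each
#   mask            -> True for silence-padded slots, False for real audio
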
def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor,
                  device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
    """
    Run inference on audio segments.

    Args:
        model: The loaded model
        audio_segments: Preprocessed audio segments tensor (48, 1, 240000), or embedding segments
        padding_mask: Boolean mask of shape (48); True marks padded slots
        device: Device to run inference on

    Returns:
        Dictionary with prediction results
    """
    model.eval()
    model.to(device)
    model = model.half()

    with torch.no_grad():
        # Check the input shape and adjust it to match wav_collate_with_mask
        if audio_segments.shape[1] == 1:  # shape (48, 1, 240000)
            # Drop the channel dimension and add a batch dimension
            audio_segments = audio_segments[:, 0, :].unsqueeze(0)  # (1, 48, 240000)
        else:
            # The input may actually be embedding segments rather than raw audio
            audio_segments = audio_segments.unsqueeze(0)  # (1, 48, 768)

        if padding_mask.dim() == 1:
            padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]

        # Move to the device and convert the data to half precision to match the model
        audio_segments = audio_segments.to(device).half()
        mask = padding_mask.to(device)

        # Run inference (with the mask)
        outputs = model(audio_segments, mask)

        # Handle the model's output structure
        if isinstance(outputs, dict):
            result = outputs
        else:
            # Single tensor output (logits)
            logits = outputs.squeeze()
            prob = scaled_sigmoid(logits, scale_factor=1.0, linear_property=0.0).item()
            result = {
                "prediction": "Fake" if prob > 0.5 else "Real",
                "confidence": f"{max(prob, 1-prob)*100:.2f}",
                "fake_probability": f"{prob:.4f}",
                "real_probability": f"{1-prob:.4f}",
                "raw_output": logits.cpu().numpy().tolist()
            }

    return result
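
# Usage sketch (names follow the inference() pipeline below; the backbone embedding is
# assumed to have shape (48, 768), so the non-(48, 1, T) branch above is taken):
#   results = run_inference(classifier, embedding, padding_mask)
#   results["prediction"]  -> "Fake" or "Real"
#   results["confidence"]  -> confidence of the predicted class, as a percentage string
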
# Custom scaling function to moderate extreme sigmoid values
def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
    # Apply scaling to make the sigmoid less extreme
    scaled_x = x * scale_factor
    # Combine the sigmoid with a linear component
    raw_prob = torch.sigmoid(scaled_x) * (1 - linear_property) + linear_property * ((x + 25) / 50)
    # Clip to ensure bounds
    return torch.clamp(raw_prob, min=0.01, max=0.99)
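
# Worked example of the moderation effect with the default parameters (values rounded):
#   torch.sigmoid(torch.tensor(10.0))   -> ~1.00
#   scaled_sigmoid(torch.tensor(10.0))  -> ~0.83   (0.7 * sigmoid(2.0) + 0.3 * 35/50)
#   scaled_sigmoid(torch.tensor(0.0))   -> 0.50
# Note that run_inference calls it with scale_factor=1.0 and linear_property=0.0,
# which reduces to a plain sigmoid clamped to [0.01, 0.99].
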
def get_model(model_type, device):
    """Load the specified model."""
    if model_type == "MERT":
        ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'
        # map_location keeps the checkpoint tensors on the requested device
        model = MERT_AudioCNN.load_from_checkpoint(
            ckpt_file,
            map_location=device  # e.g. 'cuda:0' or 'cpu'
        ).to(device)
        embed_dim = 768
    elif model_type == "pure_MERT":
        from ICASSP_2026.MERT.networks import MERTFeatureExtractor
        model = MERTFeatureExtractor().to(device)
        embed_dim = 768
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model.eval()
    return model, embed_dim

def inference(audio_path):
    # Set the device explicitly
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    backbone_model, input_dim = get_model('MERT', device)

    segments, padding_mask = load_audio(audio_path, sr=24000)
    segments = segments.to(device).to(torch.float32)
    padding_mask = padding_mask.to(device).unsqueeze(0)
    logits, embedding = backbone_model(segments.squeeze(1))

    # Pass map_location here as well when loading the classifier
    model = MusicAudioClassifier.load_from_checkpoint(
        checkpoint_path='checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt',
        input_dim=input_dim,
        map_location=device
    )

    # Run inference
    print(f"Segments shape: {segments.shape}")
    print("Running inference...")
    results = run_inference(model, embedding, padding_mask, device)

    # Print the results
    print(f"Results: {results}")
    return results


if __name__ == "__main__":
    inference("some path")