import os
import json
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torchaudio
import scipy.signal as signal
import pytorch_lightning as pl

from model import MusicAudioClassifier
from dataset_f import FakeMusicCapsDataset
from preprocess import get_segments_from_wav, find_optimal_segment_length

def highpass_filter(y, sr, cutoff=1000, order=5):
    if isinstance(sr, np.ndarray):
        sr = np.mean(sr)
    if not isinstance(sr, (int, float)):
        raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")

    nyquist = 0.5 * sr
    if cutoff <= 0 or cutoff >= nyquist:
        cutoff = max(10, min(cutoff, nyquist - 1))
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    y_filtered = signal.lfilter(b, a, y)
    return y_filtered
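
# Example (sketch; assumes `y` is a 1-D numpy array of samples at 24 kHz):
#   y_hp = highpass_filter(y, sr=24000, cutoff=1000)
# The output has the same shape as the input, with content below the cutoff attenuated.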

def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an audio file and split it into segments.

    Extracts up to 48 fixed-length segments and adds padding when fewer are found.

    Args:
        audio_path: Path to the audio file
        sr: Target sampling rate (default 24000)

    Returns:
        Tuple containing:
            - Tensor of audio segments with shape (48, 1, 240000)
            - Padding mask tensor of shape (48); True = padding, False = real audio
    """
    beats, downbeats = get_segments_from_wav(audio_path)
    optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert the data type to float32
    waveform = waveform.to(torch.float32)

    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(sample_rate, sr)
        waveform = resampler(waveform)

    # Convert to mono (if necessary)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # 240000 samples = 10 seconds @ 24 kHz
    fixed_samples = 240000

    # Append 5 seconds (120000 samples) of silence so short files can still yield a full segment
    if waveform.shape[1] <= fixed_samples:
        padding = torch.zeros(1, 120000, dtype=torch.float32)
        # Add the padding after the original audio
        waveform = torch.cat([waveform, padding], dim=1)

    # Build one segment starting at each downbeat
    segments = []
    for i, start_time in enumerate(cleaned_downbeats):
        # Start sample index
        start_sample = int(start_time * sr)
        # End sample index (start + fixed length)
        end_sample = start_sample + fixed_samples

        # Skip segments that would run past the end of the file
        if end_sample > waveform.size(1):
            continue

        # Extract a segment of exactly fixed_samples samples
        segment = waveform[:, start_sample:end_sample]

        # Optionally apply the highpass filter while keeping the channel dimension
        # (trying inference after different preprocessing may be worth exploring)
        # filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0)
        filtered = torch.tensor(segment.squeeze().numpy(), dtype=torch.float32).unsqueeze(0)  # no feature-extractor processor here
        # Model-specific preprocessing is intentionally not applied at this stage.
        segments.append(filtered)

        # Use at most 48 segments
        if len(segments) >= 48:
            break

    # Handle the case where no segments were produced
    if not segments:
        return torch.zeros((1, 1, fixed_samples), dtype=torch.float32), torch.ones(1, dtype=torch.bool)

    # Stack into a tensor, keeping the (n_segments, 1, time_samples) shape
    stacked_segments = torch.stack(segments)

    # Number of real (non-padded) segments
    num_segments = stacked_segments.shape[0]

    # Padding mask (False = real audio, True = padding)
    padding_mask = torch.zeros(48, dtype=torch.bool)

    # Pad with empty segments if fewer than 48 were extracted
    if num_segments < 48:
        padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
        stacked_segments = torch.cat([stacked_segments, padding], dim=0)
        # Mark the padded positions (True = padding)
        padding_mask[num_segments:] = True

    return stacked_segments, padding_mask
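
# Example (sketch; assumes a local audio file path):
#   segments, mask = load_audio("example.wav", sr=24000)
#   segments.shape -> torch.Size([48, 1, 240000]); mask.shape -> torch.Size([48])
#   mask[i] is True where segment i is silence padding rather than real audio.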

def run_inference(model,
                  audio_segments: torch.Tensor,
                  padding_mask: torch.Tensor,
                  device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
    """
    Run inference on audio segments.

    Args:
        model: The loaded model
        audio_segments: Preprocessed audio segments tensor (48, 1, 240000)
        padding_mask: Boolean mask of shape (48); True marks padded segments
        device: Device to run inference on

    Returns:
        Dictionary with prediction results
    """
    model.eval()
    model.to(device)
    model = model.half()

    with torch.no_grad():
        # Adjust the input shape to match what wav_collate_with_mask produces
        if audio_segments.shape[1] == 1:  # (48, 1, 240000) waveform segments
            # Drop the channel dimension and add a batch dimension
            audio_segments = audio_segments[:, 0, :].unsqueeze(0)  # (1, 48, 240000)
        else:
            # The segments may already be embeddings rather than raw audio
            audio_segments = audio_segments.unsqueeze(0)  # (1, 48, 768)

        if padding_mask.dim() == 1:
            padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]

        # Convert the data to half precision and move it to the device
        audio_segments = audio_segments.to(device).half()
        mask = padding_mask.to(device)

        # Run inference (with the mask)
        outputs = model(audio_segments, mask)

        # Handle the model's output structure
        if isinstance(outputs, dict):
            result = outputs
        else:
            # Single tensor (logits)
            logits = outputs.squeeze()
            prob = scaled_sigmoid(logits, scale_factor=1.0, linear_property=0.0).item()
            result = {
                "prediction": "Fake" if prob > 0.5 else "Real",
                "confidence": f"{max(prob, 1 - prob) * 100:.2f}",
                "fake_probability": f"{prob:.4f}",
                "real_probability": f"{1 - prob:.4f}",
                "raw_output": logits.cpu().numpy().tolist()
            }

    return result
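
# Note: run_inference accepts either raw waveform segments of shape (48, 1, 240000)
# or per-segment embeddings of shape (48, embed_dim); the shape check above selects the branch.
# Example (sketch; `model`, `segments`, `mask` as produced elsewhere in this file):
#   result = run_inference(model, segments, mask)
#   print(result["prediction"], result["fake_probability"])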

# Custom scaling function to moderate extreme sigmoid values
def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
    # Apply scaling to make the sigmoid less extreme
    scaled_x = x * scale_factor
    # Combine the sigmoid with a linear component
    raw_prob = torch.sigmoid(scaled_x) * (1 - linear_property) + linear_property * ((x + 25) / 50)
    # Clip to ensure bounds
    return torch.clamp(raw_prob, min=0.011, max=0.989)
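
# Worked example with the defaults (scale_factor=0.2, linear_property=0.3):
#   x = 0  -> sigmoid(0.0) * 0.7 + 0.3 * (25 / 50) = 0.35 + 0.15 = 0.50
#   x = 25 -> sigmoid(5.0) * 0.7 + 0.3 * (50 / 50) ≈ 0.695 + 0.30 ≈ 0.995 -> clamped to 0.989
# With scale_factor=1.0 and linear_property=0.0 (as used in run_inference above),
# this reduces to a plain sigmoid clamped to [0.011, 0.989].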

def get_model(model_type, device):
    """Load the specified model."""
    if model_type == "MERT":
        from ISMIR_2025.MERT.networks import CCV
        # from model import MusicAudioClassifier
        model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2,
                    freeze_feature_extractor=True).to(device)
        # model = MusicAudioClassifier(input_dim=768, is_emb=True, mode='both', share_parameter=False).to(device)
        ckpt_file = 'mert_finetune_10.pth'
        model.load_state_dict(torch.load(ckpt_file, map_location=device))
        embed_dim = 768
    elif model_type == "pure_MERT":
        from ISMIR_2025.MERT.networks import MERTFeatureExtractor
        model = MERTFeatureExtractor().to(device)
        embed_dim = 768
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.eval()
    return model, embed_dim
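
# Example (sketch):
#   backbone, embed_dim = get_model("MERT", "cuda" if torch.cuda.is_available() else "cpu")
# The MERT branch expects the checkpoint 'mert_finetune_10.pth' to exist in the working directory.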

def inference(audio_path):
    parser = argparse.ArgumentParser(description="Music classifier inference")
    parser.add_argument("--model_type", type=str, required=True, choices=["MERT", "pure_MERT"], help="Type of model")
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--output_path", type=str, default=None, help="Path to save results (default: print to console)")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run inference on")
    args = parser.parse_args()

    # Fall back to the bundled demo file when no path is supplied
    if audio_path is None:
        audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3"

    # Note: Model loading would be handled by your code
    print(f"Loading model of type {args.model_type} from {args.checkpoint_path}")
    backbone_model, input_dim = get_model(args.model_type, args.device)

    segments, padding_mask = load_audio(audio_path, sr=24000)
    segments = segments.to(args.device).to(torch.float32)
    padding_mask = padding_mask.to(args.device).unsqueeze(0)
    logits, embedding = backbone_model(segments.squeeze(1))

    # Sanity check on a single dataset item
    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
    test_data, test_target = test_dataset[0]
    test_data = test_data.to(args.device).to(torch.float32)
    test_target = test_target.to(args.device)
    output, _ = backbone_model(test_data.unsqueeze(0))

    # Load the segment-level classifier from the checkpoint
    model = MusicAudioClassifier.load_from_checkpoint(
        args.checkpoint_path,
        input_dim=input_dim,
        # emb_model=backbone_model
        is_emb=True,
        # mode='both'
    )

    # Run inference
    print(f"Segments shape: {segments.shape}")
    print("Running inference...")
    results = run_inference(model, embedding, padding_mask, device=args.device)

    # Print the results
    print(f"Results: {results}")

    # Save the results
    if args.output_path:
        with open(args.output_path, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Results saved to {args.output_path}")

    return results
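
# Example command line (sketch; assumes this script is saved as inference.py):
#   python inference.py --model_type MERT --checkpoint_path path/to/classifier.ckpt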

if __name__ == "__main__":
    inference(None)