import os
import glob
import random

import numpy as np
import librosa
import scipy.signal as signal
import torch
import torchaudio
from torch.utils.data import Dataset
from transformers import Wav2Vec2FeatureExtractor
class FakeMusicCapsDataset(Dataset):
    def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
        self.file_paths = file_paths
        self.labels = labels
        self.sr = sr
        self.target_samples = int(target_duration * sr)  # fixed length: 10 seconds
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)

    def __len__(self):
        return len(self.file_paths)
    def pre_emphasis(self, x, alpha=0.97):
        return np.append(x[0], x[1:] - alpha * x[:-1])

    def highpass_filter(self, y, sr, cutoff=1000, order=5):
        if isinstance(sr, np.ndarray):
            sr = np.mean(sr)
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        y_filtered = signal.lfilter(b, a, y)
        return y_filtered
    # Augmentations: time stretch, pitch shift, and white-noise addition
    def augment_audio(self, y, sr):
        if isinstance(y, torch.Tensor):
            y = y.numpy()  # tensor -> numpy for librosa
        if random.random() < 0.5:  # time stretch
            rate = random.uniform(0.8, 1.2)
            y = librosa.effects.time_stretch(y=y, rate=rate)
        if random.random() < 0.5:  # pitch shift
            n_steps = random.randint(-2, 2)
            y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps)
        if random.random() < 0.5:  # white-noise addition
            noise_level = np.random.uniform(0.001, 0.005)
            y = y + np.random.normal(0, noise_level, y.shape)
        return torch.tensor(y, dtype=torch.float32)  # back to tensor
    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sr = torchaudio.load(audio_path)
        target_sr = self.processor.sampling_rate
        if sr != target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
            waveform = resampler(waveform)
        waveform = waveform.mean(dim=0)  # downmix to mono: [channels, time] -> [time]
        if label == 0:  # real audio: augmentation only
            waveform = self.augment_audio(waveform, self.sr)
        if label == 1:  # generated audio: high-pass filter, then augmentation
            waveform = self.highpass_filter(waveform, self.sr)
            # waveform = self.pre_emphasis(waveform)
            waveform = self.augment_audio(waveform, self.sr)
        # Truncate or zero-pad to exactly target_samples
        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            waveform = waveform[:self.target_samples]
        elif current_samples < self.target_samples:
            pad_length = self.target_samples - current_samples
            waveform = torch.nn.functional.pad(waveform, (0, pad_length))
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.numpy()  # the feature extractor expects a numpy array
        inputs = self.processor(waveform, sampling_rate=target_sr, return_tensors="pt", padding=True)
        return inputs["input_values"].squeeze(0), torch.tensor(label, dtype=torch.long)  # [1, time] -> [time]
def collate_fn(batch, target_samples=16000 * 10):
    inputs, labels = zip(*batch)  # unzip batch
    processed_inputs = []
    for waveform in inputs:
        current_samples = waveform.shape[0]
        if current_samples > target_samples:
            # Center-crop to the target length
            start_idx = (current_samples - target_samples) // 2
            cropped_waveform = waveform[start_idx:start_idx + target_samples]
        else:
            # Right-pad with zeros to the target length
            pad_length = target_samples - current_samples
            cropped_waveform = torch.nn.functional.pad(waveform, (0, pad_length))
        processed_inputs.append(cropped_waveform)
    processed_inputs = torch.stack(processed_inputs)  # [batch, target_samples]
    labels = torch.tensor(labels, dtype=torch.long)  # [batch]
    return processed_inputs, labels
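# Shape check for collate_fn (illustrative only; the clip lengths below are made up):
# a 7.5 s clip is right-padded and a 12.5 s clip is center-cropped, so both come
# out at 160000 samples (10 s at 16 kHz).
#
#   batch = [(torch.randn(120000), 0), (torch.randn(200000), 1)]
#   x, y = collate_fn(batch)
#   x.shape  # torch.Size([2, 160000])
#   y        # tensor([0, 1])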
def preprocess_audio(audio_path, target_sr=16000, max_length=160000):
    """
    Convert an audio file into model input:
    - target_sr: resample to 16 kHz
    - max_length: cap at 160000 samples (10 seconds)
    """
    waveform, sr = torchaudio.load(audio_path)
    # Resample if needed
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    # Convert to mono
    waveform = waveform.mean(dim=0).unsqueeze(0)  # (1, sequence_length)
    current_samples = waveform.shape[1]
    if current_samples > max_length:
        # Center-crop to max_length
        start_idx = (current_samples - max_length) // 2
        waveform = waveform[:, start_idx:start_idx + max_length]
    elif current_samples < max_length:
        # Right-pad with zeros to max_length
        pad_length = max_length - current_samples
        waveform = torch.nn.functional.pad(waveform, (0, pad_length))
    return waveform
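# Example (hypothetical file name): preprocess_audio("clip.wav") returns a
# mono tensor of shape [1, 160000], i.e. exactly 10 s at 16 kHz.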
def collect_files(base_path):
    # Real recordings live under "real/", AI-generated ones under "generative/"
    real_files = glob.glob(os.path.join(base_path, "real", "**", "*.wav"), recursive=True)
    gen_files = glob.glob(os.path.join(base_path, "generative", "**", "*.wav"), recursive=True)
    real_labels = [0] * len(real_files)
    gen_labels = [1] * len(gen_files)
    return real_files + gen_files, real_labels + gen_labels
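

# A minimal usage sketch, assuming a FakeMusicCaps-style layout with "real/" and
# "generative/" subfolders under a hypothetical "dataset" root. It wires the
# dataset into a DataLoader with the fixed-length collate_fn and pulls one batch.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    files, labels = collect_files("dataset")  # hypothetical root directory
    dataset = FakeMusicCapsDataset(files, labels)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

    waveforms, targets = next(iter(loader))
    print(waveforms.shape, targets.shape)  # expected: torch.Size([4, 160000]), torch.Size([4])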