Update gradio_app.py
gradio_app.py  CHANGED  (+12 -6)
@@ -40,11 +40,11 @@ def separate_speakers_core(audio_path):
     waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)
 
     if waveform.dim() == 1:
-        waveform = waveform.unsqueeze(0)
-    audio_input = waveform.unsqueeze(0).to(device)
+        waveform = waveform.unsqueeze(0)  # Ensure shape is (1, samples)
+    audio_input = waveform.unsqueeze(0).to(device)  # Shape: (1, 1, samples)
 
     with torch.no_grad():
-        ests_speech = sep_model(audio_input).squeeze(0)
+        ests_speech = sep_model(audio_input).squeeze(0)  # Shape: (num_speakers, samples)
 
     session_id = uuid.uuid4().hex[:8]
     output_dir = os.path.join("output_sep", session_id)
@@ -53,15 +53,21 @@ def separate_speakers_core(audio_path):
     output_files = []
     for i in range(ests_speech.shape[0]):
         path = os.path.join(output_dir, f"speaker_{i+1}.wav")
-        …
-        …
-        …
+        speaker_waveform = ests_speech[i].cpu()
+
+        if speaker_waveform.dim() == 1:
+            speaker_waveform = speaker_waveform.unsqueeze(0)  # (1, samples)
 
+        # Ensure correct dtype and save in a widely compatible format
+        speaker_waveform = speaker_waveform.to(torch.float32)
+        torchaudio.save(path, speaker_waveform, TARGET_SR, format="wav", encoding="PCM_S", bits_per_sample=16)
+        output_files.append(path)
 
     return output_files
 
 
+
 @spaces.GPU()
 def separate_dnr(audio_file):
     audio, sr = torchaudio.load(audio_file)
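For readers skimming the diff, here is the updated function assembled as one standalone sketch. sep_model, device, and TARGET_SR come from elsewhere in gradio_app.py; the loading, resampling guard, and os.makedirs lines are reconstructed from surrounding context and are assumptions, as is the model mapping a (batch, channels, samples) tensor to (batch, num_speakers, samples).

# Sketch of the post-commit flow. sep_model, device, and TARGET_SR are
# assumed to be defined elsewhere in gradio_app.py; lines outside the two
# hunks are reconstructed from context.
import os
import uuid

import torch
import torchaudio
import torchaudio.transforms as T

def separate_speakers_core(audio_path):
    waveform, original_sr = torchaudio.load(audio_path)
    if original_sr != TARGET_SR:
        waveform = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR)(waveform)

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)            # Ensure shape is (1, samples)
    audio_input = waveform.unsqueeze(0).to(device)  # Shape: (1, 1, samples)

    with torch.no_grad():
        ests_speech = sep_model(audio_input).squeeze(0)  # (num_speakers, samples)

    session_id = uuid.uuid4().hex[:8]
    output_dir = os.path.join("output_sep", session_id)
    os.makedirs(output_dir, exist_ok=True)

    output_files = []
    for i in range(ests_speech.shape[0]):
        path = os.path.join(output_dir, f"speaker_{i+1}.wav")
        speaker_waveform = ests_speech[i].cpu()

        if speaker_waveform.dim() == 1:
            speaker_waveform = speaker_waveform.unsqueeze(0)  # (1, samples)

        # Ensure correct dtype and save in a widely compatible format
        speaker_waveform = speaker_waveform.to(torch.float32)
        torchaudio.save(path, speaker_waveform, TARGET_SR,
                        format="wav", encoding="PCM_S", bits_per_sample=16)
        output_files.append(path)

    return output_files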
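The explicit format="wav", encoding="PCM_S", bits_per_sample=16 arguments are the substance of the change: without them torchaudio writes float32 WAVs, which some browser audio decoders refuse to play, while 16-bit signed PCM works in effectively every player Gradio might hand the file to. A quick hypothetical sanity check (the path below is illustrative, not from the commit):

# Confirm one exported file round-trips as 16-bit PCM WAV at the target rate.
import torchaudio

path = "output_sep/abc12345/speaker_1.wav"  # hypothetical session path
wav, sr = torchaudio.load(path)
info = torchaudio.info(path)
print(wav.shape, sr)                        # e.g. torch.Size([1, N]) and TARGET_SR
print(info.encoding, info.bits_per_sample)  # expected: PCM_S 16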