import torch
import librosa
import numpy as np
import os
import webrtcvad
import wave
import contextlib
import gradio as gr
from utils.VAD_segments import *
from utils.hparam import hparam as hp
from utils.speech_embedder_net import *
from utils.evaluation import *
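
# Pipeline overview: each recording is run through webrtcvad to keep only voiced
# segments, converted to STFT frames, embedded with the SpeechEmbedder network,
# and the two averaged embeddings are compared with cosine similarity against a
# user-chosen threshold.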

def read_wave(audio_data):
    """Convert a Gradio audio tuple (sample_rate, numpy_array) into usable audio.

    Returns (float_audio, pcm_data): the floating-point samples and the matching
    16-bit PCM bytes. If the sample rate is not supported by webrtcvad, the audio
    is resampled to 16000 Hz.
    """
    sample_rate, data = audio_data
    # Ensure the audio is mono (a 1D array)
    assert len(data.shape) == 1, "Audio data must be a 1D array"
    # Convert integer PCM to floating point in [-1, 1] if necessary
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    # Sample rates accepted by webrtcvad
    supported_sample_rates = (8000, 16000, 32000, 48000)
    # If the sample rate is not supported, resample to 16000 Hz
    if sample_rate not in supported_sample_rates:
        data = librosa.resample(data, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000
    # Convert the float samples to 16-bit PCM bytes for the VAD
    pcm_data = (data * np.iinfo(np.int16).max).astype(np.int16).tobytes()
    return data, pcm_data
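
# Example (hypothetical values): a Gradio microphone recording arrives as a
# (sample_rate, samples) tuple such as (44100, int16 ndarray); read_wave would
# return the samples as 16 kHz floats plus the int16 PCM bytes that webrtcvad expects.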

def VAD_chunk(aggressiveness, data):
    """Run webrtcvad over the recording and split detected speech into ~0.4 s chunks.

    Returns (speech_times, speech_segs): the (start, end) times in seconds and the
    matching slices of the float audio.
    """
    audio, byte_audio = read_wave(data)
    vad = webrtcvad.Vad(int(aggressiveness))
    frames = frame_generator(20, byte_audio, hp.data.sr)
    frames = list(frames)
    times = vad_collector(hp.data.sr, 20, 200, vad, frames)
    speech_times = []
    speech_segs = []
    for i, time in enumerate(times):
        start = np.round(time[0], decimals=2)
        end = np.round(time[1], decimals=2)
        j = start
        # Slice each voiced region into 0.4 s windows; the while-else appends
        # the final, shorter remainder once the loop finishes.
        while j + .4 < end:
            end_j = np.round(j + .4, decimals=2)
            speech_times.append((j, end_j))
            speech_segs.append(audio[int(j * hp.data.sr):int(end_j * hp.data.sr)])
            j = end_j
        else:
            speech_times.append((j, end))
            speech_segs.append(audio[int(j * hp.data.sr):int(end * hp.data.sr)])
    return speech_times, speech_segs
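
# Example: a voiced region detected from 1.00 s to 2.10 s is split into the
# chunks (1.00, 1.40), (1.40, 1.80), (1.80, 2.10) of the float audio.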

def get_embedding(data, embedder_net, device, n_threshold=-1):
    """Compute an averaged speaker embedding for one recording, or None if no speech is found."""
    times, segs = VAD_chunk(0, data)
    if not segs:
        print('No voice activity detected')
        return None
    concat_seg = concat_segs(times, segs)
    if not concat_seg:
        print('No concatenated segments')
        return None
    STFT_frames = get_STFTs(concat_seg)
    if not STFT_frames:
        # print('No STFT frames')
        return None
    STFT_frames = np.stack(STFT_frames, axis=2)
    STFT_frames = torch.tensor(np.transpose(STFT_frames, axes=(2, 1, 0)), device=device)
    with torch.no_grad():
        embeddings = embedder_net(STFT_frames)
    # Keep the first n_threshold frame-level embeddings and average them into one vector
    embeddings = embeddings[:n_threshold, :]
    avg_embedding = torch.mean(embeddings, dim=0, keepdim=True).cpu().numpy()
    return avg_embedding
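
# The returned embedding is a single (1, embedding_dim) row vector: the mean of
# the frame-level embeddings, so recordings of different lengths stay comparable.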

# Load the trained SpeechEmbedder checkpoint; map_location lets a GPU-trained
# checkpoint load on the CPU-only runtime.
model_path = "./speech_id_checkpoint/saved_02.model"
embedder_net = SpeechEmbedder()
embedder_net.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
embedder_net.eval()

def process_audio(audio1, audio2, threshold):
    """Embed both recordings and compare them with cosine similarity."""
    e1 = get_embedding(audio1, embedder_net, torch.device("cpu"))
    if e1 is None:
        return "No Voice Detected in file 1"
    e2 = get_embedding(audio2, embedder_net, torch.device("cpu"))
    if e2 is None:
        return "No Voice Detected in file 2"
    cosi = cosine_similarity(e1, e2)
    if cosi > threshold:
        return "Same Speaker"
    else:
        return "Different Speaker"

# Define the Gradio interface
def gradio_interface(audio1, audio2, threshold):
    output_text = process_audio(audio1, audio2, threshold)
    return output_text

# Create the Gradio interface with microphone inputs
# (sources=["microphone"] assumes Gradio 4.x; older 3.x releases use source="microphone")
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Audio(sources=["microphone"], type="numpy", label="Audio File 1"),
        gr.Audio(sources=["microphone"], type="numpy", label="Audio File 2"),
        gr.Slider(0.0, 1.0, value=0.85, step=0.01, label="Threshold"),
    ],
    outputs="text",
    title="Gujarati Text Independent Speaker Verification",
    description="Record two audio clips; the model reports whether they come from the same speaker.",
)

# Launch the interface
iface.launch(share=False)