Wav2Txt / app.py
Merlintxu's picture
Update app.py
e35f365 verified
raw
history blame
2.89 kB
import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import librosa
import subprocess
from langdetect import detect
# Modelos por idioma
MODELS = {
"es": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
"en": "facebook/wav2vec2-large-960h", # Puedes añadir más modelos aquí según sea necesario
# Añadir más modelos por idioma si es necesario
}
def convert_audio_to_wav(audio_path):
wav_path = "converted_audio.wav"
command = ["ffmpeg", "-i", audio_path, "-ac", "1", "-ar", "16000", wav_path]
subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return wav_path
def detect_language(audio_path):
# Cargar los primeros 15 segundos del audio
speech, _ = librosa.load(audio_path, sr=16000, duration=15)
# Convertir audio a texto usando el modelo inglés como predeterminado para detección
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")
input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]
return detect(transcription)
def transcribe_audio(audio):
# Convertir audio a formato WAV
wav_audio = convert_audio_to_wav(audio)
# Detectar el idioma del audio
language = detect_language(wav_audio)
model_name = MODELS.get(language, "facebook/wav2vec2-large-960h") # Modelo predeterminado en caso de que no se detecte el idioma
# Cargar el modelo y el procesador adecuados
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
# Cargar el audio completo
speech, rate = librosa.load(wav_audio, sr=16000)
# Procesar el audio
input_values = processor(speech, return_tensors="pt", sampling_rate=rate).input_values
# Generar las predicciones (logits)
with torch.no_grad():
logits = model(input_values).logits
# Obtener las predicciones (tokens) y convertirlas en texto
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]
# Guardar la transcripción en un archivo de texto
with open("transcription.txt", "w") as file:
file.write(transcription)
return "transcription.txt"
# Configurar la interfaz de Gradio
iface = gr.Interface(
fn=transcribe_audio,
inputs=gr.Audio(type="filepath"),
outputs=gr.File(),
title="Transcriptor de Audio Multilingüe",
description="Sube un archivo de audio y obtén la transcripción en un archivo de texto."
)
# Iniciar la interfaz
if __name__ == "__main__":
iface.launch()