# TinyOctopus inference script: dialect identification, ASR, and speech translation.
import torch
from transformers import WhisperFeatureExtractor
from models.tinyoctopus import TINYOCTOPUS
from utils import prepare_one_sample
# Load model
# Pick GPU when available; the model and inputs must live on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): `cfg` is never defined or imported in this file — as written this
# line raises NameError. Presumably a config object is parsed elsewhere (argparse /
# OmegaConf or similar); TODO confirm where `cfg` is supposed to come from.
model = TINYOCTOPUS.from_config(cfg.config.model)
model.to(device)
# Inference only: disables dropout / puts norm layers in eval mode.
model.eval()
# Load processor
# Feature extractor converts raw audio into the log-mel features the model expects.
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")
def transcribe(audio_path, task="dialect"):
    """Run the model on a single audio file and return the decoded text.

    Args:
        audio_path (str): Path to the audio file.
        task (str): Which instruction to send the model. One of
            "dialect", "asr", or "translation".

    Returns:
        str: The generated text on success, or an "Error: ..." string if
        anything fails during preparation or generation (best-effort API:
        this function never raises for runtime failures).

    Raises:
        ValueError: If ``task`` is not one of the recognized options.
    """
    # Instruction text per task; the Arabic prompts drive ASR and translation.
    prompts_by_task = {
        "dialect": "What is the dialect of the speaker?",
        "asr": "تعرف على الكلام وأعطني النص.",
        "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."
    }
    if task not in prompts_by_task:
        raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")
    instruction = prompts_by_task[task]
    try:
        sample = prepare_one_sample(audio_path, wav_processor)
        # The <Speech>...</Speech> wrapper marks where audio embeddings are spliced in.
        wrapped_prompt = [f"<Speech><SpeechHere></Speech> {instruction.strip()}"]
        text = model.generate(sample, {"temperature": 0.7}, prompts=wrapped_prompt)[0]
        # Strip sentence-boundary tokens left over from decoding.
        for marker in ('<s>', '</s>'):
            text = text.replace(marker, '')
        return text.strip()
    except Exception as e:
        # Deliberate best-effort: surface failures as a string, not an exception.
        return f"Error: {e}"