import torch
from transformers import WhisperFeatureExtractor

from models.tinyoctopus import TINYOCTOPUS
from utils import prepare_one_sample

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: `cfg` is assumed to be the parsed run configuration loaded elsewhere
# (e.g. from the repo's YAML config); it is not defined in this snippet.
model = TINYOCTOPUS.from_config(cfg.config.model)
model.to(device)
model.eval()

# Whisper feature extractor used to turn raw audio into input features.
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")

def transcribe(audio_path, task="dialect"):
    """
    Perform inference on an audio file.

    Args:
        audio_path (str): Path to the audio file.
        task (str): Task to perform. Options: "dialect", "asr", "translation".

    Returns:
        str: The generated text.
    """
    task_prompts = {
        "dialect": "What is the dialect of the speaker?",
        # Arabic: "Recognize the speech and give me the text."
        "asr": "تعرف على الكلام وأعطني النص.",
        # Arabic: "Please translate this audio clip into English."
        "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."
    }

    if task not in task_prompts:
        raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")

    try:
        prompt = task_prompts[task]
        # Extract Whisper features for the audio and wrap the task prompt
        # in the speech placeholder expected by the model.
        samples = prepare_one_sample(audio_path, wav_processor)
        prompt = [f"<Speech><SpeechHere></Speech> {prompt.strip()}"]
        generated_text = model.generate(samples, {"temperature": 0.7}, prompts=prompt)[0]
        # Strip BOS/EOS markers from the decoded output.
        return generated_text.replace('<s>', '').replace('</s>', '').strip()
    except Exception as e:
        return f"Error: {e}"
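
# Hypothetical usage sketch, not part of the original snippet: the audio path
# below is a placeholder for illustration only.
if __name__ == "__main__":
    for task in ("dialect", "asr", "translation"):
        print(f"[{task}] {transcribe('samples/example.wav', task=task)}")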