File size: 1,501 Bytes
82b272f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from transformers import WhisperFeatureExtractor
from models.tinyoctopus import TINYOCTOPUS
from utils import prepare_one_sample

# Load model
# NOTE(review): `cfg` is not defined anywhere in this file — presumably a config
# object (e.g. OmegaConf / YAML) loaded by an enclosing script or notebook cell.
# Confirm where it comes from before running this standalone.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TINYOCTOPUS.from_config(cfg.config.model)
model.to(device)
model.eval()  # inference only — disables dropout/batch-norm updates

# Load processor
# Feature extractor must match the speech encoder the model was trained with.
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")

def transcribe(audio_path, task="dialect", temperature=0.7):
    """
    Run inference on a single audio file.

    Args:
        audio_path (str): Path to the audio file.
        task (str): Task to perform. Options: "dialect", "asr", "translation".
        temperature (float): Sampling temperature forwarded to generation.
            Defaults to 0.7, matching the previously hard-coded value.

    Returns:
        str: The generated text, or a best-effort ``"Error: ..."`` string if
        loading the audio or running the model fails (kept so callers that
        display the result directly never crash).

    Raises:
        ValueError: If ``task`` is not one of the supported options.
    """
    task_prompts = {
        "dialect": "What is the dialect of the speaker?",
        "asr": "تعرف على الكلام وأعطني النص.",
        "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."
    }

    if task not in task_prompts:
        raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")

    # Wrap the task prompt in the speech placeholder tags; <SpeechHere> is
    # presumably replaced by the audio embedding inside model.generate — this
    # cannot fail, so it stays outside the try block.
    prompt = [f"<Speech><SpeechHere></Speech> {task_prompts[task].strip()}"]

    try:
        # Only the steps that can actually fail (audio I/O + model forward)
        # are guarded.
        samples = prepare_one_sample(audio_path, wav_processor)
        generated_text = model.generate(samples, {"temperature": temperature}, prompts=prompt)[0]
    except Exception as e:
        # Best-effort: surface the failure as text rather than raising.
        return f"Error: {e}"

    # Strip the decoder's sentence-boundary tokens from the output.
    return generated_text.replace('<s>', '').replace('</s>', '').strip()