SaraAlthubaiti committed · Commit 82b272f · verified · 1 Parent(s): 106af52

Create inference.py

Files changed (1): inference.py (+43, -0)
inference.py ADDED
import torch
from omegaconf import OmegaConf
from transformers import WhisperFeatureExtractor

from models.tinyoctopus import TINYOCTOPUS
from utils import prepare_one_sample

# Load config (assumption: the repo ships a YAML config; adjust the path and the
# attribute access below to match the project's actual config layout).
cfg = OmegaConf.load("config.yaml")

# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TINYOCTOPUS.from_config(cfg.config.model)
model.to(device)
model.eval()

# Load the Whisper feature extractor used to preprocess the audio
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")


def transcribe(audio_path, task="dialect"):
    """
    Perform inference on an audio file.

    Args:
        audio_path (str): Path to the audio file.
        task (str): Task to perform. Options: "dialect", "asr", "translation".

    Returns:
        str: The generated text.
    """
    task_prompts = {
        "dialect": "What is the dialect of the speaker?",
        "asr": "تعرف على الكلام وأعطني النص.",  # "Recognize the speech and give me the text."
        "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."  # "Please translate this audio clip into English."
    }

    if task not in task_prompts:
        raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")

    try:
        prompt = task_prompts[task]
        samples = prepare_one_sample(audio_path, wav_processor)
        # Wrap the prompt in the speech placeholder tokens expected by the model
        prompt = [f"<Speech><SpeechHere></Speech> {prompt.strip()}"]
        generated_text = model.generate(samples, {"temperature": 0.7}, prompts=prompt)[0]
        return generated_text.replace('<s>', '').replace('</s>', '').strip()

    except Exception as e:
        return f"Error: {e}"
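A minimal usage sketch, assuming the file is run directly; "example.wav" is a placeholder path:

if __name__ == "__main__":
    # Dialect identification, Arabic speech recognition, and speech translation
    print(transcribe("example.wav", task="dialect"))
    print(transcribe("example.wav", task="asr"))
    print(transcribe("example.wav", task="translation"))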