AvtnshM committed
Commit 40d87de · verified · 1 Parent(s): c41cc32

Upload 2 files

Files changed (2):
  1. app.py +208 -0
  2. requirements.txt +23 -0
app.py ADDED
@@ -0,0 +1,208 @@
+ import gradio as gr
+ import torch
+ import librosa
+ import numpy as np
+ from transformers import pipeline
+ import gc
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ class OptimizedShukaASR:
+     def __init__(self):
+         self.pipe = None
+         self.load_model()
+
+     def load_model(self):
+         """Load model with optimizations for CPU inference"""
+         try:
+             # Force CPU usage and optimize for inference.
+             # float32 is used here: fp16 kernels are poorly supported on CPU
+             # and can fail with "not implemented for 'Half'" errors.
+             self.pipe = pipeline(
+                 model='sarvamai/shuka_v1',
+                 trust_remote_code=True,
+                 device=-1,  # Force CPU
+                 torch_dtype=torch.float32,
+                 model_kwargs={
+                     # Note: torch_dtype is passed once above; repeating it in
+                     # model_kwargs conflicts with the pipeline kwarg.
+                     "low_cpu_mem_usage": True,
+                     "use_cache": True,
+                 }
+             )
+
+             # Set to eval mode
+             if hasattr(self.pipe.model, 'eval'):
+                 self.pipe.model.eval()
+
+             # Compile for faster inference (PyTorch 2.0+)
+             try:
+                 self.pipe.model = torch.compile(self.pipe.model, mode="reduce-overhead")
+             except Exception:
+                 pass  # Skip if torch.compile is not available
+
+             print("Model loaded successfully with optimizations")
+
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             self.pipe = None
+
+     def preprocess_audio(self, audio_input, target_sr=16000, max_duration=30):
+         """Preprocess audio: mono, normalized, length-limited, resampled"""
+         try:
+             if isinstance(audio_input, tuple):
+                 sr, audio_data = audio_input
+                 audio_data = audio_data.astype(np.float32)
+                 if len(audio_data.shape) > 1:
+                     audio_data = audio_data.mean(axis=1)  # Convert to mono
+                 peak = np.max(np.abs(audio_data))
+                 if peak > 0:  # Guard against division by zero on silent clips
+                     audio_data = audio_data / peak  # Normalize
+             else:
+                 audio_data, sr = librosa.load(audio_input, sr=target_sr)
+
+             # Limit audio duration to reduce processing time
+             max_samples = int(max_duration * target_sr)
+             if len(audio_data) > max_samples:
+                 audio_data = audio_data[:max_samples]
+                 print(f"Audio truncated to {max_duration} seconds")
+
+             # Resample if needed (only the tuple path can still be off-rate)
+             if sr != target_sr:
+                 audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=target_sr)
+
+             return audio_data, target_sr
+
+         except Exception as e:
+             raise Exception(f"Audio preprocessing failed: {e}")
+
+     def transcribe(self, audio_input, language="auto"):
+         """Transcribe audio to text"""
+         if self.pipe is None:
+             return "Model not loaded. Please check the setup."
+
+         try:
+             # Preprocess audio
+             audio, sr = self.preprocess_audio(audio_input)
+
+             # Prepare system prompt for ASR only
+             if language == "auto":
+                 system_prompt = "Transcribe the following audio accurately. Only provide the transcription, nothing else."
+             else:
+                 system_prompt = f"Transcribe the following audio in {language}. Only provide the transcription, nothing else."
+
+             turns = [
+                 {'role': 'system', 'content': system_prompt},
+                 {'role': 'user', 'content': '<|audio|>'}
+             ]
+
+             # Run inference with memory optimization
+             with torch.no_grad():
+                 result = self.pipe(
+                     {
+                         'audio': audio,
+                         'turns': turns,
+                         'sampling_rate': sr
+                     },
+                     max_new_tokens=256,  # ASR needs far fewer tokens than chat
+                     do_sample=False,  # Deterministic output; a temperature would be ignored with sampling off
+                     pad_token_id=self.pipe.tokenizer.eos_token_id
+                 )
+
+             # Clean up memory
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+             gc.collect()
+
+             # Extract transcription
+             if isinstance(result, list) and len(result) > 0:
+                 transcription = result[0].get('generated_text', '').strip()
+             elif isinstance(result, dict):
+                 transcription = result.get('generated_text', '').strip()
+             else:
+                 transcription = str(result).strip()
+
+             return transcription
+
+         except Exception as e:
+             return f"Transcription failed: {str(e)}"
+
+ # Initialize the ASR system
+ asr_system = OptimizedShukaASR()
+
+ def transcribe_audio(audio, language):
+     """Gradio interface function"""
+     if audio is None:
+         return "Please provide an audio file."
+
+     result = asr_system.transcribe(audio, language)
+     return result
+
+ # Language options
+ languages = [
+     ("Auto-detect", "auto"),
+     ("English", "english"),
+     ("Hindi", "hindi"),
+     ("Bengali", "bengali"),
+     ("Gujarati", "gujarati"),
+     ("Kannada", "kannada"),
+     ("Malayalam", "malayalam"),
+     ("Marathi", "marathi"),
+     ("Oriya", "oriya"),
+     ("Punjabi", "punjabi"),
+     ("Tamil", "tamil"),
+     ("Telugu", "telugu")
+ ]
+
+ # Create Gradio interface
+ with gr.Blocks(title="Shuka v1 ASR - Multilingual Speech Recognition") as demo:
+     gr.Markdown("# 🎙️ Shuka v1 ASR - Fast Multilingual Transcription")
+     gr.Markdown("Upload an audio file or record directly to get a transcription in multiple Indic languages.")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(
+                 label="Audio Input",
+                 type="filepath",
+                 format="wav"
+             )
+             language_dropdown = gr.Dropdown(
+                 choices=languages,
+                 value="auto",
+                 label="Language (optional)"
+             )
+             transcribe_btn = gr.Button("🚀 Transcribe", variant="primary")
+
+         with gr.Column():
+             output_text = gr.Textbox(
+                 label="Transcription",
+                 placeholder="Transcription will appear here...",
+                 lines=10
+             )
+
+     # Event handlers
+     transcribe_btn.click(
+         fn=transcribe_audio,
+         inputs=[audio_input, language_dropdown],
+         outputs=output_text
+     )
+
+     # Auto-transcribe on audio upload
+     audio_input.change(
+         fn=transcribe_audio,
+         inputs=[audio_input, language_dropdown],
+         outputs=output_text
+     )
+
+     # Tips section
+     gr.Markdown("## 📝 Tips for best results:")
+     gr.Markdown("""
+     - Audio files are automatically limited to 30 seconds for faster processing
+     - Supported formats: WAV, MP3, M4A, WEBM
+     - For best accuracy, use clear audio with minimal background noise
+     - The model supports 11 Indic languages + English
+     """)
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )
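
A quick way to sanity-check the pipeline without launching the UI is to import the module and call the shared instance directly. This is a minimal sketch, not part of the commit; "sample.wav" is an illustrative path standing in for any short local recording:

    # smoke_test.py (hypothetical helper, not part of this commit)
    # Importing app builds the pipeline once (asr_system) but does not
    # start Gradio, since demo.launch() is guarded by __main__.
    from app import asr_system

    # Any short local clip works here; the filename is an assumption.
    print(asr_system.transcribe("sample.wav", language="auto"))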
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ # Core ML libraries
+ torch==2.1.0
+ transformers==4.41.2
+ peft==0.11.1
+
+ # Audio processing
+ librosa==0.10.2
+ soundfile==0.12.1
+
+ # Gradio for web interface
+ gradio==4.20.0
+
+ # Utilities
+ numpy==1.24.3
+ scipy==1.11.1
+ torchaudio==2.1.0
+
+ # Optional optimizations
+ accelerate==0.28.0
+ bitsandbytes==0.43.0
+
+ # System utilities
+ psutil==5.9.5
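
For local testing, the usual flow should apply: `pip install -r requirements.txt`, then `python app.py` to serve the interface on port 7860, matching the `demo.launch()` settings above. Note that `accelerate` and `bitsandbytes` are listed under optional optimizations; the CPU-only path that `app.py` forces (`device=-1`) never invokes `bitsandbytes` quantization, so that pin can likely be dropped on CPU-only hardware.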