# ShukaAI_ASR / app.py
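"""Ultra-light Gradio demo for sarvamai/shuka_v1 speech-to-text.

The model is loaded lazily on the first request, runs CPU-only, and clips
input audio to 15 seconds to keep latency low.
"""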
import gradio as gr
import torch
import librosa
import numpy as np
from transformers import pipeline
import gc
import warnings
import os

warnings.filterwarnings("ignore")

# Environment tweaks: fall back to CPU for ops missing on MPS, and silence
# tokenizer fork-parallelism warnings.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class UltraLightShukaASR:
    def __init__(self):
        self.pipe = None
        self.model_loaded = False

    def load_model_lazy(self):
        """Lazy-load the model only when first needed."""
        if self.model_loaded:
            return True
        try:
            print("Loading Shuka v1 model...")
            # Try with minimal, CPU-only configuration first
            self.pipe = pipeline(
                model='sarvamai/shuka_v1',
                trust_remote_code=True,
                device=-1,  # CPU only
                model_kwargs={
                    "low_cpu_mem_usage": True,
                    "use_cache": False,  # disable KV cache to save memory
                    "torch_dtype": torch.float32,
                }
            )
            print("✅ Model loaded successfully!")
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return False
    def preprocess_audio_minimal(self, audio_input, target_sr=16000, max_duration=15):
        """Minimal audio preprocessing for speed."""
        try:
            if isinstance(audio_input, tuple):
                # Gradio numpy input: (sample_rate, samples)
                sr, audio_data = audio_input
                audio_data = audio_data.astype(np.float32)
                if audio_data.ndim > 1:
                    audio_data = np.mean(audio_data, axis=1)  # downmix to mono
                if sr != target_sr:
                    audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=target_sr)
                audio_data = audio_data[: target_sr * max_duration]  # enforce clip length
            else:
                audio_data, sr = librosa.load(audio_input, sr=target_sr, duration=max_duration)

            # Quick peak normalization
            peak = np.max(np.abs(audio_data))
            if peak > 0:
                audio_data = audio_data / peak

            # Trim silence from start and end
            audio_data, _ = librosa.effects.trim(audio_data, top_db=20)
            return audio_data, target_sr
        except Exception as e:
            raise RuntimeError(f"Audio preprocessing failed: {e}") from e
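
    # Sketch of the preprocessing contract (illustrative values, not assertions):
    #   audio, sr = asr.preprocess_audio_minimal("clip.wav")
    #   audio is float32 mono, peak-normalized into [-1, 1]; sr == 16000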

    def transcribe_fast(self, audio_input, language_hint=""):
        """Fast transcription with minimal overhead."""
        # Lazy-load the model on first use
        if not self.load_model_lazy():
            return "❌ Model failed to load. Please check your setup."
        try:
            # Quick audio processing
            audio, sr = self.preprocess_audio_minimal(audio_input)

            # Minimal system prompt for speed
            system_content = "Transcribe audio to text."
            if language_hint and language_hint != "auto":
                system_content += f" Language: {language_hint}."

            turns = [
                {'role': 'system', 'content': system_content},
                {'role': 'user', 'content': '<|audio|>'}
            ]

            # Fast inference settings (greedy decoding, short output budget)
            tokenizer = getattr(self.pipe, 'tokenizer', None)
            with torch.inference_mode():  # cheaper than no_grad for pure inference
                result = self.pipe(
                    {
                        'audio': audio,
                        'turns': turns,
                        'sampling_rate': sr
                    },
                    max_new_tokens=128,  # short cap keeps CPU latency low
                    do_sample=False,     # deterministic
                    num_beams=1,         # no beam search
                    pad_token_id=tokenizer.eos_token_id if tokenizer is not None else None
                )

            # Immediate cleanup
            del audio
            gc.collect()

            # Extract generated text from the pipeline output
            if isinstance(result, list) and len(result) > 0:
                text = result[0].get('generated_text', '').strip()
            elif isinstance(result, dict):
                text = result.get('generated_text', '').strip()
            else:
                text = str(result).strip()

            # Clean up the output (drop echoed prompt fragments if they appear)
            if "Transcribe audio to text" in text:
                text = text.replace("Transcribe audio to text", "").strip()
            if text.startswith("Language:"):
                text = text.split(".", 1)[-1].strip() if "." in text else text

            return text if text else "No speech detected"
        except Exception as e:
            return f"❌ Transcription error: {e}"

# Initialize ASR system (the model itself loads lazily on first request)
print("Initializing Ultra-Light Shuka ASR...")
asr_system = UltraLightShukaASR()


def process_audio(audio, language):
    """Main processing function."""
    if audio is None:
        return "Please upload or record an audio file."
    return asr_system.transcribe_fast(audio, language)

# Simple language options: (display name, hint passed to the model)
LANGUAGES = [
    ("Auto", "auto"),
    ("English", "english"),
    ("Hindi", "hindi"),
    ("Bengali", "bengali"),
    ("Tamil", "tamil"),
    ("Telugu", "telugu"),
    ("Gujarati", "gujarati"),
    ("Kannada", "kannada"),
    ("Malayalam", "malayalam"),
    ("Marathi", "marathi"),
    ("Punjabi", "punjabi"),
    ("Oriya", "oriya")
]

# Ultra-minimal Gradio interface
css = """
.gradio-container {
    max-width: 800px !important;
}
.output-text textarea {
    font-size: 16px !important;
}
"""
with gr.Blocks(css=css, title="Fast Shuka ASR") as demo:
    gr.HTML("""
        <div style='text-align: center; margin-bottom: 20px;'>
            <h1>🚀 Ultra-Fast Shuka v1 ASR</h1>
            <p>Optimized for speed • Multilingual • 15-second max clips</p>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                label="🎙️ Audio Input",
                type="filepath",
                format="wav",
                elem_id="audio-input"
            )
            language_select = gr.Dropdown(
                choices=LANGUAGES,
                value="auto",
                label="🌍 Language Hint",
                info="Optional - helps with accuracy"
            )
        with gr.Column(scale=2):
            output_box = gr.Textbox(
                label="📝 Transcription",
                placeholder="Upload audio to see transcription here...",
                lines=8,
                elem_classes=["output-text"]
            )
            gr.Button("🔄 Clear", size="sm").click(
                lambda: ("", None),
                outputs=[output_box, audio_input]
            )

    # Auto-transcribe on upload
    audio_input.change(
        fn=process_audio,
        inputs=[audio_input, language_select],
        outputs=output_box,
        show_progress=True
    )

    # Also re-transcribe when the language hint changes
    language_select.change(
        fn=process_audio,
        inputs=[audio_input, language_select],
        outputs=output_box,
        show_progress=True
    )

    gr.HTML("""
        <div style='margin-top: 20px; padding: 15px; background: #f0f0f0; border-radius: 10px;'>
            <h4>⚡ Speed Optimizations Active:</h4>
            <ul style='margin: 10px 0;'>
                <li>✅ Auto audio trimming (15s max)</li>
                <li>✅ CPU-optimized inference</li>
                <li>✅ Minimal token generation</li>
                <li>✅ Memory cleanup after each request</li>
            </ul>
            <p><strong>Tip:</strong> For fastest results, use short, clear audio clips in WAV format.</p>
        </div>
    """)

if __name__ == "__main__":
    demo.queue(max_size=3)  # limit concurrent requests
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )
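
# Usage note: launched locally with `python app.py`, the UI is served at
# http://localhost:7860 (matching server_port above).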