# app.py — Persian Whisper ASR (HF Spaces friendly)
import os, time, base64, datetime, logging
from html import escape
from difflib import SequenceMatcher
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# ===== Logging =====
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("persian-whisper-space")
# ===== Env =====
HF_TOKEN = os.getenv("HF_TOKEN", None)
if HF_TOKEN is None:
logger.warning("HF_TOKEN is not set. Add it in Space Settings → Secrets.")
# Default model (can be swapped out at any time)
MODEL_ID = os.getenv(
"MODEL_ID",
"MohammadReza-Halakoo/Whisper-Small-PersianASR-20-percent-17-0"
)
# To enable silence removal, set ENABLE_SILENCE_REMOVAL=1 under Settings → Variables
ENABLE_SILENCE_REMOVAL = os.getenv("ENABLE_SILENCE_REMOVAL", "0") == "1"
# ===== Device & dtype =====
if torch.cuda.is_available():
device = "cuda"
torch_dtype = torch.float16
logger.info("GPU detected → using CUDA + float16")
else:
device = "cpu"
try:
        # On CPU, try bfloat16 first; if it is not supported, fall back to float32
        torch_dtype = torch.bfloat16  # usually safe; if it fails, the except below catches it
_ = torch.tensor([0], dtype=torch_dtype) # sanity check
except Exception:
torch_dtype = torch.float32
logger.info(f"No GPU detected → falling back to CPU + {torch_dtype}")
# ===== Load model =====
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
MODEL_ID,
token=HF_TOKEN,
    torch_dtype=torch_dtype if device == "cuda" else None,  # on CPU keep the default float32 weights
low_cpu_mem_usage=True,
    device_map="auto" if device == "cuda" else None
)
if device == "cpu":  # with device_map="auto", accelerate has already placed the weights on the GPU
    model = model.to(device)
# Pad token safety
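# Most Whisper checkpoints already ship a pad token distinct from EOS; the guards below only
# kick in for checkpoints where pad_token_id is missing or collides with eos_token_id.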
if getattr(model.config, "pad_token_id", None) is None:
model.config.pad_token_id = processor.tokenizer.pad_token_id
if model.config.pad_token_id == model.config.eos_token_id:
if processor.tokenizer.pad_token_id != processor.tokenizer.eos_token_id:
model.config.pad_token_id = processor.tokenizer.pad_token_id
else:
processor.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(processor.tokenizer))
model.config.pad_token_id = processor.tokenizer.pad_token_id
# ===== Optional prompt ids (fa/transcribe) =====
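# get_decoder_prompt_ids returns (position, token_id) pairs — the language, task and
# no-timestamps special tokens — which are handed to generate() further down to pin fa/transcribe.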
try:
    prompt_ids = processor.get_decoder_prompt_ids(language="fa", task="transcribe")
except Exception:
prompt_ids = None
# ===== Audio utils =====
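# Depending on the Gradio component and version, the audio input may arrive as a plain file path,
# a tempfile-like object exposing .name, or a dict with a "name" key; resolve_path normalizes all of these.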
def resolve_path(x):
if not x:
return None
if isinstance(x, str):
return x
if hasattr(x, "name"):
return x.name
if isinstance(x, dict) and "name" in x:
return x["name"]
return None
def load_audio_preserving_quality(audio_path, target_sr=16000):
try:
audio, sr = librosa.load(audio_path, sr=None, mono=False)
if audio.ndim > 1:
audio = np.mean(audio, axis=0)
if sr != target_sr:
audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
sr = target_sr
if audio.dtype != np.float32:
audio = audio.astype(np.float32)
audio = np.nan_to_num(audio)
return audio, sr
except Exception as e:
logger.exception(f"Audio load error: {e}")
return None, None
def remove_intermediate_silence(audio, sr, silence_thresh=-38, min_silence_len=700, padding=200):
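    # pydub conventions: silence_thresh is in dBFS, min_silence_len and padding are in milliseconds.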
if not ENABLE_SILENCE_REMOVAL:
return audio, sr
try:
        # Import here so the whole app doesn't crash if pydub/ffmpeg is not installed
from pydub import AudioSegment
from pydub.silence import detect_nonsilent
audio_segment = AudioSegment(
(audio * np.iinfo(np.int16).max).astype(np.int16).tobytes(),
frame_rate=sr, sample_width=2, channels=1
)
ranges = detect_nonsilent(audio_segment, min_silence_len=min_silence_len, silence_thresh=silence_thresh)
if not ranges:
return np.array([], dtype=np.float32), sr
merged = AudioSegment.empty()
for start, end in ranges:
start = max(0, start - padding)
end = min(len(audio_segment), end + padding)
merged += audio_segment[start:end]
data = np.array(merged.get_array_of_samples()).astype(np.float32)
data /= np.iinfo(np.int16).max
return data, sr
except Exception as e:
logger.warning(f"Silence removal disabled (error): {e}")
return audio, sr
def is_silent(audio, threshold=1e-4):
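    # RMS below ~1e-4 is roughly -80 dBFS for float samples in [-1, 1], i.e. effectively silence.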
if audio is None or len(audio) == 0:
return True
rms = float(np.sqrt(np.mean(audio**2)))
return rms < threshold
def merge_transcriptions(transcriptions):
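    # Adjacent chunks overlap by the stride, so their transcriptions can repeat a few words at the seam;
    # splice each new piece in at the longest match between the previous tail and the new head.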
if not transcriptions:
return ""
out = transcriptions[0]
for i in range(1, len(transcriptions)):
prev, cur = out, transcriptions[i]
N = 50
match = SequenceMatcher(None, prev[-N:], cur[:N]).find_longest_match(0, min(N, len(prev)), 0, min(N, len(cur)))
if match.size > 10:
out += cur[match.b + match.size :]
else:
out += " " + cur
return out
# ===== Image helper (optional) =====
def image_to_base64(path):
try:
with open(path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
except Exception:
return None
image_base64 = image_to_base64("assets/hero.jpg")
# Embed the hero image inline as a data URI when available (minimal markup)
img_html = f'<img src="data:image/jpeg;base64,{image_base64}" alt="hero" />' if image_base64 else ""
# ===== Inference =====
def transcribe_audio(mic=None, upload_audio=None, file=None):
t0 = time.time()
audio_path = resolve_path(mic) or resolve_path(upload_audio) or resolve_path(file)
logger.info(f"audio_path: {audio_path!r}")
if not audio_path:
return "لطفاً یک فایل صوتی یا صدای ضبطشده ارسال کنید.", None, None, None
audio, sr = load_audio_preserving_quality(audio_path, 16000)
if audio is None:
return "خطا در بارگذاری و پردازش صوت.", None, None, None
audio, sr = remove_intermediate_silence(audio, sr)
if is_silent(audio):
return "صوت ورودی حاوی صدای قابل پردازش نیست.", None, None, None
    # Chunking: ~22 s windows with a 3 s stride (lighter than 29 s windows)
max_chunk_length, stride_length = 22, 3
max_chunk_samples, stride_samples = int(max_chunk_length*sr), int(stride_length*sr)
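    # Worked example at sr=16000: windows of 352000 samples advancing by 304000 samples (19 s),
    # so a 60 s clip yields three chunks starting at 0 s, 19 s and 38 s.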
chunks, start = [], 0
L = len(audio)
while start < L:
end = min(start + max_chunk_samples, L)
chunks.append(audio[start:end])
if end >= L: break
start += max_chunk_samples - stride_samples
if not chunks:
return "صوت ورودی خالی است.", None, None, None
    # Lighter generation settings for CPU
gen_kwargs = dict(max_new_tokens=225, do_sample=False, num_beams=1, length_penalty=1.0)
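    # Greedy decoding (num_beams=1, do_sample=False) keeps CPU latency predictable; 225 new tokens is
    # usually ample for a 22 s chunk of speech.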
trans = []
for i, chunk in enumerate(chunks, 1):
try:
            inputs = processor(chunk, sampling_rate=sr, return_tensors="pt")  # default padding fills Whisper's fixed 30 s window
            feats = inputs.input_features.to(device=device, dtype=model.dtype)  # match the model dtype (float16 on GPU)
with torch.no_grad():
                if prompt_ids is not None:
                    # get_decoder_prompt_ids yields (position, token_id) pairs meant for forced_decoder_ids
                    ids = model.generate(feats, forced_decoder_ids=prompt_ids, **gen_kwargs)
                else:
                    ids = model.generate(feats, **gen_kwargs)
text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
trans.append(text)
except Exception as e:
logger.exception(f"Model error on chunk {i}: {e}")
return "خطا در تبدیل گفتار به متن رخ داد.", None, None, None
final_text = merge_transcriptions(trans).strip()
if not final_text:
return "هیچ متنی استخراج نشد.", None, None, None
# Save .txt
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"transcription_{ts}.txt"
with open(filename, "w", encoding="utf-8") as f:
f.write(final_text)
escaped = escape(final_text)
html_buttons = f"""
با استفاده از مدل خصوصی، صوت شما به متن تبدیل میشود.