"""FastAPI service for phoneme transcription, pronunciation alignment and grammar checks.

Endpoints:
    POST /transcribe_file/     decode an uploaded audio file and transcribe it to
                               phonemes with a wav2vec2 CTC model.
    POST /align_phoneme/       align predicted phonemes against the g2p pronunciation
                               of a list of words and score each word.
    POST /correct_grammar_gf/  run Gramformer-based grammar analysis on a transcript.
    GET  /                     health check.
"""

import os

# Configure the environment BEFORE importing torch, either directly or through any
# module below that imports it. OpenMP runtimes typically read OMP_NUM_THREADS
# when they are loaded, so setting it after `import torch` may have no effect.
# setdefault keeps any value already provided by the deployment environment.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:64")

import torch  # noqa: E402
from fastapi import FastAPI, File, UploadFile  # noqa: E402
from fastapi.responses import JSONResponse  # noqa: E402
from g2p_en import G2p  # noqa: E402
from gramformer import Gramformer  # noqa: E402
from minineedle import core, needle  # noqa: E402
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor  # noqa: E402

from audio import decode_audio_bytes, preprocess_audio  # noqa: E402
from grammar import analyse_grammar_gf  # noqa: E402
from model import AlignmentRequest, CorrectionRequest  # noqa: E402
from phoneme import transcribe_tensor  # noqa: E402
from utils import (  # noqa: E402
    arpabet_to_ipa_seq,
    levenshtein_similarity_score as similarity_score,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sample rate (Hz) handed to preprocess_audio and to the transcription model.
TARGET_SR = 16000
MODEL_ID = "vitouphy/wav2vec2-xls-r-300m-phoneme"

# Models are loaded once at import time and shared by every request.
g2p = G2p()
gf = Gramformer(models=1, use_gpu=torch.cuda.is_available())
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).to(DEVICE).eval()

app = FastAPI(title="Audio Phoneme Transcription API")


@app.on_event("shutdown")
def shutdown_event():
    """Close the module-level ``language_tool`` on shutdown, if one has been set.

    This module never defines ``language_tool``, so a bare reference would raise
    ``NameError`` during every shutdown. It is therefore looked up in the module
    globals and closed only when present.
    """
    language_tool = globals().get("language_tool")
    if language_tool is not None:
        language_tool.close()


# region FastAPI app
@app.post("/transcribe_file/")
async def transcribe_file(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file into phonemes.

    The upload is decoded according to its content type, preprocessed to
    ``TARGET_SR`` and passed through the wav2vec2 model.

    Returns:
        ``{"transcription": phonemes}`` on success.
        A 400 JSON error when the uploaded payload is empty.
        A 500 JSON error with the exception message when decoding,
        preprocessing or inference fails.
    """
    try:
        audio_bytes = await audio_file.read()
        # Test the payload itself: the UploadFile wrapper object is always truthy,
        # so checking it would never reject an empty upload.
        if not audio_bytes:
            return JSONResponse(
                status_code=400, content={"error": "Empty audio data"}
            )

        audio, sr = decode_audio_bytes(audio_bytes, audio_file.content_type)
        waveform = preprocess_audio(audio, sr, TARGET_SR)
        # NOTE(review): inference runs synchronously inside this async handler and
        # blocks the event loop while it runs. Consider offloading it to a
        # threadpool if concurrent requests matter.
        phonemes = transcribe_tensor(
            model=model,
            processor=processor,
            waveform=waveform,
            sr=TARGET_SR,
            device=DEVICE,
        )
        print(f"[OK] {waveform.shape[-1] / TARGET_SR:.2f}s audio")
        return {"transcription": phonemes}
    except Exception as e:
        # Top-level request boundary: report any failure as a JSON 500.
        print(f"[ERR] {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})


def _gaps_to_dash(aligned):
    """Return ``aligned`` with every minineedle gap object replaced by ``"-"``."""
    return ["-" if isinstance(a, core.Gap) else a for a in aligned]


def _expected_phonemes(words):
    """Build the expected (ground-truth) IPA pronunciation of ``words`` via g2p.

    Returns:
        A tuple ``(phonemes_gt, word_boundaries)``. ``phonemes_gt`` is the flat
        list of IPA phonemes for all words in order. ``word_boundaries`` holds
        one ``(word, n_phonemes)`` pair per word.
    """
    phonemes_gt = []
    word_boundaries = []
    for word in words:
        # Drop the " " separator tokens g2p can emit; they are not phonemes.
        phs = arpabet_to_ipa_seq([p for p in g2p(word) if p != " "])
        word_boundaries.append((word, len(phs)))
        phonemes_gt.extend(phs)
    return phonemes_gt, word_boundaries


def _split_alignment_by_word(word_boundaries, al_gt, al_pred):
    """Cut two equal-length aligned sequences into one segment per word.

    Each word's segment runs up to and including its last expected phoneme.
    A column whose expected side is a gap (``"-"``) is an insertion. It stays
    with the word it falls inside, or goes to the next word when it sits between
    words.

    Returns:
        A list of ``(word, gt_segment, pred_segment)`` tuples.
    """
    segments = []
    idx = 0
    for word, word_len in word_boundaries:
        seg_gt, seg_pred = [], []
        consumed = 0
        while consumed != word_len and idx < len(al_gt):
            if al_gt[idx] != "-":
                consumed += 1
            seg_gt.append(al_gt[idx])
            seg_pred.append(al_pred[idx])
            idx += 1
        segments.append((word, seg_gt, seg_pred))
    # NOTE(review): insertions after the final word's last expected phoneme
    # (al_gt[idx:]) are not assigned to any word. Confirm this is intended.
    return segments


@app.post("/align_phoneme/")
async def align_phoneme(request: AlignmentRequest):
    """Align predicted phonemes with the expected pronunciation of each word.

    ``request.phonemes`` (predicted, ARPAbet) and the g2p pronunciation of
    ``request.words`` are both converted to IPA. They are then globally aligned
    with Needleman-Wunsch, and the alignment is split back into per-word
    segments.

    Returns:
        ``{"alignment": [...]}`` with one entry per word, holding ``word``,
        ``correct_ipa``, ``user_ipa`` and ``score``. The IPA values are aligned
        segments with gaps written as ``"-"``.
        A 400 JSON error when either the phonemes or the words are empty.
    """
    phonemes_pred = request.phonemes
    words = request.words
    if not phonemes_pred or not words:
        return JSONResponse(
            status_code=400, content={"error": "Empty phonemes or words"}
        )

    phonemes_pred = arpabet_to_ipa_seq(phonemes_pred)
    phonemes_gt, word_boundaries = _expected_phonemes(words)

    alignment = needle.NeedlemanWunsch(phonemes_gt, phonemes_pred)
    alignment.align()
    al_gt, al_pred = alignment.get_aligned_sequences()
    al_gt = _gaps_to_dash(al_gt)
    al_pred = _gaps_to_dash(al_pred)

    word_to_pred = [
        {
            "word": word,
            "correct_ipa": seg_gt,
            "user_ipa": seg_pred,
            "score": similarity_score(seg_gt, seg_pred),
        }
        for word, seg_gt, seg_pred in _split_alignment_by_word(
            word_boundaries, al_gt, al_pred
        )
    ]
    return {"alignment": word_to_pred}


@app.post("/correct_grammar_gf/")
async def correct_grammar_gf(request: CorrectionRequest):
    """Run Gramformer-based grammar analysis on ``request.transcript``.

    Returns:
        ``{"corrections": ...}``, as produced by ``analyse_grammar_gf``.
        A 400 JSON error when the transcript is empty.
    """
    transcript = request.transcript
    if not transcript:
        return JSONResponse(status_code=400, content={"error": "Empty transcript"})
    return {"corrections": analyse_grammar_gf(transcript, gf)}


@app.get("/")
def health():
    """Report service status, the loaded model id and the compute device."""
    return {"status": "ok", "model": MODEL_ID, "device": str(DEVICE)}
# endregion