mrchan1207 committed on
Commit
e747cdc
·
verified ·
1 Parent(s): 88423ae
Files changed (1) hide show
  1. app.py +34 -2
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import librosa
3
  import soundfile as sf
4
  import io
5
- from fastapi import FastAPI, File, UploadFile
6
  from fastapi.responses import JSONResponse
7
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
8
 
@@ -41,7 +41,7 @@ except Exception as e:
41
 
42
  # --- 3. Define the Transcription Endpoint ---
43
  @app.post("/transcribe/")
44
- async def transcribe_audio(audio_file: UploadFile = File(...)):
45
  if not model or not processor:
46
  return JSONResponse(status_code=503, content={"error": "Model is not loaded."})
47
 
@@ -71,6 +71,38 @@ async def transcribe_audio(audio_file: UploadFile = File(...)):
71
  except Exception as e:
72
  print(f"Error during transcription: {str(e)}")
73
  return JSONResponse(status_code=500, content={"error": f"An error occurred: {str(e)}"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # --- 4. Root Endpoint for Health Check ---
76
  @app.get("/")
 
2
  import librosa
3
  import soundfile as sf
4
  import io
5
+ from fastapi import FastAPI, File, UploadFile, Request
6
  from fastapi.responses import JSONResponse
7
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
8
 
 
41
 
42
  # --- 3. Define the Transcription Endpoint ---
43
  @app.post("/transcribe/")
44
+ async def transcribe(audio_file: UploadFile = File(...)):
45
  if not model or not processor:
46
  return JSONResponse(status_code=503, content={"error": "Model is not loaded."})
47
 
 
71
  except Exception as e:
72
  print(f"Error during transcription: {str(e)}")
73
  return JSONResponse(status_code=500, content={"error": f"An error occurred: {str(e)}"})
74
+
75
@app.post("/transcribe_audio/")
async def transcribe_audio(request: Request):
    """Transcribe raw audio bytes posted as the request body.

    Accepts any format ``soundfile`` can decode, downmixes multi-channel
    audio to mono, resamples to 16 kHz, and runs the Wav2Vec2 CTC model.

    Returns:
        dict: ``{"transcription": <text>}`` on success.
        JSONResponse: 503 if the model failed to load, 400 if the body is
        empty or not decodable audio, 500 for unexpected inference errors.
    """
    if not model or not processor:
        return JSONResponse(status_code=503, content={"error": "Model is not loaded."})

    contents = await request.body()
    if not contents:
        # An empty body is a client error; reject it up front instead of
        # letting the decoder fail and surface as a generic 500.
        return JSONResponse(
            status_code=400,
            content={"error": "Empty request body; expected raw audio bytes."},
        )

    try:
        audio_data, original_sr = sf.read(io.BytesIO(contents))
    except Exception as e:
        # Undecodable audio is a client-side problem, not a server fault.
        return JSONResponse(
            status_code=400,
            content={"error": f"Could not decode audio: {str(e)}"},
        )

    try:
        if audio_data.ndim > 1:
            # Downmix multi-channel (e.g. stereo) audio to mono by
            # averaging across channels.
            audio_data = audio_data.mean(axis=1)

        # Wav2Vec2 models expect 16 kHz input.
        resampled_audio = librosa.resample(y=audio_data, orig_sr=original_sr, target_sr=16000)

        inputs = processor(resampled_audio, sampling_rate=16000, return_tensors="pt", padding=True)

        # Move the input tensors to the same device as the model
        # (assumes a module-level `device` matching the model — confirm).
        inputs = inputs.to(device)

        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            logits = model(**inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        print(f"Transcription complete: {transcription}")
        return {"transcription": transcription}

    except Exception as e:
        print(f"Error during transcription: {str(e)}")
        return JSONResponse(status_code=500, content={"error": f"An error occurred: {str(e)}"})
106
 
107
  # --- 4. Root Endpoint for Health Check ---
108
  @app.get("/")