legusxyz committed on
Commit
2218ac8
·
verified ·
1 Parent(s): 554e451

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -6
app.py CHANGED
@@ -1,8 +1,132 @@
 
 
1
  import torch
 
 
 
 
 
2
 
3
- if torch.cuda.is_available():
4
- device = torch.device("cuda")
5
- print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
6
- else:
7
- device = torch.device("cpu")
8
- print("CUDA is not available, using CPU.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
3
  import torch
4
+ import tempfile
5
+ import os
6
+ import time
7
+ from fastapi.responses import HTMLResponse
8
+ from fastapi.staticfiles import StaticFiles
9
 
10
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
11
+ # True
12
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
13
+
14
+ model = load_pytorch_model()
15
+ model = model.to("cuda")
16
+
17
+ # Define FastAPI app
18
+ app = FastAPI()
19
+
20
+ # Load the Whisper model once during startup
21
+ device = torch.device("cuda")
22
+ asr_pipeline = pipeline(model="openai/whisper-large", device=device) # Initialize Whisper model
23
+ # asr_pipeline = pipeline( model="openai/whisper-small", device=device, language="pt")
24
+
25
+
26
+ # Basic GET endpoint
27
+ @app.get("/")
28
+ def read_root():
29
+ return {"message": "Welcome to the FastAPI app on Hugging Face Spaces!"}
30
+
31
+ # POST endpoint to transcribe audio
32
+ @app.post("/transcribe/")
33
+ async def transcribe_audio(file: UploadFile = File(...)):
34
+ start_time = time.time()
35
+
36
+ # Save the uploaded file using a temporary file manager
37
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
38
+ temp_audio_file.write(await file.read())
39
+ temp_file_path = temp_audio_file.name
40
+
41
+ # Transcribe the audio with long-form generation enabled
42
+ transcription_start = time.time()
43
+ transcription = asr_pipeline(temp_file_path, return_timestamps=True) # Enable timestamp return for long audio files
44
+ transcription_end = time.time()
45
+
46
+ # Clean up temporary file after use
47
+ os.remove(temp_file_path)
48
+
49
+ # Log time durations
50
+ end_time = time.time()
51
+ print(f"Time to transcribe audio: {transcription_end - transcription_start:.4f} seconds")
52
+ print(f"Total execution time: {end_time - start_time:.4f} seconds")
53
+
54
+ return {"transcription": transcription['text']}
55
+
56
+ @app.get("/playground/", response_class=HTMLResponse)
57
+ def playground():
58
+ html_content = """
59
+ <!DOCTYPE html>
60
+ <html lang="en">
61
+ <head>
62
+ <meta charset="UTF-8">
63
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
64
+ <title>Voice Recorder</title>
65
+ </head>
66
+ <body>
67
+ <h1>Record your voice</h1>
68
+ <button id="startBtn">Start Recording</button>
69
+ <button id="stopBtn" disabled>Stop Recording</button>
70
+ <p id="status">Press start to record your voice...</p>
71
+
72
+ <audio id="audioPlayback" controls style="display:none;"></audio>
73
+ <script>
74
+ let mediaRecorder;
75
+ let audioChunks = [];
76
+
77
+ const startBtn = document.getElementById('startBtn');
78
+ const stopBtn = document.getElementById('stopBtn');
79
+ const status = document.getElementById('status');
80
+ const audioPlayback = document.getElementById('audioPlayback');
81
+
82
+ // Start Recording
83
+ startBtn.addEventListener('click', async () => {
84
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
85
+ mediaRecorder = new MediaRecorder(stream);
86
+ mediaRecorder.start();
87
+
88
+ status.textContent = 'Recording...';
89
+ startBtn.disabled = true;
90
+ stopBtn.disabled = false;
91
+
92
+ mediaRecorder.ondataavailable = event => {
93
+ audioChunks.push(event.data);
94
+ };
95
+ });
96
+
97
+ // Stop Recording
98
+ stopBtn.addEventListener('click', () => {
99
+ mediaRecorder.stop();
100
+ mediaRecorder.onstop = async () => {
101
+ status.textContent = 'Recording stopped. Preparing to send...';
102
+ const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
103
+ const audioUrl = URL.createObjectURL(audioBlob);
104
+ audioPlayback.src = audioUrl;
105
+ audioPlayback.style.display = 'block';
106
+ audioChunks = [];
107
+
108
+ // Send audio blob to FastAPI endpoint
109
+ const formData = new FormData();
110
+ formData.append('file', audioBlob, 'recording.wav');
111
+
112
+ const response = await fetch('/transcribe/', {
113
+ method: 'POST',
114
+ body: formData,
115
+ });
116
+
117
+ const result = await response.json();
118
+ status.textContent = 'Transcription: ' + result.transcription;
119
+ };
120
+
121
+ startBtn.disabled = false;
122
+ stopBtn.disabled = true;
123
+ });
124
+ </script>
125
+ </body>
126
+ </html>
127
+ """
128
+ return HTMLResponse(content=html_content)
129
+ # If running as the main module, start Uvicorn
130
+ if __name__ == "__main__":
131
+ import uvicorn
132
+ uvicorn.run(app, host="0.0.0.0", port=7860)