File size: 7,120 Bytes
2218ac8
 
b9cd341
2218ac8
 
 
487abe1
2218ac8
 
8d7f55f
2218ac8
 
 
43ab7a4
fc4579d
43ab7a4
 
045b9e5
43ab7a4
 
 
 
 
 
 
d520218
43ab7a4
 
 
d520218
045b9e5
7bc406e
43ab7a4
e3f2431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43ab7a4
2218ac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import tempfile
import time

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor

# FastAPI application instance; every route in this module is registered on it.
app = FastAPI()

# Device index for Hugging Face pipelines: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# Whisper speech-recognition model and its processor (tokenizer + feature extractor).
# Any Whisper checkpoint works here, e.g. "openai/whisper-small" for a lighter model.
model_name = "openai/whisper-large-v2"
model = WhisperForConditionalGeneration.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

# Pin the decoder prompt to Portuguese / "transcribe" so output is Portuguese text.
# NOTE(review): newer transformers releases deprecate config.forced_decoder_ids in favour
# of generate_kwargs={"language": ..., "task": ...}; confirm against the pinned version.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="portuguese", task="transcribe")
model.config.forced_decoder_ids = forced_decoder_ids

# ASR pipeline built from the customised model; tokenizer and feature extractor are
# passed explicitly because `model` is an instance, not a checkpoint name.
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=device,
)

# Portuguese extractive question-answering model.
# NOTE: this rebinds the module-level `model_name`; from here on it names the QA model.
model_name = 'pierreguillou/bert-base-cased-squad-v1.1-portuguese'
# Place it on the same device as the ASR pipeline (previously it always ran on CPU).
qa_pipeline = pipeline("question-answering", model=model_name, device=device)

# Fixed Portuguese reference passage (a COVID-19 pandemic summary). The /answer/
# endpoint passes it to the QA pipeline as `context`, so every answer is a span
# extracted from this text. Raw string literal: backslashes would be kept verbatim.
context = r"""
A pandemia de COVID-19, também conhecida como pandemia de coronavírus, é uma pandemia em curso de COVID-19, 
uma doença respiratória aguda causada pelo coronavírus da síndrome respiratória aguda grave 2 (SARS-CoV-2). 
A doença foi identificada pela primeira vez em Wuhan, na província de Hubei, República Popular da China, 
em 1 de dezembro de 2019, mas o primeiro caso foi reportado em 31 de dezembro do mesmo ano. 
Acredita-se que o vírus tenha uma origem zoonótica, porque os primeiros casos confirmados 
tinham principalmente ligações ao Mercado Atacadista de Frutos do Mar de Huanan, que também vendia animais vivos. 
Em 11 de março de 2020, a Organização Mundial da Saúde declarou o surto uma pandemia. Até 8 de fevereiro de 2021, 
pelo menos 105 743 102 casos da doença foram confirmados em pelo menos 191 países e territórios, 
com cerca de 2 308 943 mortes e 58 851 440 pessoas curadas.
"""

# JSON body schema accepted by POST /answer/.
class QuestionRequest(BaseModel):
    """Payload holding a single question to answer against the fixed context."""

    question: str  # question text, forwarded verbatim to the QA pipeline

# POST endpoint: answer a question by extracting a span from the module-level `context`.
@app.post("/answer/")
async def answer_question(request: QuestionRequest):
    """Run the QA pipeline on ``request.question`` against ``context``.

    Returns:
        dict with ``question``, ``answer`` (extracted text), ``score`` (pipeline
        confidence rounded to 4 decimals) and the pipeline's ``start``/``end``
        offsets of the answer.

    Raises:
        HTTPException: 400 if the question is blank; 500 if the QA pipeline fails.
    """
    question = request.question
    if not question.strip():
        # A blank question is a client error; don't let it surface as a 500.
        raise HTTPException(status_code=400, detail="question must not be empty")

    try:
        result = qa_pipeline(question=question, context=context)
    except Exception as e:
        # Top-level request boundary: report the failure but keep the cause chained.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "question": question,
        "answer": result['answer'],
        "score": round(result['score'], 4),
        "start": result['start'],
        "end": result['end'],
    }

# Root endpoint: responds with a static greeting.
@app.get("/")
def read_root():
    """Return a fixed welcome message as JSON."""
    greeting = "Welcome to the FastAPI app on Hugging Face Spaces!"
    payload = dict(message=greeting)
    return payload

# POST endpoint: transcribe an uploaded audio file with the Whisper ASR pipeline.
@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe the uploaded audio and return ``{"transcription": <text>}``.

    The upload is written to a temporary file so the pipeline can read it by
    path. The file is always deleted afterwards, even if transcription fails.
    Timing for the transcription step and the whole request is printed to stdout.

    Raises:
        HTTPException: 400 if the uploaded file is empty.
    """
    start_time = time.perf_counter()

    audio_bytes = await file.read()
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="uploaded file is empty")

    # delete=False keeps the file on disk after it is closed so the pipeline can
    # open it by path; removal is done explicitly in the `finally` below.
    temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    temp_file_path = temp_audio_file.name
    try:
        with temp_audio_file:
            temp_audio_file.write(audio_bytes)

        transcription_start = time.perf_counter()
        # Enable timestamp return for long audio files
        transcription = asr_pipeline(temp_file_path, return_timestamps=True)
        transcription_end = time.perf_counter()
    finally:
        # Previously skipped whenever transcription raised, leaking the temp file.
        os.remove(temp_file_path)

    end_time = time.perf_counter()
    print(f"Time to transcribe audio: {transcription_end - transcription_start:.4f} seconds")
    print(f"Total execution time: {end_time - start_time:.4f} seconds")

    return {"transcription": transcription['text']}

@app.get("/playground/", response_class=HTMLResponse)
def playground():
    """Serve a minimal browser page for recording audio and sending it to ``/transcribe/``.

    The page records from the microphone with ``MediaRecorder`` and plays the clip back.
    It uploads the clip as multipart field ``file`` and shows either the returned
    transcription or the error reported by the server.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Voice Recorder</title>
    </head>
    <body>
        <h1>Record your voice</h1>
        <button id="startBtn">Start Recording</button>
        <button id="stopBtn" disabled>Stop Recording</button>
        <p id="status">Press start to record your voice...</p>

        <audio id="audioPlayback" controls style="display:none;"></audio>
        <script>
            let mediaRecorder;

            const startBtn = document.getElementById('startBtn');
            const stopBtn = document.getElementById('stopBtn');
            const status = document.getElementById('status');
            const audioPlayback = document.getElementById('audioPlayback');

            // Upload a finished recording and show the server's answer or error.
            async function sendRecording(audioBlob) {
                const formData = new FormData();
                formData.append('file', audioBlob, 'recording.wav');
                try {
                    const response = await fetch('/transcribe/', {
                        method: 'POST',
                        body: formData,
                    });
                    const result = await response.json();
                    if (!response.ok) {
                        status.textContent = 'Transcription failed: ' + (result.detail || response.status);
                        return;
                    }
                    status.textContent = 'Transcription: ' + result.transcription;
                } catch (err) {
                    status.textContent = 'Transcription failed: ' + err.message;
                }
            }

            // Start Recording
            startBtn.addEventListener('click', async () => {
                let stream;
                try {
                    stream = await navigator.mediaDevices.getUserMedia({ audio: true });
                } catch (err) {
                    status.textContent = 'Microphone unavailable: ' + err.message;
                    return;
                }

                // Per-recording state so a new recording never mixes with an old one.
                const recorder = new MediaRecorder(stream);
                const chunks = [];
                recorder.ondataavailable = event => {
                    chunks.push(event.data);
                };
                // Registered before start() so the stop event cannot be missed.
                recorder.onstop = async () => {
                    // Release the microphone; otherwise the browser keeps capturing.
                    stream.getTracks().forEach(track => track.stop());
                    status.textContent = 'Recording stopped. Preparing to send...';
                    // MediaRecorder emits a compressed container, not WAV, so label
                    // the blob with the type the recorder actually produced.
                    const audioBlob = new Blob(chunks, { type: recorder.mimeType });
                    if (audioPlayback.src) {
                        URL.revokeObjectURL(audioPlayback.src);
                    }
                    audioPlayback.src = URL.createObjectURL(audioBlob);
                    audioPlayback.style.display = 'block';
                    await sendRecording(audioBlob);
                };

                mediaRecorder = recorder;
                recorder.start();

                status.textContent = 'Recording...';
                startBtn.disabled = true;
                stopBtn.disabled = false;
            });

            // Stop Recording
            stopBtn.addEventListener('click', () => {
                mediaRecorder.stop();
                startBtn.disabled = false;
                stopBtn.disabled = true;
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
# Script entry point: when this file is executed directly, serve `app` with
# Uvicorn on all interfaces at port 7860.
if __name__ == "__main__":
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)