"""FastAPI service exposing Portuguese speech transcription (Whisper) and
extractive question answering (BERT) over a fixed COVID-19 context."""

from fastapi import FastAPI, HTTPException, UploadFile, File
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
import torch
import tempfile
import os
import time
from pydantic import BaseModel
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

# Define FastAPI app
app = FastAPI()

# Use the first CUDA device when available; -1 selects CPU for HF pipelines.
device = 0 if torch.cuda.is_available() else -1

# Load Whisper model and processor.
model_name = "openai/whisper-large-v2"  # Use the model of your choice, e.g., whisper-small or whisper-large
model = WhisperForConditionalGeneration.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

# Force Portuguese transcription regardless of detected audio language.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="portuguese", task="transcribe")
model.config.forced_decoder_ids = forced_decoder_ids

# Initialize the ASR pipeline with the modified model and processor.
asr_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,  # Explicitly set the tokenizer from the processor
    feature_extractor=processor.feature_extractor,  # Set the feature extractor for audio input
    device=device,
)

# Load question-answering model.
# NOTE: kept in a separate variable so it does not shadow the Whisper
# checkpoint name above (the original reused `model_name`).
qa_model_name = 'pierreguillou/bert-base-cased-squad-v1.1-portuguese'
qa_pipeline = pipeline("question-answering", model=qa_model_name)

# Fixed context every question is answered against.
context = r"""
A pandemia de COVID-19, também conhecida como pandemia de coronavírus, é uma pandemia em curso de COVID-19, uma doença respiratória aguda causada pelo coronavírus da síndrome respiratória aguda grave 2 (SARS-CoV-2). A doença foi identificada pela primeira vez em Wuhan, na província de Hubei, República Popular da China, em 1 de dezembro de 2019, mas o primeiro caso foi reportado em 31 de dezembro do mesmo ano. 
Acredita-se que o vírus tenha uma origem zoonótica, porque os primeiros casos confirmados tinham principalmente ligações ao Mercado Atacadista de Frutos do Mar de Huanan, que também vendia animais vivos. Em 11 de março de 2020, a Organização Mundial da Saúde declarou o surto uma pandemia. Até 8 de fevereiro de 2021, pelo menos 105 743 102 casos da doença foram confirmados em pelo menos 191 países e territórios, com cerca de 2 308 943 mortes e 58 851 440 pessoas curadas. """


class QuestionRequest(BaseModel):
    # Request body for POST /answer/: the question to answer over `context`.
    question: str


@app.post("/answer/")
async def answer_question(request: QuestionRequest):
    """Answer `request.question` against the fixed COVID-19 context.

    Returns the question, the extracted answer span, its confidence score
    (rounded to 4 decimals), and the span's start/end character offsets.
    Raises HTTP 500 if the QA pipeline fails.
    """
    try:
        # Use the QA model to answer the question based on the context.
        result = qa_pipeline(question=request.question, context=context)
        return {
            "question": request.question,
            "answer": result['answer'],
            "score": round(result['score'], 4),
            "start": result['start'],
            "end": result['end'],
        }
    except Exception as e:
        # FIX: HTTPException was previously raised without being imported,
        # which would have caused a NameError instead of a clean 500.
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
def read_root():
    """Health/landing endpoint."""
    return {"message": "Welcome to the FastAPI app on Hugging Face Spaces!"}


@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to Portuguese text.

    The upload is spooled to a temporary .wav file on disk (the HF pipeline
    takes a path), transcribed with timestamps enabled to support long-form
    audio, and the temp file is always removed afterwards.
    """
    start_time = time.time()

    # Save the uploaded file using a temporary file manager.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
        temp_audio_file.write(await file.read())
        temp_file_path = temp_audio_file.name

    try:
        # Transcribe the audio with long-form generation enabled.
        transcription_start = time.time()
        transcription = asr_pipeline(temp_file_path, return_timestamps=True)  # Enable timestamp return for long audio files
        transcription_end = time.time()
    finally:
        # FIX: clean up the temp file even if transcription raises,
        # so failed requests do not leak files on disk.
        os.remove(temp_file_path)

    # Log time durations.
    end_time = time.time()
    print(f"Time to transcribe audio: {transcription_end - transcription_start:.4f} seconds")
    print(f"Total execution time: {end_time - start_time:.4f} seconds")

    return {"transcription": transcription['text']}


@app.get("/playground/", response_class=HTMLResponse)
def playground():
    """Serve the minimal recording playground page."""
    html_content = """
Press start to record your voice...
"""
    return HTMLResponse(content=html_content)


# If running as the main module, start Uvicorn.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)