Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| import torch | |
| import nemo.collections.asr as nemo_asr | |
| import wandb | |
| from pydub.utils import mediainfo | |
# Deployment settings. The W&B entity, project and model artifact name can
# each be overridden through an environment variable of the same name.
MODEL_HISTORY_DAYS = 180  # not referenced in the visible part of this file
WANDB_ENTITY = os.getenv("WANDB_ENTITY", "tarteel")
WANDB_PROJECT_NAME = os.getenv("WANDB_PROJECT_NAME", "nemo-experiments")
MODEL_NAME = os.getenv("MODEL_NAME", "CfCtcLg-SpeUni1024-DI-EATLDN-CA:v0")
def get_device():
    """Return ``"cuda"`` when torch reports CUDA as available, else ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"
# Fetch the model artifact through the W&B public API client; no run is
# started for this lookup.
wandb_api = wandb.Api(overrides={"entity": WANDB_ENTITY})
artifact = wandb_api.artifact(f"{WANDB_ENTITY}/{WANDB_PROJECT_NAME}/{MODEL_NAME}")
artifact_dir = artifact.download()

# Use the first ``.nemo`` checkpoint found anywhere under the downloaded
# artifact directory. Fail with an explicit message instead of a bare
# IndexError when the artifact contains no checkpoint.
model_path = next(
    (
        os.path.join(root, file_name)
        for root, _dirs, files in os.walk(artifact_dir)
        for file_name in files
        if file_name.endswith(".nemo")
    ),
    None,
)
if model_path is None:
    raise FileNotFoundError(
        f"No .nemo checkpoint found in artifact {MODEL_NAME!r} "
        f"(downloaded to {artifact_dir!r})"
    )

# Restore the checkpoint onto the GPU when one is available, otherwise the CPU.
model = nemo_asr.models.EncDecCTCModelBPE.restore_from(
    model_path, map_location=get_device()
)
def transcribe(audio_file):
    """Transcribe a single audio file with the module-level ASR model.

    Args:
        audio_file: Audio input, forwarded to ``model.transcribe`` as a
            one-element batch.

    Returns:
        The first (and only) entry of the batch result.
    """
    results = model.transcribe([audio_file])
    return results[0]
def get_duration_ms(audio_file):
    """Return the duration of an audio file in whole milliseconds.

    Args:
        audio_file: Path of the audio file, passed to ``mediainfo``.

    Returns:
        The ``"duration"`` entry reported by ``mediainfo``, converted from
        seconds to milliseconds and rounded to the nearest integer.

    Raises:
        KeyError: If ``mediainfo`` reports no ``"duration"`` entry.
        ValueError: If the reported duration is not numeric.
    """
    # ffprobe (via mediainfo) reports the duration as a fractional-seconds
    # string such as "12.345000", which int() rejects. Parse it as a float
    # and scale before rounding so the sub-second part is kept.
    duration_s = float(mediainfo(audio_file)["duration"])
    return round(duration_s * 1000)
# Build the UI: one audio upload feeding two actions (transcription and
# duration lookup), each with its own text output and named API endpoint.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # ﷽
        """
    )

    with gr.Row():
        audio_file = gr.Audio(source="upload", type="filepath", label="File")

    with gr.Row():
        output_file = gr.TextArea(label="Audio Transcription")

    b1 = gr.Button("Transcribe")
    b1.click(
        fn=transcribe,
        inputs=[audio_file],
        outputs=[output_file],
        api_name="transcribe",
    )

    b2 = gr.Button("Get Duration")
    with gr.Row():
        duration = gr.TextArea(label="Duration")
    b2.click(
        fn=get_duration_ms,
        inputs=[audio_file],
        outputs=[duration],
        api_name="get_duration_ms",
    )

# Start the web server for the assembled interface.
demo.launch()