import gradio as gr
from nemo.collections.asr.models import ASRModel
import torch
import spaces
from huggingface_hub import hf_hub_download

# Pick the compute device; the tiny tensor op below is a sanity check that
# tensors can actually be placed on it
device = "cuda" if torch.cuda.is_available() else "cpu"
zero = torch.Tensor([0]).to(device)
print(f"Device being used: {zero.device}")

# HuggingFace repo ids of the available models
MODEL_NAMES = [
    "NAMAA-Space/EgypTalk-ASR-v2"
]

# Cache loaded models
LOADED_MODELS = {}

# Filename of the NeMo checkpoint inside the HF repo
NEMO_FILENAME = "asr-egyptian-nemo-v2.0.nemo"

def get_model(model_name):
    if model_name not in LOADED_MODELS:
        print(f"Loading model {model_name} on {device}...")
        # Download the .nemo artifact from HF and restore
        nemo_path = hf_hub_download(repo_id=model_name, filename=NEMO_FILENAME)
        model = ASRModel.restore_from(nemo_path)
        model = model.to(device)
        model.eval()
        LOADED_MODELS[model_name] = model
        print(f"Model {model_name} loaded successfully!")
    return LOADED_MODELS[model_name]
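
# Note: on ZeroGPU Spaces the GPU is attached only inside @spaces.GPU
# functions, so the model is loaded lazily on the first request rather than
# eagerly at import time. A hypothetical eager warm-up would look like:
#     get_model(MODEL_NAMES[0])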

def extract_text(item):
    # NeMo's transcribe() may return Hypothesis objects, dicts, or plain
    # strings depending on the model and decoding config; normalize all of
    # them to a plain string.
    if hasattr(item, "text"):
        return item.text
    if isinstance(item, dict) and "text" in item:
        return item["text"]
    if isinstance(item, str):
        return item
    return str(item)

@spaces.GPU(duration=120)
def transcribe_and_score(audio):
    if audio is None:
        return ""
    model = get_model(MODEL_NAMES[0])
    predictions = model.transcribe([audio])
    item = predictions[0] if isinstance(predictions, list) else predictions
    return extract_text(item).strip()
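
# Minimal usage sketch (assumes a local 16 kHz mono WAV named sample.wav;
# the filename is a placeholder, not part of this app):
#     print(transcribe_and_score("sample.wav"))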

@spaces.GPU(duration=120)
def batch_transcribe(audio_files):
    if not audio_files:
        return []
    model = get_model(MODEL_NAMES[0])
    # gr.Files may hand back plain filepath strings or tempfile-like objects
    # depending on the Gradio version; normalize everything to paths.
    paths = [f.name if hasattr(f, "name") else f for f in audio_files]
    predictions = model.transcribe(paths)
    if not isinstance(predictions, list):
        predictions = [predictions]
    # One row per file for the results Dataframe
    return [[extract_text(p).strip()] for p in predictions]
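
# Batch usage sketch (paths are placeholders):
#     batch_transcribe(["a.wav", "b.wav"])
#     # -> [["first transcription"], ["second transcription"]]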

with gr.Blocks(title="EgypTalk-ASR-v2") as demo:
    gr.Markdown("""
    # EgypTalk-ASR-v2
    Upload or record an audio file; the app transcribes it with EgypTalk-ASR-v2 (Egyptian Arabic speech recognition).
    """)
    with gr.Tab("Single Test"):
        with gr.Row():
            audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
        transcribe_btn = gr.Button("Transcribe")
        pred_output = gr.Textbox(label="Transcription")
        transcribe_btn.click(transcribe_and_score, inputs=[audio_input], outputs=[pred_output])

    with gr.Tab("Batch Test"):
        gr.Markdown("Upload multiple audio files. Batch size is limited by GPU/CPU memory.")
        audio_files = gr.Files(label="Audio Files (wav)")
        batch_btn = gr.Button("Batch Transcribe")
        preds_output = gr.Dataframe(headers=["Transcription"], label="Results")
        batch_btn.click(batch_transcribe, inputs=[audio_files], outputs=[preds_output])


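# share=True creates a public Gradio link when running locally; HuggingFace
# Spaces ignores it, since the app is already served publicly there.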
demo.launch(share=True)