import importlib
from types import SimpleNamespace

import gradio as gr
import pandas as pd

import spaces
import torch

from utmosv2.utils import get_dataset, get_model

description = (
    "# πŸš€ UTMOSv2 demo\n\n"
    "[![GitHub](https://img.shields.io/badge/-GitHub-181717.svg?logo=github&style=flat)](https://github.com/sarulab-speech/UTMOSv2)\n\n"
    "This is a demonstration of MOS prediction using UTMOSv2. "
    "This demonstration only accepts `.wav` format. Best at 16 kHz sampling rate."
)

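# Run inference on the GPU and load the fusion_stage3 config module, exposing its
# attributes as a mutable namespace so the inference-time settings below can override them.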
device = torch.device("cuda")
config = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{attr: getattr(config, attr) for attr in config.__dict__ if not attr.startswith("__")}
)
cfg.reproduce = False
cfg.config = "fusion_stage3"
cfg.print_config = False
cfg.data_config = None
cfg.phase = "inference"
cfg.num_workers = 1

@spaces.GPU
@torch.inference_mode()
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
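    """Predict the MOS of the audio file at `audio_path`.

    The full prediction averages 5 folds x 5 randomly sampled frames (25 forward
    passes in total); when `quick` is set, the first prediction is returned directly.
    """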
    data = pd.DataFrame({"file_path": [audio_path]})
    data["dataset"] = domain
    data["mos"] = 0
    preds = 0.0
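    # Accumulate predictions from each fold's checkpoint.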
    for fold in range(5):
        cfg.now_fold = fold
        cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
        model = get_model(cfg, device).eval()
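        # Rebuilding the test dataset on each repetition samples a new random frame
        # of the input waveform, giving five frame-level predictions per fold.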
        for _ in range(5):
            test_dataset = get_dataset(cfg, data, "test")
            inputs = [
                torch.tensor(t, dtype=torch.float32).unsqueeze(0).to(device)
                for t in test_dataset[0][:-1]
            ]
            p = model(*inputs)
            preds += p.cpu().numpy()[0][0]
            if quick:
                return preds
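    # Average over the 5 folds x 5 repetitions.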
    preds /= 25.0
    return preds


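# Gradio UI: audio input, data-domain selector, and quick-prediction toggle on the left;
# the predicted MOS appears in a textbox on the right.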
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio")
            domain = gr.Dropdown(
                [
                    "sarulab",
                    "bvcc",
                    "somos",
                    "blizzard2008",
                    "blizzard2009",
                    "blizzard2010-EH1",
                    "blizzard2010-EH2",
                    "blizzard2010-ES1",
                    "blizzard2010-ES3",
                    "blizzard2011",
                ],
                label="Data-domain ID for the MOS prediction",
                value="sarulab",
            )
            quick = gr.Checkbox(
                label="Quick prediction",
                value=True,
                info=(
                    "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
                    "of the input speech waveform for all five folds. "
                    "Check this box to make a quick prediction with a single repetition instead."
                ),
            )
            submit = gr.Button(value="Submit")

        with gr.Column():
            output = gr.Textbox(label="Predicted MOS", type="text")
    submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])

demo.queue().launch()