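"""Gradio demo: caption an uploaded audio clip with a model trained on
AudioCaps (loaded from ./checkpoints/audiocaps).

Run this script directly; pass `--share` for a public Gradio link.
"""
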
from pathlib import Path
import argparse
from functools import partial

import gradio as gr
import torch
from torchaudio.functional import resample

import utils.train_util as train_util
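

# Build the captioning model and tokenizer from the experiment config,
# then restore both from the checkpoint.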
def load_model(cfg,
               ckpt_path,
               device):
    model = train_util.init_model_from_config(cfg["model"])
    ckpt = torch.load(ckpt_path, "cpu")
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    model = model.to(device)
    tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
    if not tokenizer.loaded:
        tokenizer.load_state_dict(ckpt["tokenizer"])
    # tell the model which special token ids (bos/eos/pad) to use when decoding
    model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
    return model, tokenizer
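

# Convert a Gradio audio input to a mono float waveform at the model's
# sample rate, then decode a caption with beam search.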
def infer(file, device, model, tokenizer, target_sr):
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Gradio returns integer PCM; rescale to floats in [-1, 1]
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # downmix multi-channel audio to mono
    if wav.ndim > 1:
        wav = wav.mean(1)
    wav = resample(wav, sr, target_sr)
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0).to(device)
    input_dict = {
        "mode": "inference",
        "wav": wav,
        "wav_len": [wav_len],
        "specaug": False,
        "sample_method": "beam",
        "beam_size": 3,
    }
    with torch.no_grad():
        output_dict = model(input_dict)
    seq = output_dict["seq"].cpu().numpy()
    cap = tokenizer.decode(seq)[0]
    return cap
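

# Disabled: toggle between file upload and microphone input; kept for
# reference together with the commented-out radio/mic widgets below.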
# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)
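

# Entry point: parse CLI flags, load the checkpoint, build the Gradio UI,
# and launch the demo.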
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    exp_dir = Path("./checkpoints/audiocaps")

    cfg = train_util.load_config(exp_dir / "config.yaml")
    target_sr = cfg["target_sr"]
    model, tokenizer = load_model(cfg, exp_dir / "ckpt.pth", device)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # radio = gr.Radio(
                #     ["file", "mic"],
                #     value="file",
                #     label="Select input type"
                # )
                file = gr.Audio(label="Input", visible=True)
                # mic = gr.Microphone(label="Input", visible=False)
                # radio.change(fn=input_toggle, inputs=radio, outputs=[file, mic])
                btn = gr.Button("Run")
            with gr.Column():
                output = gr.Textbox(label="Output")
        btn.click(
            fn=partial(infer,
                       device=device,
                       model=model,
                       tokenizer=tokenizer,
                       target_sr=target_sr),
            inputs=[file],
            outputs=output,
        )

    demo.launch(share=args.share)