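# Hugging Face file-viewer header (page residue, not part of the program):
# author: wsntxxn
# commit: dd3d338 - "Change to Hugging Face calling"
# links: raw | history | blame
# size: 3.38 kB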
from functools import partial
import gradio as gr
import torch
from torchaudio.functional import resample
from transformers import AutoModel, PreTrainedTokenizerFast
def load_model(model_name, device):
    """Load a captioning model and its tokenizer from the Hugging Face Hub.

    Args:
        model_name: Name of the model to load; "AudioCaps" or "Clotho".
        device: Target passed to the model's ``.to()`` call.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` is not one of the known model names.
    """
    # Hub repositories per selectable model: (captioning model, tokenizer).
    checkpoints = {
        "AudioCaps": (
            "wsntxxn/effb2-trm-audiocaps-captioning",
            "wsntxxn/audiocaps-simple-tokenizer",
        ),
        "Clotho": (
            "wsntxxn/effb2-trm-clotho-captioning",
            "wsntxxn/clotho-simple-tokenizer",
        ),
    }

    # Reject unknown names up front; otherwise nothing would be bound and the
    # function would only fail later with an unrelated-looking error.
    if model_name not in checkpoints:
        known = ", ".join(sorted(checkpoints))
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of: {known}"
        )

    model_repo, tokenizer_repo = checkpoints[model_name]
    # trust_remote_code=True is kept from the original call: the model is
    # loaded with code shipped in its Hub repository.
    model = AutoModel.from_pretrained(
        model_repo,
        trust_remote_code=True,
    ).to(device)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_repo)
    return model, tokenizer
def infer(file, runner):
    """Generate a caption for one audio clip.

    Args:
        file: A ``(sample_rate, samples)`` pair, or ``None`` when no audio
            has been provided.
        runner: Object exposing ``model``, ``tokenizer`` and ``target_sr``.

    Returns:
        The decoded caption, or an empty string when there is no audio.
    """
    # "Run" can be pressed before any audio is supplied. There is nothing to
    # caption in that case, so return early instead of failing on the unpack.
    if file is None:
        return ""

    sr, wav = file
    wav = torch.as_tensor(wav)

    # Rescale 16-/32-bit integer PCM by its full-scale value; any other dtype
    # is passed through unchanged.
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31

    # Collapse multi-dimensional input by averaging over dim 1.
    # NOTE(review): assumes samples are laid out as (num_samples, num_channels)
    # -- confirm against the audio input component.
    if wav.ndim > 1:
        wav = wav.mean(1)

    wav = resample(wav, sr, runner.target_sr)
    # Length is taken after resampling and before the batch dim is added.
    wav_len = len(wav)
    wav = wav.float().unsqueeze(0)

    with torch.no_grad():
        # The first element of the model output is decoded as the caption.
        word_idx = runner.model(
            audio=wav,
            audio_length=[wav_len],
        )[0]

    return runner.tokenizer.decode(word_idx, skip_special_tokens=True)
# def input_toggle(input_type):
# if input_type == "file":
# return gr.update(visible=True), gr.update(visible=False)
# elif input_type == "mic":
# return gr.update(visible=False), gr.update(visible=True)
class InferRunner:
    """Holds the active captioning model, its tokenizer and its sample rate."""

    def __init__(self, model_name):
        """Select CUDA when available (CPU otherwise) and load ``model_name``."""
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # Initial load goes through the same path as a later model switch.
        self.change_model(model_name)

    def change_model(self, model_name):
        """Replace the model and tokenizer with those for ``model_name``.

        Also refreshes ``target_sr`` from the newly loaded model's config.
        """
        loaded_model, loaded_tokenizer = load_model(model_name, self.device)
        self.model = loaded_model
        self.tokenizer = loaded_tokenizer
        self.target_sr = loaded_model.config.sample_rate
def change_model(radio):
    """Switch the shared runner to the model named by ``radio``.

    Args:
        radio: Model name forwarded to ``infer_runner.change_model``.
    """
    # The module-level ``infer_runner`` is only read here (a method is called
    # on it, the name is never rebound), so no ``global`` statement is needed.
    infer_runner.change_model(radio)
# Page layout: a title row, a row of paper/code badges, then one row holding
# the controls (model choice, audio input, run button) and the caption output.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")

    with gr.Row():
        # Badge markdown is kept flush-left so the string is unchanged.
        gr.Markdown("""
[![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
[![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
""")

    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                choices=["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model",
            )
            # The runner is created with the radio's initial value; later
            # selections are applied to this same object by change_model.
            infer_runner = InferRunner(radio.value)
            audio_input = gr.Audio(label="Input", visible=True)
            radio.change(fn=change_model, inputs=[radio])
            run_button = gr.Button("Run")

        with gr.Column():
            caption_box = gr.Textbox(label="Output")

        # Bind the shared runner once; the button then only supplies the audio.
        caption_fn = partial(infer, runner=infer_runner)
        run_button.click(
            fn=caption_fn,
            inputs=[audio_input],
            outputs=caption_box,
        )

demo.launch()