Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src'))) | |
| import shutil | |
| import mimetypes | |
| import subprocess | |
| import gradio as gr | |
| import torchaudio | |
| import spaces | |
| from model_helper import load_model_checkpoint, transcribe | |
| from prepare_media import prepare_media | |
| from typing import Tuple, Dict, Literal | |
# Key into MODELS that selects which checkpoint configuration the app loads.
MODEL_NAME = 'YPTF.MoE+Multi (noPS)'  # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
# Value placed after the '-pr' flag of every model's argument list.
PRECISION = '16'  # if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
# Value placed after the '-p' flag of every model's argument list.
PROJECT = '2024'


def _model_entry(checkpoint: str, *flags: str) -> dict:
    """Build one MODELS entry for *checkpoint*.

    Every entry shares the same frame: the checkpoint name first, then
    ``-p PROJECT``, then the model-specific *flags*, and finally
    ``-pr PRECISION``. Building the frame in one place keeps the project and
    precision values from being mixed up in individual entries.
    """
    return {
        "checkpoint": checkpoint,
        "args": [checkpoint, '-p', PROJECT, *flags, '-pr', PRECISION],
    }


# Maps a display name to its checkpoint identifier and the argument list that
# is handed to load_model_checkpoint().
MODELS = {
    # NOTE(review): "[email protected]" looks like an e-mail obfuscation
    # artifact that replaced a "<name>@model.ckpt" identifier -- restore the
    # real checkpoint name before selecting this model.
    "YMT3+": _model_entry("[email protected]"),
    "YPTF+Single (noPS)": _model_entry(
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        '-enc', 'perceiver-tf', '-ac', 'spec',
        '-hop', '300', '-atc', '1',
    ),
    "YPTF+Multi (PS)": _model_entry(
        "mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k@model.ckpt",
        '-tk', 'mc13_full_plus_256',
        '-dec', 'multi-t5', '-nl', '26', '-enc', 'perceiver-tf',
        '-ac', 'spec', '-hop', '300', '-atc', '1',
    ),
    "YPTF.MoE+Multi (noPS)": _model_entry(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
        '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
        '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
        '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1',
    ),
    "YPTF.MoE+Multi (PS)": _model_entry(
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        '-tk', 'mc13_full_plus_256', '-dec', 'multi-t5',
        '-nl', '26', '-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
        '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu', '-epe', 'rope',
        '-rp', '1', '-ac', 'spec', '-hop', '300', '-atc', '1',
    ),
}
# Text file that prepare_media() writes filtered yt-dlp output to.
log_file = 'amt/log.txt'

# Build the selected model on the CPU first, then move it to the GPU.
# The return value of .to() is discarded, so this relies on the move
# happening in place (presumably a torch.nn.Module -- confirm).
model = load_model_checkpoint(device="cpu", args=MODELS[MODEL_NAME]["args"])
model.to("cuda")
def _download_youtube_audio(url, simulate=False) -> str:
    """Fetch the audio of *url* with yt-dlp and return the expected mp3 path.

    yt-dlp's combined stdout/stderr is echoed line by line. Lines belonging to
    the Google device-authorization flow, plus "Authorization successful" and
    "Video unavailable" notices, are also written to ``log_file``.
    """
    output_stem = './downloaded/yt_audio'
    mp3_path = output_stem + '.mp3'

    # Remove the previous download so that a failed run cannot silently hand
    # back stale audio. A simulated run downloads nothing, so it leaves any
    # existing file alone.
    if not simulate and os.path.exists(mp3_path):
        os.remove(mp3_path)

    command = ['yt-dlp', '-x', url, '-f', 'bestaudio',
               '-o', output_stem, '--audio-format', 'mp3', '--restrict-filenames',
               '--extractor-retries', '10',
               '--force-overwrites', '--username', 'oauth2', '--password', '', '-v']
    if simulate:
        command.append('-s')

    with open(log_file, 'w') as lf:
        # The Popen context manager closes the pipe and waits for the process.
        with subprocess.Popen(command, stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT, text=True) as process:
            for line in process.stdout:
                print(line)
                if "www.google.com/device" in line:
                    # Highlight the device URL, and the last token on the line
                    # (presumably the device code), with ANSI colours.
                    hl_text = line.replace(
                        "https://www.google.com/device",
                        "\033[93mhttps://www.google.com/device\x1b[0m").split()
                    hl_text[-1] = "\x1b[31;1m" + hl_text[-1] + "\x1b[0m"
                    lf.write(' '.join(hl_text))
                    lf.flush()
                elif "Authorization successful" in line or "Video unavailable" in line:
                    lf.write(line)
                    lf.flush()
    return mp3_path


def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True,
                  simulate=False) -> Dict:
    """Prepare media from a local path or a YouTube URL and return audio info.

    NOTE: this definition replaces the ``prepare_media`` name imported from
    the ``prepare_media`` module at the top of the file.

    Args:
        source_path_or_url: Local audio file path, or the URL handed to yt-dlp.
        source_type: 'audio_filepath' uses the path as-is; 'youtube_url'
            downloads the audio to './downloaded/yt_audio.mp3' first.
        delete_video: Currently unused; kept for interface compatibility.
        simulate: When true, '-s' is appended to the yt-dlp command.

    Returns:
        A dict with the keys "filepath", "track_name", "sample_rate",
        "bits_per_sample", "num_channels", "num_frames", "duration" and
        "encoding", filled from ``torchaudio.info``.

    Raises:
        ValueError: If ``source_type`` is not one of the two supported values.
    """
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = _download_youtube_audio(source_path_or_url, simulate)
    else:
        raise ValueError(source_type)

    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        # File name up to its first dot.
        "track_name": os.path.basename(audio_file).split('.')[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        # Frames divided by sample rate, truncated to an int.
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
def handle_audio(file_path):
    """Transcribe an uploaded audio file and return the transcription result.

    The upload is copied to ``received_audio<ext>`` in the working directory,
    described with ``prepare_media`` and passed to ``transcribe`` together
    with the module-level ``model``.

    Args:
        file_path: Path of the audio file to transcribe.

    Returns:
        The value returned by ``transcribe`` (named ``midifile_path`` here).
    """
    # NOTE(review): `spaces` is imported at the top of the file but never
    # used; if this app runs on ZeroGPU hardware this function presumably
    # needs the `@spaces.GPU` decorator -- confirm against the deployment.

    # Pick an extension: MIME-based guess first, then the path's own suffix,
    # then ".bin". guess_extension() raises on None, so it is only called
    # when guess_type() actually recognised the file.
    mime_type, _ = mimetypes.guess_type(file_path)
    guessed_ext = mimetypes.guess_extension(mime_type) if mime_type else None
    ext = guessed_ext or os.path.splitext(file_path)[1] or ".bin"

    output_path = f"received_audio{ext}"
    shutil.copy(file_path, output_path)

    audio_info = prepare_media(output_path, source_type='audio_filepath')
    midifile_path = transcribe(model, audio_info)
    return midifile_path
# Gradio UI: an audio input delivered to handle_audio() as a file path, and
# a file component that shows whatever handle_audio() returns.
demo = gr.Interface(fn=handle_audio, inputs=gr.Audio(type="filepath"), outputs=gr.File())

# Start the server on port 7860 only when the file is run as a script.
if __name__ == "__main__":
    demo.launch(server_port=7860)