import argparse
import soundfile as sf
import torch
from pathlib import Path
import torchaudio
from transformers import AutoTokenizer

from model import VoilaAlphaModel, VoilaModel
from spkr import SpeakerEmbedding
from voila_tokenizer import VoilaTokenizer
from tokenize_func import (
    voila_input_format,
    AUDIO_TOKEN_FORMAT,
    DEFAULT_AUDIO_TOKEN,
)


def disable_torch_init():
    """
    Skip PyTorch's default parameter initialization for Linear and LayerNorm
    layers to speed up model creation; the pretrained checkpoint overwrites
    these weights anyway.
    """
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def load_model(model_name):
    disable_torch_init()
    # The alpha checkpoint uses the "audio" input type (eval_model also passes it
    # audio_datas); the other checkpoints use the "token" input type.
    if "alpha" in model_name:
        model_type = "audio"
        cls = VoilaAlphaModel
    else:
        model_type = "token"
        cls = VoilaModel
    model = cls.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_flash_attention_2=True,
        use_cache=True,
    )
    model = model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer_voila = VoilaTokenizer(device="cuda")
    return model, tokenizer, tokenizer_voila, model_type


def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history,
               ref_embs, ref_embs_mask, max_new_tokens=4096):
    # step1: initializing
    num_codebooks = model.config.num_codebooks
    codebook_size = model.config.codebook_size

    # IDs of the first and last audio tokens, and of the audio placeholder token
    AUDIO_MIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(0))
    assert isinstance(AUDIO_MIN_TOKEN_ID, int)
    AUDIO_MAX_TOKEN_ID = tokenizer.convert_tokens_to_ids(
        AUDIO_TOKEN_FORMAT.format(codebook_size * num_codebooks - 1)
    )
    assert isinstance(AUDIO_MAX_TOKEN_ID, int)
    AUDIO_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_AUDIO_TOKEN)
    assert isinstance(AUDIO_TOKEN_ID, int)

    # step2: set infer config
    data_cfg = {
        "input_type": model_type,
        "task_type": task_type,
        "num_codebooks": num_codebooks,
        "codebook_size": codebook_size,
    }

    # step3: infer
    input_ids, audio_datas, audio_data_masks = voila_input_format(history, tokenizer, tokenizer_voila, data_cfg)
    input_ids = torch.as_tensor([input_ids]).transpose(1, 2).cuda()  # transpose to [bs, seq, num_codebooks]
    gen_params = {
        "input_ids": input_ids,
        "ref_embs": ref_embs,
        "ref_embs_mask": ref_embs_mask,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "llm_audio_token_id": AUDIO_TOKEN_ID,
        "min_audio_token_id": AUDIO_MIN_TOKEN_ID,
        "temperature": 0.8,
        "top_k": 50,
        "audio_temperature": 0.2,
        "audio_top_k": 50,
    }
    if model_type == "audio":
        audio_datas = torch.tensor([audio_datas], dtype=torch.bfloat16).cuda()
        audio_data_masks = torch.tensor([audio_data_masks]).cuda()
        gen_params["audio_datas"] = audio_datas
        gen_params["audio_data_masks"] = audio_data_masks
    print(tokenizer.decode(input_ids[0, :, 0]))
    with torch.inference_mode():
        outputs = model.run_generate(**gen_params)

    # Keep only the newly generated rows; each row holds one token per codebook.
    outputs = outputs[0].cpu().tolist()
    predict_outputs = outputs[input_ids.shape[1]:]
    text_outputs = []
    audio_outputs = [[] for _ in range(num_codebooks)]
    for item in predict_outputs:
        if AUDIO_MIN_TOKEN_ID <= item[0] <= AUDIO_MAX_TOKEN_ID:
            # Audio frame: map each column back to a code index in [0, codebook_size).
            for n, at in enumerate(item):
                audio_outputs[n].append((at - AUDIO_MIN_TOKEN_ID) % codebook_size)
        elif item[0] != tokenizer.eos_token_id:
            # Text step: only the first column carries the text token.
            text_outputs.append(item[0])

    if task_type in ["chat_tito"]:
        return tokenizer.decode(text_outputs)
    elif task_type in ["chat_aiao"]:
        audio_values = tokenizer_voila.decode(torch.tensor(audio_outputs).cuda())
        # The waveform is returned together with a fixed sample rate of 16000 Hz.
        return audio_values.detach().cpu().numpy(), 16000
    else:
        raise NotImplementedError(f"task type {task_type} is not supported for inference yet.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--instruction", type=str, default="")
    parser.add_argument("--input-text", type=str, default=None)
    parser.add_argument("--input-audio", type=str, default=None)
    parser.add_argument("--ref-audio", type=str, default="examples/test1.mp3")
    parser.add_argument("--model-name", type=str, default="maitrix-org/Voila-chat")
    parser.add_argument("--result-path", type=str, default="output")
    parser.add_argument("--task-type", type=str, default="chat_aiao")
    args = parser.parse_args()
    assert args.model_name in [
        "maitrix-org/Voila-alpha",
        "maitrix-org/Voila-base",
        "maitrix-org/Voila-chat",
    ], f"unsupported model: {args.model_name}"

    # step0: load the model and tokenizers
    model, tokenizer, tokenizer_voila, model_type = load_model(args.model_name)

    # step1: prepare inputs
    Path(args.result_path).mkdir(exist_ok=True, parents=True)
    history = {
        "instruction": args.instruction,
        "conversations": [],
    }
    if args.input_text is not None:
        history["conversations"].append({"from": "user", "text": args.input_text})
    elif args.input_audio is not None:
        history["conversations"].append({"from": "user", "audio": {"file": args.input_audio}})
    else:
        raise ValueError("Please provide --input-text or --input-audio (if both are given, --input-text is used).")
    history["conversations"].append({"from": "assistant"})

    # step2: encode the reference speaker audio (only needed for chat_aiao)
    ref_embs, ref_embs_mask = None, None
    if args.task_type in ["chat_aiao"]:
        spkr_model = SpeakerEmbedding(device="cuda")
        wav, sr = torchaudio.load(args.ref_audio)
        ref_embs = spkr_model(wav, sr)
        ref_embs_mask = torch.tensor([1]).cuda()

    out = eval_model(model, tokenizer, tokenizer_voila, model_type, args.task_type,
                     history, ref_embs, ref_embs_mask)
    if args.task_type in ["chat_tito"]:
        print(out)
    else:
        wav, sr = out
        save_name = f"{args.result_path}/out.wav"
        sf.write(save_name, wav, sr)
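

# --- Illustrative sketch: hypothetical helper, not part of the original Voila script ---
# eval_model() walks the generated rows, each a list of num_codebooks token IDs.
# A row whose first ID lies in [audio_min_id, audio_max_id] is treated as an audio
# frame. Every column n is mapped back to a code index with
# (id - audio_min_id) % codebook_size, which suggests (assumption) that each
# codebook owns its own block of codebook_size consecutive IDs.
# Any other row is a text step. Its first column is kept as a text token
# unless it is EOS.
# The function below reproduces that loop with plain Python lists so the logic can
# be checked without a GPU. The name split_generated_tokens is an assumption for
# illustration only, and nothing in this script calls it.
def split_generated_tokens(rows, audio_min_id, audio_max_id, codebook_size,
                           num_codebooks, eos_token_id):
    """Split generated [seq, num_codebooks] rows into text IDs and per-codebook audio codes."""
    text_ids = []
    audio_codes = [[] for _ in range(num_codebooks)]
    for row in rows:
        if audio_min_id <= row[0] <= audio_max_id:
            # Audio frame: undo the per-codebook ID offset of every column.
            for n, token_id in enumerate(row):
                audio_codes[n].append((token_id - audio_min_id) % codebook_size)
        elif row[0] != eos_token_id:
            # Text step: only the first column carries the text token.
            text_ids.append(row[0])
    return text_ids, audio_codes

# Usage with made-up IDs (audio IDs 100..107, codebook_size=4, num_codebooks=2, EOS=2):
# split_generated_tokens([[5, 5], [101, 106], [2, 2], [7, 7]], 100, 107, 4, 2, 2)
# returns ([5, 7], [[1], [2]])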