Sample Inference Script
import re
from argparse import ArgumentParser

import torch
import torchaudio
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from jinja2 import Template
from rich import print
from torchaudio import functional as F
from xcodec2.modeling_xcodec2 import XCodec2Model
# Command-line interface for the text-to-speech script.
parser = ArgumentParser(
    description=(
        "Convert text to speech with an ExLlamaV2 language model and an "
        "XCodec2 vocoder."
    )
)
parser.add_argument(
    "-m", "--model", required=True, help="ExLlamaV2 model directory."
)
parser.add_argument(
    "-v",
    "--vocoder",
    required=True,
    help="XCodec2 checkpoint passed to XCodec2Model.from_pretrained.",
)
parser.add_argument("-i", "--input", required=True, help="Text to synthesize.")
parser.add_argument(
    "-a",
    "--audio",
    default="",
    help="Optional reference audio whose speech tokens prefix the generation "
    "(requires --transcript).",
)
parser.add_argument(
    "-t",
    "--transcript",
    default="",
    help="Transcript of the --audio clip (requires --audio).",
)
parser.add_argument(
    "-o", "--output", default="output.wav", help="Path of the audio file to write."
)
parser.add_argument(
    "-d",
    "--debug",
    action="store_true",
    help="Echo the generated token text while streaming.",
)
parser.add_argument(
    "--sample_rate",
    type=int,
    default=16000,
    help="Sample rate (Hz) the reference audio is resampled to and the output "
    "is saved with.",
)
parser.add_argument(
    "--max_seq_len",
    type=int,
    default=2048,
    help="Maximum context length in tokens (prompt + generated).",
)
parser.add_argument(
    "--temperature", type=float, default=0.8, help="Sampling temperature."
)
parser.add_argument(
    "--top_p", type=float, default=1.0, help="Nucleus (top-p) sampling threshold."
)
args = parser.parse_args()

# The reference clip is only used when BOTH the audio and its transcript are
# given; reject a half-specified pair instead of silently ignoring it.
if bool(args.audio) != bool(args.transcript):
    parser.error("--audio and --transcript must be given together")
# Load the ExLlamaV2 language model together with its KV cache, tokenizer and
# dynamic generator, then report how long the whole setup took.
with Timer() as timer:
    config = ExLlamaV2Config(args.model)
    # Upper bound on prompt + generated tokens for this run.
    config.max_seq_len = args.max_seq_len
    # Model and cache are created lazily and materialized by load_autosplit,
    # which decides device placement while the weights are loaded.
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache, progress=True)
    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
# The elapsed time is read after the timed block has closed.
print("Loaded model in {:.2f} seconds.".format(timer.interval))
# Load the XCodec2 codec that converts between waveforms and speech tokens,
# move it to the GPU and switch it to evaluation mode.
with Timer() as timer:
    vocoder = XCodec2Model.from_pretrained(args.vocoder).cuda().eval()
print("Loaded vocoder in {:.2f} seconds.".format(timer.interval))
# Optional voice prompt: turn the reference clip into "<|s_N|>" speech tokens
# that will prefix the assistant turn, and prepend its transcript to the text.
# Produces two strings used when building the prompt: `transcript`, `audio`.
if args.audio and args.transcript:
    with Timer() as timer:
        transcript = f"{args.transcript} "
        waveform, source_rate = torchaudio.load(args.audio)
        waveform = waveform.cuda()
        # Down-mix multi-channel audio to a single channel, keeping dim 0.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if source_rate != args.sample_rate:
            waveform = F.resample(waveform, source_rate, args.sample_rate)
        # Pure inference: no autograd graph is needed for the codec.
        with torch.inference_mode():
            codes = vocoder.encode_code(waveform)
        # Copy the code sequence to the host in one transfer; formatting each
        # element of a CUDA tensor would sync the device once per code.
        code_list = codes[0, 0, :].tolist()
        audio = "".join(f"<|s_{code}|>" for code in code_list)
    print(f"Encoded audio in {timer.interval:.2f} seconds.")
else:
    transcript = ""
    audio = ""
# Build the chat prompt: the user turn carries the text to speak (optionally
# preceded by the reference transcript); the assistant turn is left open after
# the speech-start marker (plus any reference speech tokens) so the model
# continues it. Produces `input_ids` for the generator.
with Timer() as timer:
    messages = [
        {
            "role": "user",
            "content": (
                "Convert the text to speech:"
                "<|TEXT_UNDERSTANDING_START|>"
                f"{transcript}{args.input}"
                "<|TEXT_UNDERSTANDING_END|>"
            ),
        },
        {"role": "assistant", "content": f"<|SPEECH_GENERATION_START|>{audio}"},
    ]
    # Render with the whitespace options Hugging Face uses for chat templates,
    # so the prompt matches the format the template was written for.
    template = Template(
        tokenizer.tokenizer_config_dict["chat_template"],
        trim_blocks=True,
        lstrip_blocks=True,
    )
    prompt = template.render(messages=messages)
    # Cut the rendered text right after the final (assistant) message, dropping
    # whatever the template appends (e.g. an end-of-turn token), like
    # transformers' `continue_final_message`. A fixed-length slice would break
    # whenever the template's suffix is not exactly that length.
    final_message = messages[-1]["content"]
    cut = prompt.rfind(final_message)
    if cut == -1:
        # Some templates trim message content before rendering it.
        final_message = final_message.strip()
        cut = prompt.rfind(final_message)
    if cut == -1:
        raise SystemExit(
            "Rendered chat template does not contain the assistant prompt."
        )
    input_ids = tokenizer.encode(prompt[: cut + len(final_message)], add_bos=True)
print(f"Encoded input in {timer.interval:.2f} seconds.")
# Sample speech tokens until the end-of-speech marker is produced or the
# context is exhausted. Streamed text chunks are collected in `output`.
with Timer() as timer:
    prompt_len = input_ids.shape[-1]
    max_new_tokens = config.max_seq_len - prompt_len
    # A prompt that already fills the context leaves no room to generate;
    # fail clearly instead of handing a non-positive budget to the generator.
    if max_new_tokens <= 0:
        raise SystemExit(
            f"Prompt is {prompt_len} tokens but --max_seq_len is "
            f"{config.max_seq_len}; increase --max_seq_len or shorten the "
            "input/reference."
        )
    gen_settings = ExLlamaV2Sampler.Settings()
    gen_settings.temperature = args.temperature
    gen_settings.top_p = args.top_p
    job = ExLlamaV2DynamicJob(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        gen_settings=gen_settings,
        stop_conditions=["<|SPEECH_GENERATION_END|>"],
    )
    generator.enqueue(job)
    output = []
    while generator.num_remaining_jobs():
        for result in generator.iterate():
            if result.get("stage") == "streaming":
                text = result.get("text")
                output.append(text)
                if args.debug:
                    print(text, end="", flush=True)
            if result.get("eos"):
                generator.clear_queue()
    if args.debug:
        print()
# NOTE(review): this counts streamed chunks, which equals the token count only
# if every chunk carries exactly one token.
print(f"Generated {len(output)} tokens in {timer.interval:.2f} seconds.")
# Convert the generated "<|s_N|>" speech tokens back into a waveform with the
# vocoder and write it to --output at --sample_rate.
with Timer() as timer:
    # Parse the concatenated stream rather than each chunk on its own, so a
    # chunk holding several tokens, part of a token, or stray text cannot make
    # the int() conversion fail.
    streamed = "".join(chunk for chunk in output if chunk)
    speech_ids = [int(n) for n in re.findall(r"<\|s_(\d+)\|>", streamed)]
    if not speech_ids:
        raise SystemExit("The model produced no speech tokens; nothing to decode.")
    codes = torch.tensor([[speech_ids]]).cuda()
    with torch.inference_mode():
        decoded = vocoder.decode_code(codes)
    # First batch item / first channel, as a (1, T) CPU tensor for saving.
    waveform = decoded[0, 0, :].unsqueeze(0).cpu()
    torchaudio.save(args.output, waveform, args.sample_rate)
print(f"Decoded audio in {timer.interval:.2f} seconds.")
- Downloads last month: 4
Inference Providers
NEW
This model is not currently available via any of the supported third-party Inference Providers, and the HF Inference API was unable to determine this model's library.