# Sample Inference Script
import random
import re
import sys
from argparse import ArgumentParser
from pathlib import Path
from warnings import simplefilter
sys.path.append("xcodec_mini_infer")
simplefilter("ignore")
import torch
import torchaudio
import yaml
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Config,
ExLlamaV2Tokenizer,
Timer,
)
from exllamav2.generator import (
ExLlamaV2DynamicGenerator,
ExLlamaV2DynamicJob,
ExLlamaV2Sampler,
)
from rich import print
from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream
# Command-line interface: model path plus genre/lyrics inputs (either may be
# a file path or literal text), an optional seed, and sampling knobs.
cli = ArgumentParser()
cli.add_argument("-m", "--model", required=True)
cli.add_argument("-g", "--genre", required=True)
cli.add_argument("-l", "--lyrics", required=True)
cli.add_argument("-s", "--seed", type=int, default=None)
cli.add_argument("-d", "--debug", action="store_true")
# Sampling hyperparameters.
cli.add_argument("--repetition_penalty", type=float, default=1.2)
cli.add_argument("--temperature", type=float, default=1.0)
cli.add_argument("--top_p", type=float, default=0.93)
args = cli.parse_args()
with Timer() as timer:
    # Build the model lazily, then let load_autosplit() split layers across
    # available GPUs while the lazy cache is materialized alongside them.
    config = ExLlamaV2Config(args.model)
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)
    # Dynamic generator handles batched/streaming jobs; warmup pre-compiles
    # kernels so later timing numbers reflect steady-state speed.
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)
    generator.warmup()
print(f"Loaded model in {timer.interval:.2f} seconds.")
# Resolve genre and lyrics: each argument may be a path to a UTF-8 text file
# or the literal text itself.
genre = Path(args.genre)
genre = genre.read_text(encoding="utf-8") if genre.is_file() else args.genre
genre = genre.strip()
lyrics = Path(args.lyrics)
lyrics = lyrics.read_text(encoding="utf-8") if lyrics.is_file() else args.lyrics
lyrics = lyrics.strip()
# Split into "[Label] body" segments. A trailing newline is re-appended
# because the text was stripped above: without it the final segment cannot
# satisfy the `\n(?=\[|\Z)` part of the pattern and is silently dropped.
lyrics = re.findall(r"\[(\w+)\](.*?)\n(?=\[|\Z)", lyrics + "\n", re.DOTALL)
lyrics = [f"[{l[0]}]\n{l[1].strip()}\n\n" for l in lyrics]
lyrics_joined = "\n".join(lyrics)
# Sampler configuration. NOTE(review): ids 32002 and 45334-46357 appear to be
# the <EOA> token and the xcodec audio-code vocabulary — confirm against the
# tokenizer; decoding is restricted to exactly this set.
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.allow_tokens(tokenizer, [32002] + list(range(45334, 46358)))
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p
# `is not None` (not truthiness) so an explicit `--seed 0` is honored instead
# of being replaced by a random seed.
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
stop_conditions = ["<EOA>"]
output_joined = ""  # generated segments, re-fed as context for later segments
output = []  # every streamed token string, consumed by the decode stage
with Timer() as timer:
    # Generate audio tokens one lyrics segment at a time; each prompt embeds
    # the full lyrics plus all previously generated segments as context.
    for segment in lyrics:
        current = []
        # Renamed from `input` to avoid shadowing the builtin.
        prompt = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}[start_of_segment]{segment}<SOA><xcodec>"
        )
        input_ids = tokenizer.encode(prompt, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        # Everything not used by the prompt is available for generation.
        max_new_tokens = config.max_seq_len - input_len
        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )
        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
        )
        generator.enqueue(job)
        with Timer() as inner:
            # Drain the generator, streaming token strings as they arrive.
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") == "streaming":
                        text = result.get("text")
                        if text:
                            current.append(text)
                            output.append(text)
                            if args.debug:
                                print(text, end="", flush=True)
                    if result.get("eos") and current:
                        # Segment finished: fold it into the running context
                        # so the next segment conditions on it.
                        current_joined = "".join(current)
                        output_joined += (
                            f"[start_of_segment]{segment}<SOA><xcodec>"
                            f"{current_joined}<EOA>[end_of_segment]"
                        )
        if args.debug:
            print()
        print(f"Generated {len(current)} tokens in {inner.interval:.2f} seconds.")
print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")
with Timer() as timer:
    # Build the xcodec decoder from its shipped config and checkpoint.
    codec_config = Path("xcodec_mini_infer/final_ckpt/config.yaml")
    codec_config = yaml.safe_load(codec_config.read_bytes())
    codec = SoundStream(**codec_config["generator"]["config"])
    # map_location="cpu" keeps the checkpoint tensors off the GPU while the
    # state dict is copied in; only the assembled model is moved to CUDA.
    state_dict = torch.load(
        "xcodec_mini_infer/final_ckpt/ckpt_00360000.pth", map_location="cpu"
    )
    codec.load_state_dict(state_dict["codec_model"])
    codec = codec.eval().cuda()
print(f"Loaded codec in {timer.interval:.2f} seconds.")
with Timer() as timer, torch.inference_mode():
    # Parse generated token strings like "<xcodec/0/123>" back into integer
    # codec ids via the regex capture group, instead of the brittle fixed
    # slice `o[10:-1]` which silently depends on the token's exact length.
    pattern = re.compile(r"<xcodec/0/(\d+)>")
    matches = (pattern.match(o) for o in output)
    output_ids = [int(m.group(1)) for m in matches if m]

    def _save_track(ids, path):
        """Decode a list of codec ids with the xcodec model and write a 16 kHz WAV."""
        codes = torch.tensor([[ids]]).cuda()
        # -> (n_codebooks, batch, frames); assumed layout expected by
        # SoundStream.decode — confirm against the codec implementation.
        codes = codes.permute(1, 0, 2)
        audio = codec.decode(codes)
        torchaudio.save(path, audio.squeeze(0).cpu(), 16000)

    # Vocal and instrumental code streams are interleaved in the output.
    _save_track(output_ids[::2], "vocal.wav")
    _save_track(output_ids[1::2], "inst.wav")
print(f"Decoded audio in {timer.interval:.2f} seconds.")
Downloads last month: 14
Inference Providers
NEW
This model is not currently available via any of the supported third-party Inference Providers, and
HF Inference API was unable to determine this model's library.
Model tree for Annuvin/YuE-s1-7B-anneal-en-cot-8.0bpw-h8-exl2
Base model
m-a-p/YuE-s1-7B-anneal-en-cot