Sample Inference Script

import random
import re
import sys
from argparse import ArgumentParser
from pathlib import Path
from warnings import simplefilter

sys.path.append("xcodec_mini_infer")
simplefilter("ignore")

import torch
import torchaudio
import yaml
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from rich import print

from xcodec_mini_infer.models.soundstream_hubert_new import SoundStream

# Command-line interface. --genre and --lyrics each accept either literal text
# or the path of a UTF-8 text file; which one was given is resolved later in
# the script by checking whether the value names an existing file.
parser = ArgumentParser(
    description="Generate vocal and instrumental audio from genre tags and lyrics.",
)
parser.add_argument(
    "-m",
    "--model",
    required=True,
    help="model path handed to ExLlamaV2Config",
)
parser.add_argument(
    "-g",
    "--genre",
    required=True,
    help="genre text, or path to a UTF-8 file containing it",
)
parser.add_argument(
    "-l",
    "--lyrics",
    required=True,
    help="lyrics split into [section] blocks, or path to a UTF-8 file containing them",
)
parser.add_argument(
    "-s",
    "--seed",
    type=int,
    default=None,
    help="sampling seed (a random one is drawn when omitted)",
)
parser.add_argument(
    "-d",
    "--debug",
    action="store_true",
    help="print generated text as it is streamed",
)
parser.add_argument(
    "--repetition_penalty",
    type=float,
    default=1.2,
    help="token repetition penalty (default: %(default)s)",
)
parser.add_argument(
    "--temperature",
    type=float,
    default=1.0,
    help="sampling temperature (default: %(default)s)",
)
parser.add_argument(
    "--top_p",
    type=float,
    default=0.93,
    help="top-p sampling threshold (default: %(default)s)",
)
args = parser.parse_args()

# Bring up the language model. The model and its cache are created lazily and
# then materialised together by load_autosplit(); the tokenizer and generator
# are built afterwards and the generator is warmed up before first use.
with Timer() as load_timer:
    config = ExLlamaV2Config(args.model)

    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config, lazy_init=True)

    generator = ExLlamaV2DynamicGenerator(
        model=model,
        cache=cache,
        tokenizer=tokenizer,
    )
    generator.warmup()

load_seconds = load_timer.interval
print(f"Loaded model in {load_seconds:.2f} seconds.")

def _read_text_or_literal(value):
    """Return the contents of the file named by *value*, or *value* itself.

    *value* is treated as a path only if it names an existing regular file.
    Literal text that is not a usable file name (for example, long lyrics
    that exceed the platform's file-name limit, or text with an embedded NUL)
    is returned unchanged instead of raising from the existence check.
    """
    path = Path(value)

    try:
        is_file = path.is_file()
    except (OSError, ValueError):
        # Only the existence probe is guarded; a genuine read error on a real
        # file must still surface below.
        is_file = False

    return path.read_text(encoding="utf-8") if is_file else value


genre = _read_text_or_literal(args.genre).strip()

# The section regex below requires a newline after every section body.
# strip() removes the trailing one, which would silently drop the last
# section (and yield nothing at all for single-section lyrics), so exactly
# one newline is put back.
lyrics = _read_text_or_literal(args.lyrics).strip() + "\n"

# Split into (tag, body) pairs: a "[tag]" header followed by everything up to
# the newline that precedes the next "[" or the end of the text.
lyrics = re.findall(r"\[(\w+)\](.*?)\n(?=\[|\Z)", lyrics, re.DOTALL)
lyrics = [f"[{tag}]\n{body.strip()}\n\n" for tag, body in lyrics]
lyrics_joined = "\n".join(lyrics)

# Sampler configuration. Sampling is restricted to token id 32002 plus the
# 1024 ids in 45334..46357.
# NOTE(review): presumably 32002 is the <EOA> token and the range is the
# <xcodec/0/N> codebook parsed at the end of the script - confirm against the
# model's tokenizer.
gen_settings = ExLlamaV2Sampler.Settings()
gen_settings.allow_tokens(tokenizer, [32002, *range(45334, 46358)])
gen_settings.temperature = args.temperature
gen_settings.token_repetition_penalty = args.repetition_penalty
gen_settings.top_p = args.top_p

# Test against None, not truthiness, so that an explicit "--seed 0" is
# honoured instead of being replaced by a random seed.
seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
stop_conditions = ["<EOA>"]

# Transcript of every finished segment (header, generated text, end markers);
# it is replayed inside the prompt of each later segment.
output_joined = ""
# Flat list of streamed text chunks from all segments, consumed by the audio
# decoding step at the end of the script.
output = []

with Timer() as timer:
    for segment in lyrics:
        # Text chunks streamed for this segment only.
        current = []
        segment_header = f"[start_of_segment]{segment}<SOA><xcodec>"

        # Named "prompt" rather than "input" so the builtin is not shadowed.
        prompt = (
            "Generate music from the given lyrics segment by segment.\n"
            f"[Genre] {genre}\n"
            f"{lyrics_joined}{output_joined}{segment_header}"
        )

        input_ids = tokenizer.encode(prompt, encode_special_tokens=True)
        input_len = input_ids.shape[-1]
        max_new_tokens = config.max_seq_len - input_len

        print(
            f"Using {input_len} tokens of {config.max_seq_len} tokens "
            f"with {max_new_tokens} tokens left."
        )

        if max_new_tokens <= 0:
            # The prompt already fills the context window, so nothing can be
            # generated for this or any later segment (prompts only grow).
            # Stop here and keep what earlier segments produced.
            print("Context window exhausted; skipping the remaining segments.")
            break

        job = ExLlamaV2DynamicJob(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            gen_settings=gen_settings,
            seed=seed,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
        )

        generator.enqueue(job)

        with Timer() as inner:
            while generator.num_remaining_jobs():
                for result in generator.iterate():
                    if result.get("stage") == "streaming":
                        text = result.get("text")

                        if text:
                            current.append(text)
                            output.append(text)

                            if args.debug:
                                print(text, end="", flush=True)

                    # When the job ends, record the segment in the transcript,
                    # but only if it actually produced some text.
                    if result.get("eos") and current:
                        current_joined = "".join(current)
                        output_joined += (
                            f"{segment_header}{current_joined}<EOA>[end_of_segment]"
                        )

                        if args.debug:
                            print()

        # len(current) counts streamed text chunks.
        print(f"Generated {len(current)} tokens in {inner.interval:.2f} seconds.")

print(f"Finished in {timer.interval:.2f} seconds with seed {seed}.")

# Build the audio codec from its YAML config and restore the checkpoint weights.
with Timer() as timer:
    codec_config = Path("xcodec_mini_infer/final_ckpt/config.yaml")
    codec_config = yaml.safe_load(codec_config.read_bytes())
    codec = SoundStream(**codec_config["generator"]["config"])

    # Load onto the CPU first: without map_location the tensors are restored
    # to whatever device they were saved from. The model is moved to the GPU
    # explicitly below, once the weights are in place.
    # NOTE: torch.load unpickles the file - only use a trusted checkpoint.
    state_dict = torch.load(
        "xcodec_mini_infer/final_ckpt/ckpt_00360000.pth",
        map_location="cpu",
    )
    codec.load_state_dict(state_dict["codec_model"])
    codec = codec.eval().cuda()

print(f"Loaded codec in {timer.interval:.2f} seconds.")

def _decode_track(codes, path):
    """Decode a list of codec ids with the codec model and save it to *path*.

    The ids are wrapped into a (1, 1, len(codes)) tensor on the GPU, decoded,
    and the result is written with a sample rate of 16000.
    """
    frames = torch.tensor([[codes]]).cuda()
    frames = frames.permute(1, 0, 2)
    audio = codec.decode(frames)
    audio = audio.squeeze(0).cpu()
    torchaudio.save(path, audio, 16000)


with Timer() as timer, torch.inference_mode():
    # Extract the numeric id from every <xcodec/0/N> token. Searching the
    # joined text, rather than slicing each chunk, keeps this correct even if
    # a streamed chunk carries more than one token.
    pattern = re.compile(r"<xcodec/0/(\d+)>")
    output_ids = [int(code) for code in pattern.findall("".join(output))]

    if len(output_ids) < 2:
        # Each track needs at least one id; fail with a clear message instead
        # of an opaque error from inside the codec.
        raise SystemExit("Fewer than two codec tokens were generated; nothing to decode.")

    # Ids are interleaved: even positions go to vocal.wav, odd ones to inst.wav.
    _decode_track(output_ids[::2], "vocal.wav")
    _decode_track(output_ids[1::2], "inst.wav")

print(f"Decoded audio in {timer.interval:.2f} seconds.")
Downloads last month
14
Inference Providers NEW
This model is not currently available via any of the supported third-party Inference Providers, and HF Inference API was unable to determine this model's library.

Model tree for Annuvin/YuE-s1-7B-anneal-en-cot-8.0bpw-h8-exl2

Quantized
(9)
this model