import atexit from dataclasses import fields from time import perf_counter import torch.multiprocessing as mp from tqdm.auto import tqdm from transformers import AutoTokenizer from flashcosyvoice.config import Config, SamplingParams from flashcosyvoice.engine.model_runner import ModelRunner from flashcosyvoice.engine.scheduler import Scheduler from flashcosyvoice.engine.sequence import Sequence class LLMEngine: def __init__(self, model, **kwargs): config_fields = {field.name for field in fields(Config)} config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} config = Config(model, **config_kwargs) self.ps = [] self.events = [] ctx = mp.get_context("spawn") assert config.tensor_parallel_size == 1, "NOTE(xcsong): Currently only support tp=1" for i in range(1, config.tensor_parallel_size): event = ctx.Event() process = ctx.Process(target=ModelRunner, args=(config, i, event)) process.start() self.ps.append(process) self.events.append(event) if hasattr(config.hf_config, "speech_vocab_size"): # NOTE: non-chat model, all these special tokens keep randomly initialized. special_tokens = { 'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': [ '<|im_start|>', '<|im_end|>', '<|endofprompt|>', '[breath]', '', '', '[noise]', '[laughter]', '[cough]', '[clucking]', '[accent]', '[quick_breath]', "", "", "[hissing]", "[sigh]", "[vocalized-noise]", "[lipsmack]", "[mn]" ] } self.tokenizer = AutoTokenizer.from_pretrained(f"{config.model}/CosyVoice-BlankEN") self.tokenizer.add_special_tokens(special_tokens) self.skip_special_tokens = True else: self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True) if hasattr(config.hf_config, "eos_token_id"): config.eos = config.hf_config.eos_token_id else: config.eos = self.tokenizer.eos_token_id self.model_runner = ModelRunner(config, config.rank, self.events) self.scheduler = Scheduler(config) self.config = config atexit.register(self.exit) def exit(self): self.model_runner.call("exit") del self.model_runner for p in self.ps: p.join() def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): if isinstance(prompt, str): prompt = self.tokenizer.encode(prompt) seq = Sequence(prompt, sampling_params) self.scheduler.add(seq) def step(self): seqs, is_prefill = self.scheduler.schedule() token_ids = self.model_runner.call("run", seqs, is_prefill) self.scheduler.postprocess(seqs, token_ids) outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) return outputs, num_tokens def is_finished(self): return self.scheduler.is_finished() def generate( self, prompts: list[str] | list[list[int]], sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True, ) -> list[str]: if use_tqdm: pbar = tqdm(total=len(prompts), desc="Generating tokens (LLM)", leave=False, dynamic_ncols=True, position=self.config.rank + 1) if not isinstance(sampling_params, list): sampling_params = [sampling_params] * len(prompts) for prompt, sp in zip(prompts, sampling_params): self.add_request(prompt, sp) outputs = {} prefill_throughput = decode_throughput = instant_decode_throughput = 0. total_decode_tokens = 0 total_decode_time = 0. while not self.is_finished(): t = perf_counter() output, num_tokens = self.step() step_time = perf_counter() - t if use_tqdm: if num_tokens > 0: prefill_throughput = num_tokens / step_time else: instant_decode_throughput = -num_tokens / step_time total_decode_tokens += -num_tokens total_decode_time += step_time decode_throughput = total_decode_tokens / total_decode_time if total_decode_time > 0 else 0 pbar.set_postfix({ "Prefill": f"{int(prefill_throughput)}tok/s", "AvgDecode": f"{int(decode_throughput)}tok/s", "InstDecode": f"{int(instant_decode_throughput)}tok/s", }) for seq_id, token_ids in output: outputs[seq_id] = token_ids if use_tqdm: pbar.update(1) outputs = [outputs[seq_id] for seq_id in sorted(outputs)] outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] if use_tqdm: pbar.close() return outputs