import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels

# Forward HF_TOKEN to the environment only when it is actually set;
# os.environ rejects None values.
hf_token = os.getenv("HF_TOKEN", None)
if hf_token is not None:
    os.environ["HF_TOKEN"] = hf_token


class StepAudio2Base:

    def __init__(self, model_path: str):
        self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right")
        self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
        self.eos_token_id = self.llm_tokenizer.eos_token_id

    def __call__(self, messages: list, **kwargs):
        messages, mels = self.apply_chat_template(messages)

        # Tokenize the prompt: plain strings go through the tokenizer, lists of
        # pre-computed token ids are used as-is.
        prompt_ids = []
        for msg in messages:
            if isinstance(msg, str):
                prompt_ids.append(self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"])
            elif isinstance(msg, list):
                prompt_ids.append(torch.tensor([msg], dtype=torch.int32))
            else:
                raise ValueError(f"Unsupported content type: {type(msg)}")
        prompt_ids = torch.cat(prompt_ids, dim=-1).cuda()
        attention_mask = torch.ones_like(prompt_ids)

        if len(mels) == 0:
            mels = None
            mel_lengths = None
        else:
            mels, mel_lengths = padding_mels(mels)
            mels = mels.cuda()
            mel_lengths = mel_lengths.cuda()

        generate_inputs = {
            "input_ids": prompt_ids,
            "wavs": mels,
            "wav_lens": mel_lengths,
            "attention_mask": attention_mask,
        }

        generation_config = dict(
            max_new_tokens=2048,
            pad_token_id=self.llm_tokenizer.pad_token_id,
            eos_token_id=self.eos_token_id,
        )
        generation_config.update(kwargs)
        generation_config = GenerationConfig(**generation_config)

        outputs = self.llm.generate(**generate_inputs, generation_config=generation_config)
        output_token_ids = outputs[0, prompt_ids.shape[-1]:-1].tolist()  # strip the prompt and the trailing EOS
        # Ids below 151688 decode as text; ids above 151695 are discrete audio
        # tokens, re-based to start at 0 for the token-to-wav stage.
        output_text_tokens = [i for i in output_token_ids if i < 151688]
        output_audio_tokens = [i - 151696 for i in output_token_ids if i > 151695]
        output_text = self.llm_tokenizer.decode(output_text_tokens)
        return output_token_ids, output_text, output_audio_tokens

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            content = msg
            if isinstance(content, str):
                results.append(content)
            elif isinstance(content, dict):
                if content["type"] == "text":
                    results.append(f"{content['text']}")
                elif content["type"] == "audio":
                    # Split audio into 25 s chunks (16 kHz) and mark each chunk
                    # with audio placeholder tokens in the prompt.
                    audio = load_audio(content['audio'])
                    for i in range(0, audio.shape[0], 16000 * 25):
                        mel = log_mel_spectrogram(audio[i:i + 16000 * 25], n_mels=128, padding=479)
                        mels.append(mel)
                        audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                        results.append(f"<audio_start>{audio_tokens}<audio_end>")
                elif content["type"] == "token":
                    results.append(content["token"])
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return results, mels


class StepAudio2(StepAudio2Base):

    def __init__(self, model_path: str):
        super().__init__(model_path)
        self.llm_tokenizer.eos_token = "<|EOT|>"
        self.llm.config.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
        self.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"
            if isinstance(content, str):
                text_with_audio = f"<|BOT|>{role}\n{content}"
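                # Each turn is framed as <|BOT|>{role}\n{content}; <|EOT|> is
                # appended below unless the message sets eot=False, which leaves
                # the assistant turn open for the model to continue (used by the
                # speech-response prompts in __main__).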
                text_with_audio += '<|EOT|>' if msg.get('eot', True) else ''
                results.append(text_with_audio)
            elif isinstance(content, list):
                results.append(f"<|BOT|>{role}\n")
                for item in content:
                    if item["type"] == "text":
                        results.append(f"{item['text']}")
                    elif item["type"] == "audio":
                        audio = load_audio(item['audio'])
                        for i in range(0, audio.shape[0], 16000 * 25):
                            mel = log_mel_spectrogram(audio[i:i + 16000 * 25], n_mels=128, padding=479)
                            mels.append(mel)
                            audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                            results.append(f"<audio_start>{audio_tokens}<audio_end>")
                    elif item["type"] == "token":
                        results.append(item["token"])
                if msg.get('eot', True):
                    results.append('<|EOT|>')
            elif content is None:
                # An open assistant turn: the model generates the content.
                results.append(f"<|BOT|>{role}\n")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return results, mels


if __name__ == '__main__':
    from token2wav import Token2wav

    model = StepAudio2('stepfun-ai/Step-Audio-2-mini')
    token2wav = Token2wav('token2wav')

    # Text-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Text-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_male.wav')
    with open('output-male.wav', 'wb') as f:
        f.write(audio)

    # Speech-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Speech-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_female.wav')
    with open('output-female.wav', 'wb') as f:
        f.write(audio)

    # Multi-turn conversation
    print()
    messages.pop(-1)
    messages += [
        {"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"}, {"type": "token", "token": tokens}]},
        {"role": "human", "content": "Now write a 4-line poem about it."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
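    # Note on the multi-turn pattern above: the finished assistant turn is fed
    # back as raw ids via {"type": "token", ...}, so the interleaved text and
    # audio tokens the model generated are reused verbatim rather than being
    # re-tokenized from the decoded text.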
"assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]}, {"role": "assistant", "content": None} ] tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True) print(text)