from pathlib import Path
import json
from typing import List, Union

import soundfile as sf
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration

from mm_story_agent.modality_agents.llm import QwenAgent
from mm_story_agent.prompts_en import story_to_music_reviser_system, story_to_music_reviewer_system


class MusicGenSynthesizer:
    """Generate a music clip from a text prompt with a MusicGen model and write it to disk."""

    def __init__(self,
                 model_name: str = 'facebook/musicgen-medium',
                 device: str = 'cuda',
                 sample_rate: int = 16000,
                 ) -> None:
        """Load the MusicGen processor and model.

        Args:
            model_name: model id or path passed to ``from_pretrained``.
            device: torch device that the model (and later the inputs) are moved to.
            sample_rate: sample rate (Hz) of the written audio file. Generated
                audio is resampled from the model's native rate to this value.
        """
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)
        self.sample_rate = sample_rate

    def call(self,
             prompt: Union[str, List[str]],
             save_path: Union[str, Path],
             max_new_tokens: int = 1536,
             **generate_kwargs,
             ):
        """Generate audio for ``prompt`` and save it to ``save_path``.

        Args:
            prompt: a text description, or a list of descriptions. For a list,
                the whole batch is generated but only the first clip is saved.
            save_path: output audio file path, passed to ``soundfile.write``.
            max_new_tokens: generation length passed to ``model.generate``
                (default 1536, the previously hard-coded value).
            **generate_kwargs: extra keyword arguments forwarded to
                ``model.generate`` (e.g. from the agent's ``call_cfg``).
        """
        # The processor takes a batch of texts: wrap a single string, but do
        # not nest a list that is already a batch.
        texts = [prompt] if isinstance(prompt, str) else list(prompt)
        inputs = self.processor(
            text=texts,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        audio = self.model.generate(**inputs, max_new_tokens=max_new_tokens, **generate_kwargs)
        # Keep the first item and first channel only; output is presumably
        # (batch, channels, samples) per the MusicGen docs.
        wav = audio[0, 0].cpu()
        # Convert from the audio encoder's native rate to the requested rate.
        wav = torchaudio.functional.resample(wav, self.model.config.audio_encoder.sampling_rate, self.sample_rate)
        sf.write(save_path, wav.numpy(), self.sample_rate)


class MusicGenAgent:
    """Write a music description for a story with an LLM, then render it with MusicGen.

    ``config`` is indexed with the keys ``"revise_cfg"`` (keyword arguments for
    ``generate_music_prompt_from_story``) and ``"call_cfg"`` (keyword arguments
    for ``MusicGenSynthesizer.call``).
    """

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store the config and select the LLM agent class.

        Raises:
            ValueError: if ``llm_type`` is not supported (only ``"qwen2"``).
        """
        self.config = config
        if llm_type == "qwen2":
            self.LLM = QwenAgent
        else:
            # Fail fast instead of leaving self.LLM unset, which previously
            # surfaced later as an unrelated AttributeError.
            raise ValueError(f"Unsupported llm_type: {llm_type!r}")

    def generate_music_prompt_from_story(
            self,
            pages: List,
            num_turns: int = 3
        ):
        """Iteratively draft and review a music description for ``pages``.

        A reviser agent produces or refines the description, and a reviewer
        agent critiques it. The loop runs for at most ``num_turns`` rounds and
        stops early when the reviewer replies exactly ``"Check passed."``.

        Returns:
            The last description produced by the reviser ("" if ``num_turns`` <= 0).
        """
        music_prompt_reviser = self.LLM(story_to_music_reviser_system, track_history=False)
        music_prompt_reviewer = self.LLM(story_to_music_reviewer_system, track_history=False)

        music_prompt = ""
        review = ""
        for _turn in range(num_turns):
            # run() returns a 2-tuple; the second element (a success flag,
            # presumably) is not used here.
            music_prompt, _ = music_prompt_reviser.run(json.dumps({
                "story": pages,
                "previous_result": music_prompt,
                "improvement_suggestions": review,
            }, ensure_ascii=False))
            review, _ = music_prompt_reviewer.run(json.dumps({
                "story_content": pages,
                "music_description": music_prompt
            }, ensure_ascii=False))
            if review == "Check passed.":
                break

        return music_prompt

    def call(self, pages: List, device: str, save_path: Union[str, Path]):
        """Generate a music prompt for ``pages`` and save the audio as ``save_path/music.wav``.

        Args:
            pages: story pages, serialized into the LLM requests as JSON.
            device: torch device for the MusicGen model.
            save_path: directory in which ``music.wav`` is written.

        Returns:
            ``{"prompt": <music description>, "modality": "music"}``.
        """
        save_path = Path(save_path)
        music_prompt = self.generate_music_prompt_from_story(pages, **self.config["revise_cfg"])
        # NOTE(review): the MusicGen model is reloaded on every call; consider
        # caching the synthesizer if this is invoked repeatedly.
        generation_agent = MusicGenSynthesizer(device=device)
        generation_agent.call(
            prompt=music_prompt,
            save_path=save_path / "music.wav",
            **self.config["call_cfg"]
        )
        return {
            "prompt": music_prompt,
            "modality": "music"
        }