# Xu Xuenan
# Transformers MusicGen
# 676ec69
from pathlib import Path
import json
from typing import List, Union
import soundfile as sf
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from mm_story_agent.modality_agents.llm import QwenAgent
from mm_story_agent.prompts_en import story_to_music_reviser_system, story_to_music_reviewer_system
class MusicGenSynthesizer:
    """Text-to-music synthesizer backed by a Hugging Face MusicGen checkpoint.

    The processor and model are loaded once in ``__init__``; ``call`` turns a
    text prompt into audio, resamples it to ``sample_rate`` and writes it to
    disk with ``soundfile``.
    """

    def __init__(self,
                 model_name: str = 'facebook/musicgen-medium',
                 device: str = 'cuda',
                 sample_rate: int = 16000,
                 ) -> None:
        """Load the MusicGen processor and model.

        Args:
            model_name: Hugging Face model id (or local path) of the MusicGen
                checkpoint.
            device: Torch device the model is moved to; generation inputs are
                moved to the same device.
            sample_rate: Sample rate (Hz) of the audio written by ``call``. The
                model output is resampled to this rate.
        """
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)
        self.sample_rate = sample_rate

    def call(self,
             prompt: Union[str, List[str]],
             save_path: Union[str, Path],
             max_new_tokens: int = 1536,
             ):
        """Generate music for ``prompt`` and write it to ``save_path``.

        Args:
            prompt: A text description, or a list of descriptions. A list is
                batched through the model, but only the audio generated for the
                first prompt is written.
            save_path: Output file path. ``soundfile`` infers the format from
                the file extension.
            max_new_tokens: Number of audio tokens to generate. This controls
                the clip length; the default matches the previously hard-coded
                value.

        Raises:
            ValueError: if ``prompt`` is an empty list.
        """
        # The processor expects a flat list of strings. Wrapping a list prompt
        # in another list (as ``[prompt]`` did) yields a nested batch.
        texts = [prompt] if isinstance(prompt, str) else list(prompt)
        if not texts:
            raise ValueError("prompt must be a string or a non-empty list of strings")
        inputs = self.processor(
            text=texts,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        # Per the HF docs, generate() returns audio shaped (batch, channels, samples).
        # Keep the first channel of the first batch item.
        wav = self.model.generate(**inputs, max_new_tokens=max_new_tokens)[0, 0].cpu()
        # Convert from the audio encoder's native rate to the requested output rate.
        wav = torchaudio.functional.resample(
            wav, self.model.config.audio_encoder.sampling_rate, self.sample_rate
        )
        sf.write(save_path, wav.numpy(), self.sample_rate)
class MusicGenAgent:
    """Agent that writes a music description for a story and renders it.

    A "reviser" LLM drafts a music description from the story pages. A
    "reviewer" LLM critiques it, and the loop repeats until the reviewer replies
    exactly "Check passed." or the turn budget runs out. The final description
    is synthesized with ``MusicGenSynthesizer``.
    """

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store the agent configuration and select the LLM backend.

        Args:
            config: Mapping that may contain ``"revise_cfg"`` (keyword arguments
                for ``generate_music_prompt_from_story``) and ``"call_cfg"``
                (keyword arguments for ``MusicGenSynthesizer.call``).
            llm_type: LLM backend name. Only ``"qwen2"`` is supported.

        Raises:
            ValueError: if ``llm_type`` is not supported.
        """
        self.config = config
        if llm_type == "qwen2":
            self.LLM = QwenAgent
        else:
            # Fail fast. Previously self.LLM was silently left unset, and the
            # problem only surfaced later as an AttributeError.
            raise ValueError(f"Unsupported llm_type: {llm_type!r}")

    def generate_music_prompt_from_story(
            self,
            pages: List,
            num_turns: int = 3
    ):
        """Iteratively draft and review a music description for ``pages``.

        Args:
            pages: Story pages. They are JSON-serialized and sent to both LLMs.
            num_turns: Maximum number of revise/review rounds. With 0 rounds
                the result is the empty string.

        Returns:
            The latest music description produced by the reviser.
        """
        music_prompt_reviser = self.LLM(story_to_music_reviser_system, track_history=False)
        music_prompt_reviewer = self.LLM(story_to_music_reviewer_system, track_history=False)
        music_prompt = ""
        review = ""
        for _ in range(num_turns):
            # run() returns a (text, flag) pair.
            # NOTE(review): the flag presumably signals success but is ignored
            # here, as in the original; confirm whether failures should abort.
            music_prompt, _success = music_prompt_reviser.run(json.dumps({
                "story": pages,
                "previous_result": music_prompt,
                "improvement_suggestions": review,
            }, ensure_ascii=False))
            review, _success = music_prompt_reviewer.run(json.dumps({
                "story_content": pages,
                "music_description": music_prompt
            }, ensure_ascii=False))
            # The reviewer signals acceptance with this exact sentinel string.
            if review == "Check passed.":
                break
        return music_prompt

    def call(self, pages: List, device: str, save_path: str):
        """Generate a music description for ``pages`` and render it to disk.

        Args:
            pages: Story pages passed to ``generate_music_prompt_from_story``.
            device: Torch device used by ``MusicGenSynthesizer``.
            save_path: Directory that receives ``music.wav``. It is created if
                missing.

        Returns:
            A dict with the final ``"prompt"`` and ``"modality": "music"``.
        """
        save_path = Path(save_path)
        # Missing or empty (None) sub-configs fall back to the callee defaults.
        revise_cfg = self.config.get("revise_cfg") or {}
        call_cfg = self.config.get("call_cfg") or {}
        music_prompt = self.generate_music_prompt_from_story(pages, **revise_cfg)
        # The synthesizer is built per call, so the model is (re)loaded each time.
        generation_agent = MusicGenSynthesizer(device=device)
        save_path.mkdir(parents=True, exist_ok=True)
        generation_agent.call(
            prompt=music_prompt,
            save_path=save_path / "music.wav",
            **call_cfg
        )
        return {
            "prompt": music_prompt,
            "modality": "music"
        }