# Spaces: Running on L4
import json
from pathlib import Path
from typing import List, Optional, Union

import soundfile as sf
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration

from mm_story_agent.modality_agents.llm import QwenAgent
from mm_story_agent.prompts_en import story_to_music_reviser_system, story_to_music_reviewer_system
class MusicGenSynthesizer:
    """Text-to-music synthesizer backed by a Hugging Face MusicGen checkpoint.

    The processor and model are loaded once in the constructor. ``call``
    generates audio for a text description, resamples it to ``sample_rate``
    and writes it to ``save_path``.
    """

    def __init__(self,
                 model_name: str = 'facebook/musicgen-medium',
                 device: str = 'cuda',
                 sample_rate: int = 16000,
                 ) -> None:
        """Load the processor and model for ``model_name`` onto ``device``.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
            device: Device the model and the tokenized inputs are moved to.
            sample_rate: Sample rate (Hz) of the audio written by ``call``.
        """
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)
        self.sample_rate = sample_rate

    @staticmethod
    def _as_prompt_batch(prompt: Union[str, List[str]]) -> List[str]:
        """Return ``prompt`` as a flat, non-empty list of strings.

        A single string is wrapped in a list; a list (or other iterable) of
        strings is copied as-is instead of being nested inside another list.

        Raises:
            ValueError: If the resulting batch is empty.
        """
        if isinstance(prompt, str):
            return [prompt]
        batch = list(prompt)
        if not batch:
            raise ValueError("prompt must contain at least one description")
        return batch

    @staticmethod
    def _tokens_for_duration(duration: float, frame_rate: float) -> int:
        """Convert a clip length in seconds to a number of generated tokens.

        Raises:
            ValueError: If ``duration`` is not positive.
        """
        if duration <= 0:
            raise ValueError(f"duration must be positive, got {duration!r}")
        return max(1, int(round(duration * frame_rate)))

    def call(self,
             prompt: Union[str, List[str]],
             save_path: Union[str, Path],
             max_new_tokens: int = 1536,
             duration: Optional[float] = None,
             ):
        """Generate music for ``prompt`` and write it to ``save_path``.

        Args:
            prompt: One text description, or a list of descriptions that is
                processed as a single batch. Only the first generated clip
                is written, because there is a single output file.
            save_path: Destination audio file.
            max_new_tokens: Number of tokens to generate when ``duration``
                is not given.
                NOTE(review): the Transformers MusicGen documentation
                mentions a limit of about 30 s (1503 tokens); the original
                default of 1536 is kept unchanged -- confirm it is intended.
            duration: Optional clip length in seconds. When given it takes
                precedence over ``max_new_tokens`` and is converted with the
                frame rate of the model's audio encoder.

        Raises:
            ValueError: If ``prompt`` is an empty list or ``duration`` is not
                positive.
        """
        prompts = self._as_prompt_batch(prompt)
        audio_cfg = self.model.config.audio_encoder
        if duration is not None:
            max_new_tokens = self._tokens_for_duration(duration, audio_cfg.frame_rate)

        inputs = self.processor(
            text=prompts,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        generated = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Index [0, 0] selects the first batch item and its first channel.
        wav = generated[0, 0].cpu()
        wav = torchaudio.functional.resample(wav, audio_cfg.sampling_rate, self.sample_rate)
        sf.write(save_path, wav.numpy(), self.sample_rate)
class MusicGenAgent:
    """Agent that produces a background-music track for a story.

    A reviser LLM drafts a music description from the story pages, a reviewer
    LLM critiques it, and the final description is synthesized to
    ``music.wav`` with ``MusicGenSynthesizer``.
    """

    # Reply the reviewer's answer is compared against to stop revising early.
    _REVIEW_PASSED = "Check passed."

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store ``config`` and select the LLM agent class.

        Args:
            config: Mapping read by ``call``; its ``"revise_cfg"`` and
                ``"call_cfg"`` entries are expanded as keyword arguments.
            llm_type: LLM backend name. Only ``"qwen2"`` is supported.

        Raises:
            ValueError: If ``llm_type`` is not supported. Previously this
                left ``self.LLM`` unset and failed later with an
                ``AttributeError``.
        """
        self.config = config
        if llm_type != "qwen2":
            raise ValueError(
                f"Unsupported llm_type {llm_type!r}; expected 'qwen2'"
            )
        self.LLM = QwenAgent

    @staticmethod
    def _review_passed(review) -> bool:
        """Return True when the reviewer's reply is the pass sentinel.

        Surrounding whitespace is ignored; non-string replies never pass.
        """
        return isinstance(review, str) and review.strip() == MusicGenAgent._REVIEW_PASSED

    def generate_music_prompt_from_story(
            self,
            pages: List,
            num_turns: int = 3
    ):
        """Iteratively draft a music description for the story.

        Each turn asks the reviser for a description (given the story, the
        previous description and the latest review) and then asks the
        reviewer to critique it. The loop stops early once the reviewer
        replies with the pass sentinel.

        Args:
            pages: Story pages; serialized to JSON for both LLMs.
            num_turns: Maximum number of revise/review rounds.

        Returns:
            The last description returned by the reviser, or an empty string
            when ``num_turns`` is not positive.
        """
        music_prompt_reviser = self.LLM(story_to_music_reviser_system, track_history=False)
        music_prompt_reviewer = self.LLM(story_to_music_reviewer_system, track_history=False)
        music_prompt = ""
        review = ""
        for _ in range(num_turns):
            revise_request = json.dumps({
                "story": pages,
                "previous_result": music_prompt,
                "improvement_suggestions": review,
            }, ensure_ascii=False)
            music_prompt, _success = music_prompt_reviser.run(revise_request)

            review_request = json.dumps({
                "story_content": pages,
                "music_description": music_prompt
            }, ensure_ascii=False)
            review, _success = music_prompt_reviewer.run(review_request)

            if self._review_passed(review):
                break
        return music_prompt

    def call(self, pages: List, device: str, save_path: str):
        """Generate the music description and synthesize ``music.wav``.

        Args:
            pages: Story pages passed to ``generate_music_prompt_from_story``.
            device: Device the synthesizer model is loaded on.
            save_path: Output directory; created if it does not exist.

        Returns:
            A dict with the final ``"prompt"`` and ``"modality": "music"``.
        """
        save_path = Path(save_path)
        music_prompt = self.generate_music_prompt_from_story(pages, **self.config["revise_cfg"])
        # Make sure the target directory exists before the audio file is written.
        save_path.mkdir(parents=True, exist_ok=True)
        generation_agent = MusicGenSynthesizer(device=device)
        generation_agent.call(
            prompt=music_prompt,
            save_path=save_path / "music.wav",
            **self.config["call_cfg"]
        )
        return {
            "prompt": music_prompt,
            "modality": "music"
        }