# Source snapshot: commit a121edc ("Initial commit") by Xu Xuenan; original file size 3.7 kB.
from pathlib import Path
from typing import List
import json
import torch
import soundfile as sf
from diffusers import AudioLDM2Pipeline
from mm_story_agent.prompts_en import story_to_sound_reviser_system, story_to_sound_review_system
from mm_story_agent.modality_agents.llm import QwenAgent
class AudioLDM2Synthesizer:
    """Text-to-audio generator backed by the diffusers ``AudioLDM2Pipeline``.

    The pipeline is loaded once at construction time and moved to CUDA when it
    is available, otherwise to CPU.
    """

    def __init__(self,
                 model_path: str = None,
                 ) -> None:
        """Load the AudioLDM2 pipeline.

        Args:
            model_path: Local path or hub id of the checkpoint; ``None`` selects
                ``"cvssp/audioldm2"``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe = AudioLDM2Pipeline.from_pretrained(
            model_path if model_path is not None else "cvssp/audioldm2",
            # Half precision only on GPU: loading fp16 weights unconditionally
            # breaks the CPU fallback, where many kernels lack fp16 support.
            torch_dtype=self._dtype_for(self.device),
        ).to(self.device)

    @staticmethod
    def _dtype_for(device: torch.device) -> torch.dtype:
        """Return the weight dtype for *device*: float16 on CUDA, float32 elsewhere."""
        return torch.float16 if device.type == "cuda" else torch.float32

    def call(self,
             prompts: List[str],
             n_candidate_per_text: int = 3,
             seed: int = 0,
             guidance_scale: float = 3.5,
             ddim_steps: int = 100,
             audio_length_in_s: float = 10.0,
             ):
        """Generate one audio clip per prompt.

        Args:
            prompts: Text prompts; one clip is returned for each.
            n_candidate_per_text: Waveforms generated per prompt
                (``num_waveforms_per_prompt``); only the first of each group
                is returned.
            seed: Seed for the sampling generator, for reproducible output.
            guidance_scale: Classifier-free guidance scale.
            ddim_steps: Number of denoising steps (``num_inference_steps``).
            audio_length_in_s: Length of each generated clip in seconds.

        Returns:
            The pipeline's ``.audios`` output reduced to one waveform per
            prompt, in prompt order.
        """
        generator = torch.Generator(device=self.device).manual_seed(seed)
        audios = self.pipe(
            prompts,
            num_inference_steps=ddim_steps,
            audio_length_in_s=audio_length_in_s,
            guidance_scale=guidance_scale,
            generator=generator,
            num_waveforms_per_prompt=n_candidate_per_text,
        ).audios
        # Candidates come back grouped per prompt. diffusers documents that
        # with several waveforms per prompt they are ranked by text-audio
        # similarity, so the first of each group is presumably the best match.
        return audios[::n_candidate_per_text]
class AudioLDM2Agent:
    """Story-to-sound agent.

    Derives one sound prompt per story page through an LLM reviser/reviewer
    loop, then renders the prompts with ``AudioLDM2Synthesizer`` and writes
    them as ``.wav`` files.
    """

    # Prompt value marking a page that should not be synthesized.
    NO_SOUND = "No sounds."
    # Reviewer verdict that ends the revise/review loop early.
    CHECK_PASSED = "Check passed."

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store the config and select the LLM backend.

        Args:
            config: Mapping read by ``call``; it must provide ``"revise_cfg"``
                (kwargs for ``generate_sound_prompt_from_story``), ``"call_cfg"``
                (kwargs for ``AudioLDM2Synthesizer.call``) and ``"sample_rate"``.
            llm_type: LLM backend name; only ``"qwen2"`` is supported.

        Raises:
            ValueError: If ``llm_type`` is not supported. Previously this left
                ``self.LLM`` unset and failed later with AttributeError.
        """
        if llm_type != "qwen2":
            raise ValueError(f"Unsupported llm_type: {llm_type!r}")
        self.config = config
        self.LLM = QwenAgent

    def call(self, pages: List, save_path: str):
        """Generate and save a sound clip for every page that needs one.

        Args:
            pages: Story pages, passed to the LLM prompts via ``json.dumps``.
            save_path: Output directory. The clip for page ``i`` (1-based) is
                written to ``p{i}.wav``. The directory is created if missing.

        Returns:
            ``{"prompts": [...], "modality": "sound"}``. ``prompts`` has one
            entry per page, and silent pages hold ``"No sounds."``.
        """
        sound_prompts = self.generate_sound_prompt_from_story(pages, **self.config["revise_cfg"])
        save_dir = Path(save_path)
        save_paths = []
        forward_prompts = []
        for page_no, prompt in enumerate(sound_prompts, start=1):
            # Skip silent pages and empty prompts (e.g. num_turns == 0).
            if prompt and prompt != self.NO_SOUND:
                save_paths.append(save_dir / f"p{page_no}.wav")
                forward_prompts.append(prompt)
        if forward_prompts:
            # Load the (large) model only when there is something to render.
            generation_agent = AudioLDM2Synthesizer()
            sounds = generation_agent.call(
                forward_prompts,
                **self.config["call_cfg"]
            )
            save_dir.mkdir(parents=True, exist_ok=True)
            for sound, path in zip(sounds, save_paths):
                sf.write(str(path), sound, self.config["sample_rate"])
        return {
            "prompts": sound_prompts,
            "modality": "sound"
        }

    @staticmethod
    def _clean_sound_prompt(text: str) -> str:
        """Trim whitespace and an optional leading "Sound description:" label."""
        text = text.strip()
        if text.startswith("Sound description:"):
            # Strip again: the label is usually followed by a space, which
            # would otherwise defeat the exact "No sounds." comparison.
            text = text[len("Sound description:"):].strip()
        return text

    def generate_sound_prompt_from_story(
        self,
        pages: List,
        num_turns: int = 3
    ):
        """Produce one sound prompt per page via an LLM revise/review loop.

        For each page, the reviser proposes or refines a sound description
        using the previous result and the reviewer's suggestions. The reviewer
        then critiques it. The loop stops when the reviewer answers
        "Check passed." or after ``num_turns`` rounds. The last reviser output
        is kept in either case.

        Args:
            pages: Story pages; each is embedded as ``"story"`` in the LLM input.
            num_turns: Maximum number of revise/review rounds per page.

        Returns:
            A list with one cleaned sound prompt per page, in page order.
        """
        sound_prompt_reviser = self.LLM(story_to_sound_reviser_system, track_history=False)
        sound_prompt_reviewer = self.LLM(story_to_sound_review_system, track_history=False)
        sound_prompts = []
        for page in pages:
            review = ""
            sound_prompt = ""
            for _turn in range(num_turns):
                # run() returns (text, success); the success flag is not used here.
                raw_prompt, _success = sound_prompt_reviser.run(json.dumps({
                    "story": page,
                    "previous_result": sound_prompt,
                    "improvement_suggestions": review,
                }, ensure_ascii=False))
                sound_prompt = self._clean_sound_prompt(raw_prompt)
                review, _success = sound_prompt_reviewer.run(json.dumps({
                    "story": page,
                    "sound_description": sound_prompt
                }, ensure_ascii=False))
                if review.strip() == self.CHECK_PASSED:
                    break
            sound_prompts.append(sound_prompt)
        return sound_prompts