import json
import time
from pathlib import Path

import torch.multiprocessing as mp

from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent
from mm_story_agent.modality_agents.music_agent import MusicGenAgent
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent
from mm_story_agent.video_compose_agent import VideoComposeAgent


class MMStoryAgent:
    """Orchestrate a multimodal storytelling pipeline.

    Writes a story outline, fans out asset generation (image, sound,
    speech, music) to one subprocess per modality — each pinned to its
    own CUDA device — then composes the final storytelling video.
    """

    def __init__(self) -> None:
        # Modalities generated in parallel by generate_modality_assets.
        self.modalities = ["image", "sound", "speech", "music"]
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent,
        }
        # One GPU per modality — NOTE(review): device ids are hard-coded;
        # confirm the deployment host actually has 4 CUDA devices.
        self.modality_devices = {
            "image": "cuda:0",
            "sound": "cuda:1",
            "music": "cuda:2",
            "speech": "cuda:3",
        }
        self.agents = {}

    def call_modality_agent(self, agent, device, pages, save_path, return_dict):
        """Run one modality agent and record its result.

        Executed inside a worker process; the result dict is stored in the
        shared ``return_dict`` keyed by the modality the agent reports.
        """
        result = agent.call(pages, device, save_path)
        modality = result["modality"]
        return_dict[modality] = result

    def write_story(self, config):
        """Generate the story pages via the QA-outline writer."""
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        pages = story_writer.call(config["story_setting"])
        return pages

    def generate_speech(self, config, pages):
        """Generate speech assets for ``pages`` into <story_dir>/speech."""
        story_dir = Path(config["story_dir"])
        (story_dir / "speech").mkdir(exist_ok=True, parents=True)
        speech_agent = CosyVoiceAgent(config["speech_generation"])
        speech_agent.call(pages, story_dir / "speech")

    def generate_sound(self, config, pages):
        """Generate sound-effect assets for ``pages`` into <story_dir>/sound."""
        story_dir = Path(config["story_dir"])
        (story_dir / "sound").mkdir(exist_ok=True, parents=True)
        sound_agent = AudioLDM2Agent(config["sound_generation"])
        sound_agent.call(pages, story_dir / "sound")

    def generate_music(self, config, pages):
        """Generate background-music assets for ``pages`` into <story_dir>/music."""
        story_dir = Path(config["story_dir"])
        (story_dir / "music").mkdir(exist_ok=True, parents=True)
        music_agent = MusicGenAgent(config["music_generation"])
        music_agent.call(pages, story_dir / "music")

    def generate_image(self, config, pages):
        """Generate image assets for ``pages`` into <story_dir>/image."""
        story_dir = Path(config["story_dir"])
        (story_dir / "image").mkdir(exist_ok=True, parents=True)
        image_agent = StoryDiffusionAgent(config["image_generation"])
        image_agent.call(pages, story_dir / "image")

    def generate_modality_assets(self, config, pages):
        """Generate all modality assets in parallel and write script_data.json.

        Spawns one process per modality, collects each agent's result from a
        shared dict, folds the prompts back into a per-page script, and
        returns the generated images (or None if image generation produced
        no result — previously this raised NameError).
        """
        script_data = {"pages": [{"story": page} for page in pages]}
        story_dir = Path(config["story_dir"])
        for sub_dir in self.modalities:
            (story_dir / sub_dir).mkdir(exist_ok=True, parents=True)

        agents = {
            modality: self.modality_agent_class[modality](config[modality + "_generation"])
            for modality in self.modalities
        }

        # Keep an explicit reference to the Manager: if the Manager object
        # were garbage-collected, its server process could shut down while
        # workers are still writing to the proxy dict.
        manager = mp.Manager()
        return_dict = manager.dict()

        processes = []
        for modality in self.modalities:
            p = mp.Process(
                target=self.call_modality_agent,
                args=(
                    agents[modality],
                    self.modality_devices[modality],
                    pages,
                    story_dir / modality,
                    return_dict,
                ),
                daemon=False,
            )
            processes.append(p)
            p.start()
        for p in processes:
            p.join()

        # Fix: initialize so a failed/missing image result returns None
        # instead of raising NameError at `return images`.
        images = None
        for modality, result in return_dict.items():
            try:
                if result["modality"] == "image":
                    images = result["generation_results"]
                    for idx in range(len(pages)):
                        script_data["pages"][idx]["image_prompt"] = result["prompts"][idx]
                elif result["modality"] == "sound":
                    for idx in range(len(pages)):
                        script_data["pages"][idx]["sound_prompt"] = result["prompts"][idx]
                elif result["modality"] == "music":
                    script_data["music_prompt"] = result["prompt"]
            except Exception as e:
                # Best-effort: a malformed result from one modality should
                # not abort collection of the others.
                print(f"Error occurred during generation: {e}")

        with open(story_dir / "script_data.json", "w") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)

        return images

    def compose_storytelling_video(self, config, pages):
        """Compose the final video from the generated assets."""
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
        """End-to-end pipeline: write story, generate assets, compose video."""
        pages = self.write_story(config)
        images = self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)