# Repository page metadata captured together with this file (not program code):
# Xu Xuenan
# Initial commit
# a121edc
# raw / history blame
# 6.02 kB
import time
import json
from pathlib import Path
import torch.multiprocessing as mp
from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent
from mm_story_agent.modality_agents.music_agent import MusicGenAgent
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent
from mm_story_agent.video_compose_agent import VideoComposeAgent
class MMStoryAgent:
    """Drive the story pipeline: write pages, generate assets, compose a video.

    :meth:`call` runs the three stages in order. Asset generation starts one
    worker process per modality listed in ``self.modalities``.
    """

    def __init__(self) -> None:
        # Each modality name is also used as the asset sub-directory name and
        # as the prefix of its "<modality>_generation" config entry.
        self.modalities = ["image", "sound", "speech", "music"]
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent,
        }
        self.agents = {}

    def call_modality_agent(self, agent, pages, save_path, return_dict):
        """Run one modality agent and publish its result.

        Args:
            agent: Object exposing ``call(pages, save_path)`` that returns a
                mapping containing a ``"modality"`` key.
            pages: Story pages forwarded to the agent.
            save_path: Output location forwarded to the agent.
            return_dict: Mapping that receives the result, keyed by the
                ``"modality"`` value the agent reported.
        """
        result = agent.call(pages, save_path)
        return_dict[result["modality"]] = result

    def write_story(self, config):
        """Generate the story pages.

        Args:
            config: Mapping providing ``"story_gen_config"`` (passed to the
                ``QAOutlineStoryWriter`` constructor) and ``"story_setting"``
                (passed to its ``call``).

        Returns:
            Whatever the story writer's ``call`` returns; the rest of this
            class treats it as a sized sequence of page texts.
        """
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        return story_writer.call(config["story_setting"])

    def generate_modality_assets(self, config, pages):
        """Generate all modality assets and write ``script_data.json``.

        One agent is built per modality from ``config["<modality>_generation"]``
        and run in its own process, saving into ``<story_dir>/<modality>``.
        The prompts reported by the image, sound and music agents are merged
        into the script data, which is then dumped as JSON into the story
        directory.

        Args:
            config: Mapping providing ``"story_dir"`` and one
                ``"<modality>_generation"`` entry per modality.
            pages: Sized sequence of story page texts.

        Returns:
            The image agent's ``"generation_results"``, or ``None`` when no
            image result was collected (e.g. the image worker failed).
        """
        script_data = {"pages": [{"story": page} for page in pages]}
        story_dir = Path(config["story_dir"])

        story_dir.mkdir(exist_ok=True, parents=True)
        for sub_dir in self.modalities:
            (story_dir / sub_dir).mkdir(exist_ok=True, parents=True)

        agents = {
            modality: self.modality_agent_class[modality](config[modality + "_generation"])
            for modality in self.modalities
        }

        # The manager owns a helper server process; the ``with`` block makes
        # sure it is shut down once the workers are done.
        with mp.Manager() as manager:
            return_dict = manager.dict()
            processes = []
            for modality in self.modalities:
                process = mp.Process(
                    target=self.call_modality_agent,
                    args=(agents[modality], pages, story_dir / modality, return_dict),
                )
                processes.append((modality, process))
                process.start()

            for modality, process in processes:
                process.join()
                if process.exitcode != 0:
                    # Best effort: report the failure but keep the other assets.
                    print(
                        f"Warning: {modality} generation process exited "
                        f"with code {process.exitcode}"
                    )

            # Copy into a plain dict while the manager is still alive.
            results = dict(return_dict.items())

        # Stays None if the image worker produced no result, so the final
        # ``return`` can never hit an unbound local.
        images = None
        for result in results.values():
            try:
                kind = result["modality"]
                if kind == "image":
                    images = result["generation_results"]
                    for idx, page_entry in enumerate(script_data["pages"]):
                        page_entry["image_prompt"] = result["prompts"][idx]
                elif kind == "sound":
                    for idx, page_entry in enumerate(script_data["pages"]):
                        page_entry["sound_prompt"] = result["prompts"][idx]
                elif kind == "music":
                    script_data["music_prompt"] = result["prompt"]
            except Exception as e:
                # Deliberately broad: a malformed result from one agent must
                # not prevent the script data from being written.
                print(f"Error occurred during generation: {e}")

        # ensure_ascii=False emits raw non-ASCII characters, so the file
        # encoding must be pinned rather than left to the platform default.
        with open(story_dir / "script_data.json", "w", encoding="utf-8") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)

        return images

    def compose_storytelling_video(self, config, pages):
        """Compose the final video by delegating to ``VideoComposeAgent``.

        Args:
            config: Pipeline configuration forwarded to the agent.
            pages: Story pages forwarded to the agent.
        """
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
        """Run the full pipeline: story, modality assets, then video.

        Args:
            config: Pipeline configuration shared by all three stages.
        """
        pages = self.write_story(config)
        self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)
if __name__ == "__main__":
    # Script entry point: build the configuration from the YAML file and
    # command-line arguments, then run the whole pipeline once.
    from arg_parser import parse_yaml_and_cmd

    run_config = parse_yaml_and_cmd()
    MMStoryAgent().call(run_config)