File size: 4,168 Bytes
a121edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7325de
3b33703
f7325de
3b33703
f7325de
 
3b33703
f7325de
3b33703
f7325de
 
3b33703
f7325de
3b33703
f7325de
 
3b33703
f7325de
3b33703
f7325de
a121edc
 
 
 
 
 
 
 
 
 
 
f7325de
 
a121edc
6fc0fd0
f7325de
 
 
 
 
 
a121edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import time
import json
from pathlib import Path

import torch.multiprocessing as mp

from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent
from mm_story_agent.modality_agents.music_agent import MusicGenAgent
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent
from mm_story_agent.video_compose_agent import VideoComposeAgent


class MMStoryAgent:
    """Drive the story pipeline: write pages, generate assets, compose a video.

    The per-modality agents (image, sound, speech, music) are looked up in
    ``modality_agent_class`` and are each configured from the
    ``"<modality>_generation"`` entry of the config passed to the methods.
    """

    def __init__(self) -> None:
        # Modalities processed by generate_modality_assets; each name doubles
        # as the output sub-directory name and the "<name>_generation" config
        # key prefix.
        self.modalities = ["image", "sound", "speech", "music"]
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent
        }
        self.agents = {}

    def call_modality_agent(self, agent, pages, save_path, return_dict):
        """Run one agent and store its result under the modality it reports.

        Args:
            agent: Object exposing ``call(pages, save_path)`` that returns a
                mapping containing at least a ``"modality"`` key.
            pages: Story pages forwarded to the agent unchanged.
            save_path: Output location forwarded to the agent unchanged.
            return_dict: Mutable mapping that receives
                ``return_dict[result["modality"]] = result``.
        """
        result = agent.call(pages, save_path)
        modality = result["modality"]
        return_dict[modality] = result

    def write_story(self, config):
        """Generate and return the story pages.

        Builds a ``QAOutlineStoryWriter`` from ``config["story_gen_config"]``
        and calls it with ``config["story_setting"]``.
        """
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        pages = story_writer.call(config["story_setting"])
        return pages

    def generate_speech(self, config, pages):
        """Run only the speech agent, targeting ``<story_dir>/speech``.

        The agent's return value is discarded.
        """
        story_dir = Path(config["story_dir"])
        speech_agent = CosyVoiceAgent(config["speech_generation"])
        speech_agent.call(pages, story_dir / "speech")

    def generate_sound(self, config, pages):
        """Run only the sound agent, targeting ``<story_dir>/sound``.

        The agent's return value is discarded.
        """
        story_dir = Path(config["story_dir"])
        sound_agent = AudioLDM2Agent(config["sound_generation"])
        sound_agent.call(pages, story_dir / "sound")

    def generate_music(self, config, pages):
        """Run only the music agent, targeting ``<story_dir>/music``.

        The agent's return value is discarded.
        """
        story_dir = Path(config["story_dir"])
        music_agent = MusicGenAgent(config["music_generation"])
        music_agent.call(pages, story_dir / "music")

    def generate_image(self, config, pages):
        """Run only the image agent, targeting ``<story_dir>/image``.

        The agent's return value is discarded.
        """
        story_dir = Path(config["story_dir"])
        image_agent = StoryDiffusionAgent(config["image_generation"])
        image_agent.call(pages, story_dir / "image")

    @staticmethod
    def _attach_page_prompts(script_data, key, prompts):
        """Store ``prompts[i]`` under ``key`` on the i-th page entry."""
        # Index by page position so a too-short prompt list raises IndexError
        # (reported by the caller) rather than being silently truncated.
        for idx, page_entry in enumerate(script_data["pages"]):
            page_entry[key] = prompts[idx]

    def generate_modality_assets(self, config, pages):
        """Run every modality agent in its own process and write the script.

        One process per entry of ``self.modalities`` is started; all of them
        are joined before results are merged. Per-page image/sound prompts and
        the music prompt reported by the agents are merged into
        ``<story_dir>/script_data.json`` together with the story text.

        Args:
            config: Mapping providing ``"story_dir"`` and one
                ``"<modality>_generation"`` entry per modality.
            pages: Story pages; each one is stored as the ``"story"`` value of
                a page entry (presumably strings - must be JSON serialisable).

        Returns:
            The image agent's ``"generation_results"``, or ``None`` when no
            image result was produced (e.g. the image worker failed).
        """
        script_data = {"pages": [{"story": page} for page in pages]}
        story_dir = Path(config["story_dir"])

        # The script file is written directly into story_dir, so make sure it
        # exists even when there is no modality sub-directory to create.
        story_dir.mkdir(exist_ok=True, parents=True)
        for modality in self.modalities:
            (story_dir / modality).mkdir(exist_ok=True, parents=True)

        # Instantiate every agent up front, before any process is started.
        agents = {
            modality: self.modality_agent_class[modality](config[modality + "_generation"])
            for modality in self.modalities
        }

        # The context manager shuts the manager's server process down once the
        # results have been copied out, instead of leaking it.
        with mp.Manager() as manager:
            return_dict = manager.dict()
            workers = []
            for modality in self.modalities:
                worker = mp.Process(
                    target=self.call_modality_agent,
                    args=(agents[modality], pages, story_dir / modality, return_dict),
                    daemon=False,
                )
                workers.append((modality, worker))
                worker.start()

            for modality, worker in workers:
                worker.join()
                if worker.exitcode != 0:
                    # A crashed worker leaves no entry in return_dict; surface
                    # that instead of failing silently.
                    print(f"Warning: {modality} generation process exited with code {worker.exitcode}")

            # Copy into a plain dict while the manager is still alive.
            results = dict(return_dict.items())

        # Stays None when the image agent produced no result, so the final
        # return can never hit an unbound local.
        images = None
        for result in results.values():
            try:
                modality = result["modality"]
                if modality == "image":
                    images = result["generation_results"]
                    self._attach_page_prompts(script_data, "image_prompt", result["prompts"])
                elif modality == "sound":
                    self._attach_page_prompts(script_data, "sound_prompt", result["prompts"])
                elif modality == "music":
                    script_data["music_prompt"] = result["prompt"]
            except Exception as e:
                # Best effort: a malformed result must not prevent the script
                # file from being written.
                print(f"Error occurred during generation: {e}")

        # ensure_ascii=False emits raw non-ASCII characters, so the encoding
        # must be explicit rather than locale dependent.
        with open(story_dir / "script_data.json", "w", encoding="utf-8") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)

        return images

    def compose_storytelling_video(self, config, pages):
        """Hand the pages and config to a fresh ``VideoComposeAgent``."""
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
        """Run the full pipeline: write story, generate assets, compose video."""
        pages = self.write_story(config)
        self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)