import time
import json
from pathlib import Path

import torch.multiprocessing as mp

from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent
from mm_story_agent.modality_agents.music_agent import MusicGenAgent
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent
from mm_story_agent.video_compose_agent import VideoComposeAgent


class MMStoryAgent:
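    """End-to-end pipeline: write a story, generate image, sound, speech and
    music assets in parallel, and compose them into a storytelling video."""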

    def __init__(self) -> None:
        self.modalities = ["image", "sound", "speech", "music"]
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent
        }
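        # Each modality is pinned to a dedicated GPU so the per-modality
        # processes can run in parallel without sharing a device.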
        self.modality_devices = {
            "image": "cuda:0",
            "sound": "cuda:1",
            "music": "cuda:2",
            "speech": "cuda:3"
        }
        self.agents = {}

    def call_modality_agent(self, agent, device, pages, save_path, return_dict):
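        # Runs in a worker process: invoke one modality agent and store its
        # result in the shared dict under the modality name it reports.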
        result = agent.call(pages, device, save_path)
        modality = result["modality"]
        return_dict[modality] = result

    def write_story(self, config):
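        # Produce the list of story pages from the configured story setting.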
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        pages = story_writer.call(config["story_setting"])
        return pages
    
    def generate_speech(self, config, pages):
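        # Standalone single-modality helper (as are the three methods below);
        # call() uses the parallel generate_modality_assets path instead.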
        story_dir = Path(config["story_dir"])
        (story_dir / "speech").mkdir(exist_ok=True, parents=True)
        speech_agent = CosyVoiceAgent(config["speech_generation"])
        speech_agent.call(pages, story_dir / "speech")
    
    def generate_sound(self, config, pages):
        story_dir = Path(config["story_dir"])
        (story_dir / "sound").mkdir(exist_ok=True, parents=True)
        sound_agent = AudioLDM2Agent(config["sound_generation"])
        sound_agent.call(pages, story_dir / "sound")

    def generate_music(self, config, pages):
        story_dir = Path(config["story_dir"])
        (story_dir / "music").mkdir(exist_ok=True, parents=True)
        music_agent = MusicGenAgent(config["music_generation"])
        music_agent.call(pages, story_dir / "music")
    
    def generate_image(self, config, pages):
        story_dir = Path(config["story_dir"])
        (story_dir / "image").mkdir(exist_ok=True, parents=True)
        image_agent = StoryDiffusionAgent(config["image_generation"])
        image_agent.call(pages, story_dir / "image")
    
    def generate_modality_assets(self, config, pages):
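        # Generate image, sound, speech and music assets in parallel (one child
        # process per modality), then record the prompts in script_data.json.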
        # Default so the return below is defined even if image generation fails.
        images = None
        script_data = {"pages": [{"story": page} for page in pages]}
        story_dir = Path(config["story_dir"])

        for sub_dir in self.modalities:
            (story_dir / sub_dir).mkdir(exist_ok=True, parents=True)

        agents = {}
        for modality in self.modalities:
            agents[modality] = self.modality_agent_class[modality](config[modality + "_generation"])

        processes = []
        return_dict = mp.Manager().dict()

        for modality in self.modalities:
            p = mp.Process(
                target=self.call_modality_agent,
                args=(
                    agents[modality],
                    self.modality_devices[modality],
                    pages,
                    story_dir / modality,
                    return_dict,
                ),
                daemon=False,
            )
            processes.append(p)
            p.start()
        
        for p in processes:
            p.join()

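        # Record the prompts each agent used so they are saved alongside the
        # story text: per-page prompts for image and sound, one music prompt.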
        for modality, result in return_dict.items():
            try:
                if result["modality"] == "image":
                    images = result["generation_results"]
                    for idx in range(len(pages)):
                        script_data["pages"][idx]["image_prompt"] = result["prompts"][idx]
                elif result["modality"] == "sound":
                    for idx in range(len(pages)):
                        script_data["pages"][idx]["sound_prompt"] = result["prompts"][idx]
                elif result["modality"] == "music":
                    script_data["music_prompt"] = result["prompt"]
            except Exception as e:
                print(f"Error occurred during generation: {e}")
        
        with open(story_dir / "script_data.json", "w", encoding="utf-8") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)
        
        return images
    
    def compose_storytelling_video(self, config, pages):
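        # Compose the generated assets into the final storytelling video.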
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
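        # Top-level entry point: story writing, then parallel asset generation,
        # then video composition.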
        pages = self.write_story(config)
        images = self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)
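

# Minimal usage sketch (not part of the original module). The config keys below
# mirror the ones read above ("story_dir", "story_setting", "story_gen_config",
# "<modality>_generation"); the empty dicts are placeholders that must be filled
# from the repository's actual configuration files.
if __name__ == "__main__":
    # CUDA cannot be re-initialized in a forked subprocess, so the spawn start
    # method is typically required for the per-modality child processes.
    mp.set_start_method("spawn", force=True)

    example_config = {
        "story_dir": "generated_stories/example",   # output directory
        "story_setting": {},                        # passed to QAOutlineStoryWriter.call
        "story_gen_config": {},                     # QAOutlineStoryWriter settings
        "image_generation": {},                     # StoryDiffusionAgent settings
        "sound_generation": {},                     # AudioLDM2Agent settings
        "speech_generation": {},                    # CosyVoiceAgent settings
        "music_generation": {},                     # MusicGenAgent settings
    }
    MMStoryAgent().call(example_config)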