Spaces:
Running
on
L4
Running
on
L4
File size: 4,646 Bytes
a121edc 5152717 a121edc 5152717 a121edc f7325de 3b33703 6331da0 f7325de 3b33703 f7325de 3b33703 6331da0 f7325de 3b33703 f7325de 3b33703 6331da0 f7325de 3b33703 f7325de 3b33703 6331da0 f7325de 3b33703 f7325de a121edc f7325de a121edc 6fc0fd0 5152717 f7325de a121edc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import time
import json
from pathlib import Path
import torch.multiprocessing as mp
from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter
from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent
from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent
from mm_story_agent.modality_agents.music_agent import MusicGenAgent
from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent
from mm_story_agent.video_compose_agent import VideoComposeAgent
class MMStoryAgent:
    """Top-level orchestrator for multimodal storytelling.

    Pipeline: write a story outline/pages, generate assets for each modality
    (image, sound, speech, music) in parallel subprocesses — one CUDA device
    per modality — then compose everything into a storytelling video.
    """

    def __init__(self) -> None:
        self.modalities = ["image", "sound", "speech", "music"]
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent
        }
        # One dedicated CUDA device per modality so parallel subprocesses
        # do not contend for the same GPU.
        self.modality_devices = {
            "image": "cuda:0",
            "sound": "cuda:1",
            "music": "cuda:2",
            "speech": "cuda:3"
        }
        self.agents = {}

    def call_modality_agent(self, agent, device, pages, save_path, return_dict):
        """Run one modality agent and record its result.

        Executed inside an ``mp.Process``; the result dict (which must carry
        a ``"modality"`` key) is stored in the shared ``return_dict`` so the
        parent process can collect it after ``join()``.
        """
        result = agent.call(pages, device, save_path)
        modality = result["modality"]
        return_dict[modality] = result

    def write_story(self, config):
        """Generate the story pages from the story setting in ``config``."""
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        pages = story_writer.call(config["story_setting"])
        return pages

    # NOTE(review): the standalone generate_* helpers below call
    # ``agent.call(pages, save_path)`` (2 args) while call_modality_agent
    # passes (pages, device, save_path) (3 args) — confirm the agents'
    # ``call`` signatures; one of the two call shapes may be stale.

    def generate_speech(self, config, pages):
        """Generate speech audio for each page into <story_dir>/speech."""
        story_dir = Path(config["story_dir"])
        (story_dir / "speech").mkdir(exist_ok=True, parents=True)
        speech_agent = CosyVoiceAgent(config["speech_generation"])
        speech_agent.call(pages, story_dir / "speech")

    def generate_sound(self, config, pages):
        """Generate sound effects for each page into <story_dir>/sound."""
        story_dir = Path(config["story_dir"])
        (story_dir / "sound").mkdir(exist_ok=True, parents=True)
        sound_agent = AudioLDM2Agent(config["sound_generation"])
        sound_agent.call(pages, story_dir / "sound")

    def generate_music(self, config, pages):
        """Generate background music into <story_dir>/music."""
        story_dir = Path(config["story_dir"])
        (story_dir / "music").mkdir(exist_ok=True, parents=True)
        music_agent = MusicGenAgent(config["music_generation"])
        music_agent.call(pages, story_dir / "music")

    def generate_image(self, config, pages):
        """Generate illustrations for each page into <story_dir>/image."""
        story_dir = Path(config["story_dir"])
        (story_dir / "image").mkdir(exist_ok=True, parents=True)
        image_agent = StoryDiffusionAgent(config["image_generation"])
        image_agent.call(pages, story_dir / "image")

    def generate_modality_assets(self, config, pages):
        """Generate all modality assets in parallel and write script_data.json.

        Spawns one subprocess per modality, collects results through a shared
        manager dict, records the prompts used per page, and returns the
        image generation results (or ``None`` if image generation produced
        no result).
        """
        script_data = {"pages": [{"story": page} for page in pages]}
        story_dir = Path(config["story_dir"])

        for sub_dir in self.modalities:
            (story_dir / sub_dir).mkdir(exist_ok=True, parents=True)

        agents = {}
        for modality in self.modalities:
            agents[modality] = self.modality_agent_class[modality](config[modality + "_generation"])

        processes = []
        return_dict = mp.Manager().dict()

        for modality in self.modalities:
            p = mp.Process(
                target=self.call_modality_agent,
                args=(agents[modality],
                      self.modality_devices[modality],
                      pages,
                      story_dir / modality,
                      return_dict),
                daemon=False
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Fix: ``images`` was previously unbound when no image result was
        # returned (e.g. the image subprocess crashed), causing a NameError
        # at the return statement below.
        images = None
        for modality, result in return_dict.items():
            try:
                if result["modality"] == "image":
                    images = result["generation_results"]
                    for idx in range(len(pages)):
                        script_data["pages"][idx]["image_prompt"] = result["prompts"][idx]
                elif result["modality"] == "sound":
                    for idx in range(len(pages)):
                        script_data["pages"][idx]["sound_prompt"] = result["prompts"][idx]
                elif result["modality"] == "music":
                    script_data["music_prompt"] = result["prompt"]
            except Exception as e:
                # Best-effort: a malformed result for one modality should not
                # prevent the script data of the others from being saved.
                print(f"Error occurred during generation: {e}")

        # ensure_ascii=False emits raw non-ASCII characters, so the encoding
        # must be pinned to UTF-8 instead of the platform default.
        with open(story_dir / "script_data.json", "w", encoding="utf-8") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)

        return images

    def compose_storytelling_video(self, config, pages):
        """Assemble the generated assets into the final video."""
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
        """Run the full pipeline: story -> modality assets -> video."""
        pages = self.write_story(config)
        self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)
|