import argparse
import os
import shutil
import subprocess
import time
import uuid
from pathlib import Path

import gradio as gr
import torch.multiprocessing as mp
import yaml

# Use the "spawn" start method (required for CUDA with torch.multiprocessing);
# it must be set before any worker process is created.
mp.set_start_method('spawn', force=True)

from mm_story_agent import MMStoryAgent

# The video composition stage shells out to ImageMagick's `convert`. Install it
# on the fly if it is missing, and install the bundled policy.xml so the
# default ImageMagick security policy does not block rendering.
try:
    subprocess.run(["convert", "-version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    imagemagick_installed = True
except FileNotFoundError:
    imagemagick_installed = False

if not imagemagick_installed:
    os.system("apt update -y")
    os.system("apt install -y imagemagick")
    shutil.copy("policy.xml", "/etc/ImageMagick-6/")
with open("configs/mm_story_agent.yaml", "r") as reader:
config = yaml.load(reader, Loader=yaml.FullLoader)
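
# YAML defaults used to pre-fill the UI widgets below.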
default_story_setting = config["story_setting"]
default_story_gen_config = config["story_gen_config"]
default_slideshow_effect = config["slideshow_effect"]
default_image_config = config["image_generation"]
default_sound_config = config["sound_generation"]
default_music_config = config["music_generation"]
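
# Helpers for the Markdown status banner shown while each generation stage runs.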
def set_generating_progress_text(text):
    return gr.update(visible=True, value=f"{text} ...")


def set_text_invisible():
    return gr.update(visible=False)
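
# Recursively merge the nested dict `updates` into `original`, so UI-chosen
# values overlay the YAML defaults without discarding unrelated keys.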
def deep_update(original, updates):
    for key, value in updates.items():
        if isinstance(value, dict):
            original[key] = deep_update(original.get(key, {}), value)
        else:
            original[key] = value
    return original

def update_page(direction, page, story_data):
    if direction == 'next' and page < len(story_data) - 1:
        page += 1
    elif direction == 'prev' and page > 0:
        page -= 1
    # Return exactly two values to match the (current_page, story_content) outputs.
    return page, story_data[page] if story_data else ""
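
# Stage 1: write the story script. Each stage below overlays its UI settings
# onto `config` via deep_update before invoking the corresponding
# MMStoryAgent method.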
def write_story_fn(story_topic, main_role, scene,
                   num_outline, temperature,
                   current_page,
                   progress=gr.Progress(track_tqdm=True)):
    config["story_dir"] = f"generated_stories/{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid1().hex}"
    deep_update(config, {
        "story_setting": {
            "story_topic": story_topic,
            "main_role": main_role,
            "scene": scene,
        },
        "story_gen_config": {
            "num_outline": num_outline,
            "temperature": temperature
        },
    })
    story_gen_agent = MMStoryAgent()
    pages = story_gen_agent.write_story(config)
    # Outputs: story_data, story_accordion, story_content, video_output.
    # Clamp the page index in case the new story has fewer pages than the last one.
    return pages, gr.update(visible=True), pages[min(current_page, len(pages) - 1)], gr.update()
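
# Stage 2: generate the per-page images plus sound and music assets.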
def modality_assets_generation_fn(
        height, width, image_seed, sound_guidance_scale, sound_seed,
        n_candidate_per_text, music_duration,
        story_data):
    deep_update(config, {
        "image_generation": {
            "obj_cfg": {
                "height": height,
                "width": width,
            },
            "call_cfg": {
                "seed": image_seed
            }
        },
        "sound_generation": {
            "call_cfg": {
                "guidance_scale": sound_guidance_scale,
                "seed": sound_seed,
                "n_candidate_per_text": n_candidate_per_text
            }
        },
        "music_generation": {
            "call_cfg": {
                "duration": music_duration
            }
        }
    })
    story_gen_agent = MMStoryAgent()
    images = story_gen_agent.generate_modality_assets(config, story_data)
    # Show all page images in a single-row gallery.
    return gr.update(visible=True, value=images, columns=[len(images)], rows=[1], height="auto")
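
# Stage 3: compose the final storytelling video with slideshow effects.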
def compose_storytelling_video_fn(
        fade_duration, slide_duration, zoom_speed, move_ratio,
        sound_volume, music_volume, bg_speech_ratio, fps,
        story_data,
        progress=gr.Progress(track_tqdm=True)):
    deep_update(config, {
        "slideshow_effect": {
            "fade_duration": fade_duration,
            "slide_duration": slide_duration,
            "zoom_speed": zoom_speed,
            "move_ratio": move_ratio,
            "sound_volume": sound_volume,
            "music_volume": music_volume,
            "bg_speech_ratio": bg_speech_ratio,
            "fps": fps
        },
    })
    story_gen_agent = MMStoryAgent()
    story_gen_agent.compose_storytelling_video(config, story_data)
    # video_output
    return Path(config["story_dir"]) / "output.mp4"
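
# Build the demo UI: story and asset settings on the left; the generated
# script, image gallery, and video on the right.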
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <h1 style="text-align: center;">MM-StoryAgent</h1>
        <p style="text-align: center;">This is a demo for generating attractive storytelling videos based on the given story setting.</p>
    """)
    with gr.Row():
        with gr.Column():
            story_topic = gr.Textbox(label="Story Topic", value=default_story_setting["story_topic"])
            main_role = gr.Textbox(label="Main Role", value=default_story_setting["main_role"])
            scene = gr.Textbox(label="Scene", value=default_story_setting["scene"])
            chapter_num = gr.Number(label="Chapter Number", value=default_story_gen_config["num_outline"])
            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=default_story_gen_config["temperature"])
            with gr.Accordion("Detailed Image Configuration (Optional)", open=False):
                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['height'])
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['width'])
                image_seed = gr.Number(label="Image Seed", value=default_image_config["call_cfg"]['seed'])
            with gr.Accordion("Detailed Sound Configuration (Optional)", open=False):
                sound_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=7.0, step=0.5, value=default_sound_config["call_cfg"]['guidance_scale'])
                sound_seed = gr.Number(label="Sound Seed", value=default_sound_config["call_cfg"]['seed'])
                n_candidate_per_text = gr.Slider(label="Number of Candidates per Text", minimum=0, maximum=5, step=1, value=default_sound_config["call_cfg"]['n_candidate_per_text'])
            with gr.Accordion("Detailed Music Configuration (Optional)", open=False):
                music_duration = gr.Number(label="Music Duration", minimum=30.0, maximum=120.0, value=default_music_config["call_cfg"]["duration"])
with gr.Accordion("Detailed Slideshow Effect (Optional)", open=False):
fade_duration = gr.Slider(label="Fade Duration", minimum=0.1, maximum=1.5, step=0.1, value=default_slideshow_effect['fade_duration'])
slide_duration = gr.Slider(label="Slide Duration", minimum=0.1, maximum=1.0, step=0.1, value=default_slideshow_effect['slide_duration'])
zoom_speed = gr.Slider(label="Zoom Speed", minimum=0.1, maximum=2.0, step=0.1, value=default_slideshow_effect['zoom_speed'])
move_ratio = gr.Slider(label="Move Ratio", minimum=0.8, maximum=1.0, step=0.05, value=default_slideshow_effect['move_ratio'])
sound_volume = gr.Slider(label="Sound Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['sound_volume'])
music_volume = gr.Slider(label="Music Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['music_volume'])
bg_speech_ratio = gr.Slider(label="Background / Speech Ratio", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['bg_speech_ratio'])
fps = gr.Slider(label="FPS", minimum=1, maximum=30, step=1, value=default_slideshow_effect['fps'])
        with gr.Column():
            story_data = gr.State([])
            story_generation_information = gr.Markdown(
                label="Story Generation Status",
                value="Generating Story Script ......",
                visible=False)
            with gr.Accordion(label="Story Content", open=False, visible=False) as story_accordion:
                with gr.Row():
                    prev_button = gr.Button("Previous Page")
                    next_button = gr.Button("Next Page")
                story_content = gr.Textbox(label="Page Content")
            video_generation_information = gr.Markdown(label="Generation Status", value="Generating Video ......", visible=False)
            image_gallery = gr.Gallery(label="Images", show_label=False, visible=False)
            video_generation_btn = gr.Button("Generate Video")
            video_output = gr.Video(label="Generated Story", interactive=False)
    current_page = gr.State(0)
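
    # Page navigation for browsing the generated story script.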
    prev_button.click(
        fn=update_page,
        inputs=[gr.State("prev"), current_page, story_data],
        outputs=[current_page, story_content]
    )
    next_button.click(
        fn=update_page,
        inputs=[gr.State("next"), current_page, story_data],
        outputs=[current_page, story_content]
    )
    # Clicking "Generate Video" runs the whole pipeline as a chain of `.then()`
    # steps: write the story script, generate the modality assets, compose the
    # video, then hide the gallery and report completion.
    video_generation_btn.click(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generating Story")],
        outputs=video_generation_information
    ).then(
        fn=write_story_fn,
        inputs=[story_topic, main_role, scene,
                chapter_num, temperature,
                current_page],
        outputs=[story_data, story_accordion, story_content, video_output]
    ).then(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generating Modality Assets")],
        outputs=video_generation_information
    ).then(
        fn=modality_assets_generation_fn,
        inputs=[height, width, image_seed, sound_guidance_scale, sound_seed,
                n_candidate_per_text, music_duration,
                story_data],
        outputs=[image_gallery]
    ).then(
        fn=set_generating_progress_text,
        inputs=[gr.State("Composing Video")],
        outputs=video_generation_information
    ).then(
        fn=compose_storytelling_video_fn,
        inputs=[fade_duration, slide_duration, zoom_speed, move_ratio,
                sound_volume, music_volume, bg_speech_ratio, fps,
                story_data],
        outputs=[video_output]
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=[],
        outputs=[image_gallery]
    ).then(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generation Finished")],
        outputs=video_generation_information
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    demo.launch(share=args.share)