import gradio as gr
import tempfile
import json
from inference_shutter_speed import load_models, run_inference, OmegaConf
import torch

# Number of shutter-speed values the model expects per request.
NUM_SHUTTER_SPEEDS = 5

# Initialize models once at startup so every request reuses the same pipeline.
cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_shutter_speed.yaml")
pipeline, device = load_models(cfg)


def _validate_shutter_speeds(shutter_speed_list):
    """Parse *shutter_speed_list* as a JSON list and check its shape.

    Args:
        shutter_speed_list: Raw text from the UI, expected to look like
            ``"[0.11, 0.22, 0.33, 0.44, 0.55]"``.

    Returns:
        The parsed list of numbers.

    Raises:
        ValueError: If the text is not a JSON list of exactly
            ``NUM_SHUTTER_SPEEDS`` numbers.
    """
    try:
        values = json.loads(shutter_speed_list)
    except (TypeError, json.JSONDecodeError) as err:
        # TypeError covers a None input; JSONDecodeError covers empty or
        # malformed text (e.g. comma-separated values without brackets).
        raise ValueError(
            "Shutter speed values must be a bracketed list, "
            "e.g. [0.15, 0.32, 0.53, 0.62, 0.82]"
        ) from err
    if not isinstance(values, list) or len(values) != NUM_SHUTTER_SPEEDS:
        raise ValueError(f"Exactly {NUM_SHUTTER_SPEEDS} shutter_speed values required")
    # bool is a subclass of int, so exclude it explicitly.
    if not all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
        raise ValueError("Shutter speed values must all be numbers")
    return values


def generate_video(base_scene, shutter_speed_list):
    """Generate a video for *base_scene* using the given shutter-speed values.

    Args:
        base_scene: Text description of the scene to render.
        shutter_speed_list: JSON list (as text) of NUM_SHUTTER_SPEEDS values.
            It is passed to ``run_inference`` unchanged.

    Returns:
        Whatever ``run_inference`` returns, which is shown in the ``gr.Video`` output.
        Presumably a file path; confirm against ``run_inference``.

    Raises:
        gr.Error: On invalid input or when inference fails. Gradio shows the
            message to the user.
    """
    if not base_scene or not base_scene.strip():
        raise gr.Error("Please enter a scene description.")

    try:
        _validate_shutter_speeds(shutter_speed_list)
    except ValueError as err:
        raise gr.Error(f"Invalid shutter speed values: {err}") from err

    try:
        video_path = run_inference(
            pipeline=pipeline,
            tokenizer=pipeline.tokenizer,
            text_encoder=pipeline.text_encoder,
            base_scene=base_scene,
            shutter_speed_list=shutter_speed_list,
            device=device,
        )
    except Exception as err:
        # Top-level UI boundary: surface any inference failure to the user
        # while keeping the original exception chained for server logs.
        raise gr.Error(f"Generation failed: {err}") from err
    return video_path


# Example inputs
examples = [
    [
        "A brown and orange leather handbag with a paw print on it sits next to a book.",
        "[0.11, 0.22, 0.33, 0.44, 0.55]",
    ],
    [
        "A variety of potted plants are displayed on a windowsill, with some of them placed in yellow and white bowls. ",
        "[0.29, 0.49, 0.69, 0.79, 0.89]",
    ],
]

with gr.Blocks(title="Shutter Speed Effect Generator") as demo:
    # A space after '#' is required for Markdown to render a heading.
    gr.Markdown("# Dynamic Shutter Speed Effect Generation")

    with gr.Row():
        with gr.Column():
            scene_input = gr.Textbox(
                label="Scene Description",
                placeholder="Describe the scene you want to generate...",
            )
            shutter_speed_input = gr.Textbox(
                label="Shutter Speed Values",
                placeholder="Enter 5 values from 0.1-1.0 as a bracketed list (e.g., [0.15, 0.32, 0.53, 0.62, 0.82])",
            )
            submit_btn = gr.Button("Generate Video", variant="primary")

        with gr.Column():
            video_output = gr.Video(label="Generated Video")

    # cache_examples=True runs generate_video on every example at startup.
    gr.Examples(
        examples=examples,
        inputs=[scene_input, shutter_speed_input],
        outputs=[video_output],
        fn=generate_video,
        cache_examples=True,
    )

    submit_btn.click(
        fn=generate_video,
        inputs=[scene_input, shutter_speed_input],
        outputs=[video_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)