import gradio as gr
import asyncio
import os
import traceback
import numpy as np
import re
from functools import partial
import torch
import imageio
import cv2
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
import edge_tts
from transformers import AutoTokenizer, pipeline
from moviepy.editor import VideoFileClip, AudioFileClip
from func_timeout import func_timeout, FunctionTimedOut
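
# Assumed environment (the original does not pin dependencies): diffusers,
# transformers, safetensors, edge-tts, func-timeout, imageio[ffmpeg],
# opencv-python, and moviepy<2.0 (the `moviepy.editor` module was removed in 2.x).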

# Initialize models with cache optimization
def initialize_components():
    # `step` is also used by generate_video(), so it is made global alongside the models
    global tokenizer, text_pipe, sentiment_analyzer, pipe, step

    # Text generation components
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", cache_dir="model_cache")
    text_pipe = pipeline(
        "text-generation",
        model="Qwen/Qwen2.5-1.5B-Instruct",
        tokenizer=tokenizer,
        device_map="auto",
        model_kwargs={"cache_dir": "model_cache"}  # cache_dir must go through model_kwargs
    )

    # Sentiment analysis
    sentiment_analyzer = pipeline("sentiment-analysis", model_kwargs={"cache_dir": "model_cache"})

    # Video generation setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    step = 8
    repo = "ByteDance/AnimateDiff-Lightning"
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
    base = "emilianJR/epiCRealism"

    # Load motion adapter with caching
    adapter = MotionAdapter().to(device, dtype)
    model_path = hf_hub_download(repo, ckpt, cache_dir="model_cache")
    adapter.load_state_dict(load_file(model_path, device=device))

    # Initialize pipeline
    pipe = AnimateDiffPipeline.from_pretrained(
        base,
        motion_adapter=adapter,
        torch_dtype=dtype,
        cache_dir="model_cache"
    ).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",  # AnimateDiff-Lightning expects trailing timesteps
        beta_schedule="linear"
    )

initialize_components()

# Cleanup function for resource management
def cleanup():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    for f in ["generated_video.mp4", "final_video_with_audio.mp4", "output.mp3"]:
        if os.path.exists(f):
            try:
                os.remove(f)
            except OSError:
                pass

# Story generation functions (original logic, with a timeout around each scene render)
def generate_video(summary):
    def crossfade_transition(frames1, frames2, transition_length=10):
        # Blend the tail of the previous scene into the head of the next one
        transition_length = min(transition_length, len(frames1), len(frames2))
        blended_frames = []
        frames1_np = [np.array(frame) for frame in frames1[-transition_length:]]
        frames2_np = [np.array(frame) for frame in frames2[:transition_length]]
        for i in range(transition_length):
            alpha = i / transition_length
            beta = 1.0 - alpha
            blended = cv2.addWeighted(frames1_np[i], beta, frames2_np[i], alpha, 0)
            blended_frames.append(Image.fromarray(blended))
        return blended_frames

    # Split the summary into sentences; each sentence becomes one scene
    sentences = []
    current_sentence = ""
    for char in summary:
        current_sentence += char
        if char in {'.', '!', '?'}:
            sentences.append(current_sentence.strip())
            current_sentence = ""
    sentences = [s.strip() for s in sentences if s.strip()]

    output_dir = "generated_frames"
    video_path = "generated_video.mp4"
    os.makedirs(output_dir, exist_ok=True)

    all_frames = []
    previous_frames = None
    transition_frames = 10
    batch_size = 1

    for i in range(0, len(sentences), batch_size):
        batch_prompts = sentences[i : i + batch_size]
        for idx, prompt in enumerate(batch_prompts):
            try:
                output = func_timeout(
                    300,  # 5-minute timeout per scene
                    pipe,
                    args=(prompt,),
                    kwargs={
                        'guidance_scale': 1.0,
                        'num_inference_steps': step,
                        'width': 128,  # reduced resolution
                        'height': 128
                    }
                )
                frames = output.frames[0]
                if previous_frames is not None:
                    transition = crossfade_transition(previous_frames, frames, transition_frames)
                    all_frames.extend(transition)
                all_frames.extend(frames)
                previous_frames = frames
            except FunctionTimedOut:
                print(f"Timeout generating scene {i + idx + 1}")
                return None
            except Exception as e:
                print(f"Error generating scene: {str(e)}")
                continue

    if not all_frames:  # every scene failed
        return None
    imageio.mimsave(video_path, all_frames, fps=6)  # reduced FPS
    return video_path

# Modified main processing function with enhanced error handling
def create_story_video(prompt, progress=gr.Progress()):
    cleanup()  # clear artifacts from previous runs

    if not prompt or len(prompt.strip()) < 5:
        return "Prompt too short (min 5 characters)", None, None
    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)", None, None

    try:
        # gr.Progress expects fractions in [0, 1]
        progress(0.0, desc="Starting story generation...")
        story = generate_story(prompt)
        progress(0.25, desc="Story generated")

        progress(0.30, desc="Starting video generation...")
        video_path = generate_video(story)
        if not video_path:
            return story, None, "Video generation failed"
        progress(0.60, desc="Video rendered")

        progress(0.65, desc="Creating audio summary...")
        audio_summary = summary_of_summary(story, video_path)

        progress(0.75, desc="Generating voiceover...")
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            audio_file = loop.run_until_complete(
                generate_audio_with_sentiment(audio_summary, sentiment_analyzer)
            )
        except Exception as e:
            return story, None, f"Audio error: {str(e)}"

        progress(0.90, desc="Finalizing video...")
        output_path = 'final_video_with_audio.mp4'
        combine_video_with_audio(video_path, audio_file, output_path)

        return story, output_path, audio_summary
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(traceback.format_exc())
        return error_msg, None, None

# Other helper functions (generate_story, summary_of_summary,
# generate_audio_with_sentiment, combine_video_with_audio) were kept from the
# original code but elided in this listing; minimal placeholder sketches follow.
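
# --- Hedged sketches of the elided helpers (assumptions, not the original code) ---
# The bodies below only demonstrate plausible wiring for the models loaded in
# initialize_components(); the prompts, sentence counts, and voices are illustrative.

def generate_story(prompt):
    # Assumption: prompt Qwen2.5 through its chat template and return only the new text
    messages = [
        {"role": "system", "content": "You are a storyteller. Write a short, vivid story."},
        {"role": "user", "content": prompt},
    ]
    chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    result = text_pipe(chat_prompt, max_new_tokens=300, return_full_text=False)
    return result[0]["generated_text"].strip()

def summary_of_summary(story, video_path):
    # Assumption: the voiceover script is simply the first few sentences of the story
    parts = re.split(r'(?<=[.!?])\s+', story.strip())
    return " ".join(parts[:3])

async def generate_audio_with_sentiment(text, analyzer):
    # Assumption: pick an edge-tts voice from the sentiment label, synthesize to output.mp3
    label = analyzer(text[:512])[0]["label"]
    voice = "en-US-AriaNeural" if label == "POSITIVE" else "en-US-GuyNeural"
    audio_file = "output.mp3"
    await edge_tts.Communicate(text, voice).save(audio_file)
    return audio_file

def combine_video_with_audio(video_path, audio_path, output_path):
    # Mux video and voiceover; trim the audio if it outlasts the clip
    video = VideoFileClip(video_path)
    audio = AudioFileClip(audio_path)
    if audio.duration > video.duration:
        audio = audio.subclip(0, video.duration)
    video.set_audio(audio).write_videofile(output_path, codec="libx264", audio_codec="aac")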

# Gradio interface setup with resource management
EXAMPLE_PROMPTS = [
    "A nurse discovers an unusual pattern in patient symptoms.",
    "A family finds a time capsule during home renovation.",
    "A restaurant owner innovates to save their business.",
    "Wildlife tracking reveals climate changes.",
    "Community rebuilds after natural disaster.",
]

with gr.Blocks(title="AI Story Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 AI Story Video Generator")
    gr.Markdown("Enter a short story idea (5-500 characters)")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Story Idea",
            placeholder="Example: A detective finds a hidden room...",
            max_lines=2
        )
    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=prompt_input,
        label="Example Prompts"
    )

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    with gr.Tabs():
        with gr.Tab("Results"):
            video_output = gr.Video(label="Generated Video", interactive=False)
            story_output = gr.Textbox(label="Full Story", lines=10)
            audio_summary = gr.Textbox(label="Audio Summary", lines=3)

    generate_btn.click(
        fn=create_story_video,
        inputs=prompt_input,
        outputs=[story_output, video_output, audio_summary]
    )
    clear_btn.click(
        fn=lambda: [None, None, None],
        outputs=[story_output, video_output, audio_summary]
    )
    demo.load(fn=cleanup)
    demo.unload(fn=cleanup)

if __name__ == "__main__":
    demo.launch(server_port=7860, show_error=True)