import gradio as gr
import asyncio
import os
import traceback
import numpy as np
import re
from functools import partial
import torch
import imageio
import cv2
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
import edge_tts
from transformers import AutoTokenizer, pipeline
from moviepy.editor import VideoFileClip, AudioFileClip
from func_timeout import func_timeout, FunctionTimedOut
# Initialize models with cache optimization
def initialize_components():
    global tokenizer, text_pipe, sentiment_analyzer, pipe, step

    # Text generation components
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct", cache_dir="model_cache")
    text_pipe = pipeline(
        "text-generation",
        model="Qwen/Qwen2.5-1.5B-Instruct",
        tokenizer=tokenizer,
        device_map="auto",
        model_kwargs={"cache_dir": "model_cache"}  # forward cache_dir to from_pretrained
    )

    # Sentiment analysis (default model)
    sentiment_analyzer = pipeline("sentiment-analysis", model_kwargs={"cache_dir": "model_cache"})

    # Video generation setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    step = 8  # AnimateDiff-Lightning step count; reused as num_inference_steps in generate_video()
    repo = "ByteDance/AnimateDiff-Lightning"
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
    base = "emilianJR/epiCRealism"

    # Load motion adapter with caching
    adapter = MotionAdapter().to(device, dtype)
    model_path = hf_hub_download(repo, ckpt, cache_dir="model_cache")
    adapter.load_state_dict(load_file(model_path, device=device))

    # Initialize pipeline
    pipe = AnimateDiffPipeline.from_pretrained(
        base,
        motion_adapter=adapter,
        torch_dtype=dtype,
        cache_dir="model_cache"
    ).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        beta_schedule="linear"
    )
initialize_components()
# Cleanup function for resource management
def cleanup():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    for f in ["generated_video.mp4", "final_video_with_audio.mp4", "output.mp3"]:
        if os.path.exists(f):
            try:
                os.remove(f)
            except OSError:
                pass
# Video generation: split the story into scenes and render each with a per-scene timeout
def generate_video(summary):
    def crossfade_transition(frames1, frames2, transition_length=10):
        # Blend the tail of the previous clip into the head of the next one
        blended_frames = []
        transition_length = min(transition_length, len(frames1), len(frames2))
        frames1_np = [np.array(frame) for frame in frames1[-transition_length:]]
        frames2_np = [np.array(frame) for frame in frames2[:transition_length]]
        for i in range(transition_length):
            alpha = i / transition_length
            beta = 1.0 - alpha
            blended = cv2.addWeighted(frames1_np[i], beta, frames2_np[i], alpha, 0)
            blended_frames.append(Image.fromarray(blended))
        return blended_frames

    # Split the summary into sentences; each sentence becomes one scene
    sentences = []
    current_sentence = ""
    for char in summary:
        current_sentence += char
        if char in {'.', '!', '?'}:
            sentences.append(current_sentence.strip())
            current_sentence = ""
    if current_sentence.strip():  # keep a trailing sentence with no end punctuation
        sentences.append(current_sentence.strip())
    sentences = [s.strip() for s in sentences if s.strip()]

    output_dir = "generated_frames"
    video_path = "generated_video.mp4"
    os.makedirs(output_dir, exist_ok=True)

    all_frames = []
    previous_frames = None
    transition_frames = 10
    batch_size = 1

    for i in range(0, len(sentences), batch_size):
        batch_prompts = sentences[i : i + batch_size]
        for idx, prompt in enumerate(batch_prompts):
            try:
                output = func_timeout(
                    300,  # 5 minute timeout per scene
                    pipe,
                    args=(prompt,),
                    kwargs={
                        'guidance_scale': 1.0,
                        'num_inference_steps': step,
                        'width': 128,   # reduced resolution
                        'height': 128
                    }
                )
                frames = output.frames[0]
                if previous_frames is not None:
                    transition = crossfade_transition(previous_frames, frames, transition_frames)
                    all_frames.extend(transition)
                all_frames.extend(frames)
                previous_frames = frames
            except FunctionTimedOut:
                print(f"Timeout generating scene {i + idx + 1}")
                return None
            except Exception as e:
                print(f"Error generating scene: {str(e)}")
                continue

    if not all_frames:  # every scene failed
        return None
    imageio.mimsave(video_path, all_frames, fps=6)  # reduced FPS
    return video_path
# Main processing function with input validation and error handling
def create_story_video(prompt, progress=gr.Progress()):
    cleanup()  # clear artifacts from previous runs

    if not prompt or len(prompt.strip()) < 5:
        return "Prompt too short (min 5 characters)", None, None
    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)", None, None

    try:
        progress(0, desc="Starting story generation...")
        story = generate_story(prompt)
        progress(0.25, desc="Story generated")  # gr.Progress expects a fraction in [0, 1]

        progress(0.30, desc="Starting video generation...")
        video_path = generate_video(story)
        if not video_path:
            return story, None, "Video generation failed"
        progress(0.60, desc="Video rendered")

        progress(0.65, desc="Creating audio summary...")
        audio_summary = summary_of_summary(story, video_path)

        progress(0.75, desc="Generating voiceover...")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            audio_file = loop.run_until_complete(
                generate_audio_with_sentiment(audio_summary, sentiment_analyzer)
            )
        except Exception as e:
            return story, None, f"Audio error: {str(e)}"
        finally:
            loop.close()

        progress(0.90, desc="Finalizing video...")
        output_path = 'final_video_with_audio.mp4'
        combine_video_with_audio(video_path, audio_file, output_path)

        return story, output_path, audio_summary
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(traceback.format_exc())
        return error_msg, None, None
# The remaining helpers referenced above (generate_story, summary_of_summary,
# generate_audio_with_sentiment, combine_video_with_audio) belong to the original code;
# hedged placeholder sketches follow below.
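# --- Hypothetical sketches of the elided helper functions ---
# The original implementations are not included in this file. The versions below are
# minimal sketches that match the call sites in create_story_video(); the exact prompts,
# voice names, and duration heuristic are assumptions, not the author's code.

def generate_story(prompt):
    # Sketch: ask the Qwen chat model for a short multi-sentence story.
    messages = [
        {"role": "system", "content": "You are a storyteller. Write a short story of 4-6 sentences."},
        {"role": "user", "content": prompt},
    ]
    chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    result = text_pipe(chat_prompt, max_new_tokens=300, do_sample=True, temperature=0.8)
    # The text-generation pipeline returns the prompt plus the completion; keep only the completion.
    return result[0]["generated_text"][len(chat_prompt):].strip()


def summary_of_summary(story, video_path):
    # Sketch: condense the story so the voiceover roughly fits the video length
    # (assumes ~2.5 spoken words per second; the original heuristic is unknown).
    duration = VideoFileClip(video_path).duration
    max_words = max(10, int(duration * 2.5))
    words = story.split()
    return " ".join(words[:max_words])


async def generate_audio_with_sentiment(text, analyzer):
    # Sketch: pick an edge-tts voice and rate from the sentiment of the text,
    # then synthesize the narration to output.mp3. The voice names are assumptions.
    label = analyzer(text[:512])[0]["label"]
    if label == "POSITIVE":
        voice, rate = "en-US-JennyNeural", "+5%"
    else:
        voice, rate = "en-US-GuyNeural", "-5%"
    communicate = edge_tts.Communicate(text, voice=voice, rate=rate)
    await communicate.save("output.mp3")
    return "output.mp3"


def combine_video_with_audio(video_path, audio_path, output_path):
    # Sketch: mux the narration onto the silent video with moviepy,
    # trimming the audio if it runs longer than the video.
    video = VideoFileClip(video_path)
    audio = AudioFileClip(audio_path)
    if audio.duration > video.duration:
        audio = audio.subclip(0, video.duration)
    video = video.set_audio(audio)
    video.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path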
# Gradio interface setup with resource management
EXAMPLE_PROMPTS = [
    "A nurse discovers an unusual pattern in patient symptoms.",
    "A family finds a time capsule during home renovation.",
    "A restaurant owner innovates to save their business.",
    "Wildlife tracking reveals climate changes.",
    "Community rebuilds after natural disaster."
]

with gr.Blocks(title="AI Story Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎬 AI Story Video Generator")
    gr.Markdown("Enter a short story idea (5-500 characters)")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Story Idea",
            placeholder="Example: A detective finds a hidden room...",
            max_lines=2
        )

    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=prompt_input,
        label="Example Prompts"
    )

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    with gr.Tabs():
        with gr.Tab("Results"):
            video_output = gr.Video(label="Generated Video", interactive=False)
            story_output = gr.Textbox(label="Full Story", lines=10)
            audio_summary = gr.Textbox(label="Audio Summary", lines=3)

    generate_btn.click(
        fn=create_story_video,
        inputs=prompt_input,
        outputs=[story_output, video_output, audio_summary]
    )
    clear_btn.click(
        fn=lambda: [None, None, None],
        outputs=[story_output, video_output, audio_summary]
    )

    demo.load(fn=cleanup)
    demo.unload(fn=cleanup)
if __name__ == "__main__":
    demo.launch(server_port=7860, show_error=True)