irah23 committed on
Commit
6b6138e
·
1 Parent(s): 2208b08

update space

Browse files
Files changed (1) hide show
  1. app.py +40 -26
app.py CHANGED
@@ -9,49 +9,63 @@ import imageio
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  MAX_SEED = np.iinfo(np.int32).max
11
 
12
- # SDXL for image generation
13
  sdxl_model_id = "stabilityai/sdxl-turbo"
14
- image_pipe = DiffusionPipeline.from_pretrained(sdxl_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
15
- image_pipe = image_pipe.to(device)
 
 
16
 
17
- # Stable Video Diffusion for video generation
18
  svd_model_id = "stabilityai/stable-video-diffusion-img2vid"
19
- video_pipe = StableVideoDiffusionPipeline.from_pretrained(svd_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32, variant="fp16" if device == "cuda" else None)
20
- video_pipe.enable_model_cpu_offload() if device == "cuda" else None
 
 
 
 
 
21
 
22
  def generate_video_from_text(prompt, seed=0, randomize_seed=True):
23
  if randomize_seed:
24
  seed = random.randint(0, MAX_SEED)
25
  generator = torch.Generator(device=device).manual_seed(seed)
26
 
27
- # Generate image from text
28
- image = image_pipe(prompt=prompt, generator=generator, guidance_scale=0.0, num_inference_steps=2, width=1024, height=1024).images[0]
 
 
 
 
 
 
 
29
 
30
- # Resize to 512x512
31
  image = image.resize((512, 512))
32
 
33
- # Generate video frames from image
34
- video_frames = video_pipe(image).frames[0] # list of PIL images
35
-
36
- # Convert to video (MP4)
37
  video_path = f"/tmp/generated_{seed}.mp4"
38
  imageio.mimsave(video_path, video_frames, fps=7)
39
 
40
  return video_path, image, seed
41
 
42
- with gr.Blocks() as demo:
43
- gr.Markdown("## Text to Video using SDXL + Stable Video Diffusion")
44
-
45
- with gr.Row():
46
- prompt = gr.Textbox(label="Prompt", placeholder="Describe your scene...")
47
- run_button = gr.Button("Generate")
48
-
49
- video_output = gr.Video(label="Generated Video")
50
- image_output = gr.Image(label="Generated Image")
51
- seed_output = gr.Number(label="Seed")
52
-
53
- run_button.click(fn=generate_video_from_text, inputs=[prompt], outputs=[video_output, image_output, seed_output])
 
 
54
 
 
55
  demo.api_name = "predict"
56
-
57
  demo.launch()
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  MAX_SEED = np.iinfo(np.int32).max
11
 
12
+ # Load SDXL for image generation
13
  sdxl_model_id = "stabilityai/sdxl-turbo"
14
+ image_pipe = DiffusionPipeline.from_pretrained(
15
+ sdxl_model_id,
16
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
17
+ ).to(device)
18
 
19
+ # Load Stable Video Diffusion for video generation
20
  svd_model_id = "stabilityai/stable-video-diffusion-img2vid"
21
+ video_pipe = StableVideoDiffusionPipeline.from_pretrained(
22
+ svd_model_id,
23
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
24
+ variant="fp16" if device == "cuda" else None
25
+ )
26
+ if device == "cuda":
27
+ video_pipe.enable_model_cpu_offload()
28
 
29
  def generate_video_from_text(prompt, seed=0, randomize_seed=True):
30
  if randomize_seed:
31
  seed = random.randint(0, MAX_SEED)
32
  generator = torch.Generator(device=device).manual_seed(seed)
33
 
34
+ # Generate image
35
+ image = image_pipe(
36
+ prompt=prompt,
37
+ generator=generator,
38
+ guidance_scale=0.0,
39
+ num_inference_steps=2,
40
+ width=1024,
41
+ height=1024
42
+ ).images[0]
43
 
44
+ # Resize for SVD
45
  image = image.resize((512, 512))
46
 
47
+ # Generate video
48
+ video_frames = video_pipe(image).frames[0]
 
 
49
  video_path = f"/tmp/generated_{seed}.mp4"
50
  imageio.mimsave(video_path, video_frames, fps=7)
51
 
52
  return video_path, image, seed
53
 
54
+ # Use Interface instead of Blocks
55
+ demo = gr.Interface(
56
+ fn=generate_video_from_text,
57
+ inputs=[
58
+ gr.Textbox(label="Prompt", placeholder="Describe your scene..."),
59
+ gr.Number(label="Seed", value=0),
60
+ gr.Checkbox(label="Randomize Seed", value=True)
61
+ ],
62
+ outputs=[
63
+ gr.Video(label="Generated Video"),
64
+ gr.Image(label="Generated Image"),
65
+ gr.Number(label="Seed Used")
66
+ ]
67
+ )
68
 
69
+ # Expose endpoint
70
  demo.api_name = "predict"
 
71
  demo.launch()