pgonzzz commited on
Commit
2290869
·
verified ·
1 Parent(s): 66182cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -1,31 +1,26 @@
1
  import gradio as gr
2
  import torch
 
3
  import os
4
  import time
5
- import imageio
6
  from diffusers.utils import load_image
7
  from skyreels_v2_infer.modules import download_model
8
  from skyreels_v2_infer.pipelines import Text2VideoPipeline
9
 
 
 
 
 
 
10
  def generate_video(prompt):
11
- model_id = "Skywork/SkyReels-V2-T2V-14B-540P"
12
- model_path = download_model(model_id)
13
-
14
- pipe = Text2VideoPipeline(
15
- model_path=model_path,
16
- dit_path=model_path,
17
- use_usp=False,
18
- offload=True,
19
- )
20
-
21
- height = 544
22
- width = 960
23
- seed = int(time.time())
24
  generator = torch.Generator(device="cuda").manual_seed(seed)
25
 
26
  kwargs = {
27
  "prompt": prompt,
28
- "negative_prompt": "static, blurred, low quality",
29
  "num_frames": 97,
30
  "num_inference_steps": 30,
31
  "guidance_scale": 6.0,
@@ -38,22 +33,24 @@ def generate_video(prompt):
38
  with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
39
  video_frames = pipe(**kwargs)[0]
40
 
41
- timestamp = time.strftime("%Y%m%d_%H%M%S")
42
- filename = f"output_{timestamp}.mp4"
43
- os.makedirs("video_out", exist_ok=True)
44
- path = os.path.join("video_out", filename)
45
- imageio.mimwrite(path, video_frames, fps=24, quality=8)
46
- return path
 
 
47
 
48
  with gr.Blocks() as demo:
49
  gr.Markdown("# SkyReels V2")
50
  with gr.Row():
51
- prompt_input = gr.Textbox(label="Prompt para generar video", placeholder="Describe la escena que quieres ver")
52
  with gr.Row():
53
- run_button = gr.Button("Generar video")
54
  with gr.Row():
55
  output_video = gr.Video(label="Video generado")
56
 
57
- run_button.click(fn=generate_video, inputs=prompt_input, outputs=output_video)
58
 
59
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import gc
4
  import os
5
  import time
 
6
  from diffusers.utils import load_image
7
  from skyreels_v2_infer.modules import download_model
8
  from skyreels_v2_infer.pipelines import Text2VideoPipeline
9
 
10
+ # Descargar y preparar modelo
11
+ MODEL_ID = "Skywork/SkyReels-V2-T2V-14B-540P"
12
+ model_path = download_model(MODEL_ID)
13
+ pipe = Text2VideoPipeline(model_path=model_path, dit_path=model_path)
14
+
15
  def generate_video(prompt):
16
+ print(f"Generando video para: {prompt}")
17
+ height, width = 544, 960
18
+ seed = int(time.time()) % 4294967294
 
 
 
 
 
 
 
 
 
 
19
  generator = torch.Generator(device="cuda").manual_seed(seed)
20
 
21
  kwargs = {
22
  "prompt": prompt,
23
+ "negative_prompt": "ugly, blurry, bad quality, watermark",
24
  "num_frames": 97,
25
  "num_inference_steps": 30,
26
  "guidance_scale": 6.0,
 
33
  with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
34
  video_frames = pipe(**kwargs)[0]
35
 
36
+ output_path = f"/tmp/video_{seed}.mp4"
37
+ import imageio
38
+ imageio.mimwrite(output_path, video_frames, fps=24, quality=8)
39
+
40
+ gc.collect()
41
+ torch.cuda.empty_cache()
42
+
43
+ return output_path
44
 
45
  with gr.Blocks() as demo:
46
  gr.Markdown("# SkyReels V2")
47
  with gr.Row():
48
+ prompt_input = gr.Textbox(label="Prompt para generar video")
49
  with gr.Row():
50
+ generate_btn = gr.Button("Generar video")
51
  with gr.Row():
52
  output_video = gr.Video(label="Video generado")
53
 
54
+ generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=output_video)
55
 
56
  demo.launch()