Update app.py
app.py CHANGED
@@ -1,31 +1,26 @@
 import gradio as gr
 import torch
+import gc
 import os
 import time
-import imageio
 from diffusers.utils import load_image
 from skyreels_v2_infer.modules import download_model
 from skyreels_v2_infer.pipelines import Text2VideoPipeline

+# Descargar y preparar modelo
+MODEL_ID = "Skywork/SkyReels-V2-T2V-14B-540P"
+model_path = download_model(MODEL_ID)
+pipe = Text2VideoPipeline(model_path=model_path, dit_path=model_path)
+
 def generate_video(prompt):
-
-
-
-    pipe = Text2VideoPipeline(
-        model_path=model_path,
-        dit_path=model_path,
-        use_usp=False,
-        offload=True,
-    )
-
-    height = 544
-    width = 960
-    seed = int(time.time())
+    print(f"Generando video para: {prompt}")
+    height, width = 544, 960
+    seed = int(time.time()) % 4294967294
     generator = torch.Generator(device="cuda").manual_seed(seed)

     kwargs = {
         "prompt": prompt,
-        "negative_prompt": "
+        "negative_prompt": "ugly, blurry, bad quality, watermark",
         "num_frames": 97,
         "num_inference_steps": 30,
         "guidance_scale": 6.0,
@@ -38,22 +33,24 @@ def generate_video(prompt):
     with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
         video_frames = pipe(**kwargs)[0]

-
-
-
-
-
-
+    output_path = f"/tmp/video_{seed}.mp4"
+    import imageio
+    imageio.mimwrite(output_path, video_frames, fps=24, quality=8)
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return output_path

 with gr.Blocks() as demo:
     gr.Markdown("# SkyReels V2")
     with gr.Row():
-        prompt_input = gr.Textbox(label="Prompt para generar video"
+        prompt_input = gr.Textbox(label="Prompt para generar video")
     with gr.Row():
-
+        generate_btn = gr.Button("Generar video")
     with gr.Row():
         output_video = gr.Video(label="Video generado")

-
+    generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=output_video)

 demo.launch()
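The change boils down to two things: the broken handler code is repaired (the old version had an unterminated "negative_prompt" string and a gr.Textbox call missing its closing parenthesis), and the expensive Text2VideoPipeline is now built once at module import instead of inside generate_video, with the handler writing the frames to an MP4 and returning the file path that gr.Video serves. Below is a minimal sketch of that wiring, for illustration only: the real SkyReels pipeline is replaced by a hypothetical fake_pipeline stub so it runs on CPU without the skyreels_v2_infer package (it still needs imageio plus imageio-ffmpeg), and the English labels are placeholders rather than the Space's actual strings.

# Sketch of the pattern applied in this commit, not the Space's actual code:
# build the heavy model once at import time, have the click handler write the
# frames to disk, and return the file path so gr.Video can display it.
import time

import gradio as gr
import imageio
import numpy as np


def fake_pipeline(prompt, num_frames=24, height=64, width=64):
    # Hypothetical stand-in for pipe(**kwargs)[0]: returns a list of uint8 RGB frames.
    rng = np.random.default_rng(abs(hash(prompt)) % 2**32)
    return [rng.integers(0, 256, (height, width, 3), dtype=np.uint8) for _ in range(num_frames)]


def generate_video(prompt):
    frames = fake_pipeline(prompt)
    output_path = f"/tmp/video_{int(time.time())}.mp4"
    imageio.mimwrite(output_path, frames, fps=24)
    return output_path  # Gradio serves the file found at this path


with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    generate_btn = gr.Button("Generate")
    output_video = gr.Video(label="Generated video")
    generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=output_video)

demo.launch()

Returning a filesystem path from the click handler is what lets gr.Video show the result, and keeping the pipeline at module scope means the 14B model is initialized once per Space restart rather than once per request.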