import spaces

import gradio as gr
import sys
import time
import os
import random
# VERY IMPORTANT: Add the SkyReels-V1 root directory to the Python path
# Assuming your app.py is in the root of your cloned/forked repo.
sys.path.append(".")  # Correct path for Hugging Face Space
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
from diffusers.utils import export_to_video
from diffusers.utils import load_image
import torch

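# Pin CUDA matmul and cuDNN to full precision and disable autotuning so runs
# are as repeatable as this stack allows. Note that cudnn.deterministic stays
# False, so some kernels may still be nondeterministic across runs.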
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("highest")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # note: not referenced elsewhere in this file

# --- Model Loading ---
predictor = None  # Global predictor; loaded lazily by init_predictor below

def get_transformer_model_id(task_type: str) -> str:
    return "Skywork/SkyReels-V1-Hunyuan-I2V" if task_type == "i2v" else "Skywork/SkyReels-V1-Hunyuan-T2V"

@spaces.GPU(duration=90)
def init_predictor(task_type: str):
    global predictor
    try:
        predictor = SkyReelsVideoInfer(
            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
            model_id=get_transformer_model_id(task_type),
            quant_model=True,  # quantize to shrink the memory footprint
            world_size=1,  # single device; no multi-GPU parallelism on this Space
            is_offload=True,  # offload idle weights to CPU RAM
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
                compiler_transformer=False,  # Consider setting to True if compatible
            )
        )
        # Park the pipeline on CPU so the weights stay resident after the
        # ZeroGPU window closes; guard in case the predictor exposes no pipe.
        if hasattr(predictor, "pipe") and hasattr(predictor.pipe, "to"):
            predictor.pipe.to("cpu")
        return "Model loaded successfully!"
    except Exception as e:
        return f"Error loading model: {e}"

@spaces.GPU(duration=90)
def generate_video(prompt, seed, image=None, task_type=None):
    global predictor

    # Input Type Validation
    if task_type == "i2v" and not isinstance(image, str):
        return "Error: For i2v, please provide a valid image file path.", "{}"
    if not isinstance(prompt, str) or not isinstance(seed, (int, float)):
        return "Error: Invalid input types for prompt or seed.", "{}"


    if seed == -1:
        # -1 means "random": draw a fresh 32-bit seed
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

    kwargs = {
        "prompt": prompt,
        "height": 512,  # Consider reducing for faster processing on CPU
        "width": 512,  # Consider reducing for faster processing on CPU
        "num_frames": 97,  # Consider reducing for faster processing on CPU
        "num_inference_steps": 30,  # Consider reducing for faster processing
        "seed": int(seed), #make sure seed is int
        "guidance_scale": 6.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
        "cfg_for": False,
    }
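    # Rough duration math: 97 frames exported at 24 fps (see export_to_video
    # below) comes out to about 4 seconds of video.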

    if task_type == "i2v":
        if image is None or not os.path.exists(image):
            return "Error: Image not provided or not found.", "{}"
        try:
            kwargs["image"] = load_image(image=image)
        except Exception as e:
          return f"Error loading image: {e}", "{}"

    try:
        # Ensure the predictor is loaded
        if predictor is None:
            return "Error: Model not initialized. Please reload the Space.", "{}"

        output = predictor.inference(kwargs)
        save_dir = f"./result/{task_type}"
        os.makedirs(save_dir, exist_ok=True)
        video_out_file = f"{save_dir}/{prompt[:100].replace('/','')}_{int(seed)}.mp4"  # Ensure seed is an integer
        print(f"Generating video, local path: {video_out_file}")
        export_to_video(output, video_out_file, fps=24)
        return video_out_file, str(kwargs)  # Return kwargs as a string

    except Exception as e:
        return f"Error during video generation: {e}", "{}"

# --- Gradio Interface ---
# We'll define a single interface that handles BOTH i2v and t2v
with gr.Blocks() as demo:
    with gr.Row():
        task_type_dropdown = gr.Dropdown(
            choices=["i2v", "t2v"], label="Task Type", value="t2v"
        )  # Default to t2v
        load_model_button = gr.Button("Load Model")
        model_status = gr.Textbox(label="Model Status")
    with gr.Row():
        with gr.Column():  # Use Columns for better layout
            prompt = gr.Textbox(label="Input Prompt")
            seed = gr.Number(label="Random Seed", value=-1)
            image = gr.Image(label="Upload Image (for i2v)", type="filepath")
            submit_button = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
            output_params = gr.Textbox(label="Output Parameters")

    # Load Model Button Logic
    load_model_button.click(
        fn=init_predictor,
        inputs=[task_type_dropdown],
        outputs=[model_status]
    )

    # Submit Button Logic (Handles both i2v and t2v)
    submit_button.click(
        fn=generate_video,
        inputs=[prompt, seed, image, task_type_dropdown],  # Include task_type
        outputs=[output_video, output_params],
    )

# --- Launch the App ---
# No need for argparse in app.py for Hugging Face Spaces
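# Optional: long-running generations usually benefit from Gradio's request
# queue, e.g. `demo.queue(max_size=8)` before launch (max_size here is an
# assumption; tune it to expected traffic).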
demo.launch()  # On Spaces, the Gradio SDK runs app.py and serves this Blocks app.