Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -10,9 +10,9 @@ from diffusers import EulerDiscreteScheduler
|
|
10 |
import gradio as gr
|
11 |
|
12 |
# Download the model files
|
13 |
-
|
14 |
|
15 |
-
# Function to load models
|
16 |
def load_models():
|
17 |
text_encoder = ChatGLMModel.from_pretrained(
|
18 |
os.path.join(ckpt_dir, 'text_encoder'),
|
@@ -29,21 +29,20 @@ def load_models():
|
|
29 |
unet=unet,
|
30 |
scheduler=scheduler,
|
31 |
force_zeros_for_empty_prompt=False
|
32 |
-
)
|
33 |
|
34 |
# Create a global variable to hold the pipeline
|
35 |
pipe = load_models()
|
36 |
|
37 |
-
@spaces.GPU(duration=200)
|
38 |
def generate_image(prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, progress=gr.Progress(track_tqdm=True)):
|
39 |
if use_random_seed:
|
40 |
seed = random.randint(0, 2**32 - 1)
|
41 |
else:
|
42 |
seed = int(seed) # Ensure seed is an integer
|
43 |
|
44 |
-
# Move the model to the
|
45 |
with torch.no_grad():
|
46 |
-
generator = torch.Generator(
|
47 |
result = pipe(
|
48 |
prompt=prompt,
|
49 |
negative_prompt=negative_prompt,
|
@@ -58,8 +57,6 @@ def generate_image(prompt, negative_prompt, height, width, num_inference_steps,
|
|
58 |
|
59 |
return image, seed
|
60 |
|
61 |
-
|
62 |
-
|
63 |
# Gradio interface
|
64 |
iface = gr.Interface(
|
65 |
fn=generate_image,
|
|
|
10 |
import gradio as gr
|
11 |
|
12 |
# Download the model files
|
13 |
+
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
|
14 |
|
15 |
+
# Function to load models
|
16 |
def load_models():
|
17 |
text_encoder = ChatGLMModel.from_pretrained(
|
18 |
os.path.join(ckpt_dir, 'text_encoder'),
|
|
|
29 |
unet=unet,
|
30 |
scheduler=scheduler,
|
31 |
force_zeros_for_empty_prompt=False
|
32 |
+
)
|
33 |
|
34 |
# Create a global variable to hold the pipeline
|
35 |
pipe = load_models()
|
36 |
|
|
|
37 |
def generate_image(prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, progress=gr.Progress(track_tqdm=True)):
|
38 |
if use_random_seed:
|
39 |
seed = random.randint(0, 2**32 - 1)
|
40 |
else:
|
41 |
seed = int(seed) # Ensure seed is an integer
|
42 |
|
43 |
+
# Move the model to the CPU for inference and clear unnecessary variables
|
44 |
with torch.no_grad():
|
45 |
+
generator = torch.Generator().manual_seed(seed)
|
46 |
result = pipe(
|
47 |
prompt=prompt,
|
48 |
negative_prompt=negative_prompt,
|
|
|
57 |
|
58 |
return image, seed
|
59 |
|
|
|
|
|
60 |
# Gradio interface
|
61 |
iface = gr.Interface(
|
62 |
fn=generate_image,
|