My-AI-Projects committed on
Commit
27148b8
·
verified ·
1 Parent(s): 9ab91a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -10,9 +10,9 @@ from diffusers import EulerDiscreteScheduler
10
  import gradio as gr
11
 
12
  # Download the model files
13
- #ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
14
 
15
- # Function to load models on demand
16
  def load_models():
17
  text_encoder = ChatGLMModel.from_pretrained(
18
  os.path.join(ckpt_dir, 'text_encoder'),
@@ -29,21 +29,20 @@ def load_models():
29
  unet=unet,
30
  scheduler=scheduler,
31
  force_zeros_for_empty_prompt=False
32
- ).to("cuda")
33
 
34
  # Create a global variable to hold the pipeline
35
  pipe = load_models()
36
 
37
- @spaces.GPU(duration=200)
38
  def generate_image(prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, progress=gr.Progress(track_tqdm=True)):
39
  if use_random_seed:
40
  seed = random.randint(0, 2**32 - 1)
41
  else:
42
  seed = int(seed) # Ensure seed is an integer
43
 
44
- # Move the model to the GPU for inference and clear unnecessary variables
45
  with torch.no_grad():
46
- generator = torch.Generator(pipe.device).manual_seed(seed)
47
  result = pipe(
48
  prompt=prompt,
49
  negative_prompt=negative_prompt,
@@ -58,8 +57,6 @@ def generate_image(prompt, negative_prompt, height, width, num_inference_steps,
58
 
59
  return image, seed
60
 
61
-
62
-
63
  # Gradio interface
64
  iface = gr.Interface(
65
  fn=generate_image,
 
10
  import gradio as gr
11
 
12
  # Download the model files
13
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
14
 
15
+ # Function to load models
16
  def load_models():
17
  text_encoder = ChatGLMModel.from_pretrained(
18
  os.path.join(ckpt_dir, 'text_encoder'),
 
29
  unet=unet,
30
  scheduler=scheduler,
31
  force_zeros_for_empty_prompt=False
32
+ )
33
 
34
  # Create a global variable to hold the pipeline
35
  pipe = load_models()
36
 
 
37
  def generate_image(prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, use_random_seed, seed, progress=gr.Progress(track_tqdm=True)):
38
  if use_random_seed:
39
  seed = random.randint(0, 2**32 - 1)
40
  else:
41
  seed = int(seed) # Ensure seed is an integer
42
 
43
+ # Move the model to the CPU for inference and clear unnecessary variables
44
  with torch.no_grad():
45
+ generator = torch.Generator().manual_seed(seed)
46
  result = pipe(
47
  prompt=prompt,
48
  negative_prompt=negative_prompt,
 
57
 
58
  return image, seed
59
 
 
 
60
  # Gradio interface
61
  iface = gr.Interface(
62
  fn=generate_image,