angelahzyuan commited on
Commit
375e6d9
·
verified ·
1 Parent(s): c8055b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -3,6 +3,8 @@ from diffusers import StableDiffusionPipeline, UNet2DConditionModel
3
  import torch
4
  import random
5
  import numpy as np
 
 
6
 
7
  MODEL="UCLA-AGI/SPIN-Diffusion-iter3"
8
 
@@ -15,19 +17,22 @@ def set_seed(seed=5775709):
15
 
16
  set_seed()
17
 
18
- def get_pipeline(device='cpu'):
 
 
19
  model_id = "runwayml/stable-diffusion-v1-5"
20
  #pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker = None, requires_safety_checker = False)
21
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
22
-
23
- # load finetuned model
24
- unet_id = MODEL
25
- unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float32)
26
- pipe.unet = unet
27
- pipe = pipe.to(device)
28
- return pipe
29
 
 
 
 
 
 
 
30
 
 
31
  def generate(prompt: str, num_images: int=5, guidance_scale=7.5):
32
  pipe = get_pipeline()
33
  generator = torch.Generator(pipe.device).manual_seed(5775709)
 
3
  import torch
4
  import random
5
  import numpy as np
6
+ import spaces
7
+
8
 
9
  MODEL="UCLA-AGI/SPIN-Diffusion-iter3"
10
 
 
17
 
18
  set_seed()
19
 
20
+
21
+
22
+ def get_pipeline(device='cuda'):
23
  model_id = "runwayml/stable-diffusion-v1-5"
24
  #pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker = None, requires_safety_checker = False)
25
+ if torch.cuda.is_available():
26
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
 
 
 
 
 
 
27
 
28
+ # load finetuned model
29
+ unet_id = MODEL
30
+ unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float32)
31
+ pipe.unet = unet
32
+ pipe = pipe.to(device)
33
+ return pipe
34
 
35
+ @spaces.GPU(enable_queue=True)
36
  def generate(prompt: str, num_images: int=5, guidance_scale=7.5):
37
  pipe = get_pipeline()
38
  generator = torch.Generator(pipe.device).manual_seed(5775709)