Ziqi committed
Commit 74df120 · Parent(s): 3441cff
Files changed (2):
  1. app.py  +17 -11
  2. inference.py  +6 -4
app.py CHANGED
@@ -131,7 +131,11 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
                                   maximum=10.,
                                   step=1,
                                   value=10)
-
+        ddim_steps = gr.Slider(label='Number of DDIM Sampling Steps',
+                               minimum=10,
+                               maximum=100,
+                               step=1,
+                               value=50)
         run_button = gr.Button('Generate')
 
         # gr.Markdown('''
@@ -146,23 +150,25 @@ def create_inference_demo(func: inference_fn) -> gr.Blocks:
         #               inputs=None,
         #               outputs=weight_name)
         prompt.submit(fn=func,
-                      inputs=[
-                          model_id,
-                          prompt,
-                          num_samples,
-                          guidance_scale,
-                      ],
-                      outputs=result,
-                      queue=False)
+                      inputs=[
+                          model_id,
+                          prompt,
+                          num_samples,
+                          guidance_scale,
+                          ddim_steps,
+                      ],
+                      outputs=result,
+                      queue=False)
         run_button.click(fn=func,
                          inputs=[
                              model_id,
                              prompt,
                              num_samples,
                              guidance_scale,
+                             ddim_steps,
                          ],
-                         outputs=result,
-                         queue=False)
+                         outputs=result,
+                         queue=False)
         return demo
 
 
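For context, the change above follows the usual Gradio Blocks pattern: every component listed in an event handler's inputs is passed positionally to its callback, so exposing the DDIM step count only requires declaring the slider and appending it to both inputs lists. A minimal, self-contained sketch of that pattern (the generate stub and component names are illustrative, not code from this repo):

import gradio as gr

def generate(prompt: str, ddim_steps: int) -> str:
    # Stub callback: the real app runs the diffusion pipeline here.
    return f'would sample "{prompt}" with {int(ddim_steps)} DDIM steps'

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    ddim_steps = gr.Slider(label='Number of DDIM Sampling Steps',
                           minimum=10, maximum=100, step=1, value=50)
    result = gr.Textbox(label='Result')
    run_button = gr.Button('Generate')

    # Each component in `inputs` arrives as a positional argument of `fn`,
    # which is why the diff adds `ddim_steps` to both handlers.
    prompt.submit(fn=generate, inputs=[prompt, ddim_steps], outputs=result)
    run_button.click(fn=generate, inputs=[prompt, ddim_steps], outputs=result)

demo.launch()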
inference.py CHANGED
@@ -46,12 +46,14 @@ def inference_fn(
     prompt: str,
     num_samples: int,
     guidance_scale: float,
+    ddim_steps: int,
 ) -> PIL.Image.Image:
 
     # create inference pipeline
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id),torch_dtype=torch.float16).to(device)
-
+    if torch.cuda.is_available():
+        pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id),torch_dtype=torch.float16).to('cuda')
+    else:
+        pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id)).to('cpu')
     # make directory to save images
     image_root_folder = os.path.join('experiments', model_id, 'inference')
     os.makedirs(image_root_folder, exist_ok = True)
@@ -80,7 +82,7 @@ def inference_fn(
         os.makedirs(image_folder, exist_ok = True)
 
         # batch generation
-        images = pipe(prompt, num_inference_steps=50, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images
+        images = pipe(prompt, num_inference_steps=ddim_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images
 
         # save generated images
         for idx, image in enumerate(images):
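The pipeline-loading change follows a common diffusers pattern: request float16 weights only when CUDA is available, since half-precision inference is generally unsupported or very slow on CPU, and take the sampler step count from the caller instead of hard-coding it. A small stand-alone sketch of that pattern (the model path and prompt are illustrative, not taken from this repo):

import torch
from diffusers import StableDiffusionPipeline

model_path = 'experiments/example-model'  # illustrative; the app builds this with os.path.join('experiments', model_id)

if torch.cuda.is_available():
    # Half precision roughly halves GPU memory use.
    pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to('cuda')
else:
    # Fall back to the default float32 weights on CPU.
    pipe = StableDiffusionPipeline.from_pretrained(model_path).to('cpu')

# `num_inference_steps` is the value the app exposes as the `ddim_steps` slider.
images = pipe('a photo of a dog',
              num_inference_steps=50,
              guidance_scale=7.5,
              num_images_per_prompt=2).images
images[0].save('sample.png')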