WiNE-iNEFF commited on
Commit
edeab3f
·
1 Parent(s): d0b80da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -19,6 +19,8 @@ device = (
19
 
20
  pipeline_name = 'WiNE-iNEFF/Minecraft-Skin-Diffusion-V2'
21
  image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
 
 
22
 
23
  def show_images_save(x):
24
  grid = torchvision.utils.make_grid(x, nrow=4, padding=0)
@@ -26,16 +28,14 @@ def show_images_save(x):
26
  grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
27
  return grid_im
28
 
29
- def color_loss(images, target_color=(0.1, 0.9, 0.5, 1)):
30
  target = (torch.tensor(target_color).to(images.device) * 2 - 1)
31
  target = target[None, :, None, None]
32
  error = torch.abs(images - target).mean()
33
  return error
34
 
35
  def generate():
36
- scheduler = DDIMScheduler.from_pretrained(pipeline_name)
37
- scheduler.set_timesteps(num_inference_steps=20)
38
- x = torch.randn(12, 4, 64, 64).to(device)
39
  for i, t in enumerate(scheduler.timesteps):
40
  model_input = scheduler.scale_model_input(x, t)
41
  with torch.no_grad():
@@ -104,7 +104,7 @@ with demo:
104
  with gr.Row().style(equal_height=True):
105
  picker = gr.ColorPicker(label="color", value="#55FFAA")
106
  slide = gr.Slider(label="guidance_scale", minimum=0, maximum=100, value=50)
107
- gall = gr.Gallery(elem_id='gallery').style(grid=[4])
108
  greet_btn = gr.Button("Generate")
109
  greet_btn.click(fn=ex_g, inputs=[picker, slide], outputs=gall)
110
  gr.HTML(
 
19
 
20
  pipeline_name = 'WiNE-iNEFF/Minecraft-Skin-Diffusion-V2'
21
  image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
22
+ scheduler = DDIMScheduler.from_pretrained(pipeline_name)
23
+ scheduler.set_timesteps(num_inference_steps=20)
24
 
25
  def show_images_save(x):
26
  grid = torchvision.utils.make_grid(x, nrow=4, padding=0)
 
28
  grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
29
  return grid_im
30
 
31
+ def color_loss(images, target_color=(0, 0, 0, 1)):
32
  target = (torch.tensor(target_color).to(images.device) * 2 - 1)
33
  target = target[None, :, None, None]
34
  error = torch.abs(images - target).mean()
35
  return error
36
 
37
  def generate():
38
+ x = torch.randn(8, 4, 64, 64).to(device)
 
 
39
  for i, t in enumerate(scheduler.timesteps):
40
  model_input = scheduler.scale_model_input(x, t)
41
  with torch.no_grad():
 
104
  with gr.Row().style(equal_height=True):
105
  picker = gr.ColorPicker(label="color", value="#55FFAA")
106
  slide = gr.Slider(label="guidance_scale", minimum=0, maximum=100, value=50)
107
+ gall = gr.Gallery(elem_id='gallery').style(grid=[4])
108
  greet_btn = gr.Button("Generate")
109
  greet_btn.click(fn=ex_g, inputs=[picker, slide], outputs=gall)
110
  gr.HTML(