dragynir commited on
Commit
4f8bfe3
·
1 Parent(s): cb96de2

add adaptive model

Browse files
Files changed (5) hide show
  1. README.md +2 -1
  2. app.py +1 -1
  3. config.py +1 -1
  4. src/pipeline.py +26 -1
  5. weights/controlnet_adaptive/config.json +57 -0
README.md CHANGED
@@ -22,5 +22,6 @@ a handsome man relaxing in a chair, shirt widely unbuttoned, eyes closed, close
22
 
23
  - [x] научиться записывать демку (научился Screencastify - поставил плагин в гугл)
24
  - [ ] добавить caption.csv в data/
25
- - [ ] прокинуть параметры в демке (seed и т д)
 
26
  - [ ] настроить запуск в hugging space
 
22
 
23
  - [x] научиться записывать демку (научился Screencastify - поставил плагин в гугл)
24
  - [ ] добавить caption.csv в data/
25
+ - [ ] добавить adaptive resize
26
+ - [ ] прокинуть параметры в демке (seed и т д), + adaptive resize размеры
27
  - [ ] настроить запуск в hugging space
app.py CHANGED
@@ -58,7 +58,7 @@ with block:
58
 
59
  <p> This repo based on Unet from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>
60
  It's uses pre-trained U2NET to extract Upper body(red), Lower body(green), Full body(blue) masks, and then
61
- run StableDiffusionXLControlNetPipeline with trained controlnet to generate image conditioned on this masks.
62
  </p>
63
  """)
64
 
 
58
 
59
  <p> This repo based on Unet from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>
60
  It uses a pre-trained U2NET to extract Upper body (red), Lower body (green), Full body (blue) masks, and then
61
+ run StableDiffusionXLControlNetPipeline with the trained controlnet_baseline to generate an image conditioned on these masks.
62
  </p>
63
  """)
64
 
config.py CHANGED
@@ -15,6 +15,6 @@ class PipelineConfig:
15
 
16
  vae_path: str = 'madebyollin/sdxl-vae-fp16-fix'
17
 
18
- controlnet_path: str = os.path.join(weights_path, 'controlnet')
19
 
20
  segmentation_model_path: str = os.path.join(weights_path, 'cloth_segm.pth')
 
15
 
16
  vae_path: str = 'madebyollin/sdxl-vae-fp16-fix'
17
 
18
+ controlnet_path: str = os.path.join(weights_path, 'controlnet_adaptive')
19
 
20
  segmentation_model_path: str = os.path.join(weights_path, 'cloth_segm.pth')
src/pipeline.py CHANGED
@@ -49,7 +49,10 @@ class FashionPipeline:
49
 
50
  # extract segmentation mask
51
  segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
52
- control_mask = self.create_control_image(segm_mask).resize((resolution, resolution))
 
 
 
53
 
54
  segm_mask = self.color_segmentation_mask(segm_mask)
55
 
@@ -81,6 +84,28 @@ class FashionPipeline:
81
  ch3 = (segm_mask == 3) * 255 # Full body(blue).
82
  return Image.fromarray(np.stack([ch1, ch2, ch3], axis=-1).astype('uint8'), 'RGB')
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def __init_pipeline(self):
85
  """Init models and SDXL pipeline."""
86
  self.segmentation_model = load_seg_model(
 
49
 
50
  # extract segmentation mask
51
  segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
52
+ control_mask = self.adaptive_resize(
53
+ self.create_control_image(segm_mask),
54
+ target_image_size=resolution,
55
+ )
56
 
57
  segm_mask = self.color_segmentation_mask(segm_mask)
58
 
 
84
  ch3 = (segm_mask == 3) * 255 # Full body(blue).
85
  return Image.fromarray(np.stack([ch1, ch2, ch3], axis=-1).astype('uint8'), 'RGB')
86
 
87
+ def adaptive_resize(self, image, target_image_size=512, max_image_size=768, divisible=64):
88
+
89
+ assert target_image_size % divisible == 0
90
+ assert max_image_size % divisible == 0
91
+ assert max_image_size >= target_image_size
92
+
93
+ width, height = image.size
94
+ aspect_ratio = width / height
95
+
96
+ if height > width:
97
+ new_width = target_image_size
98
+ new_height = new_width / aspect_ratio
99
+ new_height = (new_height // divisible) * divisible
100
+ new_height = int(min(new_height, max_image_size))
101
+ else:
102
+ new_height = target_image_size
103
+ new_width = new_height / aspect_ratio
104
+ new_width = (new_width // divisible) * divisible
105
+ new_width = int(min(new_width, max_image_size))
106
+
107
+ return image.resize((new_width, new_height))
108
+
109
  def __init_pipeline(self):
110
  """Init models and SDXL pipeline."""
111
  self.segmentation_model = load_seg_model(
weights/controlnet_adaptive/config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.25.0.dev0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": "text_time",
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": 256,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20
12
+ ],
13
+ "block_out_channels": [
14
+ 320,
15
+ 640,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_channels": 3,
20
+ "conditioning_embedding_out_channels": [
21
+ 16,
22
+ 32,
23
+ 96,
24
+ 256
25
+ ],
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 2048,
28
+ "down_block_types": [
29
+ "DownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D"
32
+ ],
33
+ "downsample_padding": 1,
34
+ "encoder_hid_dim": null,
35
+ "encoder_hid_dim_type": null,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
43
+ "norm_eps": 1e-05,
44
+ "norm_num_groups": 32,
45
+ "num_attention_heads": null,
46
+ "num_class_embeds": null,
47
+ "only_cross_attention": false,
48
+ "projection_class_embeddings_input_dim": 2816,
49
+ "resnet_time_scale_shift": "default",
50
+ "transformer_layers_per_block": [
51
+ 1,
52
+ 2,
53
+ 10
54
+ ],
55
+ "upcast_attention": null,
56
+ "use_linear_projection": true
57
+ }