dragynir commited on
Commit
d50676a
·
1 Parent(s): 4f8bfe3

add adaptive model

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. src/pipeline.py +16 -5
README.md CHANGED
@@ -24,4 +24,5 @@ a handsome man relaxing in a chair, shirt widely unbuttoned, eyes closed, close
24
  - [ ] добавить caption.csv в data/
25
  - [ ] добавить adaptive resize
26
  - [ ] прокинуть параметры в демке (seed и т д), + adaptive resize размеры
 
27
  - [ ] настроить запуск в hugging space
 
24
  - [ ] добавить caption.csv в data/
25
  - [ ] добавить adaptive resize
26
  - [ ] прокинуть параметры в демке (seed и т д), + adaptive resize размеры
27
+ - [ ] разобраться с выставлением device в пайплайн
28
  - [ ] настроить запуск в hugging space
src/pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  from dataclasses import dataclass
 
2
 
3
  from PIL import Image
4
  import numpy as np
@@ -49,8 +50,11 @@ class FashionPipeline:
49
 
50
  # extract segmentation mask
51
  segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
 
 
52
  control_mask = self.adaptive_resize(
53
- self.create_control_image(segm_mask),
 
54
  target_image_size=resolution,
55
  )
56
 
@@ -84,13 +88,20 @@ class FashionPipeline:
84
  ch3 = (segm_mask == 3) * 255 # Full body(blue).
85
  return Image.fromarray(np.stack([ch1, ch2, ch3], axis=-1).astype('uint8'), 'RGB')
86
 
87
- def adaptive_resize(self, image, target_image_size=512, max_image_size=768, divisible=64):
 
 
 
 
 
 
 
88
 
89
  assert target_image_size % divisible == 0
90
  assert max_image_size % divisible == 0
91
  assert max_image_size >= target_image_size
92
 
93
- width, height = image.size
94
  aspect_ratio = width / height
95
 
96
  if height > width:
@@ -116,14 +127,14 @@ class FashionPipeline:
116
  self.controlnet = ControlNetModel.from_pretrained(
117
  self.config.controlnet_path,
118
  torch_dtype=torch.float16,
119
- device_map="auto",
120
  )
121
 
122
  self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
123
  self.config.base_model_path,
124
  controlnet=self.controlnet,
125
  torch_dtype=torch.float16,
126
- device_map="auto",
127
  )
128
 
129
  self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
 
1
  from dataclasses import dataclass
2
+ from typing import Tuple
3
 
4
  from PIL import Image
5
  import numpy as np
 
50
 
51
  # extract segmentation mask
52
  segm_mask = generate_mask(control_image, self.segmentation_model, device=self.device)
53
+ control_mask = self.create_control_image(segm_mask)
54
+
55
  control_mask = self.adaptive_resize(
56
+ image=control_mask,
57
+ initial_shape=(control_image.shape[1], control_image.shape[0]),
58
  target_image_size=resolution,
59
  )
60
 
 
88
  ch3 = (segm_mask == 3) * 255 # Full body(blue).
89
  return Image.fromarray(np.stack([ch1, ch2, ch3], axis=-1).astype('uint8'), 'RGB')
90
 
91
+ def adaptive_resize(
92
+ self,
93
+ image: Image,
94
+ initial_shape: Tuple[int, int],
95
+ target_image_size: int = 512,
96
+ max_image_size: int = 768,
97
+ divisible: int = 64,
98
+ ):
99
 
100
  assert target_image_size % divisible == 0
101
  assert max_image_size % divisible == 0
102
  assert max_image_size >= target_image_size
103
 
104
+ width, height = initial_shape
105
  aspect_ratio = width / height
106
 
107
  if height > width:
 
127
  self.controlnet = ControlNetModel.from_pretrained(
128
  self.config.controlnet_path,
129
  torch_dtype=torch.float16,
130
+ # device_map="auto",
131
  )
132
 
133
  self.pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
134
  self.config.base_model_path,
135
  controlnet=self.controlnet,
136
  torch_dtype=torch.float16,
137
+ # device_map="auto",
138
  )
139
 
140
  self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)