Upload 9 files
Browse files- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/language_utils.cpython-38.pyc +0 -0
- utils/__pycache__/options.cpython-38.pyc +0 -0
- utils/__pycache__/util.cpython-38.pyc +0 -0
- utils/language_utils.py +315 -0
- utils/logger.py +112 -0
- utils/options.py +129 -0
- utils/util.py +123 -0
utils/__init__.py
ADDED
|
File without changes
|
utils/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (119 Bytes). View file
|
|
|
utils/__pycache__/language_utils.cpython-38.pyc
ADDED
|
Binary file (5.73 kB). View file
|
|
|
utils/__pycache__/options.cpython-38.pyc
ADDED
|
Binary file (3.94 kB). View file
|
|
|
utils/__pycache__/util.cpython-38.pyc
ADDED
|
Binary file (3.81 kB). View file
|
|
|
utils/language_utils.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from curses import A_ATTRIBUTES
|
| 2 |
+
|
| 3 |
+
import numpy
|
| 4 |
+
import torch
|
| 5 |
+
from pip import main
|
| 6 |
+
from sentence_transformers import SentenceTransformer, util
|
| 7 |
+
|
| 8 |
+
# predefined shape text
|
| 9 |
+
upper_length_text = [
|
| 10 |
+
'sleeveless', 'without sleeves', 'sleeves have been cut off', 'tank top',
|
| 11 |
+
'tank shirt', 'muscle shirt', 'short-sleeve', 'short sleeves',
|
| 12 |
+
'with short sleeves', 'medium-sleeve', 'medium sleeves',
|
| 13 |
+
'with medium sleeves', 'sleeves reach elbow', 'long-sleeve',
|
| 14 |
+
'long sleeves', 'with long sleeves'
|
| 15 |
+
]
|
| 16 |
+
upper_length_attr = {
|
| 17 |
+
'sleeveless': 0,
|
| 18 |
+
'without sleeves': 0,
|
| 19 |
+
'sleeves have been cut off': 0,
|
| 20 |
+
'tank top': 0,
|
| 21 |
+
'tank shirt': 0,
|
| 22 |
+
'muscle shirt': 0,
|
| 23 |
+
'short-sleeve': 1,
|
| 24 |
+
'with short sleeves': 1,
|
| 25 |
+
'short sleeves': 1,
|
| 26 |
+
'medium-sleeve': 2,
|
| 27 |
+
'with medium sleeves': 2,
|
| 28 |
+
'medium sleeves': 2,
|
| 29 |
+
'sleeves reach elbow': 2,
|
| 30 |
+
'long-sleeve': 3,
|
| 31 |
+
'long sleeves': 3,
|
| 32 |
+
'with long sleeves': 3
|
| 33 |
+
}
|
| 34 |
+
lower_length_text = [
|
| 35 |
+
'three-point', 'medium', 'short', 'covering knee', 'cropped',
|
| 36 |
+
'three-quarter', 'long', 'slack', 'of long length'
|
| 37 |
+
]
|
| 38 |
+
lower_length_attr = {
|
| 39 |
+
'three-point': 0,
|
| 40 |
+
'medium': 1,
|
| 41 |
+
'covering knee': 1,
|
| 42 |
+
'short': 1,
|
| 43 |
+
'cropped': 2,
|
| 44 |
+
'three-quarter': 2,
|
| 45 |
+
'long': 3,
|
| 46 |
+
'slack': 3,
|
| 47 |
+
'of long length': 3
|
| 48 |
+
}
|
| 49 |
+
socks_length_text = [
|
| 50 |
+
'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery'
|
| 51 |
+
]
|
| 52 |
+
socks_length_attr = {
|
| 53 |
+
'socks': 0,
|
| 54 |
+
'stocking': 1,
|
| 55 |
+
'pantyhose': 1,
|
| 56 |
+
'leggings': 1,
|
| 57 |
+
'sheer hosiery': 1
|
| 58 |
+
}
|
| 59 |
+
hat_text = ['hat', 'cap', 'chapeau']
|
| 60 |
+
eyeglasses_text = ['sunglasses']
|
| 61 |
+
belt_text = ['belt', 'with a dress tied around the waist']
|
| 62 |
+
outer_shape_text = [
|
| 63 |
+
'with outer clothing open', 'with outer clothing unzipped',
|
| 64 |
+
'covering inner clothes', 'with outer clothing zipped'
|
| 65 |
+
]
|
| 66 |
+
outer_shape_attr = {
|
| 67 |
+
'with outer clothing open': 0,
|
| 68 |
+
'with outer clothing unzipped': 0,
|
| 69 |
+
'covering inner clothes': 1,
|
| 70 |
+
'with outer clothing zipped': 1
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
upper_types = [
|
| 74 |
+
'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee'
|
| 75 |
+
]
|
| 76 |
+
outer_types = [
|
| 77 |
+
'jacket', 'outer clothing', 'coat', 'overcoat', 'blazer', 'outerwear',
|
| 78 |
+
'duffle', 'cardigan'
|
| 79 |
+
]
|
| 80 |
+
skirt_types = ['skirt']
|
| 81 |
+
dress_types = ['dress']
|
| 82 |
+
pant_types = ['jeans', 'pants', 'trousers']
|
| 83 |
+
rompers_types = ['rompers', 'bodysuit', 'jumpsuit']
|
| 84 |
+
|
| 85 |
+
attr_names_list = [
|
| 86 |
+
'gender', 'hair length', '0 upper clothing length',
|
| 87 |
+
'1 lower clothing length', '2 socks', '3 hat', '4 eyeglasses', '5 belt',
|
| 88 |
+
'6 opening of outer clothing', '7 upper clothes', '8 outer clothing',
|
| 89 |
+
'9 skirt', '10 dress', '11 pants', '12 rompers'
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def generate_shape_attributes(user_shape_texts):
|
| 94 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 95 |
+
parsed_texts = user_shape_texts.split(',')
|
| 96 |
+
|
| 97 |
+
text_num = len(parsed_texts)
|
| 98 |
+
|
| 99 |
+
human_attr = [0, 0]
|
| 100 |
+
attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0]
|
| 101 |
+
|
| 102 |
+
changed = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
| 103 |
+
for text_id, text in enumerate(parsed_texts):
|
| 104 |
+
user_embeddings = model.encode(text)
|
| 105 |
+
if ('man' in text) and (text_id == 0):
|
| 106 |
+
human_attr[0] = 0
|
| 107 |
+
human_attr[1] = 0
|
| 108 |
+
|
| 109 |
+
if ('woman' in text or 'lady' in text) and (text_id == 0):
|
| 110 |
+
human_attr[0] = 1
|
| 111 |
+
human_attr[1] = 2
|
| 112 |
+
|
| 113 |
+
if (not changed[0]) and (text_id == 1):
|
| 114 |
+
# upper length
|
| 115 |
+
predefined_embeddings = model.encode(upper_length_text)
|
| 116 |
+
similarities = util.dot_score(user_embeddings,
|
| 117 |
+
predefined_embeddings)
|
| 118 |
+
arg_idx = torch.argmax(similarities).item()
|
| 119 |
+
attr[0] = upper_length_attr[upper_length_text[arg_idx]]
|
| 120 |
+
changed[0] = 1
|
| 121 |
+
|
| 122 |
+
if (not changed[1]) and ((text_num == 2 and text_id == 1) or
|
| 123 |
+
(text_num > 2 and text_id == 2)):
|
| 124 |
+
# lower length
|
| 125 |
+
predefined_embeddings = model.encode(lower_length_text)
|
| 126 |
+
similarities = util.dot_score(user_embeddings,
|
| 127 |
+
predefined_embeddings)
|
| 128 |
+
arg_idx = torch.argmax(similarities).item()
|
| 129 |
+
attr[1] = lower_length_attr[lower_length_text[arg_idx]]
|
| 130 |
+
changed[1] = 1
|
| 131 |
+
|
| 132 |
+
if (not changed[2]) and (text_id > 2):
|
| 133 |
+
# socks length
|
| 134 |
+
predefined_embeddings = model.encode(socks_length_text)
|
| 135 |
+
similarities = util.dot_score(user_embeddings,
|
| 136 |
+
predefined_embeddings)
|
| 137 |
+
arg_idx = torch.argmax(similarities).item()
|
| 138 |
+
if similarities[0][arg_idx] > 0.7:
|
| 139 |
+
attr[2] = arg_idx + 1
|
| 140 |
+
changed[2] = 1
|
| 141 |
+
|
| 142 |
+
if (not changed[3]) and (text_id > 2):
|
| 143 |
+
# hat
|
| 144 |
+
predefined_embeddings = model.encode(hat_text)
|
| 145 |
+
similarities = util.dot_score(user_embeddings,
|
| 146 |
+
predefined_embeddings)
|
| 147 |
+
if similarities[0][0] > 0.7:
|
| 148 |
+
attr[3] = 1
|
| 149 |
+
changed[3] = 1
|
| 150 |
+
|
| 151 |
+
if (not changed[4]) and (text_id > 2):
|
| 152 |
+
# glasses
|
| 153 |
+
predefined_embeddings = model.encode(eyeglasses_text)
|
| 154 |
+
similarities = util.dot_score(user_embeddings,
|
| 155 |
+
predefined_embeddings)
|
| 156 |
+
arg_idx = torch.argmax(similarities).item()
|
| 157 |
+
if similarities[0][arg_idx] > 0.7:
|
| 158 |
+
attr[4] = arg_idx + 1
|
| 159 |
+
changed[4] = 1
|
| 160 |
+
|
| 161 |
+
if (not changed[5]) and (text_id > 2):
|
| 162 |
+
# belt
|
| 163 |
+
predefined_embeddings = model.encode(belt_text)
|
| 164 |
+
similarities = util.dot_score(user_embeddings,
|
| 165 |
+
predefined_embeddings)
|
| 166 |
+
arg_idx = torch.argmax(similarities).item()
|
| 167 |
+
if similarities[0][arg_idx] > 0.7:
|
| 168 |
+
attr[5] = arg_idx + 1
|
| 169 |
+
changed[5] = 1
|
| 170 |
+
|
| 171 |
+
if (not changed[6]) and (text_id == 3):
|
| 172 |
+
# outer coverage
|
| 173 |
+
predefined_embeddings = model.encode(outer_shape_text)
|
| 174 |
+
similarities = util.dot_score(user_embeddings,
|
| 175 |
+
predefined_embeddings)
|
| 176 |
+
arg_idx = torch.argmax(similarities).item()
|
| 177 |
+
if similarities[0][arg_idx] > 0.7:
|
| 178 |
+
attr[6] = arg_idx
|
| 179 |
+
changed[6] = 1
|
| 180 |
+
|
| 181 |
+
if (not changed[10]) and (text_num == 2 and text_id == 1):
|
| 182 |
+
# dress_types
|
| 183 |
+
predefined_embeddings = model.encode(dress_types)
|
| 184 |
+
similarities = util.dot_score(user_embeddings,
|
| 185 |
+
predefined_embeddings)
|
| 186 |
+
similarity_skirt = util.dot_score(user_embeddings,
|
| 187 |
+
model.encode(skirt_types))
|
| 188 |
+
if similarities[0][0] > 0.5 and similarities[0][
|
| 189 |
+
0] > similarity_skirt[0][0]:
|
| 190 |
+
attr[10] = 1
|
| 191 |
+
attr[7] = 0
|
| 192 |
+
attr[8] = 0
|
| 193 |
+
attr[9] = 0
|
| 194 |
+
attr[11] = 0
|
| 195 |
+
attr[12] = 0
|
| 196 |
+
|
| 197 |
+
changed[0] = 1
|
| 198 |
+
changed[10] = 1
|
| 199 |
+
changed[7] = 1
|
| 200 |
+
changed[8] = 1
|
| 201 |
+
changed[9] = 1
|
| 202 |
+
changed[11] = 1
|
| 203 |
+
changed[12] = 1
|
| 204 |
+
|
| 205 |
+
if (not changed[12]) and (text_num == 2 and text_id == 1):
|
| 206 |
+
# rompers_types
|
| 207 |
+
predefined_embeddings = model.encode(rompers_types)
|
| 208 |
+
similarities = util.dot_score(user_embeddings,
|
| 209 |
+
predefined_embeddings)
|
| 210 |
+
max_similarity = torch.max(similarities).item()
|
| 211 |
+
if max_similarity > 0.6:
|
| 212 |
+
attr[12] = 1
|
| 213 |
+
attr[7] = 0
|
| 214 |
+
attr[8] = 0
|
| 215 |
+
attr[9] = 0
|
| 216 |
+
attr[10] = 0
|
| 217 |
+
attr[11] = 0
|
| 218 |
+
|
| 219 |
+
changed[12] = 1
|
| 220 |
+
changed[7] = 1
|
| 221 |
+
changed[8] = 1
|
| 222 |
+
changed[9] = 1
|
| 223 |
+
changed[10] = 1
|
| 224 |
+
changed[11] = 1
|
| 225 |
+
|
| 226 |
+
if (not changed[7]) and (text_num > 2 and text_id == 1):
|
| 227 |
+
# upper_types
|
| 228 |
+
predefined_embeddings = model.encode(upper_types)
|
| 229 |
+
similarities = util.dot_score(user_embeddings,
|
| 230 |
+
predefined_embeddings)
|
| 231 |
+
max_similarity = torch.max(similarities).item()
|
| 232 |
+
if max_similarity > 0.6:
|
| 233 |
+
attr[7] = 1
|
| 234 |
+
changed[7] = 1
|
| 235 |
+
|
| 236 |
+
if (not changed[8]) and (text_id == 3):
|
| 237 |
+
# outer_types
|
| 238 |
+
predefined_embeddings = model.encode(outer_types)
|
| 239 |
+
similarities = util.dot_score(user_embeddings,
|
| 240 |
+
predefined_embeddings)
|
| 241 |
+
arg_idx = torch.argmax(similarities).item()
|
| 242 |
+
if similarities[0][arg_idx] > 0.7:
|
| 243 |
+
attr[6] = outer_shape_attr[outer_shape_text[arg_idx]]
|
| 244 |
+
attr[8] = 1
|
| 245 |
+
changed[8] = 1
|
| 246 |
+
|
| 247 |
+
if (not changed[9]) and (text_num > 2 and text_id == 2):
|
| 248 |
+
# skirt_types
|
| 249 |
+
predefined_embeddings = model.encode(skirt_types)
|
| 250 |
+
similarity_skirt = util.dot_score(user_embeddings,
|
| 251 |
+
predefined_embeddings)
|
| 252 |
+
similarity_dress = util.dot_score(user_embeddings,
|
| 253 |
+
model.encode(dress_types))
|
| 254 |
+
if similarity_skirt[0][0] > 0.7 and similarity_skirt[0][
|
| 255 |
+
0] > similarity_dress[0][0]:
|
| 256 |
+
attr[9] = 1
|
| 257 |
+
attr[10] = 0
|
| 258 |
+
changed[9] = 1
|
| 259 |
+
changed[10] = 1
|
| 260 |
+
|
| 261 |
+
if (not changed[11]) and (text_num > 2 and text_id == 2):
|
| 262 |
+
# pant_types
|
| 263 |
+
predefined_embeddings = model.encode(pant_types)
|
| 264 |
+
similarities = util.dot_score(user_embeddings,
|
| 265 |
+
predefined_embeddings)
|
| 266 |
+
max_similarity = torch.max(similarities).item()
|
| 267 |
+
if max_similarity > 0.6:
|
| 268 |
+
attr[11] = 1
|
| 269 |
+
attr[9] = 0
|
| 270 |
+
attr[10] = 0
|
| 271 |
+
attr[12] = 0
|
| 272 |
+
changed[11] = 1
|
| 273 |
+
changed[9] = 1
|
| 274 |
+
changed[10] = 1
|
| 275 |
+
changed[12] = 1
|
| 276 |
+
|
| 277 |
+
return human_attr + attr
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def generate_texture_attributes(user_text):
|
| 281 |
+
parsed_texts = user_text.split(',')
|
| 282 |
+
|
| 283 |
+
attr = []
|
| 284 |
+
for text in parsed_texts:
|
| 285 |
+
if ('pure color' in text) or ('solid color' in text):
|
| 286 |
+
attr.append(4)
|
| 287 |
+
elif ('spline' in text) or ('stripe' in text):
|
| 288 |
+
attr.append(3)
|
| 289 |
+
elif ('plaid' in text) or ('lattice' in text):
|
| 290 |
+
attr.append(5)
|
| 291 |
+
elif 'floral' in text:
|
| 292 |
+
attr.append(1)
|
| 293 |
+
elif 'denim' in text:
|
| 294 |
+
attr.append(0)
|
| 295 |
+
else:
|
| 296 |
+
attr.append(17)
|
| 297 |
+
|
| 298 |
+
if len(attr) == 1:
|
| 299 |
+
attr.append(attr[0])
|
| 300 |
+
attr.append(17)
|
| 301 |
+
|
| 302 |
+
if len(attr) == 2:
|
| 303 |
+
attr.append(17)
|
| 304 |
+
|
| 305 |
+
return attr
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
if __name__ == "__main__":
|
| 309 |
+
user_request = input('Enter your request: ')
|
| 310 |
+
while user_request != '\\q':
|
| 311 |
+
attr = generate_shape_attributes(user_request)
|
| 312 |
+
print(attr)
|
| 313 |
+
for attr_name, attr_value in zip(attr_names_list, attr):
|
| 314 |
+
print(attr_name, attr_value)
|
| 315 |
+
user_request = input('Enter your request: ')
|
utils/logger.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MessageLogger():
|
| 7 |
+
"""Message logger for printing.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
opt (dict): Config. It contains the following keys:
|
| 11 |
+
name (str): Exp name.
|
| 12 |
+
logger (dict): Contains 'print_freq' (str) for logger interval.
|
| 13 |
+
train (dict): Contains 'niter' (int) for total iters.
|
| 14 |
+
use_tb_logger (bool): Use tensorboard logger.
|
| 15 |
+
start_iter (int): Start iter. Default: 1.
|
| 16 |
+
tb_logger (obj:`tb_logger`): Tensorboard logger. DefaultοΌ None.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, opt, start_iter=1, tb_logger=None):
|
| 20 |
+
self.exp_name = opt['name']
|
| 21 |
+
self.interval = opt['print_freq']
|
| 22 |
+
self.start_iter = start_iter
|
| 23 |
+
self.max_iters = opt['max_iters']
|
| 24 |
+
self.use_tb_logger = opt['use_tb_logger']
|
| 25 |
+
self.tb_logger = tb_logger
|
| 26 |
+
self.start_time = time.time()
|
| 27 |
+
self.logger = get_root_logger()
|
| 28 |
+
|
| 29 |
+
def __call__(self, log_vars):
|
| 30 |
+
"""Format logging message.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
log_vars (dict): It contains the following keys:
|
| 34 |
+
epoch (int): Epoch number.
|
| 35 |
+
iter (int): Current iter.
|
| 36 |
+
lrs (list): List for learning rates.
|
| 37 |
+
|
| 38 |
+
time (float): Iter time.
|
| 39 |
+
data_time (float): Data time for each iter.
|
| 40 |
+
"""
|
| 41 |
+
# epoch, iter, learning rates
|
| 42 |
+
epoch = log_vars.pop('epoch')
|
| 43 |
+
current_iter = log_vars.pop('iter')
|
| 44 |
+
lrs = log_vars.pop('lrs')
|
| 45 |
+
|
| 46 |
+
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
|
| 47 |
+
f'iter:{current_iter:8,d}, lr:(')
|
| 48 |
+
for v in lrs:
|
| 49 |
+
message += f'{v:.3e},'
|
| 50 |
+
message += ')] '
|
| 51 |
+
|
| 52 |
+
# time and estimated time
|
| 53 |
+
if 'time' in log_vars.keys():
|
| 54 |
+
iter_time = log_vars.pop('time')
|
| 55 |
+
data_time = log_vars.pop('data_time')
|
| 56 |
+
|
| 57 |
+
total_time = time.time() - self.start_time
|
| 58 |
+
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
|
| 59 |
+
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
|
| 60 |
+
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
| 61 |
+
message += f'[eta: {eta_str}, '
|
| 62 |
+
message += f'time: {iter_time:.3f}, data_time: {data_time:.3f}] '
|
| 63 |
+
|
| 64 |
+
# other items, especially losses
|
| 65 |
+
for k, v in log_vars.items():
|
| 66 |
+
message += f'{k}: {v:.4e} '
|
| 67 |
+
# tensorboard logger
|
| 68 |
+
if self.use_tb_logger and 'debug' not in self.exp_name:
|
| 69 |
+
self.tb_logger.add_scalar(k, v, current_iter)
|
| 70 |
+
|
| 71 |
+
self.logger.info(message)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def init_tb_logger(log_dir):
|
| 75 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 76 |
+
tb_logger = SummaryWriter(log_dir=log_dir)
|
| 77 |
+
return tb_logger
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_root_logger(logger_name='base', log_level=logging.INFO, log_file=None):
|
| 81 |
+
"""Get the root logger.
|
| 82 |
+
|
| 83 |
+
The logger will be initialized if it has not been initialized. By default a
|
| 84 |
+
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
| 85 |
+
also be added.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
logger_name (str): root logger name. Default: base.
|
| 89 |
+
log_file (str | None): The log filename. If specified, a FileHandler
|
| 90 |
+
will be added to the root logger.
|
| 91 |
+
log_level (int): The root logger level. Note that only the process of
|
| 92 |
+
rank 0 is affected, while other processes will set the level to
|
| 93 |
+
"Error" and be silent most of the time.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
logging.Logger: The root logger.
|
| 97 |
+
"""
|
| 98 |
+
logger = logging.getLogger(logger_name)
|
| 99 |
+
# if the logger has been initialized, just return it
|
| 100 |
+
if logger.hasHandlers():
|
| 101 |
+
return logger
|
| 102 |
+
|
| 103 |
+
format_str = '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s'
|
| 104 |
+
logging.basicConfig(format=format_str, level=log_level)
|
| 105 |
+
|
| 106 |
+
if log_file is not None:
|
| 107 |
+
file_handler = logging.FileHandler(log_file, 'w')
|
| 108 |
+
file_handler.setFormatter(logging.Formatter(format_str))
|
| 109 |
+
file_handler.setLevel(log_level)
|
| 110 |
+
logger.addHandler(file_handler)
|
| 111 |
+
|
| 112 |
+
return logger
|
utils/options.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def ordered_yaml():
|
| 9 |
+
"""Support OrderedDict for yaml.
|
| 10 |
+
|
| 11 |
+
Returns:
|
| 12 |
+
yaml Loader and Dumper.
|
| 13 |
+
"""
|
| 14 |
+
try:
|
| 15 |
+
from yaml import CDumper as Dumper
|
| 16 |
+
from yaml import CLoader as Loader
|
| 17 |
+
except ImportError:
|
| 18 |
+
from yaml import Dumper, Loader
|
| 19 |
+
|
| 20 |
+
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
|
| 21 |
+
|
| 22 |
+
def dict_representer(dumper, data):
|
| 23 |
+
return dumper.represent_dict(data.items())
|
| 24 |
+
|
| 25 |
+
def dict_constructor(loader, node):
|
| 26 |
+
return OrderedDict(loader.construct_pairs(node))
|
| 27 |
+
|
| 28 |
+
Dumper.add_representer(OrderedDict, dict_representer)
|
| 29 |
+
Loader.add_constructor(_mapping_tag, dict_constructor)
|
| 30 |
+
return Loader, Dumper
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def parse(opt_path, is_train=True):
|
| 34 |
+
"""Parse option file.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
opt_path (str): Option file path.
|
| 38 |
+
is_train (str): Indicate whether in training or not. Default: True.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
(dict): Options.
|
| 42 |
+
"""
|
| 43 |
+
with open(opt_path, mode='r') as f:
|
| 44 |
+
Loader, _ = ordered_yaml()
|
| 45 |
+
opt = yaml.load(f, Loader=Loader)
|
| 46 |
+
|
| 47 |
+
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
| 48 |
+
if opt.get('set_CUDA_VISIBLE_DEVICES', None):
|
| 49 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
| 50 |
+
print('export CUDA_VISIBLE_DEVICES=' + gpu_list, flush=True)
|
| 51 |
+
else:
|
| 52 |
+
print('gpu_list: ', gpu_list, flush=True)
|
| 53 |
+
|
| 54 |
+
opt['is_train'] = is_train
|
| 55 |
+
|
| 56 |
+
# paths
|
| 57 |
+
opt['path'] = {}
|
| 58 |
+
opt['path']['root'] = osp.abspath(
|
| 59 |
+
osp.join(__file__, osp.pardir, osp.pardir))
|
| 60 |
+
if is_train:
|
| 61 |
+
experiments_root = osp.join(opt['path']['root'], 'experiments',
|
| 62 |
+
opt['name'])
|
| 63 |
+
opt['path']['experiments_root'] = experiments_root
|
| 64 |
+
opt['path']['models'] = osp.join(experiments_root, 'models')
|
| 65 |
+
opt['path']['log'] = experiments_root
|
| 66 |
+
opt['path']['visualization'] = osp.join(experiments_root,
|
| 67 |
+
'visualization')
|
| 68 |
+
|
| 69 |
+
# change some options for debug mode
|
| 70 |
+
if 'debug' in opt['name']:
|
| 71 |
+
opt['debug'] = True
|
| 72 |
+
opt['val_freq'] = 1
|
| 73 |
+
opt['print_freq'] = 1
|
| 74 |
+
opt['save_checkpoint_freq'] = 1
|
| 75 |
+
else: # test
|
| 76 |
+
results_root = osp.join(opt['path']['root'], 'results', opt['name'])
|
| 77 |
+
opt['path']['results_root'] = results_root
|
| 78 |
+
opt['path']['log'] = results_root
|
| 79 |
+
opt['path']['visualization'] = osp.join(results_root, 'visualization')
|
| 80 |
+
|
| 81 |
+
return opt
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def dict2str(opt, indent_level=1):
|
| 85 |
+
"""dict to string for printing options.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
opt (dict): Option dict.
|
| 89 |
+
indent_level (int): Indent level. Default: 1.
|
| 90 |
+
|
| 91 |
+
Return:
|
| 92 |
+
(str): Option string for printing.
|
| 93 |
+
"""
|
| 94 |
+
msg = ''
|
| 95 |
+
for k, v in opt.items():
|
| 96 |
+
if isinstance(v, dict):
|
| 97 |
+
msg += ' ' * (indent_level * 2) + k + ':[\n'
|
| 98 |
+
msg += dict2str(v, indent_level + 1)
|
| 99 |
+
msg += ' ' * (indent_level * 2) + ']\n'
|
| 100 |
+
else:
|
| 101 |
+
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
|
| 102 |
+
return msg
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class NoneDict(dict):
|
| 106 |
+
"""None dict. It will return none if key is not in the dict."""
|
| 107 |
+
|
| 108 |
+
def __missing__(self, key):
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def dict_to_nonedict(opt):
|
| 113 |
+
"""Convert to NoneDict, which returns None for missing keys.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
opt (dict): Option dict.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
(dict): NoneDict for options.
|
| 120 |
+
"""
|
| 121 |
+
if isinstance(opt, dict):
|
| 122 |
+
new_opt = dict()
|
| 123 |
+
for key, sub_opt in opt.items():
|
| 124 |
+
new_opt[key] = dict_to_nonedict(sub_opt)
|
| 125 |
+
return NoneDict(**new_opt)
|
| 126 |
+
elif isinstance(opt, list):
|
| 127 |
+
return [dict_to_nonedict(sub_opt) for sub_opt in opt]
|
| 128 |
+
else:
|
| 129 |
+
return opt
|
utils/util.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from shutil import get_terminal_size
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger('base')
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def make_exp_dirs(opt):
|
| 15 |
+
"""Make dirs for experiments."""
|
| 16 |
+
path_opt = opt['path'].copy()
|
| 17 |
+
if opt['is_train']:
|
| 18 |
+
overwrite = True if 'debug' in opt['name'] else False
|
| 19 |
+
os.makedirs(path_opt.pop('experiments_root'), exist_ok=overwrite)
|
| 20 |
+
os.makedirs(path_opt.pop('models'), exist_ok=overwrite)
|
| 21 |
+
else:
|
| 22 |
+
os.makedirs(path_opt.pop('results_root'))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def set_random_seed(seed):
|
| 26 |
+
"""Set random seeds."""
|
| 27 |
+
random.seed(seed)
|
| 28 |
+
np.random.seed(seed)
|
| 29 |
+
torch.manual_seed(seed)
|
| 30 |
+
torch.cuda.manual_seed(seed)
|
| 31 |
+
torch.cuda.manual_seed_all(seed)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ProgressBar(object):
|
| 35 |
+
"""A progress bar which can print the progress.
|
| 36 |
+
|
| 37 |
+
Modified from:
|
| 38 |
+
https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, task_num=0, bar_width=50, start=True):
|
| 42 |
+
self.task_num = task_num
|
| 43 |
+
max_bar_width = self._get_max_bar_width()
|
| 44 |
+
self.bar_width = (
|
| 45 |
+
bar_width if bar_width <= max_bar_width else max_bar_width)
|
| 46 |
+
self.completed = 0
|
| 47 |
+
if start:
|
| 48 |
+
self.start()
|
| 49 |
+
|
| 50 |
+
def _get_max_bar_width(self):
|
| 51 |
+
terminal_width, _ = get_terminal_size()
|
| 52 |
+
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
|
| 53 |
+
if max_bar_width < 10:
|
| 54 |
+
print(f'terminal width is too small ({terminal_width}), '
|
| 55 |
+
'please consider widen the terminal for better '
|
| 56 |
+
'progressbar visualization')
|
| 57 |
+
max_bar_width = 10
|
| 58 |
+
return max_bar_width
|
| 59 |
+
|
| 60 |
+
def start(self):
|
| 61 |
+
if self.task_num > 0:
|
| 62 |
+
sys.stdout.write(f"[{' ' * self.bar_width}] 0/{self.task_num}, "
|
| 63 |
+
f'elapsed: 0s, ETA:\nStart...\n')
|
| 64 |
+
else:
|
| 65 |
+
sys.stdout.write('completed: 0, elapsed: 0s')
|
| 66 |
+
sys.stdout.flush()
|
| 67 |
+
self.start_time = time.time()
|
| 68 |
+
|
| 69 |
+
def update(self, msg='In progress...'):
|
| 70 |
+
self.completed += 1
|
| 71 |
+
elapsed = time.time() - self.start_time
|
| 72 |
+
fps = self.completed / elapsed
|
| 73 |
+
if self.task_num > 0:
|
| 74 |
+
percentage = self.completed / float(self.task_num)
|
| 75 |
+
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
|
| 76 |
+
mark_width = int(self.bar_width * percentage)
|
| 77 |
+
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
|
| 78 |
+
sys.stdout.write('\033[2F') # cursor up 2 lines
|
| 79 |
+
sys.stdout.write(
|
| 80 |
+
'\033[J'
|
| 81 |
+
) # clean the output (remove extra chars since last display)
|
| 82 |
+
sys.stdout.write(
|
| 83 |
+
f'[{bar_chars}] {self.completed}/{self.task_num}, '
|
| 84 |
+
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
|
| 85 |
+
f'ETA: {eta:5}s\n{msg}\n')
|
| 86 |
+
else:
|
| 87 |
+
sys.stdout.write(
|
| 88 |
+
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s, '
|
| 89 |
+
f'{fps:.1f} tasks/s')
|
| 90 |
+
sys.stdout.flush()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class AverageMeter(object):
|
| 94 |
+
"""
|
| 95 |
+
Computes and stores the average and current value
|
| 96 |
+
Imported from
|
| 97 |
+
https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(self):
|
| 101 |
+
self.reset()
|
| 102 |
+
|
| 103 |
+
def reset(self):
|
| 104 |
+
self.val = 0
|
| 105 |
+
self.avg = 0 # running average = running sum / running count
|
| 106 |
+
self.sum = 0 # running sum
|
| 107 |
+
self.count = 0 # running count
|
| 108 |
+
|
| 109 |
+
def update(self, val, n=1):
|
| 110 |
+
# n = batch_size
|
| 111 |
+
|
| 112 |
+
# val = batch accuracy for an attribute
|
| 113 |
+
# self.val = val
|
| 114 |
+
|
| 115 |
+
# sum = 100 * accumulative correct predictions for this attribute
|
| 116 |
+
self.sum += val * n
|
| 117 |
+
|
| 118 |
+
# count = total samples so far
|
| 119 |
+
self.count += n
|
| 120 |
+
|
| 121 |
+
# avg = 100 * avg accuracy for this attribute
|
| 122 |
+
# for all the batches so far
|
| 123 |
+
self.avg = self.sum / self.count
|