thanhnt-cf committed on
Commit
a9d8d74
·
1 Parent(s): fcd223a
app.py CHANGED
@@ -86,7 +86,7 @@ async def forward_request(
86
  service = AIServiceFactory.get_service(ai_vendor)
87
 
88
  try:
89
- json_attributes = await service.extract_attributes_with_validation(
90
  Product, # type: ignore
91
  ai_model,
92
  None,
@@ -101,7 +101,7 @@ async def forward_request(
101
  shutil.rmtree(request_temp_folder)
102
 
103
  gr.Info("Process completed!")
104
- return json_attributes
105
 
106
 
107
  def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
@@ -380,6 +380,9 @@ with gr.Blocks(title="Internal Demo for Attribution") as demo:
380
  output_json = gr.Json(
381
  label="Extracted Attributes", value={}, show_indices=False
382
  )
 
 
 
383
 
384
  # add_btn.click(
385
  # add_attribute_schema,
@@ -390,7 +393,7 @@ with gr.Blocks(title="Internal Demo for Attribution") as demo:
390
  submit_btn.click(
391
  forward_request,
392
  inputs=[attributes, product_taxnomy, product_data, ai_model, gallery],
393
- outputs=output_json,
394
  )
395
 
396
 
 
86
  service = AIServiceFactory.get_service(ai_vendor)
87
 
88
  try:
89
+ json_attributes, reevaluated = await service.extract_attributes_with_validation(
90
  Product, # type: ignore
91
  ai_model,
92
  None,
 
101
  shutil.rmtree(request_temp_folder)
102
 
103
  gr.Info("Process completed!")
104
+ return json_attributes, reevaluated
105
 
106
 
107
  def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
 
380
  output_json = gr.Json(
381
  label="Extracted Attributes", value={}, show_indices=False
382
  )
383
+ reevaluated_output_json = gr.Json(
384
+ label="Extracted Attributes", value={}, show_indices=False
385
+ )
386
 
387
  # add_btn.click(
388
  # add_attribute_schema,
 
393
  submit_btn.click(
394
  forward_request,
395
  inputs=[attributes, product_taxnomy, product_data, ai_model, gallery],
396
+ outputs=[output_json, reevaluated_output_json],
397
  )
398
 
399
 
app/core/prompts.py CHANGED
@@ -28,6 +28,13 @@ You should use the following product data to assist you, if available:
28
  If an attribute appears in both the image and the product data, use the value from the product data.
29
  """
30
 
 
 
 
 
 
 
 
31
 
32
  class Prompts(BaseSettings):
33
  EXTRACT_INFO_SYSTEM_MESSAGE: str = EXTRACT_INFO_SYSTEM
@@ -42,6 +49,10 @@ class Prompts(BaseSettings):
42
 
43
  GET_PERCENTAGE_HUMAN_MESSAGE: str = GET_PERCENTAGE_HUMAN
44
 
 
 
 
 
45
 
46
  # Create a cached instance of settings
47
  @lru_cache
 
28
  If an attribute appears in both the image and the product data, use the value from the product data.
29
  """
30
 
31
# System prompt for the second ("reevaluate") pass over already-extracted
# attributes.
REEVALUATE_SYSTEM = "You are an expert in structured data extraction. You will be given an image or a set of images of a product and a set of attributes and should reevaluate certainty of the attributes into the given structure."

# Human prompt for the reevaluate pass. Placeholders filled via str.format:
#   {product_taxonomy} - taxonomy label of the product being described
#   {product_data}     - the previously extracted attributes, as text
REEVALUATE_HUMAN = """Reevaluate the following attributes of the main product (or {product_taxonomy}) shown in the images. Here are the attributes to reevaluate:
{product_data}

If an attribute can have multiple values, you do not need to reevaluate the values, just the attribute itself. If an attribute can have only one value, reevaluate the top three values.
"""
38
 
39
  class Prompts(BaseSettings):
40
  EXTRACT_INFO_SYSTEM_MESSAGE: str = EXTRACT_INFO_SYSTEM
 
49
 
50
  GET_PERCENTAGE_HUMAN_MESSAGE: str = GET_PERCENTAGE_HUMAN
51
 
52
+ REEVALUATE_SYSTEM_MESSAGE: str = REEVALUATE_SYSTEM
53
+
54
+ REEVALUATE_HUMAN_MESSAGE: str = REEVALUATE_HUMAN
55
+
56
 
57
  # Create a cached instance of settings
58
  @lru_cache
app/services/base.py CHANGED
@@ -11,6 +11,116 @@ from app.schemas.schema_tools import (
11
  )
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class BaseAttributionService(ABC):
15
  @abstractmethod
16
  async def extract_attributes(
@@ -23,6 +133,17 @@ class BaseAttributionService(ABC):
23
  ) -> Dict[str, Any]:
24
  pass
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  @abstractmethod
27
  async def follow_schema(
28
  self, schema: Dict[str, Any], data: Dict[str, Any]
@@ -52,9 +173,37 @@ class BaseAttributionService(ABC):
52
  # pil_images=pil_images, # temporarily removed to save cost
53
  img_paths=img_paths,
54
  )
 
55
  validate_json_data(data, schema)
56
 
57
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  async def follow_schema_with_validation(
60
  self, schema: Dict[str, Any], data: Dict[str, Any]
 
11
  )
12
 
13
 
14
# Sample reevaluation payload: for each attribute, a mapping of candidate
# value -> integer score (only 0 and 100 appear here).
# NOTE(review): appears to be a debugging fixture; the only visible reference
# is a commented-out `data = example_data` line - confirm before removing.
example_data = {
    "length": {
        "maxi": 100,
        "knee_length": 0,
        "mini": 0,
        "midi": 0,
    },
    "style": {
        "a_line": 0,
        "bodycon": 0,
        "shirt_dress": 0,
        "wrap_dress": 0,
        "slip": 0,
        "smock": 0,
        "corset": 100,
        "jumper_dress": 0,
        "blazer_dress": 0,
        "asymmetric": 0,
        "shift": 0,
        "drop_waist": 0,
        "empire": 0,
        "modest": 0,
    },
    "sleeve_length": {
        "sleeveless": 0,
        "three_quarters_sleeve": 0,
        "long_sleeve": 0,
        "short_sleeve": 0,
        "strapless": 100,
    },
    "neckline": {
        "v_neck": 0,
        "sweetheart": 100,
        "round_neck": 0,
        "halter_neck": 0,
        "square_neck": 0,
        "high_neck": 0,
        "crew_neck": 0,
        "turtle_neck": 0,
        "off_the_shoulder": 0,
        "one_shoulder": 0,
        "boat_neck": 0,
    },
    "pattern": {
        "floral": 0,
        "stripe": 0,
        "leopard_print": 0,
        "plain": 100,
        "geometric": 0,
        "logo": 0,
        "graphic_print": 0,
        "other": 0,
    },
    "fabric": {
        "cotton": 0,
        "denim": 0,
        "linen": 0,
        "satin": 0,
        "silk": 0,
        "sequin": 0,
        "leather": 0,
        "velvet": 100,
        "knit": 0,
        "lace": 0,
        "suede": 0,
        "sheer": 0,
        "polyester": 0,
        "viscose": 0,
    },
    "features": {
        "pockets": 0,
        "lined": 0,
        "cut_out": 0,
        "backless": 0,
        "none": 100,
    },
    "closure": {
        "button": 0,
        "zip": 0,
        "press_stud": 0,
        "clasp": 0,
    },
    "body_fit": {
        "petite": 0,
        "maternity": 0,
        "regular": 100,
        "tall": 0,
        "plus_size": 0,
    },
    "occasion": {
        "beach": 0,
        "casual": 0,
        "cocktail": 0,
        "day": 0,
        "evening": 100,
        "mother_of_the_bride": 0,
        "party": 0,
        "prom": 0,
        "wedding_guest": 0,
        "work": 0,
        "sportswear": 0,
    },
    "season": {
        "spring": 0,
        "summer": 0,
        "autumn": 0,
        "winter": 100,
    },
}
123
+
124
  class BaseAttributionService(ABC):
125
  @abstractmethod
126
  async def extract_attributes(
 
133
  ) -> Dict[str, Any]:
134
  pass
135
 
136
@abstractmethod
async def reevaluate_atributes(
    self,
    attributes_model: Type[BaseModel],
    ai_model: str,
    img_urls: List[str],
    product_taxonomy: str,
    product_data: str,
    pil_images: List[Any] = None,
    img_paths: List[str] = None,
) -> Dict[str, Any]:
    """Re-score previously extracted attributes against the product images.

    The signature mirrors how this method is invoked: the extracted
    attributes are passed as the fifth positional argument
    (``product_data``) and local image paths via the ``img_paths`` keyword.

    Args:
        attributes_model: Pydantic model describing the expected structure.
        ai_model: Identifier of the vendor model to query.
        img_urls: Image URLs of the product (callers may pass ``None``).
        product_taxonomy: Taxonomy label of the product.
        product_data: Text form of the attributes to reevaluate.
        pil_images: Optional in-memory images.
        img_paths: Optional local image file paths.

    Returns:
        The vendor response parsed into a dictionary.
    """
146
+
147
  @abstractmethod
148
  async def follow_schema(
149
  self, schema: Dict[str, Any], data: Dict[str, Any]
 
173
  # pil_images=pil_images, # temporarily removed to save cost
174
  img_paths=img_paths,
175
  )
176
+ # data = example_data
177
  validate_json_data(data, schema)
178
 
179
+ str_data = str(data)
180
+ reevaluate_data = await self.reevaluate_atributes(
181
+ attributes_model,
182
+ ai_model,
183
+ img_urls,
184
+ product_taxonomy if product_taxonomy != "" else "main",
185
+ str_data,
186
+ # pil_images=pil_images, # temporarily removed to save cost
187
+ img_paths=img_paths,
188
+ )
189
+
190
+ init_reevaluate_data = {}
191
+ for field_name, field in attributes_model.model_fields.items(): # type: ignore
192
+ print(f"{field_name}: {field.description}")
193
+ if "single value" in field.description.lower():
194
+ max_percentage = 0
195
+ for k, v in reevaluate_data[field_name].items():
196
+ if v > max_percentage:
197
+ max_percentage = v
198
+ init_reevaluate_data[field_name] = k
199
+ elif "multiple values" in field.description.lower():
200
+ init_list = []
201
+ for k, v in reevaluate_data[field_name].items():
202
+ if v >= 60:
203
+ init_list.append(k)
204
+ init_reevaluate_data[field_name] = init_list
205
+
206
+ return data, init_reevaluate_data
207
 
208
  async def follow_schema_with_validation(
209
  self, schema: Dict[str, Any], data: Dict[str, Any]
app/services/service_openai.py CHANGED
@@ -147,6 +147,95 @@ class OpenAIService(BaseAttributionService):
147
  raise VendorError(errors.VENDOR_ERROR_INVALID_JSON)
148
 
149
  return parsed_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  @weave.op
152
  async def follow_schema(
 
147
  raise VendorError(errors.VENDOR_ERROR_INVALID_JSON)
148
 
149
  return parsed_data
150
+
151
async def reevaluate_atributes(
    self,
    attributes_model: Type[BaseModel],
    ai_model: str,
    img_urls: List[str],
    product_taxonomy: str,
    product_data: str,
    pil_images: List[Any] = None,  # do not remove, this is for weave
    img_paths: List[str] = None,
) -> Dict[str, Any]:
    """Ask OpenAI to reevaluate already-extracted attributes for a product.

    Builds one user message from the reevaluate prompt plus the product
    images (taken from ``img_urls`` when given, otherwise from
    ``img_paths``) and requests a response structured as
    ``attributes_model``.

    Args:
        attributes_model: Pydantic model used as the response format.
        ai_model: OpenAI model name.
        img_urls: Image URLs; takes precedence over ``img_paths`` when not None.
        product_taxonomy: Taxonomy label interpolated into the prompt.
        product_data: Text form of the attributes to reevaluate.
        pil_images: Unused here; kept in the signature for weave.
        img_paths: Local image paths, used only when ``img_urls`` is None.

    Returns:
        The JSON content of the first choice, parsed into a dict.

    Raises:
        BadRequestError: If OpenAI rejects the request as a bad request.
        VendorError: On any other failure of the API call, or if the
            response content cannot be parsed as JSON.
    """
    # Format the prompt once and reuse it for both the debug output and the request.
    human_message = prompts.REEVALUATE_HUMAN_MESSAGE.format(
        product_taxonomy=product_taxonomy,
        product_data=product_data,
    )

    print("Prompt: ")
    print(human_message)

    text_content = [
        {
            "type": "text",
            "text": human_message,
        },
    ]

    # Default to no images so the name is always bound, even when neither
    # img_urls nor img_paths is supplied (previously an UnboundLocalError).
    image_content = []
    if img_urls is not None:
        for img_url in img_urls:
            base64_data, data_format = get_image_base64_and_type(img_url)
            image_content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{data_format};base64,{base64_data}",
                    },
                }
            )
    elif img_paths is not None:
        image_content = [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/{get_data_format(img_path)};base64,{get_image_data(img_path)}",
                },
            }
            for img_path in img_paths
        ]

    try:
        logger.info("Extracting info via OpenAI...")
        response = await self.client.beta.chat.completions.parse(
            model=ai_model,
            messages=[
                {
                    "role": "system",
                    "content": prompts.REEVALUATE_SYSTEM_MESSAGE,
                },
                {
                    "role": "user",
                    "content": text_content + image_content,
                },
            ],
            max_tokens=1000,
            response_format=attributes_model,
            logprobs=False,
            # top_logprobs=2,
            # temperature=0.0,
            top_p=1e-45,
        )
    except openai.BadRequestError as e:
        error_message = exception_to_str(e)
        raise BadRequestError(error_message) from e
    except Exception as e:
        raise VendorError(
            errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
        ) from e

    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    try:
        content = response.choices[0].message.content
        parsed_data = json.loads(content)
    except Exception as e:
        raise VendorError(errors.VENDOR_ERROR_INVALID_JSON) from e

    return parsed_data
239
 
240
  @weave.op
241
  async def follow_schema(