thanhnt-cf commited on
Commit
0dd08cb
·
1 Parent(s): 34967e4
app.py CHANGED
@@ -31,15 +31,13 @@ async def forward_request(
31
 
32
  try:
33
  # convert attributes to schema
34
- attributes = "attributes_object = {" + attributes + "}"
35
  try:
36
- attributes = exec(attributes, globals())
37
  except:
38
  raise gr.Error(
39
  "Invalid `Attribute Schema`. Please insert valid schema following the example."
40
  )
41
- for key, value in attributes_object.items(): # type: ignore
42
- attributes_object[key] = Attribute(**value) # type: ignore
43
 
44
  if product_data == "":
45
  product_data = "{}"
@@ -89,7 +87,7 @@ async def forward_request(
89
 
90
  try:
91
  json_attributes = await service.extract_attributes_with_validation(
92
- attributes_object, # type: ignore
93
  ai_model,
94
  None,
95
  product_taxonomy,
@@ -118,36 +116,149 @@ def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_va
118
  """
119
  return attributes + schema, "", "", "", ""
120
 
 
 
 
 
 
121
 
122
- sample_schema = """"category": {
123
- "description": "Category of the garment",
124
- "data_type": "list[string]",
125
- "allowed_values": [
126
- "upper garment", "lower garment", "footwear", "accessory", "headwear", "dresses"
127
- ]
128
- },
129
-
130
- "color": {
131
- "description": "Color of the garment",
132
- "data_type": "list[string]",
133
- "allowed_values": [
134
- "black", "white", "red", "blue", "green", "yellow", "pink", "purple", "orange", "brown", "grey", "beige", "multi-color", "other"
135
- ]
136
- },
137
-
138
- "pattern": {
139
- "description": "Pattern of the garment",
140
- "data_type": "list[string]",
141
- "allowed_values": [
142
- "plain", "striped", "checkered", "floral", "polka dot", "camouflage", "animal print", "abstract", "other"
143
- ]
144
- },
145
-
146
- "material": {
147
- "description": "Material of the garment",
148
- "data_type": "string",
149
- "allowed_values": []
150
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  """
152
  description = """
153
  This is a simple demo for Attribution. Follow the steps below:
@@ -239,32 +350,32 @@ with gr.Blocks(title="Internal Demo for Attribution") as demo:
239
  max_lines=30,
240
  )
241
 
242
- with gr.Accordion("Add Attributes", open=False):
243
- attr_name = gr.Textbox(
244
- label="Attribute name", placeholder="Enter attribute name"
245
- )
246
- attr_desc = gr.Textbox(
247
- label="Description", placeholder="Enter description"
248
- )
249
- attr_type = gr.Dropdown(
250
- label="Type",
251
- choices=[
252
- "string",
253
- "list[string]",
254
- "int",
255
- "list[int]",
256
- "float",
257
- "list[float]",
258
- "bool",
259
- "list[bool]",
260
- ],
261
- interactive=True,
262
- )
263
- allowed_values = gr.Textbox(
264
- label="Allowed values (separated by comma)",
265
- placeholder="yellow, red, blue",
266
- )
267
- add_btn = gr.Button("Add Attribute")
268
 
269
  with gr.Row():
270
  submit_btn = gr.Button("Extract Attributes")
@@ -274,11 +385,11 @@ with gr.Blocks(title="Internal Demo for Attribution") as demo:
274
  label="Extracted Attributes", value={}, show_indices=False
275
  )
276
 
277
- add_btn.click(
278
- add_attribute_schema,
279
- inputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
280
- outputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
281
- )
282
 
283
  submit_btn.click(
284
  forward_request,
 
31
 
32
  try:
33
  # convert attributes to schema
34
+ attributes = import_for_schema + attributes
35
  try:
36
+ exec(attributes, globals())
37
  except:
38
  raise gr.Error(
39
  "Invalid `Attribute Schema`. Please insert valid schema following the example."
40
  )
 
 
41
 
42
  if product_data == "":
43
  product_data = "{}"
 
87
 
88
  try:
89
  json_attributes = await service.extract_attributes_with_validation(
90
+ Product, # type: ignore
91
  ai_model,
92
  None,
93
  product_taxonomy,
 
116
  """
117
  return attributes + schema, "", "", "", ""
118
 
119
+ import_for_schema = """
120
+ from enum import Enum
121
+ from pydantic import BaseModel, Field
122
+ from typing import List
123
+ """
124
 
125
+ sample_schema = """from pydantic import BaseModel, Field
126
+
127
+
128
+ class Length(BaseModel):
129
+ maxi: int = Field(..., description="Maxi length dress")
130
+ knee_length: int = Field(..., description="Knee length dress")
131
+ mini: int = Field(..., description="Mini dress")
132
+ midi: int = Field(..., description="Midi dress")
133
+
134
+
135
+ class Style(BaseModel):
136
+ a_line: int = Field(..., description="A Line style")
137
+ bodycon: int = Field(..., description="Bodycon style")
138
+ column: int = Field(..., description="Column style")
139
+ shirt_dress: int = Field(..., description="Shirt Dress")
140
+ wrap_dress: int = Field(..., description="Wrap Dress")
141
+ slip: int = Field(..., description="Slip dress")
142
+ kaftan: int = Field(..., description="Kaftan")
143
+ smock: int = Field(..., description="Smock")
144
+ corset: int = Field(..., description="Corset bodice")
145
+ pinafore: int = Field(..., description="Pinafore")
146
+ jumper_dress: int = Field(..., description="Jumper Dress")
147
+ blazer_dress: int = Field(..., description="Blazer Dress")
148
+ tunic: int = Field(..., description="Tunic")
149
+
150
+
151
+ class SleeveLength(BaseModel):
152
+ sleeveless: int = Field(..., description="Sleeveless")
153
+ three_quarters_sleeve: int = Field(..., description="Three quarters Sleeve")
154
+ long_sleeve: int = Field(..., description="Long Sleeve")
155
+ short_sleeve: int = Field(..., description="Short Sleeve")
156
+ strapless: int = Field(..., description="Strapless")
157
+
158
+
159
+ class Neckline(BaseModel):
160
+ v_neck: int = Field(..., description="V Neck")
161
+ sweetheart: int = Field(..., description="Sweetheart neckline")
162
+ round_neck: int = Field(..., description="Round Neck")
163
+ halter_neck: int = Field(..., description="Halter Neck")
164
+ square_neck: int = Field(..., description="Square Neck")
165
+ high_neck: int = Field(..., description="High Neck")
166
+ crew_neck: int = Field(..., description="Crew Neck")
167
+ cowl_neck: int = Field(..., description="Cowl Neck")
168
+ turtle_neck: int = Field(..., description="Turtle Neck")
169
+ off_the_shoulder: int = Field(..., description="Off the Shoulder")
170
+ one_shoulder: int = Field(..., description="One Shoulder")
171
+
172
+
173
+ class Pattern(BaseModel):
174
+ floral: int = Field(..., description="Floral pattern")
175
+ stripe: int = Field(..., description="Stripe pattern")
176
+ leopard_print: int = Field(..., description="Leopard print")
177
+ spot: int = Field(..., description="Spot pattern")
178
+ plain: int = Field(..., description="Plain")
179
+ geometric: int = Field(..., description="Geometric pattern")
180
+ logo: int = Field(..., description="Logo print")
181
+ graphic_print: int = Field(..., description="Graphic print")
182
+ check: int = Field(..., description="Check pattern")
183
+ other: int = Field(..., description="Other pattern")
184
+
185
+
186
+ class Fabric(BaseModel):
187
+ cotton: int = Field(..., description="Cotton")
188
+ denim: int = Field(..., description="Denim")
189
+ jersey: int = Field(..., description="Jersey")
190
+ linen: int = Field(..., description="Linen")
191
+ satin: int = Field(..., description="Satin")
192
+ silk: int = Field(..., description="Silk")
193
+ sequin: int = Field(..., description="Sequin")
194
+ leather: int = Field(..., description="Leather")
195
+ velvet: int = Field(..., description="Velvet")
196
+ knit: int = Field(..., description="Knit")
197
+ lace: int = Field(..., description="Lace")
198
+ suede: int = Field(..., description="Suede")
199
+ sheer: int = Field(..., description="Sheer")
200
+ tulle: int = Field(..., description="Tulle")
201
+ crepe: int = Field(..., description="Crepe")
202
+ polyester: int = Field(..., description="Polyester")
203
+ viscose: int = Field(..., description="Viscose")
204
+
205
+
206
+ class Features(BaseModel):
207
+ pockets: int = Field(..., description="Has pockets")
208
+ lined: int = Field(..., description="Lined")
209
+ cut_out: int = Field(..., description="Cut out design")
210
+ backless: int = Field(..., description="Backless")
211
+ none: int = Field(..., description="No special features")
212
+
213
+
214
+ class Closure(BaseModel):
215
+ button: int = Field(..., description="Button closure")
216
+ zip: int = Field(..., description="Zip closure")
217
+ press_stud: int = Field(..., description="Press stud closure")
218
+ clasp: int = Field(..., description="Clasp closure")
219
+
220
+
221
+ class BodyFit(BaseModel):
222
+ petite: int = Field(..., description="Petite fit")
223
+ maternity: int = Field(..., description="Maternity fit")
224
+ regular: int = Field(..., description="Regular fit")
225
+ tall: int = Field(..., description="Tall fit")
226
+ plus_size: int = Field(..., description="Plus size fit")
227
+
228
+
229
+ class Occasion(BaseModel):
230
+ beach: int = Field(..., description="Suitable for beach")
231
+ casual: int = Field(..., description="Casual wear")
232
+ cocktail: int = Field(..., description="Cocktail event")
233
+ day: int = Field(..., description="Day wear")
234
+ evening: int = Field(..., description="Evening wear")
235
+ mother_of_the_bride: int = Field(..., description="Mother of the bride dress")
236
+ party: int = Field(..., description="Party wear")
237
+ prom: int = Field(..., description="Prom dress")
238
+ wedding_guest: int = Field(..., description="Wedding guest dress")
239
+ work: int = Field(..., description="Work attire")
240
+ sportswear: int = Field(..., description="Sportswear")
241
+
242
+
243
+ class Season(BaseModel):
244
+ spring: int = Field(..., description="Spring season")
245
+ summer: int = Field(..., description="Summer season")
246
+ autumn: int = Field(..., description="Autumn season")
247
+ winter: int = Field(..., description="Winter season")
248
+
249
+
250
+ class Product(BaseModel):
251
+ length: Length = Field(..., description="Single value ,Length of the dress")
252
+ style: Style = Field(..., description="Can have multiple values, Style of the dress")
253
+ sleeve_length: SleeveLength = Field(..., description="Single value ,Sleeve length of the dress")
254
+ neckline: Neckline = Field(..., description="Single value ,Neckline of the dress")
255
+ pattern: Pattern = Field(..., description="Can have multiple values, Pattern of the dress")
256
+ fabric: Fabric = Field(..., description="Can have multiple values, Fabric of the dress")
257
+ features: Features = Field(..., description="Can have multiple values, Features of the dress")
258
+ closure: Closure = Field(..., description="Can have multiple values ,Closure of the dress")
259
+ body_fit: BodyFit = Field(..., description="Single value ,Body fit of the dress")
260
+ occasion: Occasion = Field(..., description="Can have multiple values ,Occasion of the dress")
261
+ season: Season = Field(..., description="Single value ,Season of the dress")
262
  """
263
  description = """
264
  This is a simple demo for Attribution. Follow the steps below:
 
350
  max_lines=30,
351
  )
352
 
353
+ # with gr.Accordion("Add Attributes", open=False):
354
+ # attr_name = gr.Textbox(
355
+ # label="Attribute name", placeholder="Enter attribute name"
356
+ # )
357
+ # attr_desc = gr.Textbox(
358
+ # label="Description", placeholder="Enter description"
359
+ # )
360
+ # attr_type = gr.Dropdown(
361
+ # label="Type",
362
+ # choices=[
363
+ # "string",
364
+ # "list[string]",
365
+ # "int",
366
+ # "list[int]",
367
+ # "float",
368
+ # "list[float]",
369
+ # "bool",
370
+ # "list[bool]",
371
+ # ],
372
+ # interactive=True,
373
+ # )
374
+ # allowed_values = gr.Textbox(
375
+ # label="Allowed values (separated by comma)",
376
+ # placeholder="yellow, red, blue",
377
+ # )
378
+ # add_btn = gr.Button("Add Attribute")
379
 
380
  with gr.Row():
381
  submit_btn = gr.Button("Extract Attributes")
 
385
  label="Extracted Attributes", value={}, show_indices=False
386
  )
387
 
388
+ # add_btn.click(
389
+ # add_attribute_schema,
390
+ # inputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
391
+ # outputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
392
+ # )
393
 
394
  submit_btn.click(
395
  forward_request,
app/core/prompts.py CHANGED
@@ -19,6 +19,15 @@ FOLLOW_SCHEMA_HUMAN = """Convert following attributes to structured schema. Keep
19
 
20
  {json_info}"""
21
 
 
 
 
 
 
 
 
 
 
22
 
23
  class Prompts(BaseSettings):
24
  EXTRACT_INFO_SYSTEM_MESSAGE: str = EXTRACT_INFO_SYSTEM
@@ -29,6 +38,10 @@ class Prompts(BaseSettings):
29
 
30
  FOLLOW_SCHEMA_HUMAN_MESSAGE: str = FOLLOW_SCHEMA_HUMAN
31
 
 
 
 
 
32
 
33
  # Create a cached instance of settings
34
  @lru_cache
 
19
 
20
  {json_info}"""
21
 
22
+ GET_PERCENTAGE_SYSTEM = "You have to assign a percentage of certainty from, don't output just 0 and 1"
23
+
24
+ GET_PERCENTAGE_HUMAN = """For each allowed value in each attribute, assign a percentage of certainty (in scale of 100) that the product fits that value.
25
+ If an attribute can have multiple values, evaluate each value independently. If an attribute can have only one value, the percentages of certainty should sum up to 100.
26
+ You should use the following product data to assist you, if available:
27
+ {product_data}
28
+ If an attribute appears in both the image and the product data, use the value from the product data.
29
+ """
30
+
31
 
32
  class Prompts(BaseSettings):
33
  EXTRACT_INFO_SYSTEM_MESSAGE: str = EXTRACT_INFO_SYSTEM
 
38
 
39
  FOLLOW_SCHEMA_HUMAN_MESSAGE: str = FOLLOW_SCHEMA_HUMAN
40
 
41
+ GET_PERCENTAGE_SYSTEM_MESSAGE: str = GET_PERCENTAGE_SYSTEM
42
+
43
+ GET_PERCENTAGE_HUMAN_MESSAGE: str = GET_PERCENTAGE_HUMAN
44
+
45
 
46
  # Create a cached instance of settings
47
  @lru_cache
app/services/base.py CHANGED
@@ -31,7 +31,7 @@ class BaseAttributionService(ABC):
31
 
32
  async def extract_attributes_with_validation(
33
  self,
34
- attributes: Dict[str, Any],
35
  ai_model: str,
36
  img_urls: List[str],
37
  product_taxonomy: str,
@@ -42,17 +42,6 @@ class BaseAttributionService(ABC):
42
  # validate_json_schema(schema)
43
 
44
  # create mappings for keys of attributes, to make the key following naming convention of python variables
45
- forward_mapping = {}
46
- reverse_mapping = {}
47
- for i, key in enumerate(attributes.keys()):
48
- forward_mapping[key] = f'{to_snake_case(key)}_{i}'
49
- reverse_mapping[f'{to_snake_case(key)}_{i}'] = key
50
-
51
- transformed_attributes = {}
52
- for key, value in attributes.items():
53
- transformed_attributes[forward_mapping[key]] = value
54
-
55
- attributes_model = convert_attribute_to_model(transformed_attributes)
56
  schema = attributes_model.model_json_schema()
57
  data = await self.extract_attributes(
58
  attributes_model,
@@ -65,11 +54,7 @@ class BaseAttributionService(ABC):
65
  )
66
  validate_json_data(data, schema)
67
 
68
- # reverse the key mapping to the original keys
69
- reverse_data = {}
70
- for key, value in data.items():
71
- reverse_data[reverse_mapping[key]] = value
72
- return reverse_data
73
 
74
  async def follow_schema_with_validation(
75
  self, schema: Dict[str, Any], data: Dict[str, Any]
 
31
 
32
  async def extract_attributes_with_validation(
33
  self,
34
+ attributes_model: Type[BaseModel],
35
  ai_model: str,
36
  img_urls: List[str],
37
  product_taxonomy: str,
 
42
  # validate_json_schema(schema)
43
 
44
  # create mappings for keys of attributes, to make the key following naming convention of python variables
 
 
 
 
 
 
 
 
 
 
 
45
  schema = attributes_model.model_json_schema()
46
  data = await self.extract_attributes(
47
  attributes_model,
 
54
  )
55
  validate_json_data(data, schema)
56
 
57
+ return data
 
 
 
 
58
 
59
  async def follow_schema_with_validation(
60
  self, schema: Dict[str, Any], data: Dict[str, Any]
app/services/service_anthropic.py CHANGED
@@ -26,8 +26,8 @@ elif ENV == "UAT":
26
  elif ENV == "PROD":
27
  pass
28
 
29
- if ENV != "PROD":
30
- weave.init(project_name=weave_project_name)
31
  settings = get_settings()
32
  prompts = get_prompts()
33
  logger = setup_logger(__name__)
@@ -82,12 +82,12 @@ class AnthropicService(BaseAttributionService):
82
  # this is not expected, raise some errors here later.
83
  pass
84
 
85
- system_message = [{"type": "text", "text": prompts.EXTRACT_INFO_SYSTEM_MESSAGE}]
86
 
87
  text_messages = [
88
  {
89
  "type": "text",
90
- "text": prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
91
  product_taxonomy=product_taxonomy,
92
  product_data=product_data_to_str(product_data),
93
  ),
 
26
  elif ENV == "PROD":
27
  pass
28
 
29
+ # if ENV != "PROD":
30
+ # weave.init(project_name=weave_project_name)
31
  settings = get_settings()
32
  prompts = get_prompts()
33
  logger = setup_logger(__name__)
 
82
  # this is not expected, raise some errors here later.
83
  pass
84
 
85
+ system_message = [{"type": "text", "text": prompts.GET_PERCENTAGE_SYSTEM_MESSAGE}]
86
 
87
  text_messages = [
88
  {
89
  "type": "text",
90
+ "text": prompts.GET_PERCENTAGE_HUMAN_MESSAGE.format(
91
  product_taxonomy=product_taxonomy,
92
  product_data=product_data_to_str(product_data),
93
  ),
app/services/service_openai.py CHANGED
@@ -31,8 +31,8 @@ elif ENV == "UAT":
31
  elif ENV == "PROD":
32
  pass
33
 
34
- if ENV != "PROD":
35
- weave.init(project_name=weave_project_name)
36
  settings = get_settings()
37
  prompts = get_prompts()
38
  logger = setup_logger(__name__)
@@ -71,7 +71,7 @@ class OpenAIService(BaseAttributionService):
71
  ) -> Dict[str, Any]:
72
 
73
  print("Prompt: ")
74
- print(prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(product_taxonomy=product_taxonomy, product_data=product_data_to_str(product_data)))
75
 
76
  text_content = [
77
  {
@@ -118,7 +118,7 @@ class OpenAIService(BaseAttributionService):
118
  messages=[
119
  {
120
  "role": "system",
121
- "content": prompts.EXTRACT_INFO_SYSTEM_MESSAGE,
122
  },
123
  {
124
  "role": "user",
 
31
  elif ENV == "PROD":
32
  pass
33
 
34
+ # if ENV != "PROD":
35
+ # weave.init(project_name=weave_project_name)
36
  settings = get_settings()
37
  prompts = get_prompts()
38
  logger = setup_logger(__name__)
 
71
  ) -> Dict[str, Any]:
72
 
73
  print("Prompt: ")
74
+ print(prompts.GET_PERCENTAGE_HUMAN_MESSAGE.format(product_taxonomy=product_taxonomy, product_data=product_data_to_str(product_data)))
75
 
76
  text_content = [
77
  {
 
118
  messages=[
119
  {
120
  "role": "system",
121
+ "content": prompts.GET_PERCENTAGE_SYSTEM_MESSAGE,
122
  },
123
  {
124
  "role": "user",