Commit 9645c29 by thanhnt-cf · 1 Parent(s): e85027d

update demo app
app/config.py CHANGED
@@ -12,10 +12,14 @@ if os.getenv("HUGGINGFACE_DEMO"):
 else:
     from app.aws.secrets import get_secret
 
+    ENV = os.getenv("ENV", "LOCAL")
     secrets = get_secret()
-    os.environ["WANDB_API_KEY"] = secrets["WANDB_API_KEY"]
+    if ENV != "PROD":
+        os.environ["WANDB_API_KEY"] = secrets["WANDB_API_KEY"]
     OPENAI_API_KEY = secrets["OPENAI_API_KEY"]
     ANTHROPIC_API_KEY = secrets["ANTHROPIC_API_KEY"]
+    REDIS_PASSWORD = secrets["REDIS_PASSWORD"] if ENV == "PROD" else ""
+    REDIS_USE_SSL = True if ENV == "PROD" or ENV == "UAT" else False
 os.environ["WANDB_BASE_URL"] = "https://api.wandb.ai"
 
 
@@ -59,21 +63,20 @@ class Settings(BaseSettings):
     RATE_LIMIT_PERIOD: int = 60
 
     # Cache Configuration
-    REDIS_URL: Optional[str] = None
-    CACHE_TTL: int = 3600 # 1 hour
+    REDIS_PASSWORD: Optional[str] = REDIS_PASSWORD
+    REDIS_USE_SSL: Optional[bool] = REDIS_USE_SSL
 
     # Logging
     LOG_LEVEL: str = "INFO"
     LOG_FORMAT: str = "json"
 
-    # Timeout Configuration
-    OPENAI_TIMEOUT: float = 30.0
-    ANTHROPIC_TIMEOUT: float = 30.0
-
     # API Keys
     OPENAI_API_KEY: str = OPENAI_API_KEY
     ANTHROPIC_API_KEY: str = ANTHROPIC_API_KEY
 
+    #
+    MAX_DOWNLOAD_RETRY: int = 10 # times
+
     def validate_api_keys(self):
         """Validate that required API keys are present."""
         if not self.OPENAI_API_KEY:
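For reference, a minimal sketch of how the new environment-driven values resolve (ENV names and expressions are copied from the diff above; the summary of outcomes is illustrative, not part of the commit):

import os

secrets = {"WANDB_API_KEY": "...", "REDIS_PASSWORD": "..."}  # placeholders; real values come from get_secret()

ENV = os.getenv("ENV", "LOCAL")            # "LOCAL", "DEV", "UAT", or "PROD"
if ENV != "PROD":                          # the W&B key is only exported outside production
    os.environ["WANDB_API_KEY"] = secrets["WANDB_API_KEY"]
REDIS_PASSWORD = secrets["REDIS_PASSWORD"] if ENV == "PROD" else ""
REDIS_USE_SSL = ENV in ("PROD", "UAT")     # equivalent to the ternary used in the diff
# LOCAL/DEV -> REDIS_PASSWORD="" and REDIS_USE_SSL=False; UAT -> "" and True; PROD -> secret password and True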
app/core/prompts.py CHANGED
@@ -6,7 +6,7 @@ from pydantic_settings import BaseSettings
 EXTRACT_INFO_SYSTEM = "You are an expert in structured data extraction. You will be given an image or a set of images of a product and should extract its properties into the given structure."
 
 EXTRACT_INFO_HUMAN = (
-    """Output properties of the {product_taxonomy} product shown in the images. You should use the following product data to assist you, if available:
+    """Output properties of the main product (or {product_taxonomy}) shown in the images. You should use the following product data to assist you, if available:
 
 {product_data}
 
app/schemas/schema_tools.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field # do not remove this import for exec
 
 from app.core.errors import VendorError
 from app.schemas.requests import Attribute
+from app.utils.converter import to_snake_case
 from app.utils.logger import exception_to_str
 
 
@@ -74,38 +75,8 @@ def convert_attribute_to_model(attributes: Dict[str, Attribute]) -> Dict[str, An
 
         if len(allowed_values) > 0:
             enum_code = f"class {key.capitalize()}Enum(str, Enum):\n"
-            for allowed_value in allowed_values:
-                enum_name = (
-                    allowed_value.replace(" ", "_")
-                    .replace("-", "_")
-                    .replace("&", "AND")
-                    .replace("/", "_OR_")
-                    .replace(":", "__")
-                    .replace("+", "plus")
-                    .replace(",", "_")
-                    .replace(".", "_")
-                    .replace("°", "degree")
-                    .replace("(", "")
-                    .replace(")", "")
-                    .replace("'", "_")
-                    .replace('%', "")
-                    .replace("!", "")
-                    .replace("?", "")
-                    .replace("`", "")
-                    .replace("~", "")
-                    .replace(";", "")
-                    .replace("<", "")
-                    .replace(">", "")
-                    .replace("[", "")
-                    .replace("]", "")
-                    .replace("{", "")
-                    .replace("}", "")
-                    .replace("\\", "")
-                    .replace("|", "")
-                    .replace('–', "_")
-                    .replace('*', "_")
-                    .upper()
-                )
+            for i, allowed_value in enumerate(allowed_values):
+                enum_name = f'{to_snake_case(allowed_value).upper()}_{i}'
 
                 if "'" in allowed_value:
                     enum_code += f'    E{enum_name} = "{allowed_value}"\n'
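A rough illustration of the new enum member naming (the attribute values below are invented; the index suffix keeps names unique even if two values normalise to the same string, and behaviour follows to_snake_case as added in app/utils/converter.py):

from app.utils.converter import to_snake_case  # helper added in this commit

allowed_values = ["Water Resistant (50m)", "Coffee & Tea"]  # hypothetical attribute values
for i, allowed_value in enumerate(allowed_values):
    enum_name = f'{to_snake_case(allowed_value).upper()}_{i}'
    print(enum_name)
# WATER_RESISTANT_50M_0
# COFFEE_TEA_1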
app/services/base.py CHANGED
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Type, Union
 
 from pydantic import BaseModel
+from app.utils.converter import to_snake_case
 
 from app.schemas.schema_tools import (
     convert_attribute_to_model,
@@ -39,7 +40,19 @@ class BaseAttributionService(ABC):
         img_paths: List[str] = None,
     ) -> Dict[str, Any]:
         # validate_json_schema(schema)
-        attributes_model = convert_attribute_to_model(attributes)
+
+        # create mappings for keys of attributes, to make the key following naming convention of python variables
+        forward_mapping = {}
+        reverse_mapping = {}
+        for i, key in enumerate(attributes.keys()):
+            forward_mapping[key] = f'{to_snake_case(key)}_{i}'
+            reverse_mapping[f'{to_snake_case(key)}_{i}'] = key
+
+        transformed_attributes = {}
+        for key, value in attributes.items():
+            transformed_attributes[forward_mapping[key]] = value
+
+        attributes_model = convert_attribute_to_model(transformed_attributes)
         schema = attributes_model.model_json_schema()
         data = await self.extract_attributes(
             attributes_model,
@@ -51,7 +64,12 @@
             img_paths=img_paths,
         )
         validate_json_data(data, schema)
-        return data
+
+        # reverse the key mapping to the original keys
+        reverse_data = {}
+        for key, value in data.items():
+            reverse_data[reverse_mapping[key]] = value
+        return reverse_data
 
     async def follow_schema_with_validation(
         self, schema: Dict[str, Any], data: Dict[str, Any]
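A small round-trip sketch of the key mapping added above (the attribute keys and values are invented; it assumes the repo's to_snake_case is importable):

from app.utils.converter import to_snake_case  # helper added in this commit

attributes = {"Screen Size (inches)": None, "Colour / Finish": None}  # hypothetical keys

forward_mapping = {k: f'{to_snake_case(k)}_{i}' for i, k in enumerate(attributes)}
reverse_mapping = {v: k for k, v in forward_mapping.items()}
# forward_mapping == {"Screen Size (inches)": "screen_size_inches_0",
#                     "Colour / Finish": "colour_finish_1"}

data = {"screen_size_inches_0": 6.1, "colour_finish_1": "matte black"}  # pretend model output
restored = {reverse_mapping[k]: v for k, v in data.items()}
# restored uses the caller's original attribute names again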
app/services/service_anthropic.py CHANGED
@@ -16,15 +16,18 @@ from app.utils.converter import product_data_to_str
 from app.utils.image_processing import get_data_format, get_image_data
 from app.utils.logger import exception_to_str, setup_logger
 
-env = os.getenv("ENV", "LOCAL")
-if env == "LOCAL": # local or demo
+ENV = os.getenv("ENV", "LOCAL")
+if ENV == "LOCAL": # local or demo
     weave_project_name = "cfai/attribution-exp"
-elif env == "DEV":
+elif ENV == "DEV":
     weave_project_name = "cfai/attribution-dev"
-elif env == "PROD":
-    weave_project_name = "cfai/attribution-prod"
+elif ENV == "UAT":
+    weave_project_name = "cfai/attribution-uat"
+elif ENV == "PROD":
+    pass
 
-weave.init(project_name=weave_project_name)
+if ENV != "PROD":
+    weave.init(project_name=weave_project_name)
 settings = get_settings()
 prompts = get_prompts()
 logger = setup_logger(__name__)
@@ -102,8 +105,8 @@ class AnthropicService(BaseAttributionService):
                 system=system_message,
                 tools=tools,
                 messages=messages,
-                temperature=0.0,
-                top_p=1.0,
+                # temperature=0.0,
+                # top_p=1e-45,
                 top_k=1,
             )
         except anthropic.BadRequestError as e:
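The ENV-gated weave setup above can also be read as a lookup table; a sketch of that reading (project names are copied from the diff, while the dict structure and the fallback are illustrative only):

import os

import weave  # already a dependency of this module

WEAVE_PROJECTS = {
    "LOCAL": "cfai/attribution-exp",  # local or demo
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
    # "PROD" is intentionally absent: tracing is skipped in production
}

ENV = os.getenv("ENV", "LOCAL")
if ENV != "PROD":
    # falling back to the exp project for unknown ENV values is a choice made for this sketch
    weave.init(project_name=WEAVE_PROJECTS.get(ENV, "cfai/attribution-exp"))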
app/services/service_openai.py CHANGED
@@ -8,7 +8,11 @@ from openai import AsyncOpenAI
 from pydantic import BaseModel
 
 from app.utils.converter import product_data_to_str
-from app.utils.image_processing import get_data_format, get_image_data
+from app.utils.image_processing import (
+    get_data_format,
+    get_image_base64_and_type,
+    get_image_data,
+)
 from app.utils.logger import exception_to_str, setup_logger
 
 from ..config import get_settings
@@ -17,15 +21,18 @@ from ..core.errors import BadRequestError, VendorError
 from ..core.prompts import get_prompts
 from .base import BaseAttributionService
 
-env = os.getenv("ENV", "LOCAL")
-if env == "LOCAL": # local or demo
+ENV = os.getenv("ENV", "LOCAL")
+if ENV == "LOCAL": # local or demo
     weave_project_name = "cfai/attribution-exp"
-elif env == "DEV":
+elif ENV == "DEV":
     weave_project_name = "cfai/attribution-dev"
-elif env == "PROD":
-    weave_project_name = "cfai/attribution-prod"
+elif ENV == "UAT":
+    weave_project_name = "cfai/attribution-uat"
+elif ENV == "PROD":
+    pass
 
-weave.init(project_name=weave_project_name)
+if ENV != "PROD":
+    weave.init(project_name=weave_project_name)
 settings = get_settings()
 prompts = get_prompts()
 logger = setup_logger(__name__)
@@ -62,7 +69,6 @@ class OpenAIService(BaseAttributionService):
         pil_images: List[Any] = None, # do not remove, this is for weave
         img_paths: List[str] = None,
     ) -> Dict[str, Any]:
-        logger.info("Extracting info via OpenAI...")
         text_content = [
             {
                 "type": "text",
@@ -73,14 +79,22 @@
             },
         ]
         if img_urls is not None:
+            base64_data_list = []
+            data_format_list = []
+
+            for img_url in img_urls:
+                base64_data, data_format = get_image_base64_and_type(img_url)
+                base64_data_list.append(base64_data)
+                data_format_list.append(data_format)
+
             image_content = [
                 {
                     "type": "image_url",
                     "image_url": {
-                        "url": img_url,
+                        "url": f"data:image/{data_format};base64,{base64_data}",
                     },
                 }
-                for img_url in img_urls
+                for base64_data, data_format in zip(base64_data_list, data_format_list)
             ]
         elif img_paths is not None:
             image_content = [
@@ -94,6 +108,7 @@
             ]
 
         try:
+            logger.info("Extracting info via OpenAI...")
            response = await self.client.beta.chat.completions.parse(
                 model=ai_model,
                 messages=[
@@ -110,11 +125,12 @@
                 response_format=attributes_model,
                 logprobs=False,
                 # top_logprobs=2,
-                temperature=0.0,
-                top_p=1,
+                # temperature=0.0,
+                top_p=1e-45,
             )
         except openai.BadRequestError as e:
-            raise BadRequestError(exception_to_str(e))
+            error_message = exception_to_str(e)
+            raise BadRequestError(error_message)
         except Exception as e:
             raise VendorError(
                 errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
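For clarity, a sketch of a single image entry in the OpenAI messages payload after this change (the base64 string and format are placeholders):

base64_data = "iVBORw0KGgoAAA..."  # truncated placeholder returned by get_image_base64_and_type
data_format = "JPEG"               # likewise a placeholder

entry = {
    "type": "image_url",
    "image_url": {
        "url": f"data:image/{data_format};base64,{base64_data}",
    },
}
# Before this commit the entry carried the remote URL directly ("url": img_url);
# now the bytes are downloaded, resized, and embedded inline in the request.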
app/utils/converter.py CHANGED
@@ -1,3 +1,4 @@
+import re
 from typing import Dict, List, Union
 
 
@@ -40,3 +41,20 @@ def product_data_to_str(product_data: Dict[str, Union[str, List[str]]]) -> str:
         data_list.append(data_line)
 
     return "\n".join(data_list)
+
+
+def to_snake_case(s):
+    # Remove leading/trailing whitespace and convert to lowercase
+    s = s.strip().lower()
+    # Replace spaces, hyphens, and periods with underscores
+    s = re.sub(r'[\s\-\.\+]', '_', s)
+    # Remove any characters that are not alphanumeric or underscores
+    s = re.sub(r'[^a-z0-9_]', '', s)
+    # Replace multiple underscores with a single one
+    s = re.sub(r'_+', '_', s)
+    # Remove leading digits (Python variable names can't start with a number)
+    s = re.sub(r'^[0-9]+', '', s)
+    # Make sure it doesn't start with an underscore and is not empty
+    if not s or not s[0].isalpha():
+        s = 'var_' + s
+    return s
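A few invented inputs and the outputs the function as written should produce (traced by hand from the regex steps above, so treat as illustrative):

from app.utils.converter import to_snake_case

to_snake_case("Water Resistance (ATM)")  # -> "water_resistance_atm"
to_snake_case("Coffee & Tea")            # -> "coffee_tea"
to_snake_case("4K Resolution")           # -> "k_resolution" (leading digits are dropped)
to_snake_case("  Multi--Pack  ")         # -> "multi_pack"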
app/utils/image_processing.py CHANGED
@@ -1,4 +1,8 @@
 import base64
+import io
+from PIL import Image
+
+import requests
 
 
 def get_image_data(image_path):
@@ -11,4 +15,82 @@ def get_data_format(image_path):
     image_format = image_path.split(".")[-1]
     if image_format == "jpg":
         image_format = "jpeg"
-    return image_format
+    return
+
+
+def get_image_base64_and_type(image_url: str, max_dimension: int = 2048) -> tuple[str | None, str | None]:
+    try:
+        # --- 1. Download the image ---
+        response = requests.get(image_url, stream=True, timeout=20) # Added timeout
+        response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
+
+        # Check content type
+        content_type = response.headers.get('content-type')
+        allowed_types = ['image/png', 'image/jpeg', 'image/webp', 'image/gif']
+        if not content_type or content_type not in allowed_types:
+            raise ValueError(f"Unsupported image type: {content_type}. Expected one of {allowed_types}.")
+
+        # --- 2. Open the image using Pillow ---
+        image_data = response.content
+        img = Image.open(io.BytesIO(image_data))
+
+        # Check if the image is animated (GIF)
+        if img.format == 'GIF' and getattr(img, 'is_animated', False):
+            raise ValueError("Animated GIFs are not supported.")
+
+        # --- 3. Check dimensions and resize if necessary ---
+        width, height = img.size
+        longest_dim = max(width, height)
+
+        if longest_dim > max_dimension:
+            # print(f"Image dimensions ({width}x{height}) exceed max dimension ({max_dimension}). Resizing...")
+            if width > height:
+                # Width is the longest dimension
+                new_width = max_dimension
+                new_height = int(height * (max_dimension / width))
+            else:
+                # Height is the longest or they are equal
+                new_height = max_dimension
+                new_width = int(width * (max_dimension / height))
+
+            # Resize the image - Use Resampling.LANCZOS for high-quality downscaling
+            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+            # print(f"Image resized to: {img.size}"
+
+        width, height = img.size
+        shortest_dim = min(width, height)
+
+        if shortest_dim > 768:
+            if width < height:
+                new_width = 768
+                new_height = int(height * (768 / width))
+            else:
+                new_height = 768
+                new_width = int(width * (768 / height))
+            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+
+        # --- 4. Save the image to a byte buffer ---
+        # We need to save the potentially modified image back to bytes
+        buffer = io.BytesIO()
+        # Save with the JPG format. Handle potential format issues.
+        try:
+            img_format = 'JPEG'
+            img.save(buffer, format=img_format, quality=100)
+        except Exception as save_err:
+            try:
+                # Fallback to PNG if original format saving fails
+                img_format = 'PNG'
+                img.save(buffer, format=img_format)
+            except Exception as png_save_err:
+                raise Exception(f"Failed to save image in PNG format. Error: {png_save_err}")
+
+        image_bytes = buffer.getvalue()
+
+        # --- 5. Encode the image bytes to base64 ---
+        base64_encoded_image = base64.b64encode(image_bytes).decode('utf-8')
+
+        return base64_encoded_image, img_format
+
+    except Exception as e:
+        raise ValueError(f"Invalid image URL: {e}")
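To make the two-stage resize concrete: a hypothetical 4000x3000 download is first capped at 2048 on its longest side (2048x1536), then capped at 768 on its shortest side, ending at 1024x768. A minimal usage sketch (the URL is a placeholder):

from app.utils.image_processing import get_image_base64_and_type

base64_data, img_format = get_image_base64_and_type(
    "https://example.com/product.jpg", max_dimension=2048  # placeholder URL
)
data_url = f"data:image/{img_format};base64,{base64_data}"
# img_format is "JPEG" unless the JPEG save fails (e.g. an RGBA input), in which case
# the helper falls back to "PNG"; any download or validation problem raises ValueError.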