Spaces:
Sleeping
Sleeping
Commit
·
e85027d
1
Parent(s):
638f225
bring changes to demo app
Browse files- app/config.py +1 -0
- app/request_handler/extract_handler.py +25 -16
- app/request_handler/follow_handler.py +7 -3
- app/schemas/schema_tools.py +46 -5
- app/services/base.py +2 -2
- app/services/service_anthropic.py +25 -7
- app/services/service_openai.py +13 -8
- app/utils/logger.py +18 -0
app/config.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
|
|
5 |
from pydantic_settings import BaseSettings
|
6 |
|
7 |
|
|
|
8 |
if os.getenv("HUGGINGFACE_DEMO"):
|
9 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
10 |
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
|
|
|
5 |
from pydantic_settings import BaseSettings
|
6 |
|
7 |
|
8 |
+
os.environ["WEAVE_CAPTURE_CODE"] = "false"
|
9 |
if os.getenv("HUGGINGFACE_DEMO"):
|
10 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
11 |
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
|
app/request_handler/extract_handler.py
CHANGED
@@ -9,7 +9,7 @@ from app.core.errors import BadRequestError, VendorError
|
|
9 |
from app.schemas.requests import ExtractionRequest
|
10 |
from app.schemas.responses import APIResponse
|
11 |
from app.services.factory import AIServiceFactory
|
12 |
-
from app.utils.logger import setup_logger
|
13 |
|
14 |
logger = setup_logger(__name__)
|
15 |
settings = get_settings()
|
@@ -32,16 +32,17 @@ async def handle_extract(request: ExtractionRequest):
|
|
32 |
)
|
33 |
service = AIServiceFactory.get_service(ai_vendor)
|
34 |
|
35 |
-
pil_images = []
|
|
|
36 |
for url in request.img_urls:
|
37 |
try:
|
38 |
-
response = requests.get(url)
|
39 |
-
response.raise_for_status()
|
40 |
-
image = Image.open(BytesIO(response.content))
|
41 |
-
pil_images.append(image)
|
|
|
42 |
except Exception as e:
|
43 |
-
|
44 |
-
logger.error(f"Failed to download or process image from {url}: {e}")
|
45 |
raise HTTPException(
|
46 |
status_code=400,
|
47 |
detail=f"Failed to process image from {url}",
|
@@ -58,20 +59,28 @@ async def handle_extract(request: ExtractionRequest):
|
|
58 |
)
|
59 |
break
|
60 |
except BadRequestError as e:
|
61 |
-
logger.error(
|
|
|
|
|
62 |
raise HTTPException(
|
63 |
-
status_code=400,
|
|
|
|
|
64 |
)
|
65 |
except ValueError as e:
|
66 |
-
logger.error("Value error:
|
67 |
raise HTTPException(
|
68 |
-
status_code=400,
|
|
|
|
|
69 |
)
|
70 |
except VendorError as e:
|
71 |
-
logger.error("Vendor error:
|
72 |
if attempt == request.max_attempts:
|
73 |
raise HTTPException(
|
74 |
-
status_code=500,
|
|
|
|
|
75 |
)
|
76 |
else:
|
77 |
if request.ai_model in settings.ANTHROPIC_MODELS:
|
@@ -90,10 +99,10 @@ async def handle_extract(request: ExtractionRequest):
|
|
90 |
)
|
91 |
|
92 |
except HTTPException as e:
|
93 |
-
logger.error("HTTP exception:
|
94 |
raise e
|
95 |
except Exception as e:
|
96 |
-
logger.error("Exception: ", e)
|
97 |
if (
|
98 |
"overload" in str(e).lower()
|
99 |
and request.ai_model in settings.ANTHROPIC_MODELS
|
|
|
9 |
from app.schemas.requests import ExtractionRequest
|
10 |
from app.schemas.responses import APIResponse
|
11 |
from app.services.factory import AIServiceFactory
|
12 |
+
from app.utils.logger import exception_to_str, setup_logger
|
13 |
|
14 |
logger = setup_logger(__name__)
|
15 |
settings = get_settings()
|
|
|
32 |
)
|
33 |
service = AIServiceFactory.get_service(ai_vendor)
|
34 |
|
35 |
+
# pil_images = []
|
36 |
+
pil_images = None # temporarily removed to save cost
|
37 |
for url in request.img_urls:
|
38 |
try:
|
39 |
+
# response = requests.get(url)
|
40 |
+
# response.raise_for_status()
|
41 |
+
# image = Image.open(BytesIO(response.content))
|
42 |
+
# pil_images.append(image)
|
43 |
+
pass
|
44 |
except Exception as e:
|
45 |
+
# logger.error(f"Failed to download or process image from {url}: {exception_to_str(e)}")
|
|
|
46 |
raise HTTPException(
|
47 |
status_code=400,
|
48 |
detail=f"Failed to process image from {url}",
|
|
|
59 |
)
|
60 |
break
|
61 |
except BadRequestError as e:
|
62 |
+
logger.error(
|
63 |
+
f"Bad request error: {exception_to_str(e)}",
|
64 |
+
)
|
65 |
raise HTTPException(
|
66 |
+
status_code=400,
|
67 |
+
detail=exception_to_str(e),
|
68 |
+
headers={"attempt": attempt},
|
69 |
)
|
70 |
except ValueError as e:
|
71 |
+
logger.error(f"Value error: {exception_to_str(e)}")
|
72 |
raise HTTPException(
|
73 |
+
status_code=400,
|
74 |
+
detail=exception_to_str(e),
|
75 |
+
headers={"attempt": attempt},
|
76 |
)
|
77 |
except VendorError as e:
|
78 |
+
logger.error(f"Vendor error: {exception_to_str(e)}")
|
79 |
if attempt == request.max_attempts:
|
80 |
raise HTTPException(
|
81 |
+
status_code=500,
|
82 |
+
detail=exception_to_str(e),
|
83 |
+
headers={"attempt": attempt},
|
84 |
)
|
85 |
else:
|
86 |
if request.ai_model in settings.ANTHROPIC_MODELS:
|
|
|
99 |
)
|
100 |
|
101 |
except HTTPException as e:
|
102 |
+
logger.error(f"HTTP exception: {exception_to_str(e)}")
|
103 |
raise e
|
104 |
except Exception as e:
|
105 |
+
logger.error("Exception: ", exception_to_str(e))
|
106 |
if (
|
107 |
"overload" in str(e).lower()
|
108 |
and request.ai_model in settings.ANTHROPIC_MODELS
|
app/request_handler/follow_handler.py
CHANGED
@@ -4,7 +4,7 @@ from app.config import get_settings
|
|
4 |
from app.core.errors import VendorError
|
5 |
from app.schemas.requests import FollowSchemaRequest
|
6 |
from app.services.factory import AIServiceFactory
|
7 |
-
from app.utils.logger import setup_logger
|
8 |
|
9 |
logger = setup_logger(__name__)
|
10 |
settings = get_settings()
|
@@ -34,12 +34,16 @@ async def handle_follow(request: FollowSchemaRequest):
|
|
34 |
except ValueError as e:
|
35 |
if attempt == request.max_attempts:
|
36 |
raise HTTPException(
|
37 |
-
status_code=400,
|
|
|
|
|
38 |
)
|
39 |
except VendorError as e:
|
40 |
if attempt == request.max_attempts:
|
41 |
raise HTTPException(
|
42 |
-
status_code=500,
|
|
|
|
|
43 |
)
|
44 |
except Exception as e:
|
45 |
if attempt == request.max_attempts:
|
|
|
4 |
from app.core.errors import VendorError
|
5 |
from app.schemas.requests import FollowSchemaRequest
|
6 |
from app.services.factory import AIServiceFactory
|
7 |
+
from app.utils.logger import exception_to_str, setup_logger
|
8 |
|
9 |
logger = setup_logger(__name__)
|
10 |
settings = get_settings()
|
|
|
34 |
except ValueError as e:
|
35 |
if attempt == request.max_attempts:
|
36 |
raise HTTPException(
|
37 |
+
status_code=400,
|
38 |
+
detail=exception_to_str(e),
|
39 |
+
headers={"attempt": attempt},
|
40 |
)
|
41 |
except VendorError as e:
|
42 |
if attempt == request.max_attempts:
|
43 |
raise HTTPException(
|
44 |
+
status_code=500,
|
45 |
+
detail=exception_to_str(e),
|
46 |
+
headers={"attempt": attempt},
|
47 |
)
|
48 |
except Exception as e:
|
49 |
if attempt == request.max_attempts:
|
app/schemas/schema_tools.py
CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field # do not remove this import for exec
|
|
8 |
|
9 |
from app.core.errors import VendorError
|
10 |
from app.schemas.requests import Attribute
|
|
|
11 |
|
12 |
|
13 |
def validate_json_data(data: Dict[str, Any], schema: Dict[str, Any]):
|
@@ -17,7 +18,7 @@ def validate_json_data(data: Dict[str, Any], schema: Dict[str, Any]):
|
|
17 |
try:
|
18 |
jsonschema.validate(instance=data, schema=schema)
|
19 |
except jsonschema.ValidationError as e:
|
20 |
-
raise VendorError(f"Vendor generated invalid data
|
21 |
|
22 |
|
23 |
def validate_json_schema(schema: Dict[str, Any]):
|
@@ -74,16 +75,56 @@ def convert_attribute_to_model(attributes: Dict[str, Attribute]) -> Dict[str, An
|
|
74 |
if len(allowed_values) > 0:
|
75 |
enum_code = f"class {key.capitalize()}Enum(str, Enum):\n"
|
76 |
for allowed_value in allowed_values:
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
enum_code_list.append(enum_code)
|
79 |
data_type = f"{key.capitalize()}Enum"
|
80 |
|
81 |
if is_list:
|
82 |
data_type = f"List[{data_type}]"
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
entire_code = import_code + "\n".join(enum_code_list) + "\n" + master_class_code
|
89 |
exec(entire_code, globals())
|
|
|
8 |
|
9 |
from app.core.errors import VendorError
|
10 |
from app.schemas.requests import Attribute
|
11 |
+
from app.utils.logger import exception_to_str
|
12 |
|
13 |
|
14 |
def validate_json_data(data: Dict[str, Any], schema: Dict[str, Any]):
|
|
|
18 |
try:
|
19 |
jsonschema.validate(instance=data, schema=schema)
|
20 |
except jsonschema.ValidationError as e:
|
21 |
+
raise VendorError(f"Vendor generated invalid data {exception_to_str(e)}")
|
22 |
|
23 |
|
24 |
def validate_json_schema(schema: Dict[str, Any]):
|
|
|
75 |
if len(allowed_values) > 0:
|
76 |
enum_code = f"class {key.capitalize()}Enum(str, Enum):\n"
|
77 |
for allowed_value in allowed_values:
|
78 |
+
enum_name = (
|
79 |
+
allowed_value.replace(" ", "_")
|
80 |
+
.replace("-", "_")
|
81 |
+
.replace("&", "AND")
|
82 |
+
.replace("/", "_OR_")
|
83 |
+
.replace(":", "__")
|
84 |
+
.replace("+", "plus")
|
85 |
+
.replace(",", "_")
|
86 |
+
.replace(".", "_")
|
87 |
+
.replace("°", "degree")
|
88 |
+
.replace("(", "")
|
89 |
+
.replace(")", "")
|
90 |
+
.replace("'", "_")
|
91 |
+
.replace('%', "")
|
92 |
+
.replace("!", "")
|
93 |
+
.replace("?", "")
|
94 |
+
.replace("`", "")
|
95 |
+
.replace("~", "")
|
96 |
+
.replace(";", "")
|
97 |
+
.replace("<", "")
|
98 |
+
.replace(">", "")
|
99 |
+
.replace("[", "")
|
100 |
+
.replace("]", "")
|
101 |
+
.replace("{", "")
|
102 |
+
.replace("}", "")
|
103 |
+
.replace("\\", "")
|
104 |
+
.replace("|", "")
|
105 |
+
.replace('–', "_")
|
106 |
+
.replace('*', "_")
|
107 |
+
.upper()
|
108 |
+
)
|
109 |
+
|
110 |
+
if "'" in allowed_value:
|
111 |
+
enum_code += f' E{enum_name} = "{allowed_value}"\n'
|
112 |
+
else:
|
113 |
+
enum_code += f" E{enum_name} = '{allowed_value}'\n"
|
114 |
enum_code_list.append(enum_code)
|
115 |
data_type = f"{key.capitalize()}Enum"
|
116 |
|
117 |
if is_list:
|
118 |
data_type = f"List[{data_type}]"
|
119 |
|
120 |
+
if "'" in description:
|
121 |
+
master_class_code += (
|
122 |
+
f' {key}: {data_type} = Field(..., description="{description}")\n'
|
123 |
+
)
|
124 |
+
else:
|
125 |
+
master_class_code += (
|
126 |
+
f" {key}: {data_type} = Field(..., description='{description}')\n"
|
127 |
+
)
|
128 |
|
129 |
entire_code = import_code + "\n".join(enum_code_list) + "\n" + master_class_code
|
130 |
exec(entire_code, globals())
|
app/services/base.py
CHANGED
@@ -45,9 +45,9 @@ class BaseAttributionService(ABC):
|
|
45 |
attributes_model,
|
46 |
ai_model,
|
47 |
img_urls,
|
48 |
-
product_taxonomy,
|
49 |
product_data,
|
50 |
-
# pil_images=pil_images, # temporarily removed
|
51 |
img_paths=img_paths,
|
52 |
)
|
53 |
validate_json_data(data, schema)
|
|
|
45 |
attributes_model,
|
46 |
ai_model,
|
47 |
img_urls,
|
48 |
+
product_taxonomy if product_taxonomy != "" else "main",
|
49 |
product_data,
|
50 |
+
# pil_images=pil_images, # temporarily removed to save cost
|
51 |
img_paths=img_paths,
|
52 |
)
|
53 |
validate_json_data(data, schema)
|
app/services/service_anthropic.py
CHANGED
@@ -14,14 +14,14 @@ from app.core.prompts import get_prompts
|
|
14 |
from app.services.base import BaseAttributionService
|
15 |
from app.utils.converter import product_data_to_str
|
16 |
from app.utils.image_processing import get_data_format, get_image_data
|
17 |
-
from app.utils.logger import setup_logger
|
18 |
|
19 |
-
|
20 |
-
if
|
21 |
weave_project_name = "cfai/attribution-exp"
|
22 |
-
elif
|
23 |
weave_project_name = "cfai/attribution-dev"
|
24 |
-
elif
|
25 |
weave_project_name = "cfai/attribution-prod"
|
26 |
|
27 |
weave.init(project_name=weave_project_name)
|
@@ -102,16 +102,32 @@ class AnthropicService(BaseAttributionService):
|
|
102 |
system=system_message,
|
103 |
tools=tools,
|
104 |
messages=messages,
|
|
|
|
|
|
|
105 |
)
|
106 |
except anthropic.BadRequestError as e:
|
107 |
raise BadRequestError(e.message)
|
108 |
except Exception as e:
|
109 |
-
raise VendorError(
|
|
|
|
|
110 |
|
111 |
for content in response.content:
|
112 |
if content.type == "tool_use":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
return content.input
|
114 |
|
|
|
|
|
|
|
|
|
115 |
@weave.op
|
116 |
async def follow_schema(self, schema, data):
|
117 |
logger.info("Following structure via Anthropic...")
|
@@ -146,7 +162,9 @@ class AnthropicService(BaseAttributionService):
|
|
146 |
messages=messages,
|
147 |
)
|
148 |
except Exception as e:
|
149 |
-
raise VendorError(
|
|
|
|
|
150 |
|
151 |
for content in response.content:
|
152 |
if content.type == "tool_use":
|
|
|
14 |
from app.services.base import BaseAttributionService
|
15 |
from app.utils.converter import product_data_to_str
|
16 |
from app.utils.image_processing import get_data_format, get_image_data
|
17 |
+
from app.utils.logger import exception_to_str, setup_logger
|
18 |
|
19 |
+
env = os.getenv("ENV", "LOCAL")
|
20 |
+
if env == "LOCAL": # local or demo
|
21 |
weave_project_name = "cfai/attribution-exp"
|
22 |
+
elif env == "DEV":
|
23 |
weave_project_name = "cfai/attribution-dev"
|
24 |
+
elif env == "PROD":
|
25 |
weave_project_name = "cfai/attribution-prod"
|
26 |
|
27 |
weave.init(project_name=weave_project_name)
|
|
|
102 |
system=system_message,
|
103 |
tools=tools,
|
104 |
messages=messages,
|
105 |
+
temperature=0.0,
|
106 |
+
top_p=1.0,
|
107 |
+
top_k=1,
|
108 |
)
|
109 |
except anthropic.BadRequestError as e:
|
110 |
raise BadRequestError(e.message)
|
111 |
except Exception as e:
|
112 |
+
raise VendorError(
|
113 |
+
errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
|
114 |
+
)
|
115 |
|
116 |
for content in response.content:
|
117 |
if content.type == "tool_use":
|
118 |
+
if content.input is None or not content.input:
|
119 |
+
raise VendorError(
|
120 |
+
errors.VENDOR_THROW_ERROR.format(
|
121 |
+
error_message="content.input is None or content.input is empty"
|
122 |
+
)
|
123 |
+
)
|
124 |
+
|
125 |
return content.input
|
126 |
|
127 |
+
raise VendorError(
|
128 |
+
errors.VENDOR_THROW_ERROR.format(error_message="No tool_use found")
|
129 |
+
)
|
130 |
+
|
131 |
@weave.op
|
132 |
async def follow_schema(self, schema, data):
|
133 |
logger.info("Following structure via Anthropic...")
|
|
|
162 |
messages=messages,
|
163 |
)
|
164 |
except Exception as e:
|
165 |
+
raise VendorError(
|
166 |
+
errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
|
167 |
+
)
|
168 |
|
169 |
for content in response.content:
|
170 |
if content.type == "tool_use":
|
app/services/service_openai.py
CHANGED
@@ -9,7 +9,7 @@ from pydantic import BaseModel
|
|
9 |
|
10 |
from app.utils.converter import product_data_to_str
|
11 |
from app.utils.image_processing import get_data_format, get_image_data
|
12 |
-
from app.utils.logger import setup_logger
|
13 |
|
14 |
from ..config import get_settings
|
15 |
from ..core import errors
|
@@ -17,12 +17,12 @@ from ..core.errors import BadRequestError, VendorError
|
|
17 |
from ..core.prompts import get_prompts
|
18 |
from .base import BaseAttributionService
|
19 |
|
20 |
-
|
21 |
-
if
|
22 |
weave_project_name = "cfai/attribution-exp"
|
23 |
-
elif
|
24 |
weave_project_name = "cfai/attribution-dev"
|
25 |
-
elif
|
26 |
weave_project_name = "cfai/attribution-prod"
|
27 |
|
28 |
weave.init(project_name=weave_project_name)
|
@@ -111,11 +111,14 @@ class OpenAIService(BaseAttributionService):
|
|
111 |
logprobs=False,
|
112 |
# top_logprobs=2,
|
113 |
temperature=0.0,
|
|
|
114 |
)
|
115 |
except openai.BadRequestError as e:
|
116 |
-
raise BadRequestError(
|
117 |
except Exception as e:
|
118 |
-
raise VendorError(
|
|
|
|
|
119 |
|
120 |
try:
|
121 |
content = response.choices[0].message.content
|
@@ -157,7 +160,9 @@ class OpenAIService(BaseAttributionService):
|
|
157 |
temperature=0.0,
|
158 |
)
|
159 |
except Exception as e:
|
160 |
-
raise VendorError(
|
|
|
|
|
161 |
|
162 |
if response.choices[0].message.refusal:
|
163 |
logger.info("OpenAI refused to respond to the request")
|
|
|
9 |
|
10 |
from app.utils.converter import product_data_to_str
|
11 |
from app.utils.image_processing import get_data_format, get_image_data
|
12 |
+
from app.utils.logger import exception_to_str, setup_logger
|
13 |
|
14 |
from ..config import get_settings
|
15 |
from ..core import errors
|
|
|
17 |
from ..core.prompts import get_prompts
|
18 |
from .base import BaseAttributionService
|
19 |
|
20 |
+
env = os.getenv("ENV", "LOCAL")
|
21 |
+
if env == "LOCAL": # local or demo
|
22 |
weave_project_name = "cfai/attribution-exp"
|
23 |
+
elif env == "DEV":
|
24 |
weave_project_name = "cfai/attribution-dev"
|
25 |
+
elif env == "PROD":
|
26 |
weave_project_name = "cfai/attribution-prod"
|
27 |
|
28 |
weave.init(project_name=weave_project_name)
|
|
|
111 |
logprobs=False,
|
112 |
# top_logprobs=2,
|
113 |
temperature=0.0,
|
114 |
+
top_p=1,
|
115 |
)
|
116 |
except openai.BadRequestError as e:
|
117 |
+
raise BadRequestError(exception_to_str(e))
|
118 |
except Exception as e:
|
119 |
+
raise VendorError(
|
120 |
+
errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
|
121 |
+
)
|
122 |
|
123 |
try:
|
124 |
content = response.choices[0].message.content
|
|
|
160 |
temperature=0.0,
|
161 |
)
|
162 |
except Exception as e:
|
163 |
+
raise VendorError(
|
164 |
+
errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
|
165 |
+
)
|
166 |
|
167 |
if response.choices[0].message.refusal:
|
168 |
logger.info("OpenAI refused to respond to the request")
|
app/utils/logger.py
CHANGED
@@ -1,6 +1,24 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
from logging.handlers import RotatingFileHandler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
# Configure logger
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
from logging.handlers import RotatingFileHandler
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
|
7 |
+
def exception_to_str(e: Exception, max_lines: int = 12) -> str:
    """Render an exception as a short, line-limited message-plus-traceback string.

    The result is the first two lines of ``str(e)`` followed by the formatted
    traceback frames of ``e.__traceback__`` (if any), capped at ``max_lines``
    lines in total.

    Args:
        e: The exception to render.
        max_lines: Maximum number of lines to keep before the truncation
            marker is appended. Negative values are treated as 0.

    Returns:
        The rendered text, without a trailing newline. When lines were
        dropped, the last line is ``"... (truncated)"``.
    """
    import traceback  # local import keeps this helper self-contained

    # Only the head of the message is kept; vendor errors can be very long.
    message_lines = str(e).splitlines()[:2]
    if not any(message_lines):
        # An exception raised without a message would otherwise yield a blank
        # line, leaving the reader with no clue what failed.
        message_lines = [type(e).__name__]

    # format_tb returns an empty list when the exception was never raised.
    trace_lines = "".join(traceback.format_tb(e.__traceback__)).splitlines()

    lines = message_lines + trace_lines

    # Clamp so a negative limit cannot slice from the end of the list.
    limit = max(max_lines, 0)
    if len(lines) > limit:
        lines = lines[:limit] + ["... (truncated)"]

    # Joining the lines avoids the dangling newline the concatenated form had.
    return "\n".join(lines)
|
22 |
|
23 |
|
24 |
# Configure logger
|