|
import os

# Mark this process as the Hugging Face demo BEFORE any `app.*` module is
# imported further down; presumably `app.config` reads this variable at
# import/settings time — confirm.
os.environ["HUGGINGFACE_DEMO"] = "1"

from dotenv import load_dotenv

# Load variables from a local .env file before `get_settings()` runs below.
# load_dotenv() does not override variables that are already set by default,
# so the HUGGINGFACE_DEMO value assigned above takes precedence.
load_dotenv()
|
|
|
|
|
import ast
import json
import logging
import shutil
import uuid

import gradio as gr
|
|
|
from app.config import get_settings |
|
from app.schemas.requests import Attribute |
|
from app.request_handler import handle_extract |
|
from app.services.factory import AIServiceFactory |
|
|
|
|
|
# Longest allowed image side in pixels; uploads above this are downscaled
# (aspect ratio kept) before being handed to the model.
IMAGE_MAX_SIZE = 1536

# Application settings, resolved once at import time; provides the model
# lists used by the dropdown and by vendor selection.
settings = get_settings()
|
|
|
|
|
def _parse_attribute_schema(schema_text):
    """Turn the "Attribute Schema" text into a ``{name: Attribute}`` dict.

    The text is the body of a dict literal without the outer braces, so the
    braces are added here before parsing.

    Raises:
        gr.Error: if the text is not a dict literal, or an entry is rejected
            by ``Attribute``.
    """
    message = "Invalid `Attribute Schema`. Please insert valid schema following the example."
    try:
        # literal_eval, not exec: this text comes straight from the browser.
        # exec would run arbitrary code and write into this module's globals.
        raw = ast.literal_eval("{" + (schema_text or "") + "}")
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise gr.Error(message) from err
    if not isinstance(raw, dict):
        # e.g. the text "1, 2" evaluates to a set, not a dict.
        raise gr.Error(message)
    try:
        return {name: Attribute(**spec) for name, spec in raw.items()}
    except (TypeError, ValueError) as err:
        # TypeError: an entry is not a mapping. ValueError: validation failure
        # (assumes Attribute is a pydantic model whose ValidationError
        # subclasses ValueError — confirm).
        raise gr.Error(message) from err


def _parse_product_data(product_data):
    """Parse the optional "Product Data" text into a dict (blank -> ``{}``).

    Raises:
        gr.Error: if the text is not a dict literal.
    """
    message = 'Invalid `Product Data`. Please insert valid dictionary or leave it empty.'
    text = (product_data or "").strip()
    if not text:
        return {}
    try:
        # Same reasoning as for the schema: never exec user-supplied text.
        data = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise gr.Error(message) from err
    if not isinstance(data, dict):
        raise gr.Error(message)
    return data


def _resolve_vendor(ai_model):
    """Map a model name to its vendor key for ``AIServiceFactory``.

    Raises:
        gr.Error: if no model is selected or the model is in neither list.
    """
    if ai_model in settings.OPENAI_MODELS:
        return 'openai'
    if ai_model in settings.ANTHROPIC_MODELS:
        return 'anthropic'
    raise gr.Error('Please select a supported AI model.')


def _save_images(pil_images, folder):
    """Downscale the uploaded images, write them into *folder*, return the paths.

    Images with real transparency are written as PNG and everything else as
    maximum-quality JPEG. Each file extension matches the format written.
    """
    img_paths = []
    for index, item in enumerate(pil_images):
        # Each gallery item is indexable with the PIL image first
        # (presumably an (image, caption) pair from gr.Gallery(type="pil")).
        image = item[0]

        longest_side = max(image.size)
        if longest_side > IMAGE_MAX_SIZE:
            ratio = IMAGE_MAX_SIZE / longest_side
            # max(1, ...) keeps extremely thin images from collapsing to a
            # 0-pixel side, which PIL refuses to resize to.
            image = image.resize((
                max(1, int(image.width * ratio)),
                max(1, int(image.height * ratio)),
            ))

        has_alpha = image.mode in ('RGBA', 'LA') or (
            image.mode == 'P' and 'transparency' in image.info
        )
        if has_alpha:
            image = image.convert('RGBA')
            if image.getchannel('A').getextrema() == (255, 255):
                # Fully opaque: the alpha channel carries nothing, use JPEG.
                image = image.convert('RGB')
                has_alpha = False

        if has_alpha:
            img_path = os.path.join(folder, f'{index}.png')
            image.save(img_path, 'PNG')
        else:
            if image.mode not in ('RGB', 'L'):
                # The JPEG encoder cannot write modes such as P, I or F.
                image = image.convert('RGB')
            img_path = os.path.join(folder, f'{index}.jpg')
            image.save(img_path, 'JPEG', quality=100, subsampling=0)
        img_paths.append(img_path)
    return img_paths


async def forward_request(attributes, product_taxonomy, product_data, ai_model, pil_images):
    """Run attribute extraction for one "Extract Attributes" click.

    Args:
        attributes: Attribute-schema text: the body of a dict literal without
            the outer braces, one entry per attribute.
        product_taxonomy: Free-text taxonomy, forwarded unchanged to the service.
        product_data: Optional dict literal of known product data; blank means ``{}``.
        ai_model: Model name chosen in the dropdown (``None`` if nothing chosen).
        pil_images: Gallery value; each item holds a PIL image at index 0.

    Returns:
        Whatever ``service.extract_attributes_with_validation`` returns; it is
        shown in the JSON output component.

    Raises:
        gr.Error: for invalid input, or when the extraction call fails.
    """
    # Validate everything cheap before touching the filesystem.
    attributes_object = _parse_attribute_schema(attributes)
    product_data_object = _parse_product_data(product_data)
    if not pil_images:
        raise gr.Error('Please upload image(s) of the product')
    service = AIServiceFactory.get_service(_resolve_vendor(ai_model))

    # Per-request scratch folder so concurrent requests never share files.
    request_temp_folder = os.path.join('gradio_temp', str(uuid.uuid4()))
    os.makedirs(request_temp_folder, exist_ok=True)
    try:
        img_paths = _save_images(pil_images, request_temp_folder)
        try:
            json_attributes = await service.extract_attributes_with_validation(
                attributes_object,
                ai_model,
                None,
                product_taxonomy,
                product_data_object,
                img_paths=img_paths,
            )
        except Exception as err:
            # Keep the real cause in the server log; the user gets a generic message.
            logging.getLogger(__name__).exception(
                "Attribute extraction failed (model=%s)", ai_model
            )
            raise gr.Error('Failed to extract attributes. Something went wrong.') from err
    finally:
        shutil.rmtree(request_temp_folder, ignore_errors=True)

    gr.Info('Process completed!')
    return json_attributes
|
|
|
|
|
def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
    """Append one attribute entry to the schema text and clear the input form.

    The schema box holds the body of a dict literal whose entries are separated
    by commas. The new entry uses the same layout as the sample schema. Every
    string is JSON-quoted, so quotes or backslashes typed by the user cannot
    break the literal. JSON string syntax is also valid Python string syntax.

    Args:
        attributes: Current schema text (may be empty).
        attr_name: Key of the new attribute.
        attr_desc: Human-readable description.
        attr_type: Data type chosen in the dropdown (``None`` if unset).
        allowed_values: Comma-separated allowed values. Blank items are dropped,
            so an empty field yields ``[]``.

    Returns:
        ``(new_schema, "", "", "", "")``: the updated schema text plus empty
        values that reset the name, description, type and allowed-values inputs.
    """
    def _quote(value):
        return json.dumps("" if value is None else str(value), ensure_ascii=False)

    values = [item.strip() for item in (allowed_values or "").split(",")]
    allowed = ", ".join(_quote(item) for item in values if item)

    entry = (
        "\n"
        f"{_quote(attr_name)}: {{\n"
        f'    "description": {_quote(attr_desc)},\n'
        f'    "data_type": {_quote(attr_type)},\n'
        f'    "allowed_values": [{allowed}]\n'
        "},\n"
    )

    schema = attributes or ""
    stripped = schema.rstrip()
    if stripped and not stripped.endswith(","):
        # The previous entry has no trailing comma. Without one, the two entries
        # would sit side by side and the schema would no longer parse.
        schema = stripped + ",\n"
    return schema + entry, "", "", "", ""
|
|
|
|
|
# Default contents of the "Attribute Schema" box: the body of a dict literal
# (no outer braces) mapping attribute name -> description / data_type /
# allowed_values. Every entry, including the last, ends with a comma, so new
# entries can be appended after it and the text still parses.
sample_schema = """"category": {
    "description": "Category of the garment",
    "data_type": "list[string]",
    "allowed_values": [
        "upper garment", "lower garment", "footwear", "accessory", "headwear", "dresses"
    ]
},

"color": {
    "description": "Color of the garment",
    "data_type": "list[string]",
    "allowed_values": [
        "black", "white", "red", "blue", "green", "yellow", "pink", "purple", "orange", "brown", "grey", "beige", "multi-color", "other"
    ]
},

"pattern": {
    "description": "Pattern of the garment",
    "data_type": "list[string]",
    "allowed_values": [
        "plain", "striped", "checkered", "floral", "polka dot", "camouflage", "animal print", "abstract", "other"
    ]
},

"material": {
    "description": "Material of the garment",
    "data_type": "string",
    "allowed_values": []
},
"""

# Usage instructions rendered under the page title. Step 4 names the
# "Product Data (Optional)" field as it is labelled in the UI.
description = """
This is a simple demo for Attribution. Follow the steps below:

1. Upload image(s) of a product.
2. Enter the product taxonomy (e.g. 'upper garment', 'lower garment', 'bag'). If only one product is in the image, you can leave this field empty.
3. Select the AI model to use.
4. Enter known product data (optional).
5. Enter the attribute schema or use the "Add Attributes" section to add attributes.
6. Click "Extract Attributes" to get the extracted attributes.
"""

# Placeholder text shown in the Product Data box when it is empty.
product_data_placeholder = """Example:
{
    "brand": "Leaf",
    "size": "M",
    "product_name": "Leaf T-shirt",
    "color": "red"
}
"""

# Initial value of the Product Data box (stripped before use).
product_data_value = """
{
    "data1": "",
    "data2": ""
}
"""
|
|
|
with gr.Blocks(title="Internal Demo for Attribution") as demo:
    # Page header: centered title followed by the usage instructions.
    with gr.Row():
        with gr.Column(scale=12):
            gr.Markdown(
                """<div style="text-align: center; font-size: 24px;"><strong>Internal Demo for Attribution</strong></div>"""
            )
            gr.Markdown(description)

    with gr.Row():
        # Wider left column holds every input; the narrower right one shows the result.
        with gr.Column(scale=12):
            with gr.Row():
                # Product inputs: images, taxonomy, model choice, known product data.
                with gr.Column():
                    gallery = gr.Gallery(
                        label="Upload images of your product here", type="pil"
                    )
                    product_taxonomy = gr.Textbox(
                        label="Product Taxonomy",
                        placeholder="Enter product taxonomy here (e.g. 'upper garment', 'lower garment', 'bag')",
                        lines=1,
                        max_lines=1,
                    )
                    ai_model = gr.Dropdown(
                        label="AI Model",
                        choices=settings.SUPPORTED_MODELS,
                        interactive=True,
                    )
                    product_data = gr.TextArea(
                        label="Product Data (Optional)",
                        placeholder=product_data_placeholder,
                        value=product_data_value.strip(),
                        interactive=True,
                        lines=10,
                        max_lines=10,
                    )

                # Attribute schema editor plus a small form that appends entries to it.
                with gr.Column():
                    attributes = gr.TextArea(
                        label="Attribute Schema",
                        value=sample_schema,
                        placeholder="Enter schema here or use Add Attributes below",
                        interactive=True,
                        lines=30,
                        max_lines=30,
                    )

                    with gr.Accordion("Add Attributes", open=False):
                        attr_name = gr.Textbox(
                            label="Attribute name", placeholder="Enter attribute name"
                        )
                        attr_desc = gr.Textbox(
                            label="Description", placeholder="Enter description"
                        )
                        attr_type = gr.Dropdown(
                            label="Type",
                            choices=[
                                "string",
                                "list[string]",
                                "int",
                                "list[int]",
                                "float",
                                "list[float]",
                                "bool",
                                "list[bool]",
                            ],
                            interactive=True,
                        )
                        allowed_values = gr.Textbox(
                            label="Allowed values (separated by comma)",
                            placeholder="yellow, red, blue",
                        )
                        add_btn = gr.Button("Add Attribute")

            with gr.Row():
                submit_btn = gr.Button("Extract Attributes")

        with gr.Column(scale=6):
            output_json = gr.Json(
                label="Extracted Attributes", value={}, show_indices=False
            )

    # Append the form's attribute to the schema text and clear the form fields.
    add_btn.click(
        add_attribute_schema,
        inputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
        outputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
    )

    # Run extraction; the inputs map positionally onto forward_request's parameters.
    submit_btn.click(
        forward_request,
        inputs=[attributes, product_taxonomy, product_data, ai_model, gallery],
        outputs=output_json,
    )

demo.launch()
|
|