Spaces:
Sleeping
Sleeping
| import os | |
| os.environ["HUGGINGFACE_DEMO"] = "1" # set before import from app | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| ################################################################################################ | |
| import gradio as gr | |
| import uuid | |
| import shutil | |
| from app.config import get_settings | |
| from app.schemas.requests import Attribute | |
| from app.request_handler import handle_extract | |
| from app.services.factory import AIServiceFactory | |
# Longest side, in pixels, an uploaded image may keep; anything larger is scaled
# down before it is written to disk for extraction.
IMAGE_MAX_SIZE = 1536
# Application settings (supported / per-vendor model lists), loaded once at import.
settings = get_settings()
_SCHEMA_ERROR = (
    "Invalid `Attribute Schema`. Please insert valid schema following the example."
)
_PRODUCT_DATA_ERROR = (
    "Invalid `Product Data`. Please insert valid dictionary or leave it empty."
)
# Exceptions ast.literal_eval can raise on malformed or pathological input.
_LITERAL_ERRORS = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)


def _parse_attribute_schema(attributes):
    """Parse the "Attribute Schema" text (a dict body without braces) into {name: Attribute}.

    ``ast.literal_eval`` only accepts Python literals, so user text is never
    executed (the previous ``exec`` on ``globals()`` allowed arbitrary code and
    shared state between concurrent requests).

    Raises:
        gr.Error: if the text is not a dict literal or an entry is rejected by
            ``Attribute``.
    """
    import ast

    try:
        raw_schema = ast.literal_eval("{" + (attributes or "") + "}")
    except _LITERAL_ERRORS as exc:
        raise gr.Error(_SCHEMA_ERROR) from exc
    try:
        return {key: Attribute(**value) for key, value in raw_schema.items()}
    except Exception as exc:  # e.g. TypeError for a non-mapping entry, or Attribute's own validation error
        raise gr.Error(_SCHEMA_ERROR) from exc


def _parse_product_data(product_data):
    """Parse the optional "Product Data" text into a dict; blank text yields {}.

    Accepts a Python dict literal or, as a fallback, strict JSON (true/false/null).

    Raises:
        gr.Error: if the text is neither, or does not describe a dict.
    """
    import ast
    import json

    text = (product_data or "").strip()
    if not text:
        return {}
    try:
        parsed = ast.literal_eval(text)
    except _LITERAL_ERRORS:
        try:
            parsed = json.loads(text)
        except (ValueError, RecursionError) as exc:
            raise gr.Error(_PRODUCT_DATA_ERROR) from exc
    if not isinstance(parsed, dict):
        raise gr.Error(_PRODUCT_DATA_ERROR)
    return parsed


def _resolve_vendor(ai_model):
    """Map a model name to its vendor key for AIServiceFactory; gr.Error if unknown."""
    if ai_model in settings.OPENAI_MODELS:
        return "openai"
    if ai_model in settings.ANTHROPIC_MODELS:
        return "anthropic"
    # Previously this fell through and failed later with UnboundLocalError.
    raise gr.Error("Please select a supported AI model.")


def _save_images(pil_images, folder):
    """Downscale each PIL image to IMAGE_MAX_SIZE, write it into *folder*, return the paths.

    Images with real transparency are stored as PNG; everything else as
    maximum-quality JPEG. The file extension always matches the format written.
    """
    img_paths = []
    for index, image in enumerate(pil_images):
        longest_side = max(image.size)
        if longest_side > IMAGE_MAX_SIZE:
            ratio = IMAGE_MAX_SIZE / longest_side
            # max(1, ...) keeps extremely thin images from collapsing to 0 px.
            image = image.resize(
                (max(1, int(image.width * ratio)), max(1, int(image.height * ratio)))
            )

        has_alpha = image.mode in ("RGBA", "LA") or (
            image.mode == "P" and "transparency" in image.info
        )
        if has_alpha:
            image = image.convert("RGBA")
            if image.getchannel("A").getextrema() == (255, 255):
                has_alpha = False  # fully opaque: JPEG loses nothing

        if has_alpha:
            image_format, extension, save_kwargs = "PNG", "png", {}
        else:
            # JPEG cannot encode e.g. palette ("P"), RGBA or 16-bit modes; save()
            # would raise, so normalise anything other than RGB/L to RGB first.
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")
            image_format, extension = "JPEG", "jpg"
            save_kwargs = {"quality": 100, "subsampling": 0}

        img_path = os.path.join(folder, f"{index}.{extension}")
        image.save(img_path, image_format, **save_kwargs)
        img_paths.append(img_path)
    return img_paths


async def forward_request(
    attributes, product_taxonomy, product_data, ai_model, pil_images
):
    """Gradio handler for "Extract Attributes".

    Args:
        attributes: Attribute-schema text (dict body without braces).
        product_taxonomy: Free-text product category; may be empty.
        product_data: Optional known product data as dict/JSON text.
        ai_model: Model name chosen in the dropdown (None if nothing selected).
        pil_images: Gallery value; each item is indexed with [0] to get the PIL
            image (Gradio passes (image, caption) pairs).

    Returns:
        Whatever ``service.extract_attributes_with_validation`` returns; it is
        displayed in the JSON output.

    Raises:
        gr.Error: on invalid input or when the extraction service fails.
    """
    # Validate everything up front, in the original order, so bad input is
    # reported before any files are written.
    attributes_object = _parse_attribute_schema(attributes)
    product_data_object = _parse_product_data(product_data)
    if not pil_images:
        raise gr.Error("Please upload image(s) of the product")
    ai_vendor = _resolve_vendor(ai_model)

    request_temp_folder = os.path.join("gradio_temp", str(uuid.uuid4()))
    os.makedirs(request_temp_folder, exist_ok=True)
    try:
        img_paths = _save_images(
            [item[0] for item in pil_images], request_temp_folder
        )
        service = AIServiceFactory.get_service(ai_vendor)
        try:
            json_attributes = await service.extract_attributes_with_validation(
                attributes_object,
                ai_model,
                None,
                product_taxonomy,
                product_data_object,
                img_paths=img_paths,
            )
        except Exception as exc:  # keep task cancellation (BaseException) propagating
            raise gr.Error(
                "Failed to extract attributes. Something went wrong."
            ) from exc
    finally:
        # Always remove the per-request images, even when extraction fails.
        shutil.rmtree(request_temp_folder, ignore_errors=True)
    gr.Info("Process completed!")
    return json_attributes
def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
    """Append one attribute entry to the schema text and reset the form fields.

    Args:
        attributes: Current "Attribute Schema" text (a dict body without braces).
        attr_name: Name of the new attribute.
        attr_desc: Description of the attribute.
        attr_type: Data type from the "Type" dropdown (e.g. "list[string]").
        allowed_values: Comma-separated allowed values; blank entries are ignored.

    Returns:
        A 5-tuple: the updated schema text, then four empty strings that clear
        the name, description, type and allowed-values inputs.
    """
    import json

    def quote(value):
        # json.dumps escapes quotes, backslashes and newlines, so user text cannot
        # break the literal; ensure_ascii=False keeps non-ASCII text intact.
        return json.dumps(str(value).strip(), ensure_ascii=False)

    values = [v.strip() for v in (allowed_values or "").split(",") if v.strip()]
    if values:
        values_text = "[\n        " + ", ".join(quote(v) for v in values) + "\n    ]"
    else:
        values_text = "[]"

    entry = (
        f"\n{quote(attr_name)}: {{\n"
        f'    "description": {quote(attr_desc)},\n'
        f'    "data_type": {quote(attr_type)},\n'
        f'    "allowed_values": {values_text}\n'
        "},\n"
    )

    schema = attributes or ""
    if schema.strip() and not schema.rstrip().endswith(","):
        # The previous entry (e.g. the sample schema's last one, or a hand-edited
        # one) has no trailing comma; without it the combined text is not a
        # valid dict body.
        schema = schema.rstrip() + ","
    return schema + entry, "", "", "", ""
# Default "Attribute Schema": the body of a dict literal (no surrounding braces).
# Every entry, including the last, ends with a comma so that entries appended by
# "Add Attribute" still form a valid dict when the text is wrapped in {...}.
sample_schema = """"category": {
    "description": "Category of the garment",
    "data_type": "list[string]",
    "allowed_values": [
        "upper garment", "lower garment", "footwear", "accessory", "headwear", "dresses"
    ]
},
"color": {
    "description": "Color of the garment",
    "data_type": "list[string]",
    "allowed_values": [
        "black", "white", "red", "blue", "green", "yellow", "pink", "purple", "orange", "brown", "grey", "beige", "multi-color", "other"
    ]
},
"pattern": {
    "description": "Pattern of the garment",
    "data_type": "list[string]",
    "allowed_values": [
        "plain", "striped", "checkered", "floral", "polka dot", "camouflage", "animal print", "abstract", "other"
    ]
},
"material": {
    "description": "Material of the garment",
    "data_type": "string",
    "allowed_values": []
},
"""

# Usage notes rendered as Markdown under the page title.
description = """
This is a simple demo for Attribution. Follow the steps below:
1. Upload image(s) of a product.
2. Enter the product taxonomy (e.g. 'upper garment', 'lower garment', 'bag'). If only one product is in the image, you can leave this field empty.
3. Select the AI model to use.
4. Enter known product data (optional).
5. Enter the attribute schema or use the "Add Attributes" section to add attributes.
6. Click "Extract Attributes" to get the extracted attributes.
"""

# Placeholder and initial value for the "Product Data (Optional)" text area.
product_data_placeholder = """Example:
{
    "brand": "Leaf",
    "size": "M",
    "product_name": "Leaf T-shirt",
    "color": "red"
}
"""

product_data_value = """
{
    "data1": "",
    "data2": ""
}
"""
with gr.Blocks(title="Internal Demo for Attribution") as demo:
    # Page header: title and step-by-step usage notes.
    with gr.Row():
        with gr.Column(scale=12):
            gr.Markdown(
                """<div style="text-align: center; font-size: 24px;"><strong>Internal Demo for Attribution</strong></div>"""
            )
            gr.Markdown(description)

    with gr.Row():
        # Left (wider) column: all request inputs plus the submit button.
        with gr.Column(scale=12):
            with gr.Row():
                with gr.Column():
                    gallery = gr.Gallery(
                        label="Upload images of your product here", type="pil"
                    )
                    product_taxonomy = gr.Textbox(
                        label="Product Taxonomy",
                        placeholder="Enter product taxonomy here (e.g. 'upper garment', 'lower garment', 'bag')",
                        lines=1,
                        max_lines=1,
                    )
                    ai_model = gr.Dropdown(
                        label="AI Model",
                        choices=settings.SUPPORTED_MODELS,
                        # Preselect the first supported model. Otherwise the
                        # dropdown starts empty and forward_request cannot map
                        # the selection to a vendor.
                        value=next(iter(settings.SUPPORTED_MODELS), None),
                        interactive=True,
                    )
                    product_data = gr.TextArea(
                        label="Product Data (Optional)",
                        placeholder=product_data_placeholder,
                        value=product_data_value.strip(),
                        interactive=True,
                        lines=10,
                        max_lines=10,
                    )
                with gr.Column():
                    attributes = gr.TextArea(
                        label="Attribute Schema",
                        value=sample_schema,
                        placeholder="Enter schema here or use Add Attributes below",
                        interactive=True,
                        lines=30,
                        max_lines=30,
                    )
                    # Form that appends one entry to the schema text above.
                    with gr.Accordion("Add Attributes", open=False):
                        attr_name = gr.Textbox(
                            label="Attribute name", placeholder="Enter attribute name"
                        )
                        attr_desc = gr.Textbox(
                            label="Description", placeholder="Enter description"
                        )
                        attr_type = gr.Dropdown(
                            label="Type",
                            choices=[
                                "string",
                                "list[string]",
                                "int",
                                "list[int]",
                                "float",
                                "list[float]",
                                "bool",
                                "list[bool]",
                            ],
                            interactive=True,
                        )
                        allowed_values = gr.Textbox(
                            label="Allowed values (separated by comma)",
                            placeholder="yellow, red, blue",
                        )
                        add_btn = gr.Button("Add Attribute")
            with gr.Row():
                submit_btn = gr.Button("Extract Attributes")
        # Right column: extraction result as JSON.
        with gr.Column(scale=6):
            output_json = gr.Json(
                label="Extracted Attributes", value={}, show_indices=False
            )

    # "Add Attribute" rewrites the schema text and clears the four form inputs.
    add_btn.click(
        add_attribute_schema,
        inputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
        outputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
    )
    submit_btn.click(
        forward_request,
        inputs=[attributes, product_taxonomy, product_data, ai_model, gallery],
        outputs=output_json,
    )
import secrets
import sys

# Credentials for Gradio's login page, normally supplied via the environment
# (e.g. Space secrets). A hard-coded fallback password would be trivially
# guessable, so a missing ATTR_PASS gets a random one-off password instead.
attr_user = os.getenv("ATTR_USER", "1")
attr_pass = os.getenv("ATTR_PASS")
if not attr_pass:
    attr_pass = secrets.token_urlsafe(16)
    print(
        f"ATTR_PASS is not set; generated a one-off password for user "
        f"{attr_user!r}: {attr_pass}",
        file=sys.stderr,
    )
auth = (attr_user, attr_pass)

# Launch only when run as a script, so importing the module (e.g. by Gradio's
# reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(auth=auth, debug=True, ssr_mode=False)