import gradio as gr
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import torch
# Load the pre-trained CLIP model and its processor (tokenizer + image preprocessor)
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Load the fashion product images dataset from Hugging Face
dataset = load_dataset("ashraq/fashion-product-images-small")
deepfashion_database = dataset["train"]
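
# NOTE: calculate_similarities below expects each product row to carry
# precomputed "image_features" and "text_features" columns; the raw
# ashraq/fashion-product-images-small dataset does not ship them. A minimal
# sketch of the one-off precomputation (assumptions: the dataset exposes
# "image" and "productDisplayName" columns; this is slow on CPU, so in
# practice you would cache the mapped dataset to disk):
def add_clip_features(batch):
    with torch.no_grad():
        image_inputs = processor(images=batch["image"], return_tensors="pt")
        text_inputs = processor(text=batch["productDisplayName"],
                                return_tensors="pt", padding=True, truncation=True)
        batch["image_features"] = model.get_image_features(**image_inputs).tolist()
        batch["text_features"] = model.get_text_features(**text_inputs).tolist()
    return batch

deepfashion_database = deepfashion_database.map(
    add_clip_features, batched=True, batch_size=64)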
# Manual torchvision preprocessing pipeline. This is redundant alongside
# CLIPProcessor, which already resizes, center-crops, and normalizes images,
# so send_message passes the raw PIL image to encode_image instead.
def preprocess_image(image):
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'), 'RGB')
    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(image).unsqueeze(0)
def encode_text(text):
    # Tokenize the query text for CLIP's text encoder
    return processor(text=[text], return_tensors="pt", padding=True, truncation=True)

def encode_image(image):
    # Preprocess the query image for CLIP's vision encoder
    # (padding/truncation are text-only options, so they are omitted here)
    return processor(images=[image], return_tensors="pt")
def calculate_similarities(query_image, query_text):
    # The processor returns dicts of tensors, so unpack them as keyword arguments
    with torch.no_grad():
        query_image_features = model.get_image_features(**query_image)
        query_text_features = model.get_text_features(**query_text)
    similarities = []
    for product in deepfashion_database:
        # Precomputed CLIP embeddings stored on each row (see add_clip_features above)
        product_image_features = torch.tensor(product["image_features"])
        product_text_features = torch.tensor(product["text_features"])
        image_similarity = torch.nn.functional.cosine_similarity(
            query_image_features, product_image_features, dim=-1)
        text_similarity = torch.nn.functional.cosine_similarity(
            query_text_features, product_text_features, dim=-1)
        # Combine image and text agreement into a single score
        similarities.append((image_similarity * text_similarity).item())
    return similarities
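
# A faster alternative sketch: stack all catalogue embeddings once and score
# the whole database in one broadcasted operation instead of a Python loop
# (same assumption about precomputed feature columns as above):
def calculate_similarities_batched(query_image, query_text):
    with torch.no_grad():
        image_query = model.get_image_features(**query_image)
        text_query = model.get_text_features(**query_text)
    image_db = torch.tensor(deepfashion_database["image_features"])  # (N, 512)
    text_db = torch.tensor(deepfashion_database["text_features"])    # (N, 512)
    image_sim = torch.nn.functional.cosine_similarity(image_query, image_db, dim=-1)
    text_sim = torch.nn.functional.cosine_similarity(text_query, text_db, dim=-1)
    return (image_sim * text_sim).tolist()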
def initial_query(image, text):
    query_image = encode_image(image)
    query_text = encode_text(text)
    similarities = calculate_similarities(query_image, query_text)
    # Rank every product by its combined score and keep the top three
    sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
    top_3_indices = sorted_indices[:3]
    return [deepfashion_database[i] for i in top_3_indices]
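
# NOTE: generate_output_html is referenced below but never defined in the
# original script. A minimal, hypothetical sketch (assumption: each product
# row has a "productDisplayName" column worth showing):
def generate_output_html(products):
    items = "".join(f"<li>{p['productDisplayName']}</li>" for p in products)
    return f"<ul>{items}</ul>"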
chat_history = []

def send_message(txt, btn):
    if btn is not None:
        # gr.UploadButton passes a temp file object; open it as a PIL image
        image = Image.open(btn.name).convert("RGB")
        top_3_products = initial_query(image, txt)
        output_html = generate_output_html(top_3_products)
    else:
        # calculate_similarities needs both modalities, so prompt for an image
        output_html = "Please upload an image to search the catalogue."
    # gr.Chatbot has no append_message method; it renders the list of
    # (user, bot) message pairs returned from the function instead
    chat_history.append((txt, output_html))
    return chat_history
chatbot = gr.Chatbot([]).style(height=750)
txt = gr.Textbox(placeholder="Enter text and press enter, or upload an image", show_label=False)
# Only images are handled above, so restrict the upload button accordingly
btn = gr.UploadButton("📁", file_types=["image"])
gr.Interface(send_message, inputs=[txt, btn], outputs=chatbot).launch()