import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss  # For similarity search
from itertools import product

# Load BLIP model for image captioning
device = 'cuda' if torch.cuda.is_available() else 'cpu'
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# Load Sentence Transformer for embeddings
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')


def process_images(uploaded_files):
    image_filenames = []
    image_captions = []
    image_data = {}  # Local to this function

    if not uploaded_files:
        return None, None, None, None, "No images were uploaded."

    # Process each uploaded image
    for idx, img_file in enumerate(uploaded_files):
        # Open image using PIL; gr.File(type="filepath") passes plain path strings
        image = Image.open(img_file).convert('RGB')

        # Prepare the image for BLIP
        inputs = blip_processor(images=image, return_tensors="pt").to(device)

        # Generate caption using BLIP
        with torch.no_grad():
            out = blip_model.generate(**inputs)
        caption = blip_processor.decode(out[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

        # Store filename and caption in the dictionary
        filename = f'Image-{idx}'
        image_filenames.append(filename)
        image_captions.append(caption)
        image_data[filename] = {
            'image': image,
            'caption': caption
        }

    # Compute embeddings for image captions using Sentence Transformer
    if image_captions:
        caption_embeddings = sbert_model.encode(image_captions, convert_to_tensor=True)
        caption_embeddings_np = caption_embeddings.cpu().numpy()

        # Initialize FAISS vector store
        embedding_dim = caption_embeddings_np.shape[1]
        vector_store = faiss.IndexFlatIP(embedding_dim)  # Inner product on normalized vectors = cosine similarity
        faiss.normalize_L2(caption_embeddings_np)
        vector_store.add(caption_embeddings_np)

        # Return image data, vector store, filenames and captions for use in the recommendation step
        return image_data, vector_store, image_filenames, image_captions, "Images processed successfully!"
    else:
        return None, None, None, None, "No images were processed."


def recommend_outfits(user_query, image_data, vector_store, image_filenames, image_captions):
    if vector_store is None or not image_data:
        return [], "Please upload images of your clothing first."
    # Encode the user query into the same embedding space
    query_embedding = sbert_model.encode([user_query], convert_to_tensor=True)
    query_embedding_np = query_embedding.cpu().numpy()
    faiss.normalize_L2(query_embedding_np)

    # Perform similarity search in the FAISS index
    k = min(10, len(image_data))  # Retrieve top-k results
    distances, indices = vector_store.search(query_embedding_np, k)

    # Get retrieved filenames and captions
    retrieved_filenames = [image_filenames[idx] for idx in indices[0]]
    retrieved_captions = [image_captions[idx] for idx in indices[0]]

    # Categorize retrieved items by keyword matching on their captions
    categories = {
        'tops': ['shirt', 't-shirt', 'jacket', 'sweater', 'blouse', 'coat'],
        'bottoms': ['jeans', 'pants', 'shorts', 'skirt', 'trousers', 'chino'],
        'dresses': ['dress', 'gown'],
        'footwear': ['shoes', 'sneakers', 'boots', 'heels'],
        'accessories': ['hat', 'sunglasses', 'scarf', 'belt', 'bag'],
    }
    items_by_category = {cat: [] for cat in categories}
    for filename, caption in zip(retrieved_filenames, retrieved_captions):
        matched = False
        for category, keywords in categories.items():
            if any(keyword in caption.lower() for keyword in keywords):
                items_by_category[category].append((filename, caption))
                matched = True
                break
        if not matched:
            items_by_category.setdefault('others', []).append((filename, caption))

    # Generate outfit combinations
    if items_by_category['dresses']:
        # Outfits with dresses and footwear
        combinations = list(product(items_by_category['dresses'], items_by_category['footwear']))
    else:
        # Tops and bottoms
        combinations = list(product(items_by_category['tops'], items_by_category['bottoms']))
        # Optionally include footwear if available
        if items_by_category['footwear']:
            combinations = list(product(items_by_category['tops'], items_by_category['bottoms'], items_by_category['footwear']))

    # Prepare output images and captions
    outputs = []
    if combinations:
        for outfit in combinations[:3]:  # Limit to top 3 recommendations
            images = []
            captions = []
            for filename, caption in outfit:
                images.append(image_data[filename]['image'])
                captions.append(caption)
            outputs.append((images, captions))
    else:
        return [], "Not enough items to generate outfit combinations based on the current categories and retrieved items."

    # Prepare the outputs for the Gradio components
    output_images = []
    output_texts = []
    for images, captions in outputs:
        # Combine the images of one outfit horizontally into a single canvas
        widths, heights = zip(*(img.size for img in images))
        total_width = sum(widths)
        max_height = max(heights)
        new_im = Image.new('RGB', (total_width, max_height))
        x_offset = 0
        for img in images:
            new_im.paste(img, (x_offset, 0))
            x_offset += img.size[0]
        output_images.append(new_im)
        output_texts.append('\n'.join(captions))

    return output_images, '\n\n'.join(output_texts)


# Gradio interface setup using Blocks
def gradio_app():
    with gr.Blocks() as demo:
        gr.Markdown("# RAG-Based Outfit Recommendation System")
        gr.Markdown("Ever spend too much time in the morning, or before going out, trying to decide what to wear?")
        gr.Markdown("Well, think no more: just upload images of your clothing items, like a virtual wardrobe, and describe in natural language what kind of outfit you need.")
        gr.Markdown("Example: I want something classy but with bright colors for a date night.")
        gr.Markdown("The system works like a typical RAG pipeline: your clothing items are embedded in a vector space based on captions generated by a vision-language model (VLM).")
        gr.Markdown("Your query is then matched against that space to find the best-matching combinations of items from your wardrobe.")

        image_data_state = gr.State()
        vector_store_state = gr.State()
        image_filenames_state = gr.State()
        image_captions_state = gr.State()

        with gr.Tab("Step 1: Upload Images"):
            with gr.Row():
                image_input = gr.File(type="filepath", label="Upload Your Clothing Images", file_count="multiple")
            process_button = gr.Button("Process Images")
            output_message = gr.Textbox(label="Status")
            process_button.click(
                fn=process_images,
                inputs=image_input,
                outputs=[image_data_state, vector_store_state, image_filenames_state, image_captions_state, output_message]
            )

        with gr.Tab("Step 2: Get Recommendations"):
            user_query = gr.Textbox(lines=2, placeholder="Enter your outfit preference (e.g., casual for a date)", label="Describe Your Outfit Request")
            recommend_button = gr.Button("Recommend Outfits")
            output_gallery = gr.Gallery(label="Recommended Outfit Combinations")
            output_descriptions = gr.Textbox(label="Item Descriptions")
            recommend_button.click(
                fn=recommend_outfits,
                inputs=[user_query, image_data_state, vector_store_state, image_filenames_state, image_captions_state],
                outputs=[output_gallery, output_descriptions]
            )

    return demo


if __name__ == "__main__":
    demo = gradio_app()
    demo.launch(share=True)