import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss  # For similarity search
from itertools import product
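# Pipeline overview: BLIP captions each uploaded clothing image, Sentence-Transformers
# embeds the captions, FAISS indexes the embeddings, and a natural-language query is
# matched against that index to assemble outfit combinations.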
# Load BLIP model for image captioning
device = 'cuda' if torch.cuda.is_available() else 'cpu'
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# Load Sentence Transformer for embeddings
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
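# Note: all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings. Captions and user
# queries are encoded with the same model, so they share one embedding space and can be
# compared with cosine similarity via the FAISS index built below.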
def process_images(uploaded_files):
    if not uploaded_files:
        return None, None, None, None, "No images were uploaded."

    image_filenames = []
    image_captions = []
    image_data = {}  # Local to this function

    # Process each uploaded image
    for idx, img_file in enumerate(uploaded_files):
        # gr.File(type="filepath") yields plain path strings; fall back to .name for file objects
        path = img_file if isinstance(img_file, str) else img_file.name
        # Open image using PIL from the file path
        image = Image.open(path).convert('RGB')

        # Prepare the image for BLIP
        inputs = blip_processor(images=image, return_tensors="pt").to(device)

        # Generate caption using BLIP
        with torch.no_grad():
            out = blip_model.generate(**inputs)
        caption = blip_processor.decode(out[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

        # Store filename and caption in the dictionary
        filename = f'Image-{idx}'
        image_filenames.append(filename)
        image_captions.append(caption)
        image_data[filename] = {
            'image': image,
            'caption': caption
        }

    # Compute embeddings for image captions using Sentence Transformer
    if image_captions:
        caption_embeddings = sbert_model.encode(image_captions, convert_to_tensor=True)
        caption_embeddings_np = caption_embeddings.cpu().numpy()

        # Initialize FAISS vector store
        embedding_dim = caption_embeddings_np.shape[1]
        vector_store = faiss.IndexFlatIP(embedding_dim)  # Inner product over normalized vectors = cosine similarity
        faiss.normalize_L2(caption_embeddings_np)
        vector_store.add(caption_embeddings_np)

        # Return image data, vector store, image filenames and captions for use in the next function
        return image_data, vector_store, image_filenames, image_captions, "Images processed successfully!"
    else:
        return None, None, None, None, "No images were processed."
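# Quick sketch of what process_images returns (file names below are hypothetical, not
# files shipped with this Space):
#   image_data, store, names, captions, msg = process_images(["shirt.jpg", "jeans.jpg"])
# yields image_data = {"Image-0": {"image": <PIL.Image>, "caption": "..."}, ...}, a FAISS
# IndexFlatIP over the L2-normalized caption embeddings, and parallel filename/caption lists.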
def recommend_outfits(user_query, image_data, vector_store, image_filenames, image_captions):
    if vector_store is None or len(image_data) == 0:
        return [], "Please upload images of your clothing first."

    # Encode user query into the same embedding space
    query_embedding = sbert_model.encode([user_query], convert_to_tensor=True)
    query_embedding_np = query_embedding.cpu().numpy()
    faiss.normalize_L2(query_embedding_np)

    # Perform similarity search in the FAISS index
    k = min(10, len(image_data))  # Retrieve top-k results
    distances, indices = vector_store.search(query_embedding_np, k)

    # Get retrieved filenames and captions
    retrieved_filenames = [image_filenames[idx] for idx in indices[0]]
    retrieved_captions = [image_captions[idx] for idx in indices[0]]

    # Categorize retrieved items
    categories = {
        'tops': ['shirt', 't-shirt', 'jacket', 'sweater', 'blouse', 'coat'],
        'bottoms': ['jeans', 'pants', 'shorts', 'skirt', 'trousers', 'chino'],
        'dresses': ['dress', 'gown'],
        'footwear': ['shoes', 'sneakers', 'boots', 'heels'],
        'accessories': ['hat', 'sunglasses', 'scarf', 'belt', 'bag'],
    }
    items_by_category = {cat: [] for cat in categories}
    for filename, caption in zip(retrieved_filenames, retrieved_captions):
        matched = False
        for category, keywords in categories.items():
            if any(keyword in caption.lower() for keyword in keywords):
                items_by_category[category].append((filename, caption))
                matched = True
                break
        if not matched:
            items_by_category.setdefault('others', []).append((filename, caption))

    # Generate combinations
    combinations = []
    if items_by_category['dresses']:
        # Outfits with dresses and footwear
        combinations = list(product(items_by_category['dresses'], items_by_category['footwear']))
    else:
        # Tops and bottoms
        combinations = list(product(items_by_category['tops'], items_by_category['bottoms']))
        # Optionally include footwear if available
        if items_by_category['footwear']:
            combinations = list(product(items_by_category['tops'], items_by_category['bottoms'], items_by_category['footwear']))

    # Prepare output images and captions
    outputs = []
    if combinations:
        for outfit in combinations[:3]:  # Limit to top 3 recommendations
            images = []
            captions = []
            for item in outfit:
                filename, caption = item
                images.append(image_data[filename]['image'])
                captions.append(caption)
            outputs.append((images, captions))
    else:
        return [], "Not enough items to generate outfit combinations based on the current categories and retrieved items."

    # Prepare the outputs for Gradio components
    output_images = []
    output_texts = []
    for images, captions in outputs:
        # Combine images horizontally
        widths, heights = zip(*(img.size for img in images))
        total_width = sum(widths)
        max_height = max(heights)
        new_im = Image.new('RGB', (total_width, max_height))
        x_offset = 0
        for img in images:
            new_im.paste(img, (x_offset, 0))
            x_offset += img.size[0]
        output_images.append(new_im)
        output_texts.append('\n'.join(captions))

    return output_images, '\n\n'.join(output_texts)
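# End-to-end sketch (hypothetical wardrobe files, not shipped with this Space):
#   data, store, names, caps, _ = process_images(["tee.jpg", "jeans.jpg", "boots.jpg"])
#   gallery, text = recommend_outfits("casual weekend look", data, store, names, caps)
# The query is embedded with the same SBERT model, L2-normalized, and matched against the
# caption index; retrieved items are bucketed by keyword (tops/bottoms/dresses/footwear)
# before outfit combinations are assembled and stitched into side-by-side preview images.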
# Gradio Interface setup using Blocks
def gradio_app():
    with gr.Blocks() as demo:
        gr.Markdown("# RAG-Based Outfit Recommendation System")
        gr.Markdown("Ever spend too much time in the morning, or before going out, trying to decide what to wear?")
        gr.Markdown("Think no more: upload images of your clothing items to build a virtual wardrobe, then describe in natural language what kind of outfit you need.")
        gr.Markdown('Example: "I want something classy but with bright colors for a date night."')
        gr.Markdown("The system works like a typical RAG pipeline: your clothing items are embedded in a vector space based on captions generated by a vision-language model (VLM).")
        gr.Markdown("Your query is then matched against your wardrobe to find the best combinations of items.")

        image_data_state = gr.State()
        vector_store_state = gr.State()
        image_filenames_state = gr.State()
        image_captions_state = gr.State()

        with gr.Tab("Step 1: Upload Images"):
            with gr.Row():
                image_input = gr.File(type="filepath", label="Upload Your Clothing Images", file_count="multiple")
            process_button = gr.Button("Process Images")
            output_message = gr.Textbox(label="Status")
            process_button.click(
                fn=process_images,
                inputs=image_input,
                outputs=[image_data_state, vector_store_state, image_filenames_state, image_captions_state, output_message]
            )

        with gr.Tab("Step 2: Get Recommendations"):
            user_query = gr.Textbox(lines=2, placeholder="Enter your outfit preference (e.g., casual for a date)", label="Describe Your Outfit Request")
            recommend_button = gr.Button("Recommend Outfits")
            output_gallery = gr.Gallery(label="Recommended Outfit Combinations")
            output_descriptions = gr.Textbox(label="Item Descriptions")
            recommend_button.click(
                fn=recommend_outfits,
                inputs=[user_query, image_data_state, vector_store_state, image_filenames_state, image_captions_state],
                outputs=[output_gallery, output_descriptions]
            )

    return demo
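# Note: gr.State values are stored per user session, so the image dict, FAISS index and
# caption lists built in the "Upload Images" tab remain available to the recommendation
# tab for that same session without any global state.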
if __name__ == "__main__":
    demo = gradio_app()
    demo.launch(share=True)
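# share=True only matters when running locally (it creates a temporary public link);
# on Hugging Face Spaces the app is already publicly hosted and Gradio ignores the flag.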