# OutfitRec / app.py
# Last commit by BechirTrabelsi1: "Update app.py" (68ec7ef, verified)
import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss # For similarity search
from itertools import product
# Run captioning on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# BLIP image-captioning model (moved to `device`) and its paired pre/post-processor.
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)
blip_processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)

# Sentence-Transformers encoder that embeds both the generated captions and the
# user's free-text query into one vector space.
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
def process_images(uploaded_files):
    """Caption each uploaded clothing image and index the captions for retrieval.

    Args:
        uploaded_files: Value of the multi-file ``gr.File`` upload. Each item may
            be a plain path string or an object exposing the path as ``.name``.
            ``None`` (nothing uploaded) is treated as an empty list.

    Returns:
        A 5-tuple ``(image_data, vector_store, image_filenames, image_captions,
        status_message)``:

        * ``image_data`` maps a synthetic ``'Image-<n>'`` key to
          ``{'image': PIL.Image (RGB), 'caption': str}``.
        * ``vector_store`` is a FAISS inner-product index over L2-normalised
          caption embeddings; row ``i`` corresponds to ``image_filenames[i]``
          and ``image_captions[i]``.

        When no image could be processed, the first four elements are ``None``.
        Files that cannot be opened as images are skipped and counted in the
        status message.
    """
    image_filenames = []
    image_captions = []
    image_data = {}  # Local to this function
    skipped = 0

    for idx, img_file in enumerate(uploaded_files or []):
        # Depending on the Gradio version this is a str path or an object whose
        # ``.name`` holds the path; accept both.
        path = getattr(img_file, 'name', img_file)
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory copy that stays valid after the block.
            with Image.open(path) as raw:
                image = raw.convert('RGB')
        except (OSError, ValueError):
            # Unreadable / non-image upload: skip it instead of failing the batch.
            skipped += 1
            continue

        # Prepare the image for BLIP and generate a caption without tracking grads.
        inputs = blip_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            out = blip_model.generate(**inputs)
        caption = blip_processor.decode(out[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

        filename = f'Image-{idx}'
        image_filenames.append(filename)
        image_captions.append(caption)
        image_data[filename] = {
            'image': image,
            'caption': caption
        }

    skipped_note = f" {skipped} file(s) could not be read as images and were skipped." if skipped else ""

    if not image_captions:
        return None, None, None, None, "No images were processed." + skipped_note

    # Embed captions, L2-normalise them, and store them in an inner-product index
    # so that search scores equal cosine similarity.
    caption_embeddings_np = sbert_model.encode(image_captions, convert_to_tensor=True).cpu().numpy()
    faiss.normalize_L2(caption_embeddings_np)
    vector_store = faiss.IndexFlatIP(caption_embeddings_np.shape[1])
    vector_store.add(caption_embeddings_np)

    return image_data, vector_store, image_filenames, image_captions, "Images processed successfully!" + skipped_note
# Keyword lists used to bucket a caption into a garment category. Order matters:
# a caption is assigned to the FIRST category (in dict order) with a match.
_OUTFIT_CATEGORIES = {
    'tops': ['shirt', 't-shirt', 'jacket', 'sweater', 'blouse', 'coat'],
    'bottoms': ['jeans', 'pants', 'shorts', 'skirt', 'trousers', 'chino'],
    'dresses': ['dress', 'gown'],
    'footwear': ['shoes', 'sneakers', 'boots', 'heels'],
    'accessories': ['hat', 'sunglasses', 'scarf', 'belt', 'bag'],
}


def _categorize_items(filenames, captions):
    """Bucket (filename, caption) pairs by the first category whose keyword is in the caption.

    Unmatched items go to ``'others'``. Input (retrieval) order is preserved
    within every bucket.
    """
    items_by_category = {category: [] for category in _OUTFIT_CATEGORIES}
    items_by_category['others'] = []
    for filename, caption in zip(filenames, captions):
        lowered = caption.lower()
        # NOTE(review): plain substring matching, so e.g. "dressed" counts as a
        # dress and "that" as a hat; word-level matching would be stricter.
        category = next(
            (
                name
                for name, keywords in _OUTFIT_CATEGORIES.items()
                if any(keyword in lowered for keyword in keywords)
            ),
            'others',
        )
        items_by_category[category].append((filename, caption))
    return items_by_category


def _build_outfits(items_by_category):
    """Return candidate outfits, each a tuple of (filename, caption) items.

    Dress outfits come first, followed by top + bottom outfits. Footwear is
    added to both kinds when any footwear was retrieved; otherwise it is
    omitted, so a dress alone (or a top + bottom pair) still forms an outfit.
    ``itertools.product`` keeps the retrieval order, so the most relevant items
    appear in the earliest outfits.
    """
    dresses = items_by_category['dresses']
    tops = items_by_category['tops']
    bottoms = items_by_category['bottoms']
    footwear = items_by_category['footwear']
    if footwear:
        dress_outfits = product(dresses, footwear)
        separates = product(tops, bottoms, footwear)
    else:
        dress_outfits = product(dresses)
        separates = product(tops, bottoms)
    return list(dress_outfits) + list(separates)


def _stitch_horizontally(images):
    """Paste images left to right onto one RGB canvas as tall as the tallest image."""
    widths, heights = zip(*(img.size for img in images))
    canvas = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in images:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.size[0]
    return canvas


def recommend_outfits(user_query, image_data, vector_store, image_filenames, image_captions,
                      top_k=10, max_outfits=3):
    """Retrieve wardrobe items relevant to ``user_query`` and assemble outfits.

    Args:
        user_query: Free-text description of the desired outfit.
        image_data: Mapping from item key to ``{'image': PIL.Image, 'caption': str}``
            as produced by ``process_images``.
        vector_store: FAISS inner-product index over normalised caption embeddings.
        image_filenames: Item keys; position ``i`` matches row ``i`` of the index.
        image_captions: Captions; position ``i`` matches row ``i`` of the index.
        top_k: Maximum number of items to retrieve for the query (default 10).
        max_outfits: Maximum number of outfits to return (default 3).

    Returns:
        ``(images, text)``: one horizontally stitched PIL image per outfit, and a
        string with each outfit's captions (one per line, outfits separated by a
        blank line). When nothing can be recommended, ``([], message)``.
    """
    if vector_store is None or not image_data:
        return [], "Please upload images of your clothing first."

    # Embed the query in the caption space; after L2 normalisation the
    # inner-product index scores by cosine similarity.
    query_embedding_np = sbert_model.encode([user_query or ""], convert_to_tensor=True).cpu().numpy()
    faiss.normalize_L2(query_embedding_np)

    k = max(1, min(top_k, len(image_data)))
    _, indices = vector_store.search(query_embedding_np, k)
    # FAISS pads results with -1 when fewer than k neighbours exist; drop those.
    hits = [int(idx) for idx in indices[0] if idx >= 0]
    retrieved_filenames = [image_filenames[idx] for idx in hits]
    retrieved_captions = [image_captions[idx] for idx in hits]

    outfits = _build_outfits(_categorize_items(retrieved_filenames, retrieved_captions))
    if not outfits:
        return [], "Not enough items to generate outfit combinations based on the current categories and retrieved items."

    output_images = []
    output_texts = []
    for outfit in outfits[:max_outfits]:
        output_images.append(_stitch_horizontally([image_data[filename]['image'] for filename, _ in outfit]))
        output_texts.append('\n'.join(caption for _, caption in outfit))
    return output_images, '\n\n'.join(output_texts)
# Gradio Interface setup using Blocks
def gradio_app():
    """Build the two-tab Gradio Blocks UI: upload/caption images, then get recommendations.

    The outputs of ``process_images`` (images, FAISS index, item keys, captions)
    are held in ``gr.State`` objects and passed to ``recommend_outfits``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# RAG-Based Outfit Recommendation System")
        gr.Markdown("Ever spend too much time in the morning or before going out trying to decide what to wear?")
        gr.Markdown("Well, think no more: just upload images of your clothing items, like a virtual wardrobe, and describe in natural language what kind of outfit you need.")
        gr.Markdown('Example: "I want something classy but with bright colors for a date night"')
        gr.Markdown("The system works like a typical RAG pipeline: your clothing items are embedded in a vector space based on captions generated by a VLM (vision-language model).")
        gr.Markdown("Then your query is matched against them to find the best combinations of items from your wardrobe.")

        # Carry the results of step 1 into step 2.
        image_data_state = gr.State()
        vector_store_state = gr.State()
        image_filenames_state = gr.State()
        image_captions_state = gr.State()

        with gr.Tab("Step 1: Upload Images"):
            with gr.Row():
                image_input = gr.File(
                    type="filepath",
                    label="Upload Your Clothing Images",
                    file_count="multiple",
                    # Restrict the file picker to images so non-image files are
                    # not uploaded by mistake.
                    file_types=["image"],
                )
            process_button = gr.Button("Process Images")
            output_message = gr.Textbox(label="Status")
            process_button.click(
                fn=process_images,
                inputs=image_input,
                outputs=[image_data_state, vector_store_state, image_filenames_state, image_captions_state, output_message]
            )

        with gr.Tab("Step 2: Get Recommendations"):
            user_query = gr.Textbox(lines=2, placeholder="Enter your outfit preference (e.g., casual for a date)", label="Describe Your Outfit Request")
            recommend_button = gr.Button("Recommend Outfits")
            output_gallery = gr.Gallery(label="Recommended Outfit Combinations")
            output_descriptions = gr.Textbox(label="Item Descriptions")
            recommend_button.click(
                fn=recommend_outfits,
                inputs=[user_query, image_data_state, vector_store_state, image_filenames_state, image_captions_state],
                outputs=[output_gallery, output_descriptions]
            )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and launch it with share=True.
    demo = gradio_app()
    demo.launch(share=True)