Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files- Dockerfile +17 -0
- requirements.txt +97 -0
- src/main.py +91 -0
- src/utils.py +247 -0
Dockerfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.11.8-slim
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
COPY requirements.txt .
|
6 |
+
|
7 |
+
RUN pip install --upgrade pip && \
|
8 |
+
pip install --no-cache-dir -r requirements.txt
|
9 |
+
|
10 |
+
COPY src /app/src
|
11 |
+
COPY images /app/images
|
12 |
+
COPY .env /app/.env
|
13 |
+
|
14 |
+
|
15 |
+
EXPOSE 8080
|
16 |
+
|
17 |
+
CMD ["python", "-m", "src.main"]
|
requirements.txt
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.28.0
|
2 |
+
aiohttp==3.9.3
|
3 |
+
aiosignal==1.2.0
|
4 |
+
annotated-types==0.6.0
|
5 |
+
anyio==4.2.0
|
6 |
+
asttokens==2.0.5
|
7 |
+
attrdict==2.0.1
|
8 |
+
attrs==23.1.0
|
9 |
+
Bottleneck==1.3.7
|
10 |
+
Brotli==1.0.9
|
11 |
+
certifi==2024.2.2
|
12 |
+
charset-normalizer==2.0.4
|
13 |
+
click==8.1.7
|
14 |
+
comm==0.2.1
|
15 |
+
datasets==2.12.0
|
16 |
+
debugpy==1.6.7
|
17 |
+
decorator==5.1.1
|
18 |
+
dill==0.3.6
|
19 |
+
distro==1.8.0
|
20 |
+
einops==0.7.0
|
21 |
+
executing==0.8.3
|
22 |
+
faiss-cpu
|
23 |
+
filelock==3.13.1
|
24 |
+
Flask==2.2.2
|
25 |
+
frozenlist==1.4.0
|
26 |
+
fsspec==2023.10.0
|
27 |
+
geographiclib==2.0
|
28 |
+
geopy==2.4.1
|
29 |
+
h11==0.14.0
|
30 |
+
httpcore==1.0.2
|
31 |
+
httpx==0.27.0
|
32 |
+
huggingface-hub==0.20.3
|
33 |
+
idna==3.4
|
34 |
+
ipykernel==6.25.0
|
35 |
+
ipython==8.20.0
|
36 |
+
itsdangerous==2.0.1
|
37 |
+
jedi==0.18.1
|
38 |
+
Jinja2==3.1.3
|
39 |
+
jupyter_client==8.6.0
|
40 |
+
jupyter_core==5.5.0
|
41 |
+
MarkupSafe==2.1.3
|
42 |
+
matplotlib-inline==0.1.6
|
43 |
+
mpmath==1.3.0
|
44 |
+
multidict==6.0.4
|
45 |
+
multiprocess==0.70.14
|
46 |
+
nest-asyncio==1.6.0
|
47 |
+
networkx==3.1
|
48 |
+
numexpr==2.8.7
|
49 |
+
numpy==1.26.4
|
50 |
+
openai==1.16.2
|
51 |
+
packaging==23.2
|
52 |
+
pandas==2.2.1
|
53 |
+
parso==0.8.3
|
54 |
+
pexpect==4.8.0
|
55 |
+
Pillow==10.0.1
|
56 |
+
pip==23.3.1
|
57 |
+
platformdirs==3.10.0
|
58 |
+
prompt-toolkit==3.0.43
|
59 |
+
ptyprocess==0.7.0
|
60 |
+
pure-eval==0.2.2
|
61 |
+
pyarrow==14.0.2
|
62 |
+
Pygments==2.15.1
|
63 |
+
pymongo==3.12.0
|
64 |
+
PySocks==1.7.1
|
65 |
+
python-dateutil==2.8.2
|
66 |
+
python-dotenv==1.0.1
|
67 |
+
pytz==2023.3.post1
|
68 |
+
PyYAML==6.0.1
|
69 |
+
pyzmq==25.1.2
|
70 |
+
regex==2023.10.3
|
71 |
+
replicate==0.25.1
|
72 |
+
requests==2.31.0
|
73 |
+
responses==0.13.3
|
74 |
+
safetensors==0.4.2
|
75 |
+
sentencepiece==0.2.0
|
76 |
+
setuptools==68.2.2
|
77 |
+
six==1.16.0
|
78 |
+
sniffio==1.3.0
|
79 |
+
stack-data==0.2.0
|
80 |
+
sympy==1.12
|
81 |
+
timm==0.9.16
|
82 |
+
tokenizers==0.15.1
|
83 |
+
torch==2.2.2
|
84 |
+
torchaudio==2.2.2
|
85 |
+
torchvision==0.17.2
|
86 |
+
tornado==6.3.3
|
87 |
+
tqdm==4.65.0
|
88 |
+
traitlets==5.7.1
|
89 |
+
transformers==4.39.3
|
90 |
+
triton==2.2.0
|
91 |
+
typing_extensions==4.9.0
|
92 |
+
tzdata==2023.3
|
93 |
+
urllib3==2.1.0
|
94 |
+
wcwidth==0.2.5
|
95 |
+
Werkzeug==2.3.8
|
96 |
+
wheel==0.41.2
|
97 |
+
yarl==1.9.3
|
src/main.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from flask import Flask, request, jsonify
|
5 |
+
from openai import OpenAI
|
6 |
+
from pymongo import MongoClient
|
7 |
+
|
8 |
+
from src.utils import allowed_file, save_file, get_image_embeddings, \
|
9 |
+
get_clothing_type, save_data_to_db, fetch_weather, get_gender_by_username, \
|
10 |
+
prompt_gpt, get_outfit
|
11 |
+
|
12 |
+
|
13 |
+
UPLOAD_FOLDER = 'images'
|
14 |
+
|
15 |
+
app = Flask(__name__)
|
16 |
+
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
17 |
+
|
18 |
+
load_dotenv()
|
19 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
20 |
+
os.environ["REPLICATE_API_TOKEN"] = os.getenv("REPLICATE_API_TOKEN")
|
21 |
+
|
22 |
+
MONGO_URI = "mongodb+srv://moda:[email protected]/?retryWrites=true&w=majority&appName=ClusterModa"
|
23 |
+
client = MongoClient(MONGO_URI)
|
24 |
+
db = client.moda
|
25 |
+
|
26 |
+
myclient = OpenAI(api_key=OPENAI_API_KEY)
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
@app.route("/register", methods=['POST'])
|
31 |
+
def register_user():
|
32 |
+
collection = db.users
|
33 |
+
username = request.json.get('username')
|
34 |
+
gender = request.json.get("gender")
|
35 |
+
|
36 |
+
existing_user = collection.find_one({'_id': username})
|
37 |
+
if existing_user:
|
38 |
+
return jsonify({'message': 'Username already exists'}), 400
|
39 |
+
|
40 |
+
new_user = {'_id': username,
|
41 |
+
'gender': gender,
|
42 |
+
"closet": {}}
|
43 |
+
result = collection.insert_one(new_user)
|
44 |
+
|
45 |
+
return jsonify({'message': 'User registered successfully', 'user_id': str(result.inserted_id)}), 201
|
46 |
+
|
47 |
+
|
48 |
+
@app.route('/upload', methods=['POST'])
|
49 |
+
def upload_file():
|
50 |
+
if 'file' not in request.files:
|
51 |
+
return jsonify({'message': 'No file part'}), 400
|
52 |
+
file = request.files['file']
|
53 |
+
if file.filename == '':
|
54 |
+
return jsonify({'message': 'No selected file'}), 400
|
55 |
+
if not allowed_file(file.filename):
|
56 |
+
return jsonify({'message': 'File extension not allowed'}), 400
|
57 |
+
if 'username' not in request.form:
|
58 |
+
return jsonify({'message': 'No username provided'}), 400
|
59 |
+
|
60 |
+
file_path = save_file(file, app)
|
61 |
+
|
62 |
+
image = open(file_path, "rb")
|
63 |
+
clothing_type = get_clothing_type(image)
|
64 |
+
image_embeddings = get_image_embeddings(image)
|
65 |
+
data = {
|
66 |
+
"username": request.form["username"],
|
67 |
+
"image_path": file_path,
|
68 |
+
"type": clothing_type,
|
69 |
+
"image_embeddings": image_embeddings
|
70 |
+
}
|
71 |
+
save_data_to_db(data, db)
|
72 |
+
|
73 |
+
return jsonify({'message': 'File uploaded successfully'}), 200
|
74 |
+
|
75 |
+
|
76 |
+
@app.route("/recommend", methods=["POST"])
|
77 |
+
def recommend_outfit():
|
78 |
+
"""Takes as input a username, a context, latitude, and longitude"""
|
79 |
+
username = request.json.get('username')
|
80 |
+
context = request.json.get('context')
|
81 |
+
temperature = fetch_weather(float(request.json.get("latitude")), float(request.json.get("longitude")))
|
82 |
+
gender = get_gender_by_username(username, db)
|
83 |
+
|
84 |
+
outfit_description = prompt_gpt(myclient, gender, context, temperature)
|
85 |
+
outfit = get_outfit(outfit_description, username, db)
|
86 |
+
|
87 |
+
return jsonify({'outfit': outfit, 'message': "Recommendation Successful"}), 200
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
app.run(debug=True)
|
src/utils.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import os
|
3 |
+
|
4 |
+
import faiss
|
5 |
+
import numpy as np
|
6 |
+
import replicate
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from geopy.geocoders import Nominatim
|
9 |
+
from werkzeug.utils import secure_filename
|
10 |
+
|
11 |
+
|
12 |
+
CLOTHES_TYPES = ["tops", "bottoms", "shoes", "outerwear"]
|
13 |
+
load_dotenv()
|
14 |
+
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")
|
15 |
+
|
16 |
+
|
17 |
+
def allowed_file(filename):
|
18 |
+
"""Checks if uploaded file is allowed"""
|
19 |
+
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
|
20 |
+
return '.' in filename and \
|
21 |
+
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
22 |
+
|
23 |
+
def save_file(file, app):
|
24 |
+
"""Makes filename unique and saves it"""
|
25 |
+
filename = secure_filename(file.filename)
|
26 |
+
name, extension = os.path.splitext(filename)
|
27 |
+
counter = 1
|
28 |
+
while os.path.exists(os.path.join(app.config['UPLOAD_FOLDER'], filename)):
|
29 |
+
filename = f"{name}_{counter}{extension}"
|
30 |
+
counter += 1
|
31 |
+
|
32 |
+
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
33 |
+
file.save(file_path)
|
34 |
+
|
35 |
+
return file_path
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
def get_image_embeddings(image):
|
40 |
+
output = replicate.run(
|
41 |
+
"daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304",
|
42 |
+
input={
|
43 |
+
"input": image,
|
44 |
+
"modality": "vision"
|
45 |
+
}
|
46 |
+
)
|
47 |
+
|
48 |
+
return output
|
49 |
+
|
50 |
+
def get_text_embeddings(text):
|
51 |
+
output = replicate.run(
|
52 |
+
"daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304",
|
53 |
+
input={
|
54 |
+
"text_input": text,
|
55 |
+
"modality": "text"
|
56 |
+
}
|
57 |
+
)
|
58 |
+
|
59 |
+
return output
|
60 |
+
|
61 |
+
def get_clothing_type(image):
|
62 |
+
output = replicate.run(
|
63 |
+
"yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
|
64 |
+
input={
|
65 |
+
"image": image,
|
66 |
+
"prompt": f"What is this piece of clothing? Please select ONLY ONE CHOICE: {CLOTHES_TYPES}. \
|
67 |
+
If you are in doubt, just pick ONE OF THEM.\
|
68 |
+
\nIf you think it's outerwear but it doesn't have a zipper or buttons, it should \
|
69 |
+
be considered as: 'tops'. Keep in mind that you shouldn't write anything except one \
|
70 |
+
of the 4 choices, and no other text."
|
71 |
+
}
|
72 |
+
)
|
73 |
+
|
74 |
+
return "".join(output).lower()
|
75 |
+
|
76 |
+
def get_user_closet_length(query, collection):
|
77 |
+
user_doc = collection.find_one(query)
|
78 |
+
if user_doc:
|
79 |
+
closet = user_doc.get('closet', {})
|
80 |
+
return len(closet)
|
81 |
+
|
82 |
+
raise ValueError("User not found")
|
83 |
+
|
84 |
+
def save_data_to_db(data:dict, db):
|
85 |
+
collection = db.users
|
86 |
+
query = {'_id': data["username"]}
|
87 |
+
new_item = {
|
88 |
+
'path': data["image_path"],
|
89 |
+
'type': data["type"],
|
90 |
+
'embedding': data["image_embeddings"]
|
91 |
+
}
|
92 |
+
|
93 |
+
closet_length = get_user_closet_length(query, collection)
|
94 |
+
new_item_key = f"item{closet_length + 1}"
|
95 |
+
collection.update_one(
|
96 |
+
{"_id": data["username"]},
|
97 |
+
{"$set": {f"closet.{new_item_key}": new_item}}
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
def get_city_from_coord(latitude, longitude):
|
103 |
+
geolocator = Nominatim(user_agent="city_name_app")
|
104 |
+
location = geolocator.reverse((latitude, longitude), exactly_one=True)
|
105 |
+
address = location.address if location else None
|
106 |
+
if address:
|
107 |
+
city = address.split(",")[-3]
|
108 |
+
return city.strip()
|
109 |
+
|
110 |
+
return None
|
111 |
+
|
112 |
+
def fetch_weather(latitude, longitude):
|
113 |
+
city_name = get_city_from_coord(latitude, longitude)
|
114 |
+
if not city_name:
|
115 |
+
return None
|
116 |
+
|
117 |
+
url = f"http://api.openweathermap.org/data/2.5/weather?q={city_name}&appid={WEATHER_API_KEY}&units=metric"
|
118 |
+
response = requests.get(url)
|
119 |
+
data = response.json()
|
120 |
+
|
121 |
+
return data["main"]["feels_like"]
|
122 |
+
|
123 |
+
def get_gender_by_username(username:str, db):
|
124 |
+
collection = db.users
|
125 |
+
document = collection.find_one({"_id": username})
|
126 |
+
if not document:
|
127 |
+
return None
|
128 |
+
|
129 |
+
return document.get("gender")
|
130 |
+
|
131 |
+
|
132 |
+
def prompt_gpt(client, gender, context, temperature):
|
133 |
+
"""given some context, returns a dictionary describing what you should wear in your outfit."""
|
134 |
+
outfit = {}
|
135 |
+
|
136 |
+
history = [{"role": "system", "content": "You are a fashion expert who is dedicated to picking outfits for a user. \
|
137 |
+
You will receive context from the user that will help you choose an outfit. \
|
138 |
+
An outfit contains four items: top, bottom, shoes, outwear. \
|
139 |
+
The outwear is optional and depends on the weather, if it is not needed, \
|
140 |
+
only write `none` and nothing else. You will need to provide a description for each item one by one.\
|
141 |
+
I want the description to include: color, style, fit, and material. Make sure the different items \
|
142 |
+
go well toghether in terms of style, color and other factors. the answer should be in this format: \
|
143 |
+
'description': "}]
|
144 |
+
|
145 |
+
prompt = {"role": "user", "content": f"-Context: `{context}`,\n-Gender: `{gender}`\
|
146 |
+
\n-Temperature: `{temperature}` degrees celsius.\
|
147 |
+
\nGiven this context, generate a description for what I should wear as a top:"}
|
148 |
+
|
149 |
+
history.append(prompt)
|
150 |
+
response = client.chat.completions.create(
|
151 |
+
model="gpt-3.5-turbo",
|
152 |
+
messages= history
|
153 |
+
)
|
154 |
+
content = response.choices[0].message.content
|
155 |
+
history.append({"role": "assistant", "content": content})
|
156 |
+
outfit["tops"] = content[15:]
|
157 |
+
|
158 |
+
|
159 |
+
prompt = {"role": "user", "content": f"Given the previous answers, generate a description for what \
|
160 |
+
I should wear as a bottom:"}
|
161 |
+
|
162 |
+
history.append(prompt)
|
163 |
+
response = client.chat.completions.create(
|
164 |
+
model="gpt-3.5-turbo",
|
165 |
+
messages= history
|
166 |
+
)
|
167 |
+
content = response.choices[0].message.content
|
168 |
+
history.append({"role": "assistant", "content": content})
|
169 |
+
outfit["bottoms"] = content[15:]
|
170 |
+
|
171 |
+
|
172 |
+
prompt = {"role": "user", "content": f"Given the previous answers, generate a description for what \
|
173 |
+
I should wear as shoes:"}
|
174 |
+
|
175 |
+
history.append(prompt)
|
176 |
+
response = client.chat.completions.create(
|
177 |
+
model="gpt-3.5-turbo",
|
178 |
+
messages= history
|
179 |
+
)
|
180 |
+
content = response.choices[0].message.content
|
181 |
+
history.append({"role": "assistant", "content": content})
|
182 |
+
outfit["shoes"] = content[15:]
|
183 |
+
|
184 |
+
|
185 |
+
prompt = {"role": "user", "content": f"Given the previous answers, and given the temperature, \
|
186 |
+
generate a description for what I should wear as outwear. If the temperature I gave you \
|
187 |
+
({temperature}) is above 25, write none."}
|
188 |
+
|
189 |
+
history.append(prompt)
|
190 |
+
response = client.chat.completions.create(
|
191 |
+
model="gpt-3.5-turbo",
|
192 |
+
messages= history
|
193 |
+
)
|
194 |
+
content = response.choices[0].message.content
|
195 |
+
history.append({"role": "assistant", "content": content})
|
196 |
+
if not "none" in content.lower():
|
197 |
+
outfit["outerwear"] = content[15:]
|
198 |
+
|
199 |
+
|
200 |
+
return outfit
|
201 |
+
|
202 |
+
|
203 |
+
def get_items_by_type(db, username, item_type:str):
|
204 |
+
collection = db.users
|
205 |
+
document = collection.find_one({"_id": username})
|
206 |
+
if not document:
|
207 |
+
return None
|
208 |
+
closet = document.get("closet", {})
|
209 |
+
items = [closet[key] for key in closet if closet[key].get("type") == item_type]
|
210 |
+
|
211 |
+
return items
|
212 |
+
|
213 |
+
def get_items_embeddings(items):
|
214 |
+
embeddings = []
|
215 |
+
for item in items:
|
216 |
+
embedding = item.get("embedding")
|
217 |
+
embeddings.append(embedding)
|
218 |
+
|
219 |
+
return embeddings
|
220 |
+
|
221 |
+
def get_most_similar_embedding(embedding, embeddings_list):
|
222 |
+
embedding = np.asarray(embedding, dtype=np.float32)
|
223 |
+
embeddings_list = np.asarray(embeddings_list, dtype=np.float32)
|
224 |
+
|
225 |
+
d = embedding.shape[0]
|
226 |
+
index = faiss.IndexFlatL2(d)
|
227 |
+
index.add(embeddings_list)
|
228 |
+
_, most_similar_index = index.search(np.expand_dims(embedding, axis=0), 1)
|
229 |
+
|
230 |
+
return most_similar_index[0][0]
|
231 |
+
|
232 |
+
def get_outfit(outfit_description, username, db):
|
233 |
+
outfit = []
|
234 |
+
|
235 |
+
for item_type, description in outfit_description.items():
|
236 |
+
description_embedding = get_text_embeddings(description)
|
237 |
+
|
238 |
+
items = get_items_by_type(db, username, item_type)
|
239 |
+
items_embeddings= get_items_embeddings(items)
|
240 |
+
outfit_item_index = get_most_similar_embedding(description_embedding, items_embeddings)
|
241 |
+
outfit_item_path = items[outfit_item_index]["path"]
|
242 |
+
if item_type == "outerwear":
|
243 |
+
outfit.insert(0, outfit_item_path)
|
244 |
+
else:
|
245 |
+
outfit.append(outfit_item_path)
|
246 |
+
|
247 |
+
return outfit
|