Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- get_similiarty.py +38 -0
- main.py +56 -0
get_similiarty.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchvision.datasets as datasets
|
2 |
+
import numpy as np
|
3 |
+
import clip
|
4 |
+
import torch
|
5 |
+
def get_similiarity(prompt, model_resnet, model_vit, top_k=3):
|
6 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
7 |
+
data_dir = 'sample/data'
|
8 |
+
image_arr = np.loadtxt("embeddings.csv", delimiter=",")
|
9 |
+
raw_dataset = datasets.ImageFolder(data_dir)
|
10 |
+
# получите список всех изображений
|
11 |
+
# create transformer-readable tokens
|
12 |
+
inputs = clip.tokenize(prompt).to(device)
|
13 |
+
text_emb = model_resnet.encode_text(inputs)
|
14 |
+
text_emb = text_emb.cpu().detach().numpy()
|
15 |
+
scores = np.dot(text_emb, image_arr.T)
|
16 |
+
# score_vit
|
17 |
+
# get the top k indices for most similar vecs
|
18 |
+
idx = np.argsort(-scores[0])[:top_k]
|
19 |
+
image_files = []
|
20 |
+
for i in idx:
|
21 |
+
image_files.append(raw_dataset.imgs[i][0])
|
22 |
+
|
23 |
+
image_arr_vit = np.loadtxt('embeddings_vit.csv', delimiter=",")
|
24 |
+
inputs_vit = clip.tokenize(prompt).to(device)
|
25 |
+
text_emb_vit = model_vit.encode_text(inputs_vit)
|
26 |
+
text_emb_vit = text_emb_vit.cpu().detach().numpy()
|
27 |
+
scores_vit = np.dot(text_emb_vit, image_arr_vit.T)
|
28 |
+
idx_vit = np.argsort(-scores_vit[0])[:top_k]
|
29 |
+
image_files_vit = []
|
30 |
+
for i in idx_vit:
|
31 |
+
image_files_vit.append(raw_dataset.imgs[i][0])
|
32 |
+
|
33 |
+
return image_files, image_files_vit
|
34 |
+
# def get_text_enc(input_text: str):
|
35 |
+
# text = clip.tokenize([input_text]).to(device)
|
36 |
+
# text_features = model.encode_text(text).cpu()
|
37 |
+
# text_features = text_features.cpu().detach().numpy()
|
38 |
+
# return text_features
|
main.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
import clip
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
import torchvision.datasets as datasets
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import random
|
11 |
+
from get_similiarty import get_similiarity
|
12 |
+
|
13 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
+
#load model -resnet50
|
15 |
+
|
16 |
+
|
17 |
+
model_resnet = torch.load("model.pt", device )
|
18 |
+
|
19 |
+
#load model - ViT-B/32
|
20 |
+
model_vit = torch.load("model_vit.pt", device )
|
21 |
+
|
22 |
+
|
23 |
+
st.title('Find my pic!')
|
24 |
+
|
25 |
+
def find_image_disc(prompt, df):
|
26 |
+
img_descs = []
|
27 |
+
img_descs_vit = []
|
28 |
+
list_images_names, list_images_names_vit = get_similiarity(prompt, model_resnet, model_vit, 3)
|
29 |
+
for img in list_images_names:
|
30 |
+
img_descs.append(random.choice(df[df['image_name'] == img.split('/')[-1]]['comment'].values).replace('.', ''))
|
31 |
+
#vit
|
32 |
+
for img in list_images_names_vit:
|
33 |
+
img_descs_vit.append(random.choice(df[df['image_name'] == img.split('/')[-1]]['comment'].values).replace('.', ''))
|
34 |
+
|
35 |
+
return list_images_names, img_descs, list_images_names_vit, img_descs_vit
|
36 |
+
|
37 |
+
txt = st.text_area("Describe the picture you'd like to see")
|
38 |
+
|
39 |
+
df = pd.read_csv('results.csv',
|
40 |
+
sep = '|',
|
41 |
+
names = ['image_name', 'comment_number', 'comment'],
|
42 |
+
header=0)
|
43 |
+
|
44 |
+
|
45 |
+
if txt is not None:
|
46 |
+
if st.button('Find!'):
|
47 |
+
|
48 |
+
list_images, img_desc, list_images_vit, img_descs_vit = find_image_disc(txt, df)
|
49 |
+
col1, col2 = st.columns(2)
|
50 |
+
for ind, pic in enumerate(zip(list_images, list_images_vit)):
|
51 |
+
with col1:
|
52 |
+
st.image(pic[0])
|
53 |
+
st.write(img_desc[ind])
|
54 |
+
with col2:
|
55 |
+
st.image(pic[1])
|
56 |
+
st.write(img_desc[ind])
|