import clip
import numpy as np
import torch
import torchvision.datasets as datasets
def get_similiarity(prompt, model_resnet, model_vit, top_k=3):
    """Return the top-k image paths most similar to the prompt for both CLIP backbones."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir = 'sample/sample/data'

    # precomputed image embeddings for the ResNet-based CLIP encoder
    image_arr = np.loadtxt("embeddings.csv", delimiter=",")
    # the dataset provides the list of all image file paths
    raw_dataset = datasets.ImageFolder(data_dir)

    # create transformer-readable tokens and encode the prompt with the ResNet model
    inputs = clip.tokenize(prompt).to(device)
    text_emb = model_resnet.encode_text(inputs)
    text_emb = text_emb.cpu().detach().numpy()

    # similarity scores between the text embedding and every image embedding
    scores = np.dot(text_emb, image_arr.T)
    # get the top-k indices of the most similar vectors
    idx = np.argsort(-scores[0])[:top_k]
    image_files = [raw_dataset.imgs[i][0] for i in idx]

    # repeat the search with the ViT-based CLIP encoder and its embeddings
    image_arr_vit = np.loadtxt('embeddings_vit.csv', delimiter=",")
    inputs_vit = clip.tokenize(prompt).to(device)
    text_emb_vit = model_vit.encode_text(inputs_vit)
    text_emb_vit = text_emb_vit.cpu().detach().numpy()
    scores_vit = np.dot(text_emb_vit, image_arr_vit.T)
    idx_vit = np.argsort(-scores_vit[0])[:top_k]
    image_files_vit = [raw_dataset.imgs[i][0] for i in idx_vit]

    return image_files, image_files_vit
# def get_text_enc(input_text: str):
#     text = clip.tokenize([input_text]).to(device)
#     text_features = model.encode_text(text).cpu()
#     text_features = text_features.cpu().detach().numpy()
#     return text_features
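

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal example of how get_similiarity could be called, assuming the two
# CLIP backbones are loaded with clip.load and that embeddings.csv /
# embeddings_vit.csv were produced with the matching image encoders.
# The model names ("RN50", "ViT-B/32") and the example prompt are assumptions
# for illustration only.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_resnet, _ = clip.load("RN50", device=device)    # assumed ResNet backbone
    model_vit, _ = clip.load("ViT-B/32", device=device)   # assumed ViT backbone
    files_resnet, files_vit = get_similiarity(
        "a photo of a dog", model_resnet, model_vit, top_k=3
    )
    print(files_resnet)
    print(files_vit)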