Delete utils.py
utils.py
DELETED
@@ -1,200 +0,0 @@
import os
import numpy as np
import pickle
import torch
import transformers
from PIL import Image
from open_clip import create_model_from_pretrained, create_model_and_transforms
import json

# XLM model functions
from multilingual_clip import pt_multilingual_clip

from model_loading import load_model

class CustomDataSet(torch.utils.data.Dataset):
    """Dataset that serves transformed images from a directory, given a list of file names."""

    def __init__(self, main_dir, compose, image_name_list):
        self.main_dir = main_dir
        self.transform = compose
        self.total_imgs = image_name_list

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = Image.open(img_loc)
        return self.transform(image)

def features_pickle(file_path=None):
    """Load pre-computed embeddings from a pickle file."""
    with open(file_path, 'rb') as handle:
        features = pickle.load(handle)

    return features

def dataset_loading():
    """Load the English-Arabic XTD10 annotations (JSONL), sorted by id, plus the image file names."""
    with open("/home/think3/Desktop/2. tf_testing_araclip/XTD_dataset/en_ar_XTD10_edited_v2.jsonl") as filino:
        data = [json.loads(file_i) for file_i in filino]

    sorted_data = sorted(data, key=lambda x: x['id'])
    image_name_list = [lin["image_name"] for lin in sorted_data]

    return sorted_data, image_name_list

def text_encoder(language_model, text):
    """Encode the text and return both the raw and the L2-normalized embedding."""
    embedding = language_model(text)
    norm_embedding = embedding / np.linalg.norm(embedding)

    return embedding, norm_embedding

def compare_embeddings(logit_scale, img_embs, txt_embs):
    """Scaled cosine similarity between text and image embeddings (text-to-image logits)."""
    image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
    text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)

    logits_per_text = logit_scale * text_features @ image_features.t()

    return logits_per_text

def compare_embeddings_text(full_text_embds, txt_embs):
    """Cosine similarity between the query text embedding and the cached caption embeddings."""
    full_text_embds_features = full_text_embds / full_text_embds.norm(dim=-1, keepdim=True)
    text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)

    logits_per_text_full = text_features @ full_text_embds_features.t()

    return logits_per_text_full

def find_image(language_model, clip_model, text_query, dataset, image_features, text_features_new, sorted_data, num=1):
    """Retrieve the top `num` images (and closest captions) for an Arabic text query."""
    embedding, _ = text_encoder(language_model, text_query)

    logit_scale = clip_model.logit_scale.exp().float().to('cpu')

    language_logits, text_logits = {}, {}

    language_logits["Arabic"] = compare_embeddings(logit_scale, torch.from_numpy(image_features), torch.from_numpy(embedding))
    text_logits["Arabic_text"] = compare_embeddings_text(torch.from_numpy(text_features_new), torch.from_numpy(embedding))

    # Rank all images against the query and keep the `num` most probable ones
    for _, txt_logits in language_logits.items():
        probs = txt_logits.softmax(dim=-1).cpu().detach().numpy().T

        file_paths = []
        labels, json_data = {}, {}

        for i in range(1, num + 1):
            idx = np.argsort(probs, axis=0)[-i, 0]
            path = 'photos/XTD10_dataset/' + dataset.get_image_name(idx)
            path_l = (path, f"{sorted_data[idx]['caption_ar']}")

            labels[f" Image # {i}"] = probs[idx]
            json_data[f" Image # {i}"] = sorted_data[idx]

            file_paths.append(path_l)

    # Rank the cached captions against the query as well
    json_text = {}

    for _, txt_logits_full in text_logits.items():
        probs_text = txt_logits_full.softmax(dim=-1).cpu().detach().numpy().T

        for j in range(1, num + 1):
            idx = np.argsort(probs_text, axis=0)[-j, 0]
            json_text[f" Text # {j}"] = sorted_data[idx]

    return file_paths, labels, json_data, json_text

class AraClip:
    def __init__(self):
        # AraBERT text encoder projected to match the SigLIP image encoder
        self.text_model = load_model('bert-base-arabertv2-ViT-B-16-SigLIP-512-epoch-155-trained-2M', in_features=768, out_features=768)
        self.language_model = lambda queries: np.asarray(self.text_model(queries).detach().to('cpu'))
        self.clip_model, self.compose = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP-512')
        self.sorted_data, self.image_name_list = dataset_loading()

    def load_images(self):
        # Pre-computed image features for the XTD 1000-image test set
        image_features_new = features_pickle('testing_pickle_files_images_text/image_features_XTD_1000_images_arabert_siglib_best_model.pickle')
        return image_features_new

    def load_text(self):
        # Pre-computed caption features for the same set
        text_features_new = features_pickle('testing_pickle_files_images_text/text_features_XTD_1000_images_arabert_siglib_best_model.pickle')
        return text_features_new

    def load_dataset(self):
        dataset = CustomDataSet("photos/XTD10_dataset", self.compose, self.image_name_list)
        return dataset

araclip = AraClip()

def predict(text, num):
    image_paths, labels, json_data, json_text = find_image(
        araclip.language_model, araclip.clip_model, text,
        araclip.load_dataset(), araclip.load_images(), araclip.load_text(),
        araclip.sorted_data, num=int(num))

    return image_paths, labels, json_data, json_text

class Mclip:
    def __init__(self) -> None:
        self.tokenizer_mclip = transformers.AutoTokenizer.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
        self.text_model_mclip = pt_multilingual_clip.MultilingualCLIP.from_pretrained('M-CLIP/XLM-Roberta-Large-Vit-B-16Plus')
        self.language_model_mclip = lambda queries: np.asarray(self.text_model_mclip.forward(queries, self.tokenizer_mclip).detach().to('cpu'))
        self.clip_model_mclip, _, self.compose_mclip = create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
        self.sorted_data, self.image_name_list = dataset_loading()

    def load_images(self):
        # Pre-computed mCLIP image features for the XTD 1000-image test set
        image_features_mclip = features_pickle('Cach_embeddings/image_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
        return image_features_mclip

    def load_text(self):
        # Pre-computed mCLIP caption features for the same set
        text_features_new_mclip = features_pickle('Cach_embeddings/text_features_XTD_1000_images_XLM_Roberta_Large_Vit_B_16Plus_ar.pickle')
        return text_features_new_mclip

    def load_dataset(self):
        dataset_mclip = CustomDataSet("photos/XTD10_dataset", self.compose_mclip, self.image_name_list)
        return dataset_mclip

mclip = Mclip()

def predict_mclip(text, num):
    image_paths, labels, json_data, json_text = find_image(
        mclip.language_model_mclip, mclip.clip_model_mclip, text,
        mclip.load_dataset(), mclip.load_images(), mclip.load_text(),
        mclip.sorted_data, num=int(num))

    return image_paths, labels, json_data, json_text
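
For orientation only, a minimal sketch of how the two entry points above (predict and predict_mclip) could be exercised, assuming the cached pickle files, the photos/XTD10_dataset images, and the local model_loading module are available; the Arabic query string and the __main__ guard are illustrative and were not part of the original utils.py.

# Illustrative smoke test; not part of the original file.
if __name__ == "__main__":
    query = "قطة تجلس على الأريكة"  # example Arabic query ("a cat sitting on the sofa")

    # AraCLIP retrieval: top-3 images with probabilities and caption metadata
    image_paths, labels, json_data, json_text = predict(query, num=3)
    for (path, caption_ar), score in zip(image_paths, labels.values()):
        print(path, caption_ar, score.item())

    # The same query through the mCLIP baseline, for comparison
    image_paths_m, labels_m, _, _ = predict_mclip(query, num=3)
    print(list(labels_m.keys()))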