################################################## PACKAGES ############################################################ ################################################# PACKAGES ############################################################# # PyTorch for deep learning operations import torch import torch.nn as nn # PyTorch data loading and utilities import torch.multiprocessing # Additional PyTorch modules and libraries import cv2 # OpenCV for image processing # Transfer Learning model library import timm # Data manipulation and handling import requests # COCO dataset tools from pycocotools.coco import COCO import numpy as np # Hugging Face Transformers library for BERT models from transformers import BertModel, BertTokenizer, DistilBertModel, DistilBertConfig, DistilBertTokenizer import torch.nn.functional as F # Image processing and augmentations import albumentations as A # Visualization and progress tracking from tqdm import tqdm import matplotlib.pyplot as plt # Additional utility for iterating over combinations import itertools from albumentations.pytorch import ToTensorV2 import pandas as pd from configs import CFG from huggingface_hub import PyTorchModelHubMixin ################################################### MODELS ############################################################ ################################################# MODELS ############################################################## class ProjectionHead(nn.Module): def __init__(self, input_dim, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args, **kwargs): """ Projection Head module for contrastive learning. :param input_dim: Dimensionality of input features. :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim). :param dropout_rate: Dropout rate (default: CFG.dropout_rate). """ super(ProjectionHead, self).__init__(*args, **kwargs) # Attributes self.input_dim = input_dim self.projection_dim = projection_dim self.dropout_rate = dropout_rate # Layers self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim) self.gelu = nn.GELU() self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim) self.dropout = nn.Dropout(self.dropout_rate) self.normalization_layer = nn.LayerNorm(self.projection_dim) def forward(self, inputs): """ Forward pass of the projection head. :param inputs: Input features. :return: Projected features. """ x = inputs x = self.linear_layer1(x) x = self.gelu(x) x = self.linear_layer2(x) x = self.dropout(x) x = self.normalization_layer(x) return x def __call__(self, inputs): """ Callable method for the projection head. :param inputs: Input features. :return: Projected features. """ return self.forward(inputs) class ImageEncoder(nn.Module): def __init__(self, model_name=CFG.vit_name, projection_dim=CFG.projection_dim, trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs): """ Image encoder module using Vision Transformer (ViT) backbone. :param model_name: Name of the Vision Transformer model (default: CFG.vit_name). :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim). :param trainable: Whether to make the backbone trainable (default: False). :param dropout_rate: Dropout rate (default: CFG.dropout_rate). """ super(ImageEncoder, self).__init__(*args, **kwargs) # Attributes self.model_name = model_name self.projection_dim = projection_dim self.trainable = trainable self.dropout_rate = dropout_rate # Models self.pretrained_vit = timm.create_model(self.model_name, pretrained=True, num_classes=0) self.projection_head = ProjectionHead(self.pretrained_vit.embed_dim, self.projection_dim, self.dropout_rate) # Freeze pretrained ViT layers for parameter in self.pretrained_vit.parameters(): parameter.requires_grad = self.trainable def forward(self, images): """ Forward pass of the image encoder. :param images: Input images. :return: Projected features. """ x = images # forward_features: to return sequences (encoder) -> torch.Size([batch_size, 197, 768]) forward_head: to # return flattened sequences (vectors) -> torch.Size([batch_size, 768]) if num_classes=0 (no classification) # in timm.create_model and torch.Size([batch_size, 1000]) otherwise (classification) x = self.pretrained_vit.forward_features(x) # output: torch.Size([batch_size, 197, 256]) x = self.projection_head(x) return x def __call__(self, images): """ Callable method for the image encoder. :param images: Input images. :return: Projected features. """ return self.forward(images) class TextEncoder(nn.Module): def __init__(self, model_name=CFG.bert_name, projection_dim=CFG.projection_dim, trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs): """ Text encoder module using BERT backbone. :param model_name: Name of the BERT model (default: CFG.bert_name). :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim). :param trainable: Whether to make the backbone trainable (default: False). :param dropout_rate: Dropout rate (default: CFG.dropout_rate). """ super(TextEncoder, self).__init__(*args, **kwargs) # Attributes self.model_name = model_name self.projection_dim = projection_dim self.dropout_rate = dropout_rate self.trainable = trainable # Models self.pretrained_bert = BertModel.from_pretrained(self.model_name) self.projection_head = ProjectionHead(self.pretrained_bert.config.hidden_size, self.projection_dim, self.dropout_rate) # Freeze BERT for parameter in self.pretrained_bert.parameters(): parameter.requires_grad = self.trainable def forward(self, captions): """ Forward pass of the text encoder. :param captions: Input captions (input_ids, attention_mask). :return: Projected features. """ input_ids, attention_mask = captions # last_hidden_state: torch.Size([batch_size, sequence, 768]) # pooler_output: torch.Size([batch_size, 768]) x = self.pretrained_bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state # output: torch.Size([batch_size, sequence, 256]) x = self.projection_head(x) return x def __call__(self, captions): """ Callable method for the text encoder. :param captions: Input captions (input_ids, attention_mask). :return: Projected features. """ return self.forward(captions) class ModalityTokenEncoder(nn.Module): def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs): """ Modality token encoder module for encoding modality token information. :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim). :param token_size: Token size. :param device: Device to run the module on (default: 'cpu'). """ super(ModalityTokenEncoder, self).__init__(*args, **kwargs) # Attributes self.projection_dim = projection_dim self.device = device self.token_size = token_size # Models text_variance = torch.rand(1) * 0.5 + 0.1 image_variance = torch.rand(1) * 0.5 + 0.1 self.text_token = nn.Parameter(torch.normal(mean=0, std=text_variance.item(), size=(self.token_size, self.projection_dim)).to(self.device)) self.image_token = nn.Parameter(torch.normal(mean=0, std=image_variance.item(), size=(self.token_size, self.projection_dim)).to(self.device)) def forward(self, modality_type): """ Forward pass of the modality encoder. :param modality_type: Input token indicator. :return: Projected features. """ token = torch.where(torch.tensor(modality_type == "image"), self.image_token, self.text_token) return token def __call__(self, modality_type): """ Callable method for the token encoder. :param modality_type: Input token indicator. :return: Projected features. """ return self.forward(modality_type) class UniversalProjectionOutput: def __init__(self, outputs): """ Wrapper class for projection model outputs. :param outputs: Dictionary containing model outputs. """ self.outputs = outputs def __getattr__(self, name): """ Retrieve attribute from outputs dictionary. :param name: Name of the attribute to retrieve. :return: Value of the attribute. """ if name in self.outputs: return self.outputs[name] else: raise AttributeError(f"'UniversalProjectionOutput' object has no attribute '{name}'") class UniversalProjectionEncoder(nn.Module): def __init__(self, input_dim=CFG.projection_dim, num_head=CFG.num_head, num_layers=CFG.num_layers, *args, **kwargs): """ Initialize Universal Projection module. :param input_dim: Dimensionality of the input embeddings. Defaults to CFG.projection_dim. :param num_head: Number of attention heads. Defaults to CFG.num_head. :param num_layers: Number of transformer layers. Defaults to CFG.num_layers. """ super(UniversalProjectionEncoder, self).__init__(*args, **kwargs) self.input_dim = input_dim self.num_head = num_head self.num_layers = num_layers self.transformer_encoder_block = nn.TransformerEncoderLayer( d_model=self.input_dim, nhead=self.num_head, batch_first=True ) self.transformer_encoder = nn.TransformerEncoder( self.transformer_encoder_block, num_layers=self.num_layers ) # self.transformer_encoder = TransformerModel(self.input_dim, self.num_head, self.num_layers) # model_name = 'bert-large-uncased' self.layer_normalization = nn.LayerNorm(self.input_dim) # self.transfopip install torch torchvision -Urmer_encoder = BertModel.from_pretrained(model_name) def forward(self, inputs): # x: image or caption embeddings x, tokens = inputs ## Universal Projection block tokens = tokens.unsqueeze(0).expand(x.size()[0], -1, -1) # Concatenate tokens with image/caption embeddings # output_tensor = torch.cat((tokens, x), dim=1) output_tensor = x + tokens # Normalization output_norm = self.layer_normalization(output_tensor) # Projection output_encoder = self.transformer_encoder(output_norm) ## Residual Connection residual_output = output_encoder + output_tensor # output = output_encoder[:, CFG.token_size:, :] # Residual connection return UniversalProjectionOutput({'last_hidden_state': residual_output, 'mean_output': torch.mean(residual_output, dim=1), 'pooler_output': residual_output[:, 0, :]}) def __call__(self, inputs): return self.forward(inputs) class OneEncoder(nn.Module, PyTorchModelHubMixin): def __init__(self, image_encoder=ImageEncoder(), text_encoder=TextEncoder(), modality_token_encoder=ModalityTokenEncoder(), universal_projection_encoder=UniversalProjectionEncoder(), device='cpu', tokenizer=BertTokenizer.from_pretrained(CFG.bert_name), image_preprocessor=A.Compose([A.Resize(CFG.image_size, CFG.image_size, always_apply=True), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), always_apply=True), ToTensorV2()]), *args, **kwargs): """ Initialize the model. :param image_encoder: Image encoder module (default: ImageEncoder()). :param text_encoder: Text encoder module (default: TextEncoder()). :param modality_token_encoder: Modality encoder module (default: ModalityEncoder()). :param universal_projection_encoder: Universal projection encoder module (default: UniversalProjection()). :param device: Device to run the model on (default: 'cpu'). :param tokenizer: Tokenizer for text encoding (default: BertTokenizer.from_pretrained(CFG.bert_name)). :param image_preprocessor: Preprocessor for image inputs (default: A.Compose([...])). """ super(OneEncoder, self).__init__(*args, **kwargs) self.device = device self.image_encoder = image_encoder self.text_encoder = text_encoder self.universal_projection_encoder = universal_projection_encoder self.modality_token_encoder = modality_token_encoder self.modality_token_encoder.device = self.device self.tokenizer = tokenizer self.image_preprocessor = image_preprocessor # The learnable temperature parameter τ was initialized to the equivalent of 0.07 from (Wu et al., 2018) # and clipped to prevent scaling the logits by more than 100, which we found necessary # to prevent training instability. self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device)) @classmethod def load_image(cls, image_path): # Load online image if image_path.startswith("http"): response = requests.get(image_path) # Check if the request was successful if response.status_code == 200: # Convert the image content to a numpy array img_array = np.asarray(bytearray(response.content), dtype=np.uint8) # Decode the image using OpenCV image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) # Convert BGR to RGB image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Load local image else: image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return image def encode_image(self, image_paths=None, image_tensors=None, outputs="mean"): """ Encode images into feature vectors. :param image_paths: List of image paths. :param image_tensors: Torch tensor (batch, 3, 224, 224). :param outputs type of outputs: mean, pooler, sequence :return: Encoded image features. """ if image_paths is not None: image_processed = [self.image_preprocessor(image=self.load_image(image))["image"] for image in image_paths] image_processed = torch.stack(image_processed).to(self.device) with torch.no_grad(): image_features = self.image_encoder(image_processed.to(self.device)) modality_token_feature = self.modality_token_encoder("image") output_features = self.universal_projection_encoder([image_features, modality_token_feature]) elif image_tensors is not None: with torch.no_grad(): image_features = self.image_encoder(image_tensors.to(self.device)) modality_token_feature = self.modality_token_encoder("image") output_features = self.universal_projection_encoder([image_features, modality_token_feature]) if outputs == "mean": image_features = output_features.mean_output elif outputs == "sequence": image_features = output_features.last_hidden_state else: image_features = output_features.pooler_output return image_features def encode_text(self, texts, max_length=128, outputs="mean"): """ Encode text descriptions into feature vectors. :param texts: List of text descriptions. :param max_length: Maximum length of the text sequences (default: 128). :param outputs type of outputs: mean, sequence, pooler :return: Encoded text features. """ encoded_query = self.tokenizer( texts, padding=True, truncation=True, max_length=max_length ) batch = { key: torch.tensor(values).to(self.device) for key, values in encoded_query.items() } with torch.no_grad(): text_features = self.text_encoder([ batch["input_ids"], batch["attention_mask"] ]) modality_token_feature = self.modality_token_encoder("text") output_features = self.universal_projection_encoder([text_features, modality_token_feature]) if outputs == "mean": text_features = output_features.mean_output elif outputs == "sequence": text_features = output_features.last_hidden_state else: text_features = output_features.pooler_output return text_features def matching(self, image_paths, texts, normalize=True, top_k=None, strategy="similarity", temperature=0.0): """ Calculate similarities between images and texts. :param image_paths: List of paths to images. :param texts: List of text descriptions. :param normalize: Whether to normalize the features (default: True). :param top_k: Return top K results (default: None). :param strategy: Matching strategy, either 'similarity' or 'softmax' (default: 'similarity'). :param temperature: change real distribution, default = 2.5 :return: If top_k is provided, returns top probabilities and labels, otherwise returns dot similarities. """ image_features = self.encode_image(image_paths=image_paths) text_features = self.encode_text(texts=texts) if normalize: image_features = F.normalize(image_features, p=2, dim=-1) text_features = F.normalize(text_features, p=2, dim=-1) dot_similarities = (image_features @ text_features.T) * torch.exp(torch.tensor(temperature).to(self.device)) if strategy == 'softmax': dot_similarities = (float(len(set(texts))) * dot_similarities).softmax(dim=-1) if top_k is not None: top_probs, top_labels = dot_similarities.cpu().topk(top_k, dim=-1) return top_probs, top_labels else: return dot_similarities, None def image_retrieval(self, query, image_paths, image_embeddings=None, temperature=0.0, n=9, plot=False): """ Perform image retrieval based on a text query. :param query: Text query (string). :param image_paths: List of image paths (optional). :param image_embeddings: Precomputed image embeddings (optional). :param temperature: change real distribution, default = 2.5 :param n: Number of images to retrieve (default: 9). :param plot: Whether to plot the retrieved images (default: False). :return: Tuple containing similarity values and indices of the retrieved images. """ text_embeddings = self.encode_text([query]) if image_embeddings is None: image_embeddings = self.encode_image(image_paths=image_paths) image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1) text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1) dot_similarity = (text_embeddings_n @ image_embeddings_n.T) * torch.exp( torch.tensor(temperature).to(self.device)) if n > len(image_paths): n = len(image_paths) values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n) if plot: nrows = int(np.sqrt(n)) ncols = int(np.ceil(n / nrows)) matches = [image_paths[idx] for idx in indices] fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20)) for match, ax in zip(matches, axes.flatten()): image = self.load_image(f"{match}") ax.imshow(image) ax.axis("off") plt.savefig("img.png") #fig.suptitle(query) #plt.show() #return values, indices def text_retrieval(self, query, texts, text_embeddings=None, n=9, plot_image=False, temperature=0.0): """ Perform text retrieval based on an image query. :param query: Image query (path of image). :param texts: List of text samples. :param text_embeddings: Precomputed text embeddings (optional). :param n: Number of texts to retrieve (default: 9). :param plot_image: Plot the query :param temperature: change real distribution, default = 2.5 :return: List of retrieved text samples and its probabilities. """ if text_embeddings is None: text_embeddings = self.encode_text(texts) image_embeddings = self.encode_image([query]) image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1) text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1) dot_similarity = (image_embeddings_n @ text_embeddings_n.T) * torch.exp( torch.tensor(temperature).to(self.device)) if n > len(texts): n = len(texts) values, indices = torch.topk(dot_similarity.cpu().squeeze(0), n) matches = [texts[idx] for idx in indices] if plot_image: # Read and plot the image image = self.load_image(query) # Plot the image plt.imshow(image) plt.title('Random Image') plt.axis('off') plt.savefig("img.png") plt.show() return matches, values