from transformers import PreTrainedModel, ViTImageProcessor, ViTMAEModel
import torch
import numpy as np
import PIL.Image

from .configuration_magiv2 import Magiv2Config


def move_to_device(inputs, device):
    # Recursively move tensors inside dicts, lists, and tuples onto the target
    # device; numpy arrays are converted to tensors first.
    if hasattr(inputs, "keys"):
        return {k: move_to_device(v, device) for k, v in inputs.items()}
    elif isinstance(inputs, list):
        return [move_to_device(v, device) for v in inputs]
    elif isinstance(inputs, tuple):
        return tuple(move_to_device(v, device) for v in inputs)
    elif isinstance(inputs, np.ndarray):
        return torch.from_numpy(inputs).to(device)
    else:
        return inputs.to(device)


class Magiv2Model(PreTrainedModel):
    config_class = Magiv2Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.processor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
        self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)

    def move_to_device(self, input):
        return move_to_device(input, self.device)

    def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
        # Resolve the device-moving function before the empty-input early
        # return below; otherwise that path would call None.
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        if len(images) == 0:
            return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))

        assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"

        # Grayscale the crops, then convert back to RGB so the ViT image
        # processor receives the three channels it expects.
        images = [np.array(image.convert("L").convert("RGB")) for image in images]
        images = self.processor(images, return_tensors="pt").pixel_values
        images = move_to_device_fn(images)

        # temporarily change the mask ratio from default to the one specified
        old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
        self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio

        # process the crops in batches to avoid OOM
        embeddings = []
        for i in range(0, len(images), batch_size):
            crops = images[i:i + batch_size]
            # Take the [CLS] token of the last hidden state as the crop embedding.
            embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
            embeddings.append(embeddings_per_batch)
        embeddings = torch.cat(embeddings, dim=0)

        # restore the mask ratio to the default
        self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio

        return embeddings
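

if __name__ == "__main__":
    # Minimal usage sketch, assuming this module is the remote code behind a
    # published checkpoint; the repo id below is a hypothetical placeholder,
    # not a confirmed model name. Dummy blank crops stand in for real
    # character/panel crops.
    from transformers import AutoModel

    model = AutoModel.from_pretrained("user/magiv2-crop-embedder", trust_remote_code=True).eval()
    crops = [PIL.Image.new("RGB", (224, 224)) for _ in range(3)]
    with torch.no_grad():
        embeddings = model(crops, batch_size=2)
    print(embeddings.shape)  # expected: torch.Size([3, hidden_size])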