File size: 2,523 Bytes
d9a550a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75798cc
d9a550a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import PIL
import PIL.Image
import torch
from transformers import PreTrainedModel, ViTMAEModel
from transformers import ViTImageProcessor

from .configuration_magiv2 import Magiv2Config

def move_to_device(inputs, device):
    """Recursively move every tensor in a (possibly nested) container to ``device``.

    Anything exposing ``keys`` is rebuilt as a plain ``dict``. Lists stay lists
    and tuples become plain tuples. NumPy arrays are first wrapped with
    ``torch.from_numpy`` and then moved. Any other leaf must provide a
    ``.to(device)`` method, otherwise ``AttributeError`` is raised.
    """
    def _recurse(item):
        return move_to_device(item, device)

    if hasattr(inputs, "keys"):
        return {key: _recurse(value) for key, value in inputs.items()}
    if isinstance(inputs, list):
        return list(map(_recurse, inputs))
    if isinstance(inputs, tuple):
        return tuple(map(_recurse, inputs))
    if isinstance(inputs, np.ndarray):
        # NumPy arrays have no `.to`; convert to a tensor before moving.
        inputs = torch.from_numpy(inputs)
    return inputs.to(device)

class Magiv2Model(PreTrainedModel):
    """Embed PIL image crops with a ViT-MAE encoder.

    Each crop's embedding is position 0 of the encoder's ``last_hidden_state``
    (presumably ViT-MAE's [CLS] token — confirm against the encoder config).
    """

    config_class = Magiv2Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Image preprocessing is rebuilt from a plain dict stored in the config.
        self.processor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
        self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)

    def move_to_device(self, input):
        """Move ``input`` (tensor, ndarray, or nested dict/list/tuple) to ``self.device``."""
        return move_to_device(input, self.device)

    def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
        """Compute one embedding per image crop.

        Args:
            images: sequence of ``PIL.Image.Image`` crops. Each one is converted
                to grayscale ("L") and back to 3-channel RGB before preprocessing.
            move_to_device_fn: callable used to move tensors onto the compute
                device. Defaults to ``self.move_to_device``.
            mask_ratio: value written to the encoder embeddings'
                ``config.mask_ratio`` for the duration of this call. The previous
                value is always restored, even if encoding raises.
            batch_size: maximum number of crops passed to the encoder per call,
                to bound peak memory. Must be >= 1.

        Returns:
            A tensor with one row per crop. For empty ``images`` this is
            ``zeros(0, hidden_size)`` passed through ``move_to_device_fn``.

        Raises:
            TypeError: if any element of ``images`` is not a PIL image.
            ValueError: if ``batch_size`` < 1.
        """
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
        if len(images) == 0:
            return move_to_device_fn(torch.zeros(len(images), self.config.crop_embedding_model_config.hidden_size))

        # Explicit raises instead of `assert`: asserts are stripped under `python -O`.
        if not all(isinstance(image, PIL.Image.Image) for image in images):
            raise TypeError("please provide a list of PIL images")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        arrays = [np.array(image.convert("L").convert("RGB")) for image in images]
        pixel_values = self.processor(arrays, return_tensors="pt").pixel_values
        pixel_values = move_to_device_fn(pixel_values)

        # Temporarily override the mask ratio on the (shared) embeddings config.
        # try/finally guarantees it is restored even if a batch fails (e.g. OOM).
        # NOTE(review): this mutates shared model state, so concurrent forward
        # calls on the same instance would race on this attribute.
        embeddings_config = self.crop_embedding_model.embeddings.config
        old_mask_ratio = embeddings_config.mask_ratio
        embeddings_config.mask_ratio = mask_ratio
        try:
            # Process the crops in batches to avoid OOM.
            embeddings = [
                self.crop_embedding_model(pixel_values[start:start + batch_size]).last_hidden_state[:, 0]
                for start in range(0, len(pixel_values), batch_size)
            ]
        finally:
            embeddings_config.mask_ratio = old_mask_ratio

        return torch.cat(embeddings, dim=0)