from dataclasses import dataclass
from transformers import GPT2Config, CLIPVisionConfig

# Number of vision tokens ("prefix length") produced by each CLIP image
# encoder: (image_size / patch_size) ** 2 patches plus one [CLS] token.
PREFIX_MAP = {
    "openai/clip-vit-base-patch32": 50,
    "openai/clip-vit-base-patch16": 197,
    "openai/clip-vit-large-patch14": 257,
    "openai/clip-vit-large-patch14-336": 577,
}

# Hidden size (embedding width) of each GPT-2 variant.
TEXT_HIDDEN_SIZE_MAP = {
    "gpt2": 768,
    "gpt2-medium": 1024,
    "gpt2-large": 1280,
    "gpt2-xl": 1600,
}

# Hidden size of each CLIP vision transformer (ViT-B variants: 768, ViT-L variants: 1024).
IMAGE_HIDDEN_SIZE_MAP = {
    "openai/clip-vit-base-patch32": 768,
    "openai/clip-vit-base-patch16": 768,
    "openai/clip-vit-large-patch14": 1024,
    "openai/clip-vit-large-patch14-336": 1024,
}


@dataclass
class CLIPGPT2Config:
    image_model: str = "openai/clip-vit-base-patch32"  # CLIP vision checkpoint
    freeze_image_model: bool = True                     # keep CLIP weights frozen
    text_model: str = "gpt2-large"                      # GPT-2 checkpoint
    freeze_text_model: bool = True                      # keep GPT-2 weights frozen
    linear_mapping_type: str = "linear"                 # how image features map to text space
    add_image_token: bool = True                        # reserve an extra vocabulary id for an image token
    freeze_ln: bool = False                             # freeze LayerNorm too (False = keep it trainable)
    image_from_pretrained: bool = True                  # load pretrained CLIP weights
    text_from_pretrained: bool = True                   # load pretrained GPT-2 weights

    def __post_init__(self):
        # Derive sizes from the chosen checkpoints.
        self.prefix_length = PREFIX_MAP[self.image_model]
        self.image_hidden_size = IMAGE_HIDDEN_SIZE_MAP[self.image_model]
        self.text_hidden_size = TEXT_HIDDEN_SIZE_MAP[self.text_model]
        # The 336px checkpoint expects larger inputs than the default 224px.
        self.image_resize = 336 if "336" in self.image_model else 224
        self.text_config = GPT2Config.from_pretrained(self.text_model)
        self.image_config = CLIPVisionConfig.from_pretrained(self.image_model)
        # Reserve one extra vocabulary id when a dedicated image token is added.
        self.vocab_size = self.text_config.vocab_size + int(self.add_image_token)
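
# Sketch of how the derived sizes are typically consumed downstream.
# Hedged: the actual model code lives elsewhere; `build_mapping` is a
# hypothetical helper shown only to illustrate why image_hidden_size and
# text_hidden_size are computed here.
import torch.nn as nn


def build_mapping(cfg: CLIPGPT2Config) -> nn.Module:
    # Project CLIP vision features (image_hidden_size wide) into GPT-2's
    # embedding space (text_hidden_size wide).
    if cfg.linear_mapping_type == "linear":
        return nn.Linear(cfg.image_hidden_size, cfg.text_hidden_size)
    raise ValueError(f"Unsupported linear_mapping_type: {cfg.linear_mapping_type}")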
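
if __name__ == "__main__":
    # Usage sketch: pair CLIP ViT-L/14 with GPT-2 large. Assumes the GPT-2 and
    # CLIP configs are cached locally or downloadable from the Hugging Face Hub,
    # since __post_init__ calls from_pretrained on both.
    cfg = CLIPGPT2Config(
        image_model="openai/clip-vit-large-patch14",
        text_model="gpt2-large",
    )
    print(cfg.prefix_length)      # 257 (256 patches + 1 [CLS] token)
    print(cfg.image_hidden_size)  # 1024
    print(cfg.text_hidden_size)   # 1280
    print(cfg.vocab_size)         # 50258 (GPT-2's 50257 + 1 image token)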