Commit 95b3a4d (verified) by jamessyx
Parent(s): 36b91e1

Upload 27 files

src/Qwen-encoder-1.5B/added_tokens.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "<|endoftext|>": 151643,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644
+ }
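The three entries above register Qwen's chat-control tokens at fixed IDs. As a quick sanity check, a minimal sketch assuming the transformers library and that this folder is loaded from a local path (the path below is illustrative):

from transformers import AutoTokenizer

# Point this at wherever the uploaded folder is checked out.
tok = AutoTokenizer.from_pretrained("src/Qwen-encoder-1.5B")
for t in ("<|endoftext|>", "<|im_start|>", "<|im_end|>"):
    print(t, tok.convert_tokens_to_ids(t))   # expect 151643, 151644, 151645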
src/Qwen-encoder-1.5B/config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "_name_or_path": "/home/BiomikeeNew/werent4/llm2vec/output/mntp/qwen_0.5/checkpoint-120000",
+   "architectures": [
+     "Qwen2ForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 1536,
+   "initializer_range": 0.02,
+   "intermediate_size": 8960,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "model_type": "qwen2",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 2,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": 32768,
+   "tie_word_embeddings": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.40.2",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151936
+ }
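The config describes a standard Qwen2 causal-LM backbone (28 layers, hidden size 1536, grouped-query attention with 2 KV heads); the _name_or_path field only records the MNTP training checkpoint it was exported from. A minimal loading sketch, assuming transformers >= 4.40 and a local copy of this folder (path illustrative):

from transformers import AutoConfig, AutoModelForCausalLM

cfg = AutoConfig.from_pretrained("src/Qwen-encoder-1.5B")
print(cfg.model_type, cfg.hidden_size, cfg.num_hidden_layers)   # qwen2 1536 28

model = AutoModelForCausalLM.from_pretrained("src/Qwen-encoder-1.5B")
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")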
src/Qwen-encoder-1.5B/generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "transformers_version": "4.40.2"
+ }
src/Qwen-encoder-1.5B/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
src/Qwen-encoder-1.5B/model.safetensors.index.json ADDED
@@ -0,0 +1,345 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 6174857216
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
8
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
9
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
10
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
18
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
20
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
24
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
26
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
27
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
28
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
29
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
30
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
32
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
34
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.10.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
37
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
38
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.10.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
40
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
41
+ "model.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
42
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
43
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
49
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.11.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
52
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.11.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
54
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
59
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.12.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
61
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.12.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
64
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.12.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
66
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.13.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
73
+ "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
75
+ "model.layers.13.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
76
+ "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.13.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
78
+ "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.14.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
85
+ "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.14.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
88
+ "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
89
+ "model.layers.14.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
90
+ "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
93
+ "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.15.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
97
+ "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
98
+ "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.15.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
100
+ "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
101
+ "model.layers.15.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
102
+ "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
107
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.16.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
109
+ "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
110
+ "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.16.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
112
+ "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.16.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
114
+ "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
118
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
119
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
120
+ "model.layers.17.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
121
+ "model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
123
+ "model.layers.17.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
124
+ "model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
125
+ "model.layers.17.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
126
+ "model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
128
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
129
+ "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
130
+ "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
131
+ "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
132
+ "model.layers.18.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
133
+ "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
134
+ "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
135
+ "model.layers.18.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
136
+ "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
137
+ "model.layers.18.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
138
+ "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
139
+ "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
140
+ "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
141
+ "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
143
+ "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
144
+ "model.layers.19.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
145
+ "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
146
+ "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.19.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
148
+ "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
149
+ "model.layers.19.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
150
+ "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
151
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
153
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
154
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
155
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
157
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
158
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
159
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
160
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
161
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
162
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
163
+ "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
164
+ "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
165
+ "model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
166
+ "model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
167
+ "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
168
+ "model.layers.20.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
169
+ "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
170
+ "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
171
+ "model.layers.20.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
172
+ "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
173
+ "model.layers.20.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
174
+ "model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
175
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
176
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
177
+ "model.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
178
+ "model.layers.21.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
179
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
180
+ "model.layers.21.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
181
+ "model.layers.21.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
182
+ "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
183
+ "model.layers.21.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
184
+ "model.layers.21.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
185
+ "model.layers.21.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
186
+ "model.layers.21.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
187
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
188
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
189
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
190
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
191
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
192
+ "model.layers.22.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
193
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
194
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
195
+ "model.layers.22.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
196
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
197
+ "model.layers.22.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
198
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
199
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
200
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
201
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
202
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
203
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
204
+ "model.layers.23.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
205
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
206
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
207
+ "model.layers.23.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
208
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
209
+ "model.layers.23.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
210
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
211
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
212
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
213
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
214
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
215
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
216
+ "model.layers.24.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
217
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
218
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
219
+ "model.layers.24.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
220
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
221
+ "model.layers.24.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
222
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
223
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
224
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
225
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
226
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
227
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
228
+ "model.layers.25.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
229
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
230
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
231
+ "model.layers.25.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
232
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
233
+ "model.layers.25.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
234
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
235
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
236
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
237
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
238
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
239
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
240
+ "model.layers.26.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
241
+ "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
242
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
243
+ "model.layers.26.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
244
+ "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
245
+ "model.layers.26.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
246
+ "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
247
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
248
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
249
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
250
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
251
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
252
+ "model.layers.27.self_attn.k_proj.bias": "model-00002-of-00002.safetensors",
253
+ "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
254
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
255
+ "model.layers.27.self_attn.q_proj.bias": "model-00002-of-00002.safetensors",
256
+ "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
257
+ "model.layers.27.self_attn.v_proj.bias": "model-00002-of-00002.safetensors",
258
+ "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
259
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
260
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
261
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
262
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
263
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
264
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
265
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
266
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
267
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
268
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
269
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
270
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
271
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
272
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
273
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
274
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
275
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
276
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
277
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
278
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
279
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
280
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
281
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
282
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
283
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
284
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
285
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
286
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
287
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
288
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
289
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
290
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
291
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
292
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
293
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
294
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
295
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
296
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
297
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
298
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
299
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
300
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
301
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
302
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
303
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
304
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
305
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
306
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
307
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
308
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
309
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
310
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
311
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
312
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
313
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
314
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
315
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
316
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
317
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
318
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
319
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
320
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
321
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
322
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
323
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
324
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
325
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
326
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
327
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
328
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
329
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
330
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
331
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
332
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
333
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
334
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
335
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
336
+ "model.layers.9.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
337
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
338
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
339
+ "model.layers.9.self_attn.q_proj.bias": "model-00001-of-00002.safetensors",
340
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
341
+ "model.layers.9.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
342
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
343
+ "model.norm.weight": "model-00002-of-00002.safetensors"
344
+ }
345
+ }
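The index above maps every parameter name to the safetensors shard that stores it (total_size is the combined byte count of both shards). A small sketch, assuming a local copy of the file, that regroups the mapping by shard:

import json
from collections import defaultdict

# Illustrative path to the index file shown above.
with open("src/Qwen-encoder-1.5B/model.safetensors.index.json") as f:
    index = json.load(f)

shards = defaultdict(list)
for name, shard in index["weight_map"].items():
    shards[shard].append(name)

print(index["metadata"]["total_size"])      # 6174857216 bytes in total
for shard, names in sorted(shards.items()):
    print(shard, len(names), "tensors")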
src/Qwen-encoder-1.5B/special_tokens_map.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": "_",
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
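Note the unusual mask_token "_": Qwen2's tokenizer has no native mask token, so an ordinary vocabulary entry (ID 62, listed in tokenizer_config.json below) is repurposed to mark masked positions, consistent with llm2vec-style MNTP training. A quick check, again assuming transformers and an illustrative local path:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("src/Qwen-encoder-1.5B")
print(repr(tok.mask_token), tok.mask_token_id)     # '_' 62
print(tok.eos_token, tok.pad_token)                # <|im_end|> <|endoftext|>
print(tok.additional_special_tokens)               # ['<|im_start|>', '<|im_end|>']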
src/Qwen-encoder-1.5B/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
src/Qwen-encoder-1.5B/tokenizer_config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "62": {
+       "content": "_",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>"
+   ],
+   "bos_token": null,
+   "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "mask_token": "_",
+   "model_max_length": 32768,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
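The chat_template above is the standard ChatML format used by Qwen; when the first message is not a system turn, a default system prompt is injected. A minimal rendering sketch (path and message content are illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("src/Qwen-encoder-1.5B")
messages = [{"role": "user", "content": "Describe this pathology slide."}]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Describe this pathology slide.<|im_end|>
# <|im_start|>assistant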
src/Qwen-encoder-1.5B/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
src/open_clip/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .version import __version__
+
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
+     convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
+     get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
+ from .openai import load_openai_model, list_openai_models
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
+     get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
+ from .tokenizer import SimpleTokenizer, tokenize, decode
+ from .transform import image_transform, AugmentationCfg
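This __init__.py re-exports the factory and tokenizer API, so downstream code only needs package-level imports. Typical usage, sketched with an illustrative model name and pretrained tag (list_models() and list_pretrained() report what is actually registered):

import open_clip

model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k')   # names are illustrative
tokenizer = open_clip.get_tokenizer('ViT-B-32')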
src/open_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+ size 1356917
src/open_clip/constants.py ADDED
@@ -0,0 +1,6 @@
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+ INCEPTION_MEAN = (0.5, 0.5, 0.5)
+ INCEPTION_STD = (0.5, 0.5, 0.5)
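These are the per-channel normalization statistics used by the image transforms. A minimal sketch of how they plug into a torchvision pipeline (224 is an assumed image size; the real pipeline lives in transform.py):

from torchvision import transforms
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

# CLIP-style preprocessing: resize, center crop, convert to tensor, normalize.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])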
src/open_clip/convert.py ADDED
@@ -0,0 +1,190 @@
1
+ """ Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
2
+ """
3
+ from typing import Union
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from .model import CLIP, CustomTextCLIP
9
+ from .transformer import TextTransformer, Transformer
10
+
11
+
12
+ @torch.no_grad()
13
+ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
14
+ """ Load weights from .npz checkpoints for official Google big_vision image-text models
15
+
16
+ Currently the SigLIP source models are supported and a CustomTextCLIP destination model
17
+ w/ timm image encoder.
18
+ """
19
+ from timm.layers import resample_patch_embed, resample_abs_pos_embed
20
+
21
+ def _n2p(w, t=True):
22
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
23
+ w = w.flatten()
24
+ if t:
25
+ if w.ndim == 4:
26
+ w = w.transpose([3, 2, 0, 1])
27
+ elif w.ndim == 3:
28
+ w = w.transpose([2, 0, 1])
29
+ elif w.ndim == 2:
30
+ w = w.transpose([1, 0])
31
+ return torch.from_numpy(w)
32
+
33
+ w = np.load(checkpoint_path)
34
+ interpolation = 'bilinear'
35
+ antialias = False
36
+
37
+ def _convert_timm_img(module, prefix):
38
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
39
+ if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
40
+ embed_conv_w = resample_patch_embed(
41
+ embed_conv_w,
42
+ module.patch_embed.proj.weight.shape[-2:],
43
+ interpolation=interpolation,
44
+ antialias=antialias,
45
+ verbose=True,
46
+ )
47
+ module.patch_embed.proj.weight.copy_(embed_conv_w)
48
+ module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
49
+
50
+ if module.cls_token is not None:
51
+ module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
52
+
53
+ pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
54
+ if pos_embed_w.shape != module.pos_embed.shape:
55
+ assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
56
+ num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
57
+ pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
58
+ pos_embed_w,
59
+ new_size=module.patch_embed.grid_size,
60
+ num_prefix_tokens=num_prefix_tokens,
61
+ interpolation=interpolation,
62
+ antialias=antialias,
63
+ verbose=True,
64
+ )
65
+ module.pos_embed.copy_(pos_embed_w)
66
+
67
+ mha_sub, b_sub, ln1_sub = (0, 0, 1)
68
+ for i, block in enumerate(module.blocks.children()):
69
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
70
+ mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
71
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
72
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
73
+ block.attn.qkv.weight.copy_(torch.cat([
74
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
75
+ block.attn.qkv.bias.copy_(torch.cat([
76
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
77
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
78
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
79
+ for r in range(2):
80
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
81
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
82
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
83
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
84
+
85
+ module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
86
+ module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
87
+
88
+ if module.attn_pool is not None:
89
+ block_prefix = f'{prefix}MAPHead_0/'
90
+ mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
91
+ module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
92
+ module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
93
+ module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
94
+ module.attn_pool.kv.weight.copy_(torch.cat([
95
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
96
+ module.attn_pool.kv.bias.copy_(torch.cat([
97
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
98
+ module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
99
+ module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
100
+ module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
101
+ module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
102
+ for r in range(2):
103
+ getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
104
+ getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
105
+
106
+ def _convert_openclip_transformer(module: Transformer, prefix):
107
+ for i, block in enumerate(module.resblocks.children()):
108
+ block_prefix = f'{prefix}encoderblock_{i}/'
109
+ mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
110
+ block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
111
+ block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
112
+ block.attn.in_proj_weight.copy_(torch.cat([
113
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
114
+ block.attn.in_proj_bias.copy_(torch.cat([
115
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
116
+ block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
117
+ block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
118
+ block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
119
+ block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
120
+ block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
121
+ block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
122
+ block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
123
+ block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))
124
+
125
+ def _convert_openclip_txt(module: TextTransformer, prefix):
126
+ module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
127
+ pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
128
+ module.positional_embedding.copy_(pos_embed_w)
129
+ _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
130
+ module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
131
+ module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
132
+ module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
133
+ module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
134
+
135
+ _convert_timm_img(model.visual.trunk, 'params/img/')
136
+ _convert_openclip_txt(model.text, 'params/txt/')
137
+ model.logit_bias.copy_(_n2p(w['params/b'])[0])
138
+ model.logit_scale.copy_(_n2p(w['params/t'])[0])
139
+
140
+
141
+ @torch.no_grad()
142
+ def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
143
+
144
+ def _convert_timm_img(state_dict):
145
+ if fastvit:
146
+ from timm.models.fastvit import checkpoint_filter_fn
147
+ else:
148
+ from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
149
+ timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
150
+ timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
151
+ return timm_state_dict
152
+
153
+ def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
154
+ text_dict = {}
155
+ for k, v in state_dict.items():
156
+ if not k.startswith(prefix):
157
+ continue
158
+ k = k.replace(prefix, '')
159
+ k = k.replace('projection_layer', 'text_projection')
160
+ k = k.replace('embedding_layer', 'token_embedding')
161
+ if k.startswith('positional_embedding.pos_embed.pos_embed'):
162
+ k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
163
+ v = v.squeeze()
164
+ k = k.replace('final_layer_norm', 'ln_final')
165
+ k = k.replace('pre_norm_mha.0', 'ln_1')
166
+ k = k.replace('pre_norm_mha.1', 'attn')
167
+ k = k.replace('pre_norm_ffn.0', 'ln_2')
168
+ k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
169
+ k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
170
+ k = k.replace('qkv_proj.weight', 'in_proj_weight')
171
+ k = k.replace('qkv_proj.bias', 'in_proj_bias')
172
+ k = k.replace('transformer.', 'transformer.resblocks.')
173
+ text_dict['text.' + k] = v
174
+ return text_dict
175
+
176
+ image_dict = _convert_timm_img(state_dict)
177
+ text_dict = _convert_openclip_txt(state_dict)
178
+ out_dict = {**image_dict, **text_dict}
179
+ out_dict['logit_scale'] = state_dict['logit_scale']
180
+ return out_dict
181
+
182
+
183
+ def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
184
+ if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
185
+ # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
186
+ state_dict = convert_mobile_clip_state_dict(model, state_dict)
187
+ if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
188
+ # convert b model
189
+ state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
190
+ return state_dict
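convert_state_dict above detects third-party checkpoints (Apple MobileCLIP state dicts; big_vision / SigLIP .npz files are handled separately by load_big_vision_weights) and rewrites their keys to open_clip naming. It is normally invoked indirectly through load_checkpoint in factory.py; a hedged sketch of that path, with illustrative model and file names:

from open_clip.factory import create_model, load_checkpoint

model = create_model('ViT-B-16-SigLIP')           # illustrative timm-based config name
load_checkpoint(model, 'webli_siglip_b16.npz')    # .npz is routed to load_big_vision_weights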
src/open_clip/factory.py ADDED
@@ -0,0 +1,475 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ from copy import deepcopy
6
+ from dataclasses import asdict
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from .convert import convert_state_dict
14
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
15
+ resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
16
+ from .openai import load_openai_model
17
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
18
+ list_pretrained_tags_by_model, download_pretrained_from_hf
19
+ from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
20
+ from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
21
+
22
+ HF_HUB_PREFIX = 'hf-hub:'
23
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
24
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
25
+
26
+
27
+ def _natural_key(string_):
28
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
29
+
30
+
31
+ def _rescan_model_configs():
32
+ global _MODEL_CONFIGS
33
+
34
+ config_ext = ('.json',)
35
+ config_files = []
36
+ for config_path in _MODEL_CONFIG_PATHS:
37
+ if config_path.is_file() and config_path.suffix in config_ext:
38
+ config_files.append(config_path)
39
+ elif config_path.is_dir():
40
+ for ext in config_ext:
41
+ config_files.extend(config_path.glob(f'*{ext}'))
42
+
43
+ for cf in config_files:
44
+ with open(cf, 'r') as f:
45
+ model_cfg = json.load(f)
46
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
47
+ _MODEL_CONFIGS[cf.stem] = model_cfg
48
+
49
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
50
+
51
+
52
+ _rescan_model_configs() # initial populate of model config registry
53
+
54
+
55
+ def list_models():
56
+ """ enumerate available model architectures based on config files """
57
+ return list(_MODEL_CONFIGS.keys())
58
+
59
+
60
+ def add_model_config(path):
61
+ """ add model config path or file and update registry """
62
+ if not isinstance(path, Path):
63
+ path = Path(path)
64
+ _MODEL_CONFIG_PATHS.append(path)
65
+ _rescan_model_configs()
66
+
67
+
68
+ def get_model_config(model_name):
69
+ if model_name in _MODEL_CONFIGS:
70
+ return deepcopy(_MODEL_CONFIGS[model_name])
71
+ else:
72
+ return None
73
+
74
+
75
+ def _get_hf_config(model_id, cache_dir=None):
76
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
77
+ with open(config_path, 'r', encoding='utf-8') as f:
78
+ config = json.load(f)
79
+ return config
80
+
81
+
82
+ def get_tokenizer(
83
+ model_name: str = '',
84
+ context_length: Optional[int] = None,
85
+ **kwargs,
86
+ ):
87
+ if model_name.startswith(HF_HUB_PREFIX):
88
+ model_name = model_name[len(HF_HUB_PREFIX):]
89
+ try:
90
+ config = _get_hf_config(model_name)['model_cfg']
91
+ except Exception:
92
+ tokenizer = HFTokenizer(
93
+ model_name,
94
+ context_length=context_length or DEFAULT_CONTEXT_LENGTH,
95
+ **kwargs,
96
+ )
97
+ return tokenizer
98
+ else:
99
+ config = get_model_config(model_name)
100
+ assert config is not None, f"No valid model config found for {model_name}."
101
+
102
+ text_config = config.get('text_cfg', {})
103
+ if 'tokenizer_kwargs' in text_config:
104
+ tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
105
+ else:
106
+ tokenizer_kwargs = kwargs
107
+
108
+ if context_length is None:
109
+ context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
110
+
111
+ if 'hf_tokenizer_name' in text_config:
112
+ tokenizer = HFTokenizer(
113
+ text_config['hf_tokenizer_name'],
114
+ context_length=context_length,
115
+ **tokenizer_kwargs,
116
+ )
117
+ else:
118
+ tokenizer = SimpleTokenizer(
119
+ context_length=context_length,
120
+ **tokenizer_kwargs,
121
+ )
122
+
123
+ return tokenizer
124
+
125
+
126
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
127
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
128
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
129
+ state_dict = checkpoint['state_dict']
130
+ elif isinstance(checkpoint, torch.jit.ScriptModule):
131
+ state_dict = checkpoint.state_dict()
132
+ for key in ["input_resolution", "context_length", "vocab_size"]:
133
+ state_dict.pop(key, None)
134
+ else:
135
+ state_dict = checkpoint
136
+ if next(iter(state_dict.items()))[0].startswith('module'):
137
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
138
+ return state_dict
139
+
140
+
141
+ def load_checkpoint(
142
+ model: Union[CLIP, CustomTextCLIP],
143
+ checkpoint_path: str,
144
+ strict: bool = True,
145
+ ):
146
+ if Path(checkpoint_path).suffix in ('.npz', '.npy'):
147
+ # Separate path loading numpy big_vision (SigLIP) weights
148
+ from open_clip.convert import load_big_vision_weights
149
+ load_big_vision_weights(model, checkpoint_path)
150
+ return {}
151
+
152
+ state_dict = load_state_dict(checkpoint_path)
153
+
154
+ # Detect & convert 3rd party state_dicts -> open_clip
155
+ state_dict = convert_state_dict(model, state_dict)
156
+
157
+ # Detect old format and make compatible with new format
158
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
159
+ state_dict = convert_to_custom_text_state_dict(state_dict)
160
+
161
+ # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
162
+ if 'logit_bias' not in state_dict and model.logit_bias is not None:
163
+ state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
164
+
165
+ # Certain text transformers no longer expect position_ids after transformers==4.31
166
+ position_id_key = 'text.transformer.embeddings.position_ids'
167
+ if position_id_key in state_dict and not hasattr(model, position_id_key):
168
+ del state_dict[position_id_key]
169
+
170
+ # resize_pos_embed(state_dict, model)
171
+ resize_text_pos_embed(state_dict, model)
172
+
173
+ # Finally, load the massaged state_dict into model
174
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
175
+ # incompatible_keys = []
176
+
177
+ return incompatible_keys
178
+
179
+
180
+ def create_model(
181
+ model_name: str,
182
+ pretrained: Optional[str] = None,
183
+ precision: str = 'fp32',
184
+ device: Union[str, torch.device] = 'cpu',
185
+ jit: bool = False,
186
+ force_quick_gelu: bool = False,
187
+ force_custom_text: bool = False,
188
+ force_patch_dropout: Optional[float] = None,
189
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
190
+ force_preprocess_cfg: Optional[Dict[str, Any]] = None,
191
+ pretrained_image: bool = False,
192
+ pretrained_hf: bool = True,
193
+ cache_dir: Optional[str] = None,
194
+ output_dict: Optional[bool] = None,
195
+ require_pretrained: bool = False,
196
+ **model_kwargs,
197
+ ):
198
+ force_preprocess_cfg = force_preprocess_cfg or {}
199
+ preprocess_cfg = asdict(PreprocessCfg())
200
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
201
+ if has_hf_hub_prefix:
202
+ model_id = model_name[len(HF_HUB_PREFIX):]
203
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
204
+ config = _get_hf_config(model_id, cache_dir)
205
+ preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
206
+ model_cfg = config['model_cfg']
207
+ pretrained_hf = False # override, no need to load original HF text weights
208
+ else:
209
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
210
+ checkpoint_path = None
211
+ model_cfg = None
212
+
213
+ if isinstance(device, str):
214
+ device = torch.device(device)
215
+
216
+ if pretrained and pretrained.lower() == 'openai':
217
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
218
+ model = load_openai_model(
219
+ model_name,
220
+ precision=precision,
221
+ device=device,
222
+ cache_dir=cache_dir,
223
+ )
224
+ else:
225
+ model_cfg = model_cfg or get_model_config(model_name)
226
+ if model_cfg is not None:
227
+ logging.info(f'Loaded {model_name} model config.')
228
+ else:
229
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
230
+ raise RuntimeError(f'Model config for {model_name} not found.')
231
+
232
+ if force_quick_gelu:
233
+ # override for use of QuickGELU on non-OpenAI transformer models
234
+ model_cfg["quick_gelu"] = True
235
+
236
+ if force_patch_dropout is not None:
237
+ # override the default patch dropout value
238
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
239
+
240
+ if force_image_size is not None:
241
+ # override model config's image size
242
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
243
+
244
+ is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
245
+ if pretrained_image:
246
+ if is_timm_model:
247
+ # pretrained weight loading for timm models set via vision_cfg
248
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
249
+ else:
250
+ assert False, 'pretrained image towers currently only supported for timm models'
251
+
252
+ # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
253
+ cast_dtype = get_cast_dtype(precision)
254
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
255
+ if is_hf_model:
256
+ # load pretrained weights for HF text model IFF no CLIP weights being loaded
257
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
258
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
259
+
260
+ model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
261
+ if custom_text:
262
+ if "multimodal_cfg" in model_cfg:
263
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
264
+ else:
265
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
266
+ else:
267
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
268
+
269
+ if precision in ("fp16", "bf16"):
270
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
271
+ # manual mixed precision that matches original OpenAI behaviour
272
+ if is_timm_model:
273
+ # FIXME this is a bit janky, create timm based model in low-precision and
274
+ # then cast only LayerNormFp32 instances back to float32 so they don't break.
275
+ # Why? The convert_weights_to_lp fn only works with native models.
276
+ model.to(device=device, dtype=dtype)
277
+ from .transformer import LayerNormFp32
278
+
279
+ def _convert_ln(m):
280
+ if isinstance(m, LayerNormFp32):
281
+ m.weight.data = m.weight.data.to(torch.float32)
282
+ m.bias.data = m.bias.data.to(torch.float32)
283
+ model.apply(_convert_ln)
284
+ else:
285
+ model.to(device=device)
286
+ convert_weights_to_lp(model, dtype=dtype)
287
+ elif precision in ("pure_fp16", "pure_bf16"):
288
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
289
+ model.to(device=device, dtype=dtype)
290
+ else:
291
+ model.to(device=device)
292
+
293
+ pretrained_loaded = False
294
+ if pretrained:
295
+ checkpoint_path = ''
296
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
297
+ if pretrained_cfg:
298
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
299
+ preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
300
+ elif os.path.exists(pretrained):
301
+ checkpoint_path = pretrained
302
+
303
+ if checkpoint_path:
304
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
305
+ load_checkpoint(model, checkpoint_path)
306
+ else:
307
+ error_str = (
308
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
309
+ f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
310
+ logging.warning(error_str)
311
+ raise RuntimeError(error_str)
312
+ pretrained_loaded = True
313
+ elif has_hf_hub_prefix:
314
+ logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
315
+ load_checkpoint(model, checkpoint_path)
316
+ pretrained_loaded = True
317
+
318
+ if require_pretrained and not pretrained_loaded:
319
+ # callers of create_model_from_pretrained always expect pretrained weights
320
+ raise RuntimeError(
321
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
322
+
323
+ if output_dict and hasattr(model, "output_dict"):
324
+ model.output_dict = True
325
+
326
+ if jit:
327
+ model = torch.jit.script(model)
328
+
329
+ # set image preprocessing configuration in model attributes for convenience
330
+ if getattr(model.visual, 'image_size', None) is not None:
331
+ # use image_size set on model creation (via config or force_image_size arg)
332
+ force_preprocess_cfg['size'] = model.visual.image_size
333
+ set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
334
+
335
+ return model
336
+
337
+
338
+ def create_loss(args):
339
+ if args.distill:
340
+ return DistillClipLoss(
341
+ local_loss=args.local_loss,
342
+ gather_with_grad=args.gather_with_grad,
343
+ cache_labels=True,
344
+ rank=args.rank,
345
+ world_size=args.world_size,
346
+ use_horovod=args.horovod,
347
+ )
348
+ elif "coca" in args.model.lower():
349
+ return CoCaLoss(
350
+ caption_loss_weight=args.coca_caption_loss_weight,
351
+ clip_loss_weight=args.coca_contrastive_loss_weight,
352
+ local_loss=args.local_loss,
353
+ gather_with_grad=args.gather_with_grad,
354
+ cache_labels=True,
355
+ rank=args.rank,
356
+ world_size=args.world_size,
357
+ use_horovod=args.horovod,
358
+ )
359
+ elif args.siglip:
360
+ assert not args.horovod, "Horovod not currently supported for SigLip"
361
+ return SigLipLoss(
362
+ rank=args.rank,
363
+ world_size=args.world_size,
364
+ )
365
+ return ClipLoss(
366
+ local_loss=args.local_loss,
367
+ gather_with_grad=args.gather_with_grad,
368
+ cache_labels=True,
369
+ rank=args.rank,
370
+ world_size=args.world_size,
371
+ use_horovod=args.horovod,
372
+ )
373
+
374
+
375
+ def create_model_and_transforms(
376
+ model_name: str,
377
+ pretrained: Optional[str] = None,
378
+ precision: str = 'fp32',
379
+ device: Union[str, torch.device] = 'cpu',
380
+ jit: bool = False,
381
+ force_quick_gelu: bool = False,
382
+ force_custom_text: bool = False,
383
+ force_patch_dropout: Optional[float] = None,
384
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
385
+ image_mean: Optional[Tuple[float, ...]] = None,
386
+ image_std: Optional[Tuple[float, ...]] = None,
387
+ image_interpolation: Optional[str] = None,
388
+ image_resize_mode: Optional[str] = None, # only effective for inference
389
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
390
+ pretrained_image: bool = False,
391
+ pretrained_hf: bool = True,
392
+ cache_dir: Optional[str] = None,
393
+ output_dict: Optional[bool] = None,
394
+ **model_kwargs,
395
+ ):
396
+ force_preprocess_cfg = merge_preprocess_kwargs(
397
+ {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
398
+
399
+ model = create_model(
400
+ model_name,
401
+ pretrained,
402
+ precision=precision,
403
+ device=device,
404
+ jit=jit,
405
+ force_quick_gelu=force_quick_gelu,
406
+ force_custom_text=force_custom_text,
407
+ force_patch_dropout=force_patch_dropout,
408
+ force_image_size=force_image_size,
409
+ force_preprocess_cfg=force_preprocess_cfg,
410
+ pretrained_image=pretrained_image,
411
+ pretrained_hf=pretrained_hf,
412
+ cache_dir=cache_dir,
413
+ output_dict=output_dict,
414
+ **model_kwargs,
415
+ )
416
+
417
+ pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
418
+
419
+ preprocess_train = image_transform_v2(
420
+ pp_cfg,
421
+ is_train=True,
422
+ aug_cfg=aug_cfg,
423
+ )
424
+ preprocess_val = image_transform_v2(
425
+ pp_cfg,
426
+ is_train=False,
427
+ )
428
+
429
+ return model, preprocess_train, preprocess_val
430
+
431
+
432
+ def create_model_from_pretrained(
433
+ model_name: str,
434
+ pretrained: Optional[str] = None,
435
+ precision: str = 'fp32',
436
+ device: Union[str, torch.device] = 'cpu',
437
+ jit: bool = False,
438
+ force_quick_gelu: bool = False,
439
+ force_custom_text: bool = False,
440
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
441
+ image_mean: Optional[Tuple[float, ...]] = None,
442
+ image_std: Optional[Tuple[float, ...]] = None,
443
+ image_interpolation: Optional[str] = None,
444
+ image_resize_mode: Optional[str] = None, # only effective for inference
445
+ return_transform: bool = True,
446
+ cache_dir: Optional[str] = None,
447
+ **model_kwargs,
448
+ ):
449
+ force_preprocess_cfg = merge_preprocess_kwargs(
450
+ {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
451
+
452
+ model = create_model(
453
+ model_name,
454
+ pretrained,
455
+ precision=precision,
456
+ device=device,
457
+ jit=jit,
458
+ force_quick_gelu=force_quick_gelu,
459
+ force_custom_text=force_custom_text,
460
+ force_image_size=force_image_size,
461
+ force_preprocess_cfg=force_preprocess_cfg,
462
+ cache_dir=cache_dir,
463
+ require_pretrained=True,
464
+ **model_kwargs,
465
+ )
466
+
467
+ if not return_transform:
468
+ return model
469
+
470
+ preprocess = image_transform_v2(
471
+ PreprocessCfg(**model.visual.preprocess_cfg),
472
+ is_train=False,
473
+ )
474
+
475
+ return model, preprocess
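A usage sketch for the two factory entry points above. It assumes the package is importable as open_clip and exports these functions as in upstream open_clip; the checkpoint path in the commented-out call is a placeholder, and note that the CLIP class added in model.py below additionally needs the ../Qwen-encoder-1.5B config and the Virchow2 timm config to be resolvable when the model is built.

import open_clip

# build a model plus train/val preprocessing transforms (no pretrained weights required)
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    'ViT-L-14-336',      # matches the model config JSON added later in this commit
    pretrained=None,     # or a local checkpoint path / pretrained tag
    precision='fp32',
    device='cpu',
)

# or insist on pretrained weights and get only the inference transform
# model, preprocess = open_clip.create_model_from_pretrained('ViT-L-14-336', pretrained='/path/to/checkpoint.pt')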
src/open_clip/hf_configs.py ADDED
@@ -0,0 +1,67 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ # https://huggingface.co/docs/transformers/model_doc/bert
46
+ "bert": {
47
+ "config_names": {
48
+ "context_length": "max_position_embeddings",
49
+ "vocab_size": "vocab_size",
50
+ "width": "hidden_size",
51
+ "heads": "num_attention_heads",
52
+ "layers": "num_hidden_layers",
53
+ },
54
+ "pooler": "cls_pooler",
55
+ },
56
+ # https://huggingface.co/docs/transformers/model_doc/m2m_100
57
+ "m2m_100": {
58
+ "config_names": {
59
+ "context_length": "max_position_embeddings",
60
+ "vocab_size": "vocab_size",
61
+ "width": "d_model",
62
+ "heads": "encoder_attention_heads",
63
+ "layers": "encoder_layers",
64
+ },
65
+ "pooler": "cls_pooler",
66
+ },
67
+ }
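A short sketch of how the mapping above is consumed: attribute names are looked up per model_type so the HF adapter can read equivalent hyperparameters from differently named config fields. It assumes the package is importable as open_clip, and 'roberta-base' is just an illustrative checkpoint name.

from transformers import AutoConfig

from open_clip.hf_configs import arch_dict

config = AutoConfig.from_pretrained('roberta-base')
names = arch_dict[config.model_type]['config_names']
width = getattr(config, names['width'])    # -> config.hidden_size for RoBERTa
layers = getattr(config, names['layers'])  # -> config.num_hidden_layers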
src/open_clip/hf_model.py ADDED
@@ -0,0 +1,193 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+ import re
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import TensorType
10
+
11
+ try:
12
+ import transformers
13
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15
+ BaseModelOutputWithPoolingAndCrossAttentions
16
+ except ImportError as e:
17
+ transformers = None
18
+
19
+
20
+ class BaseModelOutput:
21
+ pass
22
+
23
+
24
+ class PretrainedConfig:
25
+ pass
26
+
27
+ from .hf_configs import arch_dict
28
+
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+
35
+ # TODO: ?last - for gpt-like models
36
+ _POOLERS = {}
37
+
38
+
39
+ def register_pooler(cls):
40
+ """Decorator registering pooler class"""
41
+ _POOLERS[_camel2snake(cls.__name__)] = cls
42
+ return cls
43
+
44
+
45
+ @register_pooler
46
+ class MeanPooler(nn.Module):
47
+ """Mean pooling"""
48
+
49
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
50
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
51
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
52
+
53
+
54
+ @register_pooler
55
+ class MaxPooler(nn.Module):
56
+ """Max pooling"""
57
+
58
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
59
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf) # mask out padding positions before taking the max
60
+ return masked_output.max(1).values
61
+
62
+
63
+ @register_pooler
64
+ class ClsPooler(nn.Module):
65
+ """CLS token pooling"""
66
+
67
+ def __init__(self, use_pooler_output=True):
68
+ super().__init__()
69
+ self.cls_token_position = 0
70
+ self.use_pooler_output = use_pooler_output
71
+
72
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
73
+ if (self.use_pooler_output and
74
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
75
+ (x.pooler_output is not None)
76
+ ):
77
+ return x.pooler_output
78
+
79
+ return x.last_hidden_state[:, self.cls_token_position, :]
80
+
81
+
82
+ @register_pooler
83
+ class ClsLastHiddenStatePooler(nn.Module):
84
+ """CLS token pooling
85
+ NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
86
+ """
87
+
88
+ def __init__(self):
89
+ super().__init__()
90
+ self.cls_token_position = 0
91
+
92
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
93
+ return x.last_hidden_state[:, self.cls_token_position, :]
94
+
95
+
96
+ class HFTextEncoder(nn.Module):
97
+ """HuggingFace model adapter"""
98
+ output_tokens: torch.jit.Final[bool]
99
+
100
+ def __init__(
101
+ self,
102
+ model_name_or_path: str,
103
+ output_dim: int,
104
+ config: PretrainedConfig = None,
105
+ pooler_type: str = None,
106
+ proj_type: str = None,
107
+ pretrained: bool = True,
108
+ output_tokens: bool = False,
109
+ ):
110
+ super().__init__()
111
+ self.output_tokens = output_tokens
112
+ self.output_dim = output_dim
113
+
114
+ # TODO: find better way to get this information
115
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
116
+
117
+ if transformers is None:
118
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
119
+ if config is None:
120
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
121
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
122
+ AutoModel.from_config, self.config)
123
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
124
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
125
+ self.transformer = create_func(model_args)
126
+ self.transformer = self.transformer.encoder
127
+ else:
128
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
129
+ else:
130
+ self.config = config
131
+ self.transformer = AutoModel.from_config(config)
132
+ if pooler_type is None: # get default arch pooler
133
+ pooler_type = (arch_dict[self.config.model_type]["pooler"])
134
+
135
+ # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
136
+ self.vocab_size = getattr(self.config, 'vocab_size', 0)
137
+ self.context_length = getattr(self.config, 'max_position_embeddings', 0)
138
+
139
+ self.pooler = _POOLERS[pooler_type]()
140
+
141
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
142
+ if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
143
+ self.proj = nn.Identity()
144
+ elif proj_type == 'linear':
145
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
146
+ elif proj_type == 'mlp':
147
+ hidden_size = (d_model + output_dim) // 2
148
+ self.proj = nn.Sequential(
149
+ nn.Linear(d_model, hidden_size, bias=False),
150
+ nn.GELU(),
151
+ nn.Linear(hidden_size, output_dim, bias=False),
152
+ )
153
+
154
+ def forward(self, x: TensorType):
155
+ attn_mask = (x != self.config.pad_token_id).long()
156
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
157
+ pooled_out = self.pooler(out, attn_mask)
158
+ projected = self.proj(pooled_out)
159
+
160
+ seq_len = out.last_hidden_state.shape[1]
161
+ tokens = (
162
+ out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
163
+ if type(self.pooler) == ClsPooler
164
+ else out.last_hidden_state
165
+ )
166
+
167
+ if self.output_tokens:
168
+ return projected, tokens
169
+ return projected
170
+
171
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
172
+ if not unlocked_layers: # full freezing
173
+ for n, p in self.transformer.named_parameters():
174
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
175
+ return
176
+
177
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
178
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
179
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
180
+ embeddings = getattr(
181
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
182
+ modules = [embeddings, *layer_list][:-unlocked_layers]
183
+ # freeze layers
184
+ for module in modules:
185
+ for n, p in module.named_parameters():
186
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
187
+
188
+ @torch.jit.ignore
189
+ def set_grad_checkpointing(self, enable=True):
190
+ self.transformer.gradient_checkpointing_enable()
191
+
192
+ def init_parameters(self):
193
+ pass
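To illustrate the pooler registry and the mean pooling used as the default for HF text towers, here is a small self-contained sketch; it fakes the transformer output with a SimpleNamespace so no HF checkpoint is needed, and assumes the package is importable as open_clip.

import torch
from types import SimpleNamespace

from open_clip.hf_model import _POOLERS, MeanPooler

print(sorted(_POOLERS))  # ['cls_last_hidden_state_pooler', 'cls_pooler', 'max_pooler', 'mean_pooler']

hidden = torch.randn(2, 4, 8)                      # (batch, seq_len, width)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = padding
pooled = MeanPooler()(SimpleNamespace(last_hidden_state=hidden), mask)
print(pooled.shape)                                # torch.Size([2, 8])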
src/open_clip/model.py ADDED
@@ -0,0 +1,782 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import copy
6
+ import logging
7
+ import math
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, Optional, Tuple, Union
10
+ import timm
11
+ from timm.data import resolve_data_config
12
+ from timm.data.transforms_factory import create_transform
13
+ from timm.layers import SwiGLUPacked
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+ from torch.utils.checkpoint import checkpoint
19
+ from functools import partial
20
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
21
+ from llm2vec.models import Qwen2BiModel
22
+
23
+ # from .hf_configs import arch_dict
24
+ from .hf_model import HFTextEncoder
25
+ from .timm_model import TimmModel
26
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
27
+ text_global_pool
28
+ from .utils import to_2tuple
29
+
30
+
31
+ @dataclass
32
+ class CLIPVisionCfg:
33
+ layers: Union[Tuple[int, int, int, int], int] = 12
34
+ width: int = 768
35
+ head_width: int = 64
36
+ mlp_ratio: float = 4.0
37
+ patch_size: int = 16
38
+ image_size: Union[Tuple[int, int], int] = 224
39
+
40
+ ls_init_value: Optional[float] = None # layer scale initial value
41
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
42
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
43
+ attn_pooler_queries: int = 256 # n_queries for attentional pooler
44
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
45
+ no_ln_pre: bool = False # disable pre transformer LayerNorm
46
+ pos_embed_type: str = 'learnable'
47
+ final_ln_after_pool: bool = False # apply final LayerNorm after pooling
48
+ pool_type: str = 'tok'
49
+ output_tokens: bool = False
50
+ act_kwargs: Optional[dict] = None
51
+ norm_kwargs: Optional[dict] = None
52
+
53
+ timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
54
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
55
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
56
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
57
+ timm_proj_bias: bool = False # enable bias final projection
58
+ timm_drop: float = 0. # head dropout
59
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
60
+
61
+
62
+ @dataclass
63
+ class CLIPTextCfg:
64
+ context_length: int = 77
65
+ vocab_size: int = 49408
66
+ hf_tokenizer_name: Optional[str] = None
67
+ tokenizer_kwargs: Optional[dict] = None
68
+
69
+ width: int = 512
70
+ heads: int = 8
71
+ layers: int = 12
72
+ mlp_ratio: float = 4.0
73
+ ls_init_value: Optional[float] = None # layer scale initial value
74
+ embed_cls: bool = False
75
+ pad_id: int = 0
76
+ no_causal_mask: bool = False # disable causal masking
77
+ final_ln_after_pool: bool = False # apply final LayerNorm after pooling
78
+ pool_type: str = 'argmax'
79
+ proj_bias: bool = False
80
+ output_tokens: bool = False
81
+ act_kwargs: dict = None
82
+ norm_kwargs: dict = None
83
+
84
+ # HuggingFace specific text tower config
85
+ hf_model_name: Optional[str] = None
86
+ hf_model_pretrained: bool = True
87
+ hf_proj_type: str = 'mlp'
88
+ hf_pooler_type: str = 'mean_pooler' # pooling applied to the HF text model output ('mean_pooler', 'cls_pooler', ...)
89
+
90
+
91
+ def get_cast_dtype(precision: str):
92
+ cast_dtype = None
93
+ if precision == 'bf16':
94
+ cast_dtype = torch.bfloat16
95
+ elif precision == 'fp16':
96
+ cast_dtype = torch.float16
97
+ return cast_dtype
98
+
99
+
100
+ def get_input_dtype(precision: str):
101
+ input_dtype = None
102
+ if precision in ('bf16', 'pure_bf16'):
103
+ input_dtype = torch.bfloat16
104
+ elif precision in ('fp16', 'pure_fp16'):
105
+ input_dtype = torch.float16
106
+ return input_dtype
107
+
108
+
109
+ def _build_vision_tower(
110
+ embed_dim: int,
111
+ vision_cfg: CLIPVisionCfg,
112
+ quick_gelu: bool = False,
113
+ cast_dtype: Optional[torch.dtype] = None
114
+ ):
115
+ if isinstance(vision_cfg, dict):
116
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
117
+
118
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
119
+ # memory efficient in recent PyTorch releases (>= 1.10).
120
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
121
+ act_layer = QuickGELU if quick_gelu else nn.GELU
122
+
123
+ if vision_cfg.timm_model_name:
124
+ visual = TimmModel(
125
+ vision_cfg.timm_model_name,
126
+ pretrained=vision_cfg.timm_model_pretrained,
127
+ pool=vision_cfg.timm_pool,
128
+ proj=vision_cfg.timm_proj,
129
+ proj_bias=vision_cfg.timm_proj_bias,
130
+ drop=vision_cfg.timm_drop,
131
+ drop_path=vision_cfg.timm_drop_path,
132
+ patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
133
+ embed_dim=embed_dim,
134
+ image_size=vision_cfg.image_size,
135
+ )
136
+ elif isinstance(vision_cfg.layers, (tuple, list)):
137
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
138
+ visual = ModifiedResNet(
139
+ layers=vision_cfg.layers,
140
+ output_dim=embed_dim,
141
+ heads=vision_heads,
142
+ image_size=vision_cfg.image_size,
143
+ width=vision_cfg.width,
144
+ )
145
+ else:
146
+ vision_heads = vision_cfg.width // vision_cfg.head_width
147
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
148
+ if vision_cfg.norm_kwargs:
149
+ norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
150
+ if vision_cfg.act_kwargs is not None:
151
+ act_layer = partial(act_layer, **vision_cfg.act_kwargs)
152
+
153
+ visual = VisionTransformer(
154
+ image_size=vision_cfg.image_size,
155
+ patch_size=vision_cfg.patch_size,
156
+ width=vision_cfg.width,
157
+ layers=vision_cfg.layers,
158
+ heads=vision_heads,
159
+ mlp_ratio=vision_cfg.mlp_ratio,
160
+ ls_init_value=vision_cfg.ls_init_value,
161
+ patch_dropout=vision_cfg.patch_dropout,
162
+ attentional_pool=vision_cfg.attentional_pool,
163
+ attn_pooler_queries=vision_cfg.attn_pooler_queries,
164
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
165
+ pos_embed_type=vision_cfg.pos_embed_type,
166
+ no_ln_pre=vision_cfg.no_ln_pre,
167
+ final_ln_after_pool=vision_cfg.final_ln_after_pool,
168
+ pool_type=vision_cfg.pool_type,
169
+ output_tokens=vision_cfg.output_tokens,
170
+ output_dim=embed_dim,
171
+ act_layer=act_layer,
172
+ norm_layer=norm_layer,
173
+ )
174
+
175
+ return visual
176
+
177
+
178
+ def _build_text_tower(
179
+ embed_dim: int,
180
+ text_cfg: CLIPTextCfg,
181
+ quick_gelu: bool = False,
182
+ cast_dtype: Optional[torch.dtype] = None,
183
+ ):
184
+ if isinstance(text_cfg, dict):
185
+ text_cfg = CLIPTextCfg(**text_cfg)
186
+
187
+ if text_cfg.hf_model_name:
188
+ text = HFTextEncoder(
189
+ text_cfg.hf_model_name,
190
+ output_dim=embed_dim,
191
+ proj_type=text_cfg.hf_proj_type,
192
+ pooler_type=text_cfg.hf_pooler_type,
193
+ pretrained=text_cfg.hf_model_pretrained,
194
+ output_tokens=text_cfg.output_tokens,
195
+ )
196
+ else:
197
+ act_layer = QuickGELU if quick_gelu else nn.GELU
198
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
199
+ if text_cfg.norm_kwargs:
200
+ norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
201
+ if text_cfg.act_kwargs is not None:
202
+ act_layer = partial(act_layer, **text_cfg.act_kwargs)
203
+
204
+ text = TextTransformer(
205
+ context_length=text_cfg.context_length,
206
+ vocab_size=text_cfg.vocab_size,
207
+ width=text_cfg.width,
208
+ heads=text_cfg.heads,
209
+ layers=text_cfg.layers,
210
+ mlp_ratio=text_cfg.mlp_ratio,
211
+ ls_init_value=text_cfg.ls_init_value,
212
+ output_dim=embed_dim,
213
+ embed_cls=text_cfg.embed_cls,
214
+ no_causal_mask=text_cfg.no_causal_mask,
215
+ pad_id=text_cfg.pad_id,
216
+ pool_type=text_cfg.pool_type,
217
+ proj_bias=text_cfg.proj_bias,
218
+ output_tokens=text_cfg.output_tokens,
219
+ act_layer=act_layer,
220
+ norm_layer=norm_layer,
221
+ )
222
+ return text
223
+
224
+ def resize_pos_embed(state_dict, interpolation: str = 'bicubic', antialias: bool = True):
225
+ # Rescale the grid of position embeddings when loading from state_dict
226
+
227
+
228
+ old_pos_embed = state_dict.get('pos_embed', None)
230
+ if old_pos_embed is None:
231
+ print('No positional embedding found in state_dict')
232
+ return
+ old_pos_embed = old_pos_embed[0] # drop the leading batch dim only after the None check
232
+ grid_size = to_2tuple([336 // 14, 336 // 14])
233
+ extra_tokens = 5 # FIXME detect different token configs (ie no class token, or more)
234
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
235
+ if new_seq_len == old_pos_embed.shape[0]:
236
+ print('Positional embedding grid-size matches model, no need to resize')
237
+ return
238
+
239
+ if extra_tokens:
240
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
241
+ else:
242
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
243
+
244
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
245
+
246
+ print(f'Resizing position embedding grid-size from {old_grid_size} to {grid_size}')
247
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
248
+ pos_emb_img = F.interpolate(
249
+ pos_emb_img,
250
+ size=grid_size,
251
+ mode=interpolation,
252
+ antialias=antialias,
253
+ align_corners=False,
254
+ )
255
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)
256
+ if pos_emb_tok is not None:
257
+ # import pdb
258
+ # pdb.set_trace()
259
+ new_pos_embed = torch.cat([pos_emb_tok.unsqueeze(0), pos_emb_img], dim=1)
260
+ else:
261
+ new_pos_embed = pos_emb_img
262
+ state_dict['pos_embed'] = new_pos_embed
263
+ return state_dict
264
+
265
+ class CLIP(nn.Module):
266
+ output_dict: torch.jit.Final[bool]
267
+
268
+ def __init__(
269
+ self,
270
+ embed_dim: int,
271
+ vision_cfg: CLIPVisionCfg,
272
+ text_cfg: CLIPTextCfg,
273
+ quick_gelu: bool = False,
274
+ init_logit_scale: float = np.log(1 / 0.07),
275
+ init_logit_bias: Optional[float] = None,
276
+ cast_dtype: Optional[torch.dtype] = None,
277
+ output_dict: bool = False,
278
+ ):
279
+ super().__init__()
280
+ self.output_dict = output_dict
281
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
282
+ model = timm.create_model("hf-hub:paige-ai/Virchow2", pretrained=False, mlp_layer=SwiGLUPacked, patch_size=14, img_size=336, # override patch_size / img_size for Virchow2
283
+ act_layer=torch.nn.SiLU)
284
+ self.visual2 = model
285
+
286
+ config = AutoConfig.from_pretrained("../Qwen-encoder-1.5B")
287
+
288
+ # initialize the model structure (without loading pretrained weights)
289
+ self.text = Qwen2BiModel(config)
290
+ self.proj = nn.Linear(1536, 3328) # 2048+1280
291
+ self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
292
+ if init_logit_bias is not None:
293
+ self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
294
+ else:
295
+ self.logit_bias = None
296
+
297
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
298
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
299
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
300
+
301
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
302
+
303
+ if not unlocked_layers: # full freezing
304
+ for n, p in self.text.named_parameters(): # the text tower here is self.text (Qwen2BiModel), not self.transformer
305
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
306
+ return
307
+
308
+ encoder = self.text.encoder if hasattr(self.text, 'encoder') else self.text
309
+ layer_list = encoder.layers # assumes a Qwen2-style model exposing its decoder blocks as `.layers` (the arch_dict lookup from hf_configs has no qwen2 entry)
310
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
311
+ embeddings = getattr(
312
+ self.text, "embed_tokens") # Qwen2-style models keep their token embeddings in `embed_tokens`
313
+ modules = [embeddings, *layer_list][:-unlocked_layers]
314
+ # freeze layers
315
+ for module in modules:
316
+ for n, p in module.named_parameters():
317
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
318
+
319
+ ## lock position embedding (legacy CLIP text-tower attribute; not present on the Qwen2 text tower)
320
+ if getattr(self, 'positional_embedding', None) is not None:
321
+ self.positional_embedding.requires_grad = False
322
+ ## lock token embedding (legacy attribute)
323
+ if getattr(self, 'token_embedding', None) is not None:
324
+ self.token_embedding.requires_grad = False
325
+ ## lock text projection (legacy attribute)
+ if getattr(self, 'text_projection', None) is not None:
+ self.text_projection.requires_grad = False
326
+ @torch.jit.ignore
327
+ def set_grad_checkpointing(self, enable=True):
328
+ self.visual.set_grad_checkpointing(enable)
329
+ self.visual2.set_grad_checkpointing(enable)
330
+ # self.transformer.grad_checkpointing = enable
331
+ self.text._set_gradient_checkpointing(enable)
332
+
333
+ def encode_image(self, image, normalize: bool = False):
334
+ features = self.visual(image)
335
+ features2 = self.visual2(image)
336
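+ # Virchow2 output layout (per its model card): index 0 is the class token, indices 1-4 are register tokens, patch tokens start at index 5; concatenating the class token with the mean patch token gives the recommended 2560-dim embedding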
+ features2 = torch.cat([features2[:, 0, :], features2[:, 5:, :].mean(1)], dim=-1)
337
+ features = torch.cat([features, features2], dim=-1)
338
+ return F.normalize(features, dim=-1) if normalize else features
339
+
340
+ def encode_text(self, text2, normalize: bool = False):
341
+ features = self.text(**text2)
342
+ ### mask attention
343
+ last_hidden_states = features.last_hidden_state # (batch_size, sequence_length, hidden_size)
344
+ attention_mask = text2['attention_mask'] # (batch_size, sequence_length)
345
+ # (batch_size, sequence_length, 1)
346
+ attention_mask = attention_mask.unsqueeze(-1).float()
347
+ masked_hidden_states = last_hidden_states * attention_mask
348
+ # (batch_size, 1, 1)
349
+ valid_token_count = attention_mask.sum(dim=1, keepdim=True)
350
+ # (batch_size, hidden_size)
351
+ features = masked_hidden_states.sum(dim=1) / valid_token_count.squeeze(1)
352
+ features = self.proj(features)
353
+ return F.normalize(features, dim=-1) if normalize else features
354
+
355
+ def get_logits(self, image, text):
356
+ image_features = self.encode_image(image, normalize=True)
357
+ text_features = self.encode_text(text, normalize=True)
358
+ image_logits = self.logit_scale.exp() * image_features @ text_features.T
359
+ if self.logit_bias is not None:
360
+ image_logits += self.logit_bias
361
+ text_logits = image_logits.T
362
+ return image_logits, text_logits
363
+
364
+ def forward(
365
+ self,
366
+ image: Optional[torch.Tensor] = None,
367
+ text: Optional[torch.Tensor] = None,
368
+ ):
369
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
370
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
371
+
372
+ if self.output_dict:
373
+ out_dict = {
374
+ "image_features": image_features,
375
+ "text_features": text_features,
376
+ "logit_scale": self.logit_scale.exp()
377
+ }
378
+ if self.logit_bias is not None:
379
+ out_dict['logit_bias'] = self.logit_bias
380
+ return out_dict
381
+
382
+ if self.logit_bias is not None:
383
+ return image_features, text_features, self.logit_scale.exp(), self.logit_bias
384
+ return image_features, text_features, self.logit_scale.exp()
385
+
386
+
387
+ class CustomTextCLIP(nn.Module):
388
+ output_dict: torch.jit.Final[bool]
389
+
390
+ def __init__(
391
+ self,
392
+ embed_dim: int,
393
+ vision_cfg: CLIPVisionCfg,
394
+ text_cfg: CLIPTextCfg,
395
+ quick_gelu: bool = False,
396
+ init_logit_scale: float = np.log(1 / 0.07),
397
+ init_logit_bias: Optional[float] = None,
398
+ cast_dtype: Optional[torch.dtype] = None,
399
+ output_dict: bool = False,
400
+ ):
401
+ super().__init__()
402
+ self.output_dict = output_dict
403
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
404
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
405
+ self.context_length = self.text.context_length
406
+ self.vocab_size = self.text.vocab_size
407
+ self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
408
+ if init_logit_bias is not None:
409
+ self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
410
+ else:
411
+ self.logit_bias = None
412
+
413
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
414
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
415
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
416
+
417
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
418
+ self.text.lock(unlocked_layers, freeze_layer_norm)
419
+
420
+ @torch.jit.ignore
421
+ def set_grad_checkpointing(self, enable=True):
422
+ self.visual.set_grad_checkpointing(enable)
423
+ self.text.set_grad_checkpointing(enable)
424
+
425
+ def encode_image(self, image, normalize: bool = False):
426
+ features = self.visual(image)
427
+ return F.normalize(features, dim=-1) if normalize else features
428
+
429
+ def encode_text(self, text, normalize: bool = False):
430
+ features = self.text(text)
431
+ return F.normalize(features, dim=-1) if normalize else features
432
+
433
+ def get_logits(self, image, text):
434
+ image_features = self.encode_image(image, normalize=True)
435
+ text_features = self.encode_text(text, normalize=True)
436
+ image_logits = self.logit_scale.exp() * image_features @ text_features.T
437
+ if self.logit_bias is not None:
438
+ image_logits += self.logit_bias
439
+ text_logits = image_logits.T
440
+ return image_logits, text_logits
441
+
442
+ def forward(
443
+ self,
444
+ image: Optional[torch.Tensor] = None,
445
+ text: Optional[torch.Tensor] = None,
446
+ ):
447
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
448
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
449
+
450
+ if self.output_dict:
451
+ out_dict = {
452
+ "image_features": image_features,
453
+ "text_features": text_features,
454
+ "logit_scale": self.logit_scale.exp()
455
+ }
456
+ if self.logit_bias is not None:
457
+ out_dict['logit_bias'] = self.logit_bias
458
+ return out_dict
459
+
460
+ if self.logit_bias is not None:
461
+ return image_features, text_features, self.logit_scale.exp(), self.logit_bias
462
+ return image_features, text_features, self.logit_scale.exp()
463
+
464
+
465
+
466
+
467
+
468
+ class CustomCLIP(nn.Module):
469
+ output_dict: torch.jit.Final[bool]
470
+
471
+ def __init__(
472
+ self,
473
+ embed_dim: int,
474
+ vision_cfg: CLIPVisionCfg,
475
+ text_cfg: CLIPTextCfg,
476
+ quick_gelu: bool = False,
477
+ init_logit_scale: float = np.log(1 / 0.07),
478
+ init_logit_bias: Optional[float] = None,
479
+ cast_dtype: Optional[torch.dtype] = None,
480
+ output_dict: bool = False,
481
+ ):
482
+ super().__init__()
483
+ self.output_dict = output_dict
484
+ model = timm.create_model('hf_hub:paige-ai/Virchow2', pretrained=False)
485
+
486
+ # load the locally saved model weights
487
+ checkpoint_path = "/sunyuxuan/project/2024/model/vision_encoder/pathology/virchow2/pytorch_model.bin"
488
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
489
+ model.load_state_dict(state_dict)
490
+ # import pdb
491
+ # pdb.set_trace()
492
+
493
+
494
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
495
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
496
+ self.context_length = self.text.context_length
497
+ self.vocab_size = self.text.vocab_size
498
+ self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
499
+ if init_logit_bias is not None:
500
+ self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
501
+ else:
502
+ self.logit_bias = None
503
+
504
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
505
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
506
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
507
+
508
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
509
+ self.text.lock(unlocked_layers, freeze_layer_norm)
510
+
511
+ @torch.jit.ignore
512
+ def set_grad_checkpointing(self, enable=True):
513
+ self.visual.set_grad_checkpointing(enable)
514
+ self.text.set_grad_checkpointing(enable)
515
+
516
+ def encode_image(self, image, normalize: bool = False):
517
+ features = self.visual(image)
518
+ return F.normalize(features, dim=-1) if normalize else features
519
+
520
+ def encode_text(self, text, normalize: bool = False):
521
+ features = self.text(text)
522
+ return F.normalize(features, dim=-1) if normalize else features
523
+
524
+ def get_logits(self, image, text):
525
+ image_features = self.encode_image(image, normalize=True)
526
+ text_features = self.encode_text(text, normalize=True)
527
+ image_logits = self.logit_scale.exp() * image_features @ text_features.T
528
+ if self.logit_bias is not None:
529
+ image_logits += self.logit_bias
530
+ text_logits = image_logits.T
531
+ return image_logits, text_logits
532
+
533
+ def forward(
534
+ self,
535
+ image: Optional[torch.Tensor] = None,
536
+ text: Optional[torch.Tensor] = None,
537
+ ):
538
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
539
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
540
+
541
+ if self.output_dict:
542
+ out_dict = {
543
+ "image_features": image_features,
544
+ "text_features": text_features,
545
+ "logit_scale": self.logit_scale.exp()
546
+ }
547
+ if self.logit_bias is not None:
548
+ out_dict['logit_bias'] = self.logit_bias
549
+ return out_dict
550
+
551
+ if self.logit_bias is not None:
552
+ return image_features, text_features, self.logit_scale.exp(), self.logit_bias
553
+ return image_features, text_features, self.logit_scale.exp()
554
+
555
+
556
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
557
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
558
+
559
+ def _convert_weights(l):
560
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
561
+ l.weight.data = l.weight.data.to(dtype)
562
+ if l.bias is not None:
563
+ l.bias.data = l.bias.data.to(dtype)
564
+
565
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
566
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
567
+ tensor = getattr(l, attr)
568
+ if tensor is not None:
569
+ tensor.data = tensor.data.to(dtype)
570
+
571
+ if isinstance(l, (CLIP, TextTransformer)):
572
+ # convert text nn.Parameter projections
573
+ attr = getattr(l, "text_projection", None)
574
+ if attr is not None:
575
+ attr.data = attr.data.to(dtype)
576
+
577
+ if isinstance(l, VisionTransformer):
578
+ # convert vision nn.Parameter projections
579
+ attr = getattr(l, "proj", None)
580
+ if attr is not None:
581
+ attr.data = attr.data.to(dtype)
582
+
583
+ model.apply(_convert_weights)
584
+
585
+
586
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
587
+
588
+
589
+ # used to maintain checkpoint compatibility
590
+ def convert_to_custom_text_state_dict(state_dict: dict):
591
+ if 'text_projection' in state_dict:
592
+ # old format state_dict, move text tower -> .text
593
+ new_state_dict = {}
594
+ for k, v in state_dict.items():
595
+ if any(k.startswith(p) for p in (
596
+ 'text_projection',
597
+ 'positional_embedding',
598
+ 'token_embedding',
599
+ 'transformer',
600
+ 'ln_final',
601
+ )):
602
+ k = 'text.' + k
603
+ new_state_dict[k] = v
604
+ return new_state_dict
605
+ return state_dict
606
+
607
+
608
+ def build_model_from_openai_state_dict(
609
+ state_dict: dict,
610
+ quick_gelu=True,
611
+ cast_dtype=torch.float16,
612
+ ):
613
+ vit = "visual.proj" in state_dict
614
+
615
+ if vit:
616
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
617
+ vision_layers = len(
618
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
619
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
620
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
621
+ image_size = vision_patch_size * grid_size
622
+ else:
623
+ counts: list = [
624
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
625
+ vision_layers = tuple(counts)
626
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
627
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
628
+ vision_patch_size = None
629
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
630
+ image_size = output_width * 32
631
+
632
+ embed_dim = state_dict["text_projection"].shape[1]
633
+ context_length = state_dict["positional_embedding"].shape[0]
634
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
635
+ transformer_width = state_dict["ln_final.weight"].shape[0]
636
+ transformer_heads = transformer_width // 64
637
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
638
+
639
+ vision_cfg = CLIPVisionCfg(
640
+ layers=vision_layers,
641
+ width=vision_width,
642
+ patch_size=vision_patch_size,
643
+ image_size=image_size,
644
+ )
645
+ text_cfg = CLIPTextCfg(
646
+ context_length=context_length,
647
+ vocab_size=vocab_size,
648
+ width=transformer_width,
649
+ heads=transformer_heads,
650
+ layers=transformer_layers,
651
+ )
652
+ model = CLIP(
653
+ embed_dim,
654
+ vision_cfg=vision_cfg,
655
+ text_cfg=text_cfg,
656
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
657
+ cast_dtype=cast_dtype,
658
+ )
659
+
660
+ for key in ["input_resolution", "context_length", "vocab_size"]:
661
+ state_dict.pop(key, None)
662
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
663
+ model.load_state_dict(state_dict, strict=True)
664
+ return model.eval()
665
+
666
+
667
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
668
+ model.eval()
669
+ image_size = model.visual.image_size
670
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
671
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
672
+ model = torch.jit.trace_module(
673
+ model,
674
+ inputs=dict(
675
+ forward=(example_images, example_text),
676
+ encode_text=(example_text,),
677
+ encode_image=(example_images,)
678
+ ))
679
+ model.visual.image_size = image_size
680
+ return model
681
+ #
682
+ #
683
+ # def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
684
+ # # Rescale the grid of position embeddings when loading from state_dict
685
+ # old_pos_embed = state_dict.get('visual.positional_embedding', None)
686
+ # if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
687
+ # return
688
+ # grid_size = to_2tuple(model.visual.grid_size)
689
+ # extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
690
+ # new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
691
+ # if new_seq_len == old_pos_embed.shape[0]:
692
+ # return
693
+ #
694
+ # if extra_tokens:
695
+ # pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
696
+ # else:
697
+ # pos_emb_tok, pos_emb_img = None, old_pos_embed
698
+ # old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
699
+ #
700
+ # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
701
+ # pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
702
+ # pos_emb_img = F.interpolate(
703
+ # pos_emb_img,
704
+ # size=grid_size,
705
+ # mode=interpolation,
706
+ # antialias=antialias,
707
+ # align_corners=False,
708
+ # )
709
+ # pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
710
+ # if pos_emb_tok is not None:
711
+ # new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
712
+ # else:
713
+ # new_pos_embed = pos_emb_img
714
+ # state_dict['visual.positional_embedding'] = new_pos_embed
715
+
716
+
717
+ def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
718
+ old_pos_embed = state_dict.get('positional_embedding', None)
719
+ if old_pos_embed is None:
720
+ return
721
+ # FIXME add support for text cls_token
722
+ model_pos_embed = getattr(model, 'positional_embedding', None)
723
+ if model_pos_embed is None:
724
+ model_pos_embed = getattr(model.text, 'positional_embedding', None)
725
+
726
+ old_num_pos = old_pos_embed.shape[0]
727
+ old_width = old_pos_embed.shape[1]
728
+ num_pos = model_pos_embed.shape[0]
729
+ width = model_pos_embed.shape[1]
730
+ assert old_width == width, 'text pos_embed width changed!'
731
+ if old_num_pos == num_pos:
732
+ return
733
+
734
+ logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
735
+ old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
736
+ old_pos_embed = F.interpolate(
737
+ old_pos_embed,
738
+ size=num_pos,
739
+ mode=interpolation,
740
+ antialias=antialias,
741
+ align_corners=False,
742
+ )
743
+ old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
744
+ new_pos_embed = old_pos_embed
745
+
746
+ state_dict['positional_embedding'] = new_pos_embed
747
+
748
+
749
+ def get_model_preprocess_cfg(model):
750
+ module = getattr(model, 'visual', model)
751
+ preprocess_cfg = getattr(module, 'preprocess_cfg', {})
752
+ if not preprocess_cfg:
753
+ # use separate legacy attributes if preprocess_cfg dict not found
754
+ size = getattr(module, 'image_size')
755
+ if size is not None:
756
+ preprocess_cfg['size'] = size
757
+ mean = getattr(module, 'image_mean', None)
758
+ if mean is not None:
759
+ preprocess_cfg['mean'] = mean
760
+ std = getattr(module, 'image_std', None)
761
+ if std is not None:
762
+ preprocess_cfg['std'] = std
763
+ return preprocess_cfg
764
+
765
+
766
+ def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
767
+ module = getattr(model, 'visual', model)
768
+ module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat
769
+ module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat
770
+ module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict
771
+
772
+
773
+ def get_model_tokenize_cfg(model):
774
+ module = getattr(model, 'text', model)
775
+ cfg = {}
776
+ context_length = getattr(module, 'context_length', None)
777
+ if context_length is not None:
778
+ cfg['context_length'] = context_length
779
+ vocab_size = getattr(module, 'vocab_size', None)
780
+ if vocab_size is not None:
781
+ cfg['vocab_size'] = vocab_size
782
+ return cfg
src/open_clip/model_configs/ViT-L-14-336.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "patch_size": 14
8
+ },
9
+ "text_cfg": {
10
+ "context_length": 77,
11
+ "vocab_size": 49408,
12
+ "width": 768,
13
+ "heads": 12,
14
+ "layers": 12
15
+ }
16
+ }
src/open_clip/openai.py ADDED
@@ -0,0 +1,90 @@
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
14
+ from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
15
+
16
+ __all__ = ["list_openai_models", "load_openai_model"]
17
+
18
+
19
+ def list_openai_models() -> List[str]:
20
+ """Returns the names of available CLIP models"""
21
+ return list_pretrained_models_by_tag('openai')
22
+
23
+
24
+ def load_openai_model(
25
+ name: str,
26
+ precision: Optional[str] = None,
27
+ device: Optional[Union[str, torch.device]] = None,
28
+ cache_dir: Optional[str] = None,
29
+ ):
30
+ """Load a CLIP model
31
+
32
+ Parameters
33
+ ----------
34
+ name : str
35
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
+ precision: str
37
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
+ device : Union[str, torch.device]
39
+ The device to put the loaded model
40
+ cache_dir : Optional[str]
41
+ The directory to cache the downloaded model weights
42
+
43
+ Returns
44
+ -------
45
+ model : torch.nn.Module
46
+ The CLIP model
47
+ preprocess : Callable[[PIL.Image], torch.Tensor]
48
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
49
+ """
50
+ if device is None:
51
+ device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ if precision is None:
53
+ precision = 'fp32' if device == 'cpu' else 'fp16'
54
+
55
+ if get_pretrained_url(name, 'openai'):
56
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
57
+ elif os.path.isfile(name):
58
+ model_path = name
59
+ else:
60
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
61
+
62
+ try:
63
+ # loading JIT archive
64
+ model = torch.jit.load(model_path, map_location="cpu").eval()
65
+ state_dict = None
66
+ except RuntimeError:
67
+ # loading saved state dict
68
+ state_dict = torch.load(model_path, map_location="cpu")
69
+
70
+ # Build a non-jit model from the OpenAI jitted model state dict
71
+ cast_dtype = get_cast_dtype(precision)
72
+ try:
73
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
74
+ except KeyError:
75
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
76
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
77
+
78
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
79
+ model = model.to(device)
80
+ # FIXME support pure fp16/bf16 precision modes
81
+ if precision != 'fp16':
82
+ model.float()
83
+ if precision == 'bf16':
84
+ # for bf16, convert back to low-precision
85
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
86
+
87
+ # add mean / std attributes for consistency with OpenCLIP models
88
+ model.visual.image_mean = OPENAI_DATASET_MEAN
89
+ model.visual.image_std = OPENAI_DATASET_STD
90
+ return model
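A usage sketch for the OpenAI loader above; it downloads the referenced checkpoint on first use and assumes the pretrained registry in pretrained.py (below) carries an 'openai' tag for the chosen model name.

from open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())  # model names that have an 'openai' pretrained tag
model = load_openai_model('ViT-L-14-336', precision='fp32', device='cpu')
model.eval()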
src/open_clip/pos_embed.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21
+ """
22
+ grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39
+ assert embed_dim % 2 == 0
40
+
41
+ # use half of dimensions to encode grid_h
42
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44
+
45
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46
+ return emb
47
+
48
+
49
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50
+ """
51
+ embed_dim: output dimension for each position
52
+ pos: a list of positions to be encoded: size (M,)
53
+ out: (M, D)
54
+ """
55
+ assert embed_dim % 2 == 0
56
+ omega = np.arange(embed_dim // 2, dtype=float)
57
+ omega /= embed_dim / 2.
58
+ omega = 1. / 10000**omega # (D/2,)
59
+
60
+ pos = pos.reshape(-1) # (M,)
61
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62
+
63
+ emb_sin = np.sin(out) # (M, D/2)
64
+ emb_cos = np.cos(out) # (M, D/2)
65
+
66
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67
+ return emb
68
+
69
+
70
+ # --------------------------------------------------------
71
+ # Interpolate position embeddings for high-resolution
72
+ # References:
73
+ # DeiT: https://github.com/facebookresearch/deit
74
+ # --------------------------------------------------------
75
+ def interpolate_pos_embed(model, checkpoint_model):
76
+ if 'pos_embed' in checkpoint_model:
77
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
78
+ embedding_size = pos_embed_checkpoint.shape[-1]
79
+ num_patches = model.patch_embed.num_patches
80
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81
+ # height (== width) for the checkpoint position embedding
82
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83
+ # height (== width) for the new position embedding
84
+ new_size = int(num_patches ** 0.5)
85
+ # class_token and dist_token are kept unchanged
86
+ if orig_size != new_size:
87
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89
+ # only the position tokens are interpolated
90
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92
+ pos_tokens = torch.nn.functional.interpolate(
93
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96
+ checkpoint_model['pos_embed'] = new_pos_embed
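A minimal usage sketch for the helpers above (an illustrative example, not part of the uploaded file; it assumes the two functions above are in scope and uses a hypothetical 16x16 patch grid with a 768-dim embedding):
import numpy as np
# Build a fixed 2D sine-cosine table for a 16x16 patch grid, with a leading cls-token slot.
pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16, cls_token=True)
print(pos.shape)  # (1 + 16*16, 768)
# Typical use: copy the table into a ViT's positional embedding parameter, e.g.
# model.pos_embed.data.copy_(torch.from_numpy(pos).float().unsqueeze(0))
# When loading a checkpoint trained at a different resolution, resize its table first:
# checkpoint = torch.load('checkpoint.pth', map_location='cpu')  # hypothetical path
# interpolate_pos_embed(model, checkpoint['model'])
# model.load_state_dict(checkpoint['model'], strict=False)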
src/open_clip/pretrained.py ADDED
@@ -0,0 +1,655 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from functools import partial
6
+ from typing import Dict, Union
7
+
8
+ from tqdm import tqdm
9
+
10
+ from .constants import (
11
+ IMAGENET_MEAN,
12
+ IMAGENET_STD,
13
+ INCEPTION_MEAN,
14
+ INCEPTION_STD,
15
+ OPENAI_DATASET_MEAN,
16
+ OPENAI_DATASET_STD,
17
+ )
18
+ from .version import __version__
19
+
20
+ try:
21
+ from huggingface_hub import hf_hub_download
22
+ hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
23
+ _has_hf_hub = True
24
+ except ImportError:
25
+ hf_hub_download = None
26
+ _has_hf_hub = False
27
+
28
+
29
+ def _pcfg(url='', hf_hub='', **kwargs):
30
+ # OpenAI / OpenCLIP defaults
31
+ return {
32
+ 'url': url,
33
+ 'hf_hub': hf_hub,
34
+ 'mean': OPENAI_DATASET_MEAN,
35
+ 'std': OPENAI_DATASET_STD,
36
+ 'interpolation': 'bicubic',
37
+ 'resize_mode': 'shortest',
38
+ **kwargs,
39
+ }
40
+
41
+
42
+ def _slpcfg(url='', hf_hub='', **kwargs):
43
+ # SigLIP defaults
44
+ return {
45
+ 'url': url,
46
+ 'hf_hub': hf_hub,
47
+ 'mean': INCEPTION_MEAN,
48
+ 'std': INCEPTION_STD,
49
+ 'interpolation': 'bicubic',
50
+ 'resize_mode': 'squash',
51
+ **kwargs,
52
+ }
53
+
54
+
55
+ def _apcfg(url='', hf_hub='', **kwargs):
56
+ # CLIPA defaults
57
+ return {
58
+ 'url': url,
59
+ 'hf_hub': hf_hub,
60
+ 'mean': IMAGENET_MEAN,
61
+ 'std': IMAGENET_STD,
62
+ 'interpolation': 'bilinear',
63
+ 'resize_mode': 'squash',
64
+ **kwargs,
65
+ }
66
+
67
+
68
+ def _mccfg(url='', hf_hub='', **kwargs):
69
+ # MobileCLIP
70
+ return {
71
+ 'url': url,
72
+ 'hf_hub': hf_hub,
73
+ 'mean': (0., 0., 0.),
74
+ 'std': (1., 1., 1.),
75
+ 'interpolation': 'bilinear',
76
+ 'resize_mode': 'shortest',
77
+ **kwargs,
78
+ }
79
+
80
+
81
+
82
+ _RN50 = dict(
83
+ openai=_pcfg(
84
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
85
+ yfcc15m=_pcfg(
86
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
87
+ cc12m=_pcfg(
88
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
89
+ )
90
+
91
+ _RN50_quickgelu = dict(
92
+ openai=_pcfg(
93
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
94
+ yfcc15m=_pcfg(
95
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
96
+ cc12m=_pcfg(
97
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
98
+ )
99
+
100
+ _RN101 = dict(
101
+ openai=_pcfg(
102
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
103
+ yfcc15m=_pcfg(
104
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
105
+ )
106
+
107
+ _RN101_quickgelu = dict(
108
+ openai=_pcfg(
109
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
110
+ yfcc15m=_pcfg(
111
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
112
+ )
113
+
114
+ _RN50x4 = dict(
115
+ openai=_pcfg(
116
+ "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
117
+ )
118
+
119
+ _RN50x16 = dict(
120
+ openai=_pcfg(
121
+ "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
122
+ )
123
+
124
+ _RN50x64 = dict(
125
+ openai=_pcfg(
126
+ "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
127
+ )
128
+
129
+ _VITB32 = dict(
130
+ openai=_pcfg(
131
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
132
+ laion400m_e31=_pcfg(
133
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
134
+ laion400m_e32=_pcfg(
135
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
136
+ laion2b_e16=_pcfg(
137
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
138
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
139
+ # DataComp-XL models
140
+ datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),
141
+ # DataComp-M models
142
+ datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
143
+ commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
144
+ commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
145
+ commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
146
+ commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
147
+ commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
148
+ commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
149
+ # DataComp-S models
150
+ datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
151
+ commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
152
+ commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
153
+ commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
154
+ commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
155
+ commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
156
+ commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
157
+ )
158
+
159
+ _VITB32_quickgelu = dict(
160
+ openai=_pcfg(
161
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
162
+ laion400m_e31=_pcfg(
163
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
164
+ laion400m_e32=_pcfg(
165
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
166
+ metaclip_400m=_pcfg(
167
+ "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"),
168
+ metaclip_fullcc=_pcfg(
169
+ "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"),
170
+ )
171
+
172
+ _VITB32_256 = dict(
173
+ datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),
174
+ )
175
+
176
+ _VITB16 = dict(
177
+ openai=_pcfg(
178
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
179
+ laion400m_e31=_pcfg(
180
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
181
+ laion400m_e32=_pcfg(
182
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
183
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
184
+ # DataComp-XL models
185
+ datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),
186
+ # DataComp-L models
187
+ datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
188
+ commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
189
+ commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
190
+ commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
191
+ commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
192
+ commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
193
+ commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
194
+ # DFN
195
+ dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/')
196
+ )
197
+
198
+ _VITB16_quickgelu = dict(
199
+ metaclip_400m=_pcfg(
200
+ "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"),
201
+ metaclip_fullcc=_pcfg(
202
+ "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"),
203
+ )
204
+
205
+ _VITB16_PLUS_240 = dict(
206
+ laion400m_e31=_pcfg(
207
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
208
+ laion400m_e32=_pcfg(
209
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
210
+ )
211
+
212
+ _VITL14 = dict(
213
+ openai=_pcfg(
214
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
215
+ laion400m_e31=_pcfg(
216
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
217
+ laion400m_e32=_pcfg(
218
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
219
+ laion2b_s32b_b82k=_pcfg(
220
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
221
+ mean=INCEPTION_MEAN, std=INCEPTION_STD),
222
+ # DataComp-XL models
223
+ datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
224
+ commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
225
+ commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
226
+ commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
227
+ )
228
+
229
+ _VITL14_quickgelu = dict(
230
+ metaclip_400m=_pcfg(
231
+ "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"),
232
+ metaclip_fullcc=_pcfg(
233
+ "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"),
234
+ dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'),
235
+ )
236
+
237
+ _VITL14_336 = dict(
238
+ openai=_pcfg(
239
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
240
+ )
241
+
242
+ _VITH14 = dict(
243
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
244
+ )
245
+
246
+ _VITH14_quickgelu = dict(
247
+ metaclip_fullcc=_pcfg(
248
+ "https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"),
249
+ dfn5b=_pcfg(
250
+ hf_hub='apple/DFN5B-CLIP-ViT-H-14/',
251
+ interpolation="bicubic",
252
+ resize_mode="squash"
253
+ ),
254
+ )
255
+
256
+ _VITH14_378_quickgelu = dict(
257
+ dfn5b=_pcfg(
258
+ hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',
259
+ interpolation="bicubic",
260
+ resize_mode="squash"
261
+ ),
262
+ )
263
+
264
+ _VITg14 = dict(
265
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
266
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
267
+ )
268
+
269
+ _VITbigG14 = dict(
270
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
271
+ )
272
+
273
+ _robertaViTB32 = dict(
274
+ laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
275
+ )
276
+
277
+ _xlmRobertaBaseViTB32 = dict(
278
+ laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
279
+ )
280
+
281
+ _xlmRobertaLargeFrozenViTH14 = dict(
282
+ frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
283
+ )
284
+
285
+ _convnext_base = dict(
286
+ laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
287
+ )
288
+
289
+ _convnext_base_w = dict(
290
+ laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
291
+ laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
292
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
293
+ )
294
+
295
+ _convnext_base_w_320 = dict(
296
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
297
+ laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
298
+ )
299
+
300
+ _convnext_large_d = dict(
301
+ laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
302
+ )
303
+
304
+ _convnext_large_d_320 = dict(
305
+ laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
306
+ laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
307
+ )
308
+
309
+ _convnext_xxlarge = dict(
310
+ laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
311
+ laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
312
+ laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
313
+ )
314
+
315
+ _coca_VITB32 = dict(
316
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
317
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
318
+ )
319
+
320
+ _coca_VITL14 = dict(
321
+ laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
322
+ mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
323
+ )
324
+
325
+
326
+ _PRETRAINED = {
327
+ "RN50": _RN50,
328
+ "RN50-quickgelu": _RN50_quickgelu,
329
+ "RN101": _RN101,
330
+ "RN101-quickgelu": _RN101_quickgelu,
331
+ "RN50x4": _RN50x4,
332
+ "RN50x16": _RN50x16,
333
+ "RN50x64": _RN50x64,
334
+
335
+ "ViT-B-32": _VITB32,
336
+ "ViT-B-32-256": _VITB32_256,
337
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
338
+ "ViT-B-16": _VITB16,
339
+ "ViT-B-16-quickgelu": _VITB16_quickgelu,
340
+ "ViT-B-16-plus-240": _VITB16_PLUS_240,
341
+ "ViT-L-14": _VITL14,
342
+ "ViT-L-14-quickgelu": _VITL14_quickgelu,
343
+ "ViT-L-14-336": _VITL14_336,
344
+ "ViT-H-14": _VITH14,
345
+ "ViT-H-14-quickgelu": _VITH14_quickgelu,
346
+ "ViT-H-14-378-quickgelu": _VITH14_378_quickgelu,
347
+ "ViT-g-14": _VITg14,
348
+ "ViT-bigG-14": _VITbigG14,
349
+
350
+ "roberta-ViT-B-32": _robertaViTB32,
351
+ "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
352
+ "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
353
+
354
+ "convnext_base": _convnext_base,
355
+ "convnext_base_w": _convnext_base_w,
356
+ "convnext_base_w_320": _convnext_base_w_320,
357
+ "convnext_large_d": _convnext_large_d,
358
+ "convnext_large_d_320": _convnext_large_d_320,
359
+ "convnext_xxlarge": _convnext_xxlarge,
360
+
361
+ "coca_ViT-B-32": _coca_VITB32,
362
+ "coca_ViT-L-14": _coca_VITL14,
363
+
364
+ "EVA01-g-14": dict(
365
+ # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt
366
+ laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
367
+ ),
368
+ "EVA01-g-14-plus": dict(
369
+ # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt
370
+ merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
371
+ ),
372
+ "EVA02-B-16": dict(
373
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt
374
+ merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
375
+ ),
376
+ "EVA02-L-14": dict(
377
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt
378
+ merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
379
+ ),
380
+ "EVA02-L-14-336": dict(
381
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt
382
+ merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
383
+ ),
384
+ "EVA02-E-14": dict(
385
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt
386
+ laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
387
+ ),
388
+ "EVA02-E-14-plus": dict(
389
+ # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt
390
+ laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
391
+ ),
392
+
393
+ "ViT-B-16-SigLIP": dict(
394
+ webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),
395
+ ),
396
+ "ViT-B-16-SigLIP-256": dict(
397
+ webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),
398
+ ),
399
+ "ViT-B-16-SigLIP-i18n-256": dict(
400
+ webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),
401
+ ),
402
+ "ViT-B-16-SigLIP-384": dict(
403
+ webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),
404
+ ),
405
+ "ViT-B-16-SigLIP-512": dict(
406
+ webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),
407
+ ),
408
+ "ViT-L-16-SigLIP-256": dict(
409
+ webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),
410
+ ),
411
+ "ViT-L-16-SigLIP-384": dict(
412
+ webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),
413
+ ),
414
+ "ViT-SO400M-14-SigLIP": dict(
415
+ webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
416
+ ),
417
+ "ViT-SO400M-14-SigLIP-384": dict(
418
+ webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
419
+ ),
420
+
421
+ "ViT-L-14-CLIPA": dict(
422
+ datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),
423
+ ),
424
+ "ViT-L-14-CLIPA-336": dict(
425
+ datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),
426
+ ),
427
+ "ViT-H-14-CLIPA": dict(
428
+ datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),
429
+ ),
430
+ "ViT-H-14-CLIPA-336": dict(
431
+ laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),
432
+ datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),
433
+ ),
434
+ "ViT-bigG-14-CLIPA": dict(
435
+ datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),
436
+ ),
437
+ "ViT-bigG-14-CLIPA-336": dict(
438
+ datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),
439
+ ),
440
+
441
+ "nllb-clip-base": dict(
442
+ v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),
443
+ ),
444
+ "nllb-clip-large": dict(
445
+ v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),
446
+ ),
447
+
448
+ "nllb-clip-base-siglip": dict(
449
+ v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),
450
+ mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),
451
+ ),
452
+ "nllb-clip-large-siglip": dict(
453
+ v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
454
+ mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
455
+ ),
456
+
457
+ "MobileCLIP-S1": dict(
458
+ datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),
459
+ "MobileCLIP-S2": dict(
460
+ datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),
461
+ "MobileCLIP-B": dict(
462
+ datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),
463
+ datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),
464
+ ),
465
+
466
+ "ViTamin-S": dict(
467
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),
468
+ ),
469
+ "ViTamin-S-LTT": dict(
470
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),
471
+ ),
472
+ "ViTamin-B": dict(
473
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),
474
+ ),
475
+ "ViTamin-B-LTT": dict(
476
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),
477
+ ),
478
+ "ViTamin-L": dict(
479
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),
480
+ ),
481
+ "ViTamin-L-256": dict(
482
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),
483
+ ),
484
+ "ViTamin-L-336": dict(
485
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),
486
+ ),
487
+ "ViTamin-L-384": dict(
488
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),
489
+ ),
490
+ "ViTamin-L2": dict(
491
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),
492
+ ),
493
+ "ViTamin-L2-256": dict(
494
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),
495
+ ),
496
+ "ViTamin-L2-336": dict(
497
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),
498
+ ),
499
+ "ViTamin-L2-384": dict(
500
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),
501
+ ),
502
+ "ViTamin-XL-256": dict(
503
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),
504
+ ),
505
+ "ViTamin-XL-336": dict(
506
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),
507
+ ),
508
+ "ViTamin-XL-384": dict(
509
+ datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),
510
+ ),
511
+ }
512
+
513
+
514
+ def _clean_tag(tag: str):
515
+ # normalize pretrained tags
516
+ return tag.lower().replace('-', '_')
517
+
518
+
519
+ def list_pretrained(as_str: bool = False):
520
+ """ returns list of pretrained models
521
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
522
+ """
523
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
524
+
525
+
526
+ def list_pretrained_models_by_tag(tag: str):
527
+ """ return all models having the specified pretrain tag """
528
+ models = []
529
+ tag = _clean_tag(tag)
530
+ for k in _PRETRAINED.keys():
531
+ if tag in _PRETRAINED[k]:
532
+ models.append(k)
533
+ return models
534
+
535
+
536
+ def list_pretrained_tags_by_model(model: str):
537
+ """ return all pretrain tags for the specified model architecture """
538
+ tags = []
539
+ if model in _PRETRAINED:
540
+ tags.extend(_PRETRAINED[model].keys())
541
+ return tags
542
+
543
+
544
+ def is_pretrained_cfg(model: str, tag: str):
545
+ if model not in _PRETRAINED:
546
+ return False
547
+ return _clean_tag(tag) in _PRETRAINED[model]
548
+
549
+
550
+ def get_pretrained_cfg(model: str, tag: str):
551
+ if model not in _PRETRAINED:
552
+ return {}
553
+ model_pretrained = _PRETRAINED[model]
554
+ return model_pretrained.get(_clean_tag(tag), {})
555
+
556
+
557
+ def get_pretrained_url(model: str, tag: str):
558
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
559
+ return cfg.get('url', '')
560
+
561
+
562
+ def download_pretrained_from_url(
563
+ url: str,
564
+ cache_dir: Union[str, None] = None,
565
+ ):
566
+ if not cache_dir:
567
+ cache_dir = os.path.expanduser("~/.cache/clip")
568
+ os.makedirs(cache_dir, exist_ok=True)
569
+ filename = os.path.basename(url)
570
+
571
+ if 'openaipublic' in url:
572
+ expected_sha256 = url.split("/")[-2]
573
+ elif 'mlfoundations' in url:
574
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
575
+ else:
576
+ expected_sha256 = ''
577
+
578
+ download_target = os.path.join(cache_dir, filename)
579
+
580
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
581
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
582
+
583
+ if os.path.isfile(download_target):
584
+ if expected_sha256:
585
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
586
+ return download_target
587
+ else:
588
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
589
+ else:
590
+ return download_target
591
+
592
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
593
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
594
+ while True:
595
+ buffer = source.read(8192)
596
+ if not buffer:
597
+ break
598
+
599
+ output.write(buffer)
600
+ loop.update(len(buffer))
601
+
602
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
603
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
604
+
605
+ return download_target
606
+
607
+
608
+ def has_hf_hub(necessary=False):
609
+ if not _has_hf_hub and necessary:
610
+ # if no HF Hub module installed, and it is necessary to continue, raise error
611
+ raise RuntimeError(
612
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
613
+ return _has_hf_hub
614
+
615
+
616
+ def download_pretrained_from_hf(
617
+ model_id: str,
618
+ filename: str = 'open_clip_pytorch_model.bin',
619
+ revision=None,
620
+ cache_dir: Union[str, None] = None,
621
+ ):
622
+ has_hf_hub(True)
623
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
624
+ return cached_file
625
+
626
+
627
+ def download_pretrained(
628
+ cfg: Dict,
629
+ force_hf_hub: bool = False,
630
+ cache_dir: Union[str, None] = None,
631
+ ):
632
+ target = ''
633
+ if not cfg:
634
+ return target
635
+
636
+ download_url = cfg.get('url', '')
637
+ download_hf_hub = cfg.get('hf_hub', '')
638
+ if download_hf_hub and force_hf_hub:
639
+ # use HF hub even if url exists
640
+ download_url = ''
641
+
642
+ if download_url:
643
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
644
+ elif download_hf_hub:
645
+ has_hf_hub(True)
646
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
647
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
648
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
649
+ model_id, filename = os.path.split(download_hf_hub)
650
+ if filename:
651
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
652
+ else:
653
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
654
+
655
+ return target
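A short usage sketch for the pretrained registry helpers above (a hedged example, not part of the upload; it assumes this file is importable as open_clip.pretrained and that network access is available for the commented download step):
from open_clip.pretrained import (
    list_pretrained,
    list_pretrained_tags_by_model,
    get_pretrained_cfg,
    download_pretrained,
)
# Enumerate known (architecture, tag) pairs, or just the tags for one architecture.
print(list_pretrained(as_str=True)[:3])           # e.g. ['RN50:openai', ...]
print(list_pretrained_tags_by_model('ViT-B-32'))
# Look up a config dict (url/hf_hub plus preprocessing mean/std, interpolation, resize mode).
cfg = get_pretrained_cfg('ViT-B-32', 'laion2b_s34b_b79k')
print(cfg['hf_hub'], cfg['mean'], cfg['std'])
# Resolve the config to a local checkpoint path (downloads on first use).
# path = download_pretrained(cfg)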
src/open_clip/timm_model.py ADDED
@@ -0,0 +1,152 @@
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in a CLIP model.
4
+ """
5
+ import logging
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import timm
13
+ from timm.models.layers import Mlp, to_2tuple
14
+ try:
15
+ # old timm imports < 0.8.1
16
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
17
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18
+ except ImportError:
19
+ # new timm imports >= 0.8.1
20
+ from timm.layers import RotAttentionPool2d
21
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
22
+ except ImportError:
23
+ timm = None
24
+
25
+ from .utils import freeze_batch_norm_2d
26
+
27
+
28
+ class TimmModel(nn.Module):
29
+ """ timm model adapter
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ model_name,
35
+ embed_dim,
36
+ image_size=224,
37
+ pool='avg',
38
+ proj='linear',
39
+ proj_bias=False,
40
+ drop=0.,
41
+ drop_path=None,
42
+ patch_drop=None,
43
+ pretrained=False,
44
+ ):
45
+ super().__init__()
46
+ if timm is None:
47
+ raise RuntimeError("Please `pip install timm` to use timm models.")
48
+ self.image_size = to_2tuple(image_size)
49
+
50
+ # setup kwargs that may not be common across all models
51
+ timm_kwargs = {}
52
+ if drop_path is not None:
53
+ timm_kwargs['drop_path_rate'] = drop_path
54
+ if patch_drop is not None:
55
+ timm_kwargs['patch_drop_rate'] = patch_drop
56
+
57
+ custom_pool = pool in ('abs_attn', 'rot_attn')
58
+ if proj:
59
+ assert proj in ("linear", "mlp", "none")
60
+ extra_proj = proj in ("linear", "mlp")
61
+ if not extra_proj and not custom_pool:
62
+ # use network classifier head as projection if no proj specified and no custom pooling used
63
+ # if projection is explicitly set to "none", features are passed through from the network trunk
64
+ proj_dim = 0 if proj == 'none' else embed_dim
65
+ self.trunk = timm.create_model(
66
+ model_name,
67
+ num_classes=proj_dim,
68
+ global_pool=pool,
69
+ pretrained=pretrained,
70
+ **timm_kwargs,
71
+ )
72
+ prev_chs = embed_dim
73
+ else:
74
+ self.trunk = timm.create_model(
75
+ model_name,
76
+ pretrained=pretrained,
77
+ **timm_kwargs,
78
+ )
79
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
80
+ feature_ndim = 1 if not feat_size else 2
81
+ if custom_pool:
82
+ assert feature_ndim == 2
83
+ # if attn pooling used, remove both classifier and default pool
84
+ self.trunk.reset_classifier(0, global_pool='')
85
+ else:
86
+ # reset global pool if pool config set, otherwise leave as network default
87
+ reset_kwargs = dict(global_pool=pool) if pool else {}
88
+ self.trunk.reset_classifier(0, **reset_kwargs)
89
+ prev_chs = self.trunk.num_features
90
+
91
+ head_layers = OrderedDict()
92
+
93
+ # Add custom pooling to head
94
+ if pool == 'abs_attn':
95
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
96
+ prev_chs = embed_dim
97
+ elif pool == 'rot_attn':
98
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
99
+ prev_chs = embed_dim
100
+
101
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
102
+ if proj == 'linear':
103
+ head_layers['drop'] = nn.Dropout(drop)
104
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
105
+ elif proj == 'mlp':
106
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
107
+
108
+ self.head = nn.Sequential(head_layers)
109
+
110
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
111
+ """ lock modules
112
+ Args:
113
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
114
+ """
115
+ if not unlocked_groups:
116
+ # lock full model
117
+ for param in self.trunk.parameters():
118
+ param.requires_grad = False
119
+ if freeze_bn_stats:
120
+ freeze_batch_norm_2d(self.trunk)
121
+ else:
122
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
123
+ try:
124
+ # FIXME import here until API stable and in an official release
125
+ from timm.models.helpers import group_parameters, group_modules
126
+ except ImportError:
127
+ raise RuntimeError(
128
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
129
+ matcher = self.trunk.group_matcher()
130
+ gparams = group_parameters(self.trunk, matcher)
131
+ max_layer_id = max(gparams.keys())
132
+ max_layer_id = max_layer_id - unlocked_groups
133
+ for group_idx in range(max_layer_id + 1):
134
+ group = gparams[group_idx]
135
+ for param in group:
136
+ self.trunk.get_parameter(param).requires_grad = False
137
+ if freeze_bn_stats:
138
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
139
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
140
+ freeze_batch_norm_2d(self.trunk, gmodules)
141
+
142
+ @torch.jit.ignore
143
+ def set_grad_checkpointing(self, enable=True):
144
+ try:
145
+ self.trunk.set_grad_checkpointing(enable)
146
+ except Exception as e:
147
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
148
+
149
+ def forward(self, x):
150
+ x = self.trunk(x)
151
+ x = self.head(x)
152
+ return x
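A brief sketch of wrapping a timm backbone as a CLIP vision tower with the adapter above (hedged: assumes timm is installed; the vit_base_patch16_224 model name and 512-dim embedding are illustrative choices, not values fixed by this file):
import torch
from open_clip.timm_model import TimmModel
# Wrap a timm ViT trunk and project its pooled features to a 512-dim CLIP embedding space.
tower = TimmModel(
    'vit_base_patch16_224',
    embed_dim=512,
    image_size=224,
    pool='',          # keep the trunk's default pooling
    proj='linear',    # add a dropout + linear projection head
    pretrained=False,
)
tower.lock(unlocked_groups=0, freeze_bn_stats=False)  # freeze the whole trunk
with torch.no_grad():
    feats = tower(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 512])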
src/open_clip/tokenizer.py ADDED
@@ -0,0 +1,517 @@
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import gzip
6
+ import html
7
+ import os
8
+ import random
9
+ import string
10
+ from functools import lru_cache, partial
11
+ from typing import Callable, List, Optional, Union
12
+ import warnings
13
+
14
+ import ftfy
15
+ import numpy as np
16
+ import regex as re
17
+ import torch
18
+
19
+ # https://stackoverflow.com/q/62691279
20
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
+ _nltk_init = False
22
+
23
+ DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP
24
+
25
+
26
+ @lru_cache()
27
+ def default_bpe():
28
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
29
+
30
+
31
+ @lru_cache()
32
+ def bytes_to_unicode():
33
+ """
34
+ Returns list of utf-8 bytes and a corresponding list of unicode strings.
35
+ The reversible bpe codes work on unicode strings.
36
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
37
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
38
+ This is a significant percentage of your normal, say, 32K bpe vocab.
39
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
40
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
41
+ """
42
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
43
+ cs = bs[:]
44
+ n = 0
45
+ for b in range(2**8):
46
+ if b not in bs:
47
+ bs.append(b)
48
+ cs.append(2**8+n)
49
+ n += 1
50
+ cs = [chr(n) for n in cs]
51
+ return dict(zip(bs, cs))
52
+
53
+
54
+ def get_pairs(word):
55
+ """Return set of symbol pairs in a word.
56
+ Word is represented as tuple of symbols (symbols being variable-length strings).
57
+ """
58
+ pairs = set()
59
+ prev_char = word[0]
60
+ for char in word[1:]:
61
+ pairs.add((prev_char, char))
62
+ prev_char = char
63
+ return pairs
64
+
65
+
66
+ def basic_clean(text):
67
+ text = ftfy.fix_text(text)
68
+ text = html.unescape(html.unescape(text))
69
+ return text.strip()
70
+
71
+
72
+ def whitespace_clean(text):
73
+ text = " ".join(text.split())
74
+ text = text.strip()
75
+ return text
76
+
77
+
78
+ def _clean_canonicalize(x):
79
+ # basic, remove whitespace, remove punctuation, lower case
80
+ return canonicalize_text(basic_clean(x))
81
+
82
+
83
+ def _clean_lower(x):
84
+ # basic, remove whitespace, lower case
85
+ return whitespace_clean(basic_clean(x)).lower()
86
+
87
+
88
+ def _clean_whitespace(x):
89
+ # basic, remove whitespace
90
+ return whitespace_clean(basic_clean(x))
91
+
92
+
93
+ def get_clean_fn(type: str):
94
+ if type == 'canonicalize':
95
+ return _clean_canonicalize
96
+ elif type == 'lower':
97
+ return _clean_lower
98
+ elif type == 'whitespace':
99
+ return _clean_whitespace
100
+ else:
101
+ assert False, f"Invalid clean function ({type})."
102
+
103
+
104
+ def canonicalize_text(
105
+ text,
106
+ *,
107
+ keep_punctuation_exact_string=None,
108
+ trans_punctuation: dict = str.maketrans("", "", string.punctuation),
109
+ ):
110
+ """Returns canonicalized `text` (lowercase and punctuation removed).
111
+
112
+ From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
113
+
114
+ Args:
115
+ text: string to be canonicalized.
116
+ keep_punctuation_exact_string: If provided, then this exact string is kept.
117
+ For example providing '{}' will keep any occurrences of '{}' (but will
118
+ still remove '{' and '}' that appear separately).
119
+ """
120
+ text = text.replace("_", " ")
121
+ if keep_punctuation_exact_string:
122
+ text = keep_punctuation_exact_string.join(
123
+ part.translate(trans_punctuation)
124
+ for part in text.split(keep_punctuation_exact_string)
125
+ )
126
+ else:
127
+ text = text.translate(trans_punctuation)
128
+ text = text.lower()
129
+ text = " ".join(text.split())
130
+ return text.strip()
131
+
132
+
133
+ class SimpleTokenizer(object):
134
+ def __init__(
135
+ self,
136
+ bpe_path: str = default_bpe(),
137
+ additional_special_tokens: Optional[List[str]] = None,
138
+ context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
139
+ clean: str = 'lower',
140
+ reduction_mask: str = ''
141
+ ):
142
+ self.byte_encoder = bytes_to_unicode()
143
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
144
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
145
+ merges = merges[1:49152-256-2+1]
146
+ merges = [tuple(merge.split()) for merge in merges]
147
+ vocab = list(bytes_to_unicode().values())
148
+ vocab = vocab + [v+'</w>' for v in vocab]
149
+ for merge in merges:
150
+ vocab.append(''.join(merge))
151
+ special_tokens = ['<start_of_text>', '<end_of_text>']
152
+ if additional_special_tokens:
153
+ special_tokens += additional_special_tokens
154
+ vocab.extend(special_tokens)
155
+ self.encoder = dict(zip(vocab, range(len(vocab))))
156
+ self.decoder = {v: k for k, v in self.encoder.items()}
157
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
158
+ self.cache = {t:t for t in special_tokens}
159
+ special = "|".join(special_tokens)
160
+ self.pat = re.compile(
161
+ special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
162
+ re.IGNORECASE,
163
+ )
164
+ self.vocab_size = len(self.encoder)
165
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
166
+ self.sot_token_id = self.all_special_ids[0]
167
+ self.eot_token_id = self.all_special_ids[1]
168
+ self.context_length = context_length
169
+ self.clean_fn = get_clean_fn(clean)
170
+ self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None
171
+
172
+ def bpe(self, token):
173
+ if token in self.cache:
174
+ return self.cache[token]
175
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
176
+ pairs = get_pairs(word)
177
+
178
+ if not pairs:
179
+ return token+'</w>'
180
+
181
+ while True:
182
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
183
+ if bigram not in self.bpe_ranks:
184
+ break
185
+ first, second = bigram
186
+ new_word = []
187
+ i = 0
188
+ while i < len(word):
189
+ try:
190
+ j = word.index(first, i)
191
+ new_word.extend(word[i:j])
192
+ i = j
193
+ except Exception:
194
+ new_word.extend(word[i:])
195
+ break
196
+
197
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
198
+ new_word.append(first+second)
199
+ i += 2
200
+ else:
201
+ new_word.append(word[i])
202
+ i += 1
203
+ new_word = tuple(new_word)
204
+ word = new_word
205
+ if len(word) == 1:
206
+ break
207
+ else:
208
+ pairs = get_pairs(word)
209
+ word = ' '.join(word)
210
+ self.cache[token] = word
211
+ return word
212
+
213
+ def encode(self, text):
214
+ bpe_tokens = []
215
+ text = self.clean_fn(text)
216
+ for token in re.findall(self.pat, text):
217
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
218
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
219
+ return bpe_tokens
220
+
221
+ def decode(self, tokens):
222
+ text = ''.join([self.decoder[token] for token in tokens])
223
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
224
+ return text
225
+
226
+ def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
227
+ """ Returns the tokenized representation of given input string(s)
228
+
229
+ Parameters
230
+ ----------
231
+ texts : Union[str, List[str]]
232
+ An input string or a list of input strings to tokenize
233
+ context_length : int
234
+ The context length to use; all CLIP models use 77 as the context length
235
+
236
+ Returns
237
+ -------
238
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
239
+ """
240
+ if isinstance(texts, str):
241
+ texts = [texts]
242
+
243
+ context_length = context_length or self.context_length
244
+ assert context_length, 'Please set a valid context length'
245
+
246
+ if self.reduction_fn is not None:
247
+ # use reduction strategy for tokenize if set, otherwise default to truncation below
248
+ return self.reduction_fn(
249
+ texts,
250
+ context_length=context_length,
251
+ sot_token_id=self.sot_token_id,
252
+ eot_token_id=self.eot_token_id,
253
+ encode_fn=self.encode,
254
+ )
255
+
256
+ all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
257
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
258
+
259
+ for i, tokens in enumerate(all_tokens):
260
+ if len(tokens) > context_length:
261
+ tokens = tokens[:context_length] # Truncate
262
+ tokens[-1] = self.eot_token_id
263
+ result[i, :len(tokens)] = torch.tensor(tokens)
264
+
265
+ return result
266
+
267
+
268
+ _tokenizer = SimpleTokenizer()
269
+
270
+
271
+ def decode(output_ids: torch.Tensor):
272
+ output_ids = output_ids.cpu().numpy()
273
+ return _tokenizer.decode(output_ids)
274
+
275
+
276
+ def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
277
+ return _tokenizer(texts, context_length=context_length)
278
+
279
+
280
+ def random_mask_tokenize(
281
+ texts: Union[str, List[str]],
282
+ context_length: int,
283
+ sot_token_id: int,
284
+ eot_token_id: int,
285
+ encode_fn: Callable,
286
+ shuffle: bool = False,
287
+ ):
288
+ all_tokens = [encode_fn(text) for text in texts]
289
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
290
+
291
+ for i, tokens in enumerate(all_tokens):
292
+ tokens = torch.tensor(tokens)
293
+ num_tokens = len(tokens)
294
+ if num_tokens > context_length - 2: # 2 for sot and eot token
295
+ num_keep = context_length - 2
296
+ indices = torch.randperm(len(tokens))
297
+ indices = indices[:num_keep]
298
+ if not shuffle:
299
+ indices = indices.msort()
300
+ tokens = tokens[indices]
301
+ num_tokens = num_keep
302
+ result[i, 0] = sot_token_id
303
+ result[i, 1:num_tokens + 1] = tokens
304
+ result[i, num_tokens + 1] = eot_token_id
305
+
306
+ return result
307
+
308
+
309
+ def simple_mask_tokenize(
310
+ texts: Union[str, List[str]],
311
+ context_length: int,
312
+ sot_token_id: int,
313
+ eot_token_id: int,
314
+ encode_fn: Callable,
315
+ ):
316
+ all_tokens = [encode_fn(text) for text in texts]
317
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
318
+
319
+ for i, tokens in enumerate(all_tokens):
320
+ num_tokens = len(tokens)
321
+ if num_tokens > context_length - 2: # 2 for sot and eot token
322
+ num_keep = context_length - 2
323
+ start_index = random.randint(0, num_tokens - num_keep) # high is incl
324
+ tokens = tokens[start_index: start_index + num_keep]
325
+ tokens = [sot_token_id] + tokens + [eot_token_id]
326
+ result[i, :len(tokens)] = torch.tensor(tokens)
327
+
328
+ return result
329
+
330
+
331
+ def syntax_mask_tokenize(
332
+ texts: Union[str, List[str]],
333
+ context_length: int,
334
+ sot_token_id: int,
335
+ eot_token_id: int,
336
+ encode_fn: Callable,
337
+ ) -> torch.LongTensor:
338
+ """ Returns the tokenized representation of given input string(s).
339
+ Applies syntax masking before tokenizing.
340
+ """
341
+ import nltk
342
+ global _nltk_init
343
+ if not _nltk_init:
344
+ # run them for the first time
345
+ nltk.download('punkt')
346
+ nltk.download('averaged_perceptron_tagger')
347
+ _nltk_init = True
348
+
349
+ def get_order(x):
350
+ if x.startswith('NN'):
351
+ return 1
352
+ elif x.startswith('JJ'):
353
+ return 2
354
+ elif x.startswith('VB'):
355
+ return 3
356
+ else:
357
+ return 4
358
+
359
+ # syntax masking
360
+ new_texts = []
361
+ for text in texts:
362
+ list_tokens = nltk.tokenize.word_tokenize(text)
363
+ pos_tags = nltk.pos_tag(list_tokens)
364
+ # sample the words by get_order method
365
+ order_list = [get_order(tag) for _, tag in pos_tags]
366
+ sorted_ids = np.argsort(np.array(order_list))
367
+ sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens
368
+ sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens
369
+
370
+ new_text = ''
371
+ for token in sampled_tokens:
372
+ new_text = new_text + str(token) + ' '
373
+ new_text = new_text.strip()
374
+ new_texts.append(new_text)
375
+ texts = new_texts
376
+
377
+ all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
378
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
379
+
380
+ for i, tokens in enumerate(all_tokens):
381
+ # still need to truncate first because some words produce multiple tokens
382
+ if len(tokens) > context_length:
383
+ tokens = tokens[:context_length] # Truncate
384
+ tokens[-1] = eot_token_id
385
+ result[i, :len(tokens)] = torch.tensor(tokens)
386
+
387
+ return result
388
+
389
+
390
+ def get_reduction_mask_fn(type: str):
391
+ """ Choose strategy for dropping (masking) tokens to achieve target context length"""
392
+ assert type in ('simple', 'random', 'shuffle', 'syntax')
393
+ if type == 'simple':
394
+ return simple_mask_tokenize # randomly select block [start:end]
395
+ elif type == 'random':
396
+ return random_mask_tokenize # randomly drop tokens (keep order)
397
+ elif type == 'shuffle':
398
+ return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order)
399
+ elif type == 'syntax':
400
+ return syntax_mask_tokenize # randomly drop prioritized by syntax
401
+
402
+
403
+ class HFTokenizer:
404
+ """HuggingFace tokenizer wrapper"""
405
+
406
+ def __init__(
407
+ self,
408
+ tokenizer_name: str,
409
+ context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
410
+ clean: str = 'whitespace',
411
+ strip_sep_token: bool = False,
412
+ language: Optional[str] = None,
413
+ **kwargs
414
+ ):
415
+ from transformers import AutoTokenizer
416
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs)
417
+ set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
418
+ if callable(set_lang_fn):
419
+ self.set_lang_fn = set_lang_fn
420
+ if language is not None:
421
+ self.set_language(language)
422
+ self.context_length = context_length
423
+ self.clean_fn = get_clean_fn(clean)
424
+ self.strip_sep_token = strip_sep_token
425
+
426
+ def save_pretrained(self, dest):
427
+ self.tokenizer.save_pretrained(dest)
428
+
429
+ def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
430
+ # same cleaning as for default tokenizer, except lowercasing
431
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
432
+ if isinstance(texts, str):
433
+ texts = [texts]
434
+
435
+ context_length = context_length or self.context_length
436
+ assert context_length, 'Please set a valid context length in class init or call.'
437
+
438
+ texts = [self.clean_fn(text) for text in texts]
439
+ input_ids = self.tokenizer.batch_encode_plus(
440
+ texts,
441
+ return_tensors='pt',
442
+ max_length=context_length,
443
+ padding='max_length',
444
+ truncation=True,
445
+ ).input_ids
446
+
447
+ if self.strip_sep_token:
448
+ input_ids = torch.where(
449
+ input_ids == self.tokenizer.sep_token_id,
450
+ torch.zeros_like(input_ids),
451
+ input_ids,
452
+ )
453
+
454
+ return input_ids
455
+
456
+ def set_language(self, src_lang):
457
+ if hasattr(self, 'set_lang_fn'):
458
+ self.set_lang_fn(src_lang)
459
+ else:
460
+ warnings.warn('Cannot set language for the tokenizer.')
461
+
462
+
463
+ class SigLipTokenizer:
464
+ """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs
465
+ """
466
+ VOCAB_FILES = {
467
+ # english, vocab_size=32_000
468
+ "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
469
+ # used in multilingual models (mT5, PaLI), vocab_size=250_000
470
+ "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
471
+ }
472
+
473
+ def __init__(
474
+ self,
475
+ tokenizer_name: str,
476
+ context_length: Optional[int] = 64,
477
+ ):
478
+ from transformers import T5TokenizerFast
479
+
480
+ if tokenizer_name in self.VOCAB_FILES:
481
+ # FIXME temporary hack?
482
+ import tempfile
483
+
484
+ import fsspec
485
+ vocab_file = self.VOCAB_FILES[tokenizer_name]
486
+ with tempfile.NamedTemporaryFile('wb') as dst:
487
+ with fsspec.open(vocab_file, 'rb') as src:
488
+ dst.write(src.read())
489
+ self.tokenizer = T5TokenizerFast(dst.name, legacy=False)
490
+ else:
491
+ self.tokenizer = T5TokenizerFast(tokenizer_name, legacy=False)
492
+
493
+ self.tokenizer.pad_token_id = 1
494
+ self.tokenizer.eos_token_id = 1
495
+ self.context_length = context_length
496
+
497
+ def save_pretrained(self, dest):
498
+ self.tokenizer.save_pretrained(dest)
499
+
500
+ def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
501
+ # same cleaning as for default tokenizer, except lowercasing
502
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
503
+ if isinstance(texts, str):
504
+ texts = [texts]
505
+
506
+ context_length = context_length or self.context_length
507
+ assert context_length, 'Please set a valid context length in class init or call.'
508
+
509
+ texts = [canonicalize_text(basic_clean(text)) for text in texts]
510
+ output = self.tokenizer(
511
+ texts,
512
+ return_tensors='pt',
513
+ max_length=context_length,
514
+ padding='max_length',
515
+ truncation=True,
516
+ )
517
+ return output.input_ids
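A small sketch of the tokenizer entry points above (hedged: the captions are arbitrary and the module path assumes the package layout of this upload):
from open_clip.tokenizer import SimpleTokenizer, tokenize, decode
# Module-level convenience function backed by the default BPE vocab (context length 77).
ids = tokenize(["a photo of a cat", "a diagram"])
print(ids.shape)        # torch.Size([2, 77])
# The class form exposes context length and cleaning options explicitly.
tok = SimpleTokenizer(context_length=32, clean='lower')
ids32 = tok("A very long caption " * 20)   # over-length input is truncated, eot token kept
print(ids32.shape)      # torch.Size([1, 32])
# Round-trip the non-padding token ids of the first caption back to text.
print(decode(ids[0][ids[0] > 0]))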
src/open_clip/transform.py ADDED
@@ -0,0 +1,407 @@
1
+ import numbers
2
+ import random
3
+ import warnings
4
+ from dataclasses import dataclass, asdict
5
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import torch
8
+ import torchvision.transforms.functional as F
9
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
10
+ CenterCrop, ColorJitter, Grayscale
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+ from .utils import to_2tuple
14
+
15
+
16
+ @dataclass
17
+ class PreprocessCfg:
18
+ size: Union[int, Tuple[int, int]] = 224
19
+ mode: str = 'RGB'
20
+ mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
21
+ std: Tuple[float, ...] = OPENAI_DATASET_STD
22
+ interpolation: str = 'bicubic'
23
+ resize_mode: str = 'shortest'
24
+ fill_color: int = 0
25
+
26
+ def __post_init__(self):
27
+ assert self.mode in ('RGB',)
28
+
29
+ @property
30
+ def num_channels(self):
31
+ return 3
32
+
33
+ @property
34
+ def input_size(self):
35
+ return (self.num_channels,) + to_2tuple(self.size)
36
+
37
+ _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
38
+
39
+
40
+ def merge_preprocess_dict(
41
+ base: Union[PreprocessCfg, Dict],
42
+ overlay: Dict,
43
+ ):
44
+ """ Merge overlay key-value pairs on top of base preprocess cfg or dict.
45
+ Input dicts are filtered based on PreprocessCfg fields.
46
+ """
47
+ if isinstance(base, PreprocessCfg):
48
+ base_clean = asdict(base)
49
+ else:
50
+ base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
51
+ if overlay:
52
+ overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}
53
+ base_clean.update(overlay_clean)
54
+ return base_clean
55
+
56
+
57
+ def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
58
+ return merge_preprocess_dict(base, kwargs)
59
+
60
+
61
+ @dataclass
62
+ class AugmentationCfg:
63
+ scale: Tuple[float, float] = (0.9, 1.0)
64
+ ratio: Optional[Tuple[float, float]] = None
65
+ color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
66
+ re_prob: Optional[float] = None
67
+ re_count: Optional[int] = None
68
+ use_timm: bool = False
69
+
70
+ # params for simclr_jitter_gray
71
+ color_jitter_prob: Optional[float] = None
72
+ gray_scale_prob: Optional[float] = None
73
+
74
+
75
+ def _setup_size(size, error_msg):
76
+ if isinstance(size, numbers.Number):
77
+ return int(size), int(size)
78
+
79
+ if isinstance(size, Sequence) and len(size) == 1:
80
+ return size[0], size[0]
81
+
82
+ if len(size) != 2:
83
+ raise ValueError(error_msg)
84
+
85
+ return size
86
+
87
+
88
+ class ResizeKeepRatio:
89
+ """ Resize and Keep Ratio
90
+
91
+ Copy & paste from `timm`
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ size,
97
+ longest=0.,
98
+ interpolation=InterpolationMode.BICUBIC,
99
+ random_scale_prob=0.,
100
+ random_scale_range=(0.85, 1.05),
101
+ random_aspect_prob=0.,
102
+ random_aspect_range=(0.9, 1.11)
103
+ ):
104
+ if isinstance(size, (list, tuple)):
105
+ self.size = tuple(size)
106
+ else:
107
+ self.size = (size, size)
108
+ self.interpolation = interpolation
109
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
110
+ self.random_scale_prob = random_scale_prob
111
+ self.random_scale_range = random_scale_range
112
+ self.random_aspect_prob = random_aspect_prob
113
+ self.random_aspect_range = random_aspect_range
114
+
115
+ @staticmethod
116
+ def get_params(
117
+ img,
118
+ target_size,
119
+ longest,
120
+ random_scale_prob=0.,
121
+ random_scale_range=(0.85, 1.05),
122
+ random_aspect_prob=0.,
123
+ random_aspect_range=(0.9, 1.11)
124
+ ):
125
+ """Get parameters
126
+ """
127
+ source_size = img.size[::-1] # h, w
128
+ h, w = source_size
129
+ target_h, target_w = target_size
130
+ ratio_h = h / target_h
131
+ ratio_w = w / target_w
132
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
133
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
134
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
135
+ ratio_factor = (ratio_factor, ratio_factor)
136
+ else:
137
+ ratio_factor = (1., 1.)
138
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
139
+ aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
140
+ ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
141
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
142
+ return size
143
+
144
+ def __call__(self, img):
145
+ """
146
+ Args:
147
+ img (PIL Image): Image to be cropped and resized.
148
+
149
+ Returns:
150
+ PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
151
+ """
152
+ size = self.get_params(
153
+ img, self.size, self.longest,
154
+ self.random_scale_prob, self.random_scale_range,
155
+ self.random_aspect_prob, self.random_aspect_range
156
+ )
157
+ img = F.resize(img, size, self.interpolation)
158
+ return img
159
+
160
+ def __repr__(self):
161
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
162
+ format_string += f', interpolation={self.interpolation}'
163
+ format_string += f', longest={self.longest:.3f})'
164
+ return format_string
165
+
166
+
167
+ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
168
+ """Center crops and/or pads the given image.
169
+ If the image is torch Tensor, it is expected
170
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
171
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
172
+
173
+ Args:
174
+ img (PIL Image or Tensor): Image to be cropped.
175
+ output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
176
+ it is used for both directions.
177
+ fill (int, Tuple[int]): Padding color
178
+
179
+ Returns:
180
+ PIL Image or Tensor: Cropped image.
181
+ """
182
+ if isinstance(output_size, numbers.Number):
183
+ output_size = (int(output_size), int(output_size))
184
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
185
+ output_size = (output_size[0], output_size[0])
186
+
187
+ _, image_height, image_width = F.get_dimensions(img)
188
+ crop_height, crop_width = output_size
189
+
190
+ if crop_width > image_width or crop_height > image_height:
191
+ padding_ltrb = [
192
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
193
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
194
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
195
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
196
+ ]
197
+ img = F.pad(img, padding_ltrb, fill=fill)
198
+ _, image_height, image_width = F.get_dimensions(img)
199
+ if crop_width == image_width and crop_height == image_height:
200
+ return img
201
+
202
+ crop_top = int(round((image_height - crop_height) / 2.0))
203
+ crop_left = int(round((image_width - crop_width) / 2.0))
204
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
205
+
206
+
207
+ class CenterCropOrPad(torch.nn.Module):
208
+ """Crops the given image at the center.
209
+ If the image is torch Tensor, it is expected
210
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
211
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
212
+
213
+ Args:
214
+ size (sequence or int): Desired output size of the crop. If size is an
215
+ int instead of sequence like (h, w), a square crop (size, size) is
216
+ made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
217
+ """
218
+
219
+ def __init__(self, size, fill=0):
220
+ super().__init__()
221
+ self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
222
+ self.fill = fill
223
+
224
+ def forward(self, img):
225
+ """
226
+ Args:
227
+ img (PIL Image or Tensor): Image to be cropped.
228
+
229
+ Returns:
230
+ PIL Image or Tensor: Cropped image.
231
+ """
232
+ return center_crop_or_pad(img, self.size, fill=self.fill)
233
+
234
+ def __repr__(self) -> str:
235
+ return f"{self.__class__.__name__}(size={self.size})"
236
+
237
+
238
+ def _convert_to_rgb(image):
239
+ return image.convert('RGB')
240
+
241
+
242
+ class color_jitter(object):
243
+ """
244
+ Apply Color Jitter to the PIL image with a specified probability.
245
+ """
246
+ def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
247
+ assert 0. <= p <= 1.
248
+ self.p = p
249
+ self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
250
+
251
+ def __call__(self, img):
252
+ if random.random() < self.p:
253
+ return self.transf(img)
254
+ else:
255
+ return img
256
+
257
+
258
+ class gray_scale(object):
259
+ """
260
+ Apply Gray Scale to the PIL image with a specified probability.
261
+ """
262
+ def __init__(self, p=0.2):
263
+ assert 0. <= p <= 1.
264
+ self.p = p
265
+ self.transf = Grayscale(num_output_channels=3)
266
+
267
+ def __call__(self, img):
268
+ if random.random() < self.p:
269
+ return self.transf(img)
270
+ else:
271
+ return img
272
+
273
+
274
+ def image_transform(
275
+ image_size: Union[int, Tuple[int, int]],
276
+ is_train: bool,
277
+ mean: Optional[Tuple[float, ...]] = None,
278
+ std: Optional[Tuple[float, ...]] = None,
279
+ resize_mode: Optional[str] = None,
280
+ interpolation: Optional[str] = None,
281
+ fill_color: int = 0,
282
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
283
+ ):
284
+ mean = mean or OPENAI_DATASET_MEAN
285
+ if not isinstance(mean, (list, tuple)):
286
+ mean = (mean,) * 3
287
+
288
+ std = std or OPENAI_DATASET_STD
289
+ if not isinstance(std, (list, tuple)):
290
+ std = (std,) * 3
291
+
292
+ interpolation = interpolation or 'bicubic'
293
+ assert interpolation in ['bicubic', 'bilinear', 'random']
294
+ # NOTE 'random' has no torchvision InterpolationMode equivalent, so it falls back to BICUBIC outside the timm training path
295
+ interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
296
+
297
+ resize_mode = resize_mode or 'shortest'
298
+ assert resize_mode in ('shortest', 'longest', 'squash')
299
+
300
+ if isinstance(aug_cfg, dict):
301
+ aug_cfg = AugmentationCfg(**aug_cfg)
302
+ else:
303
+ aug_cfg = aug_cfg or AugmentationCfg()
304
+
305
+ normalize = Normalize(mean=mean, std=std)
306
+
307
+ if is_train:
308
+ aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
309
+ use_timm = aug_cfg_dict.pop('use_timm', False)
310
+ if use_timm:
311
+ from timm.data import create_transform # timm can still be optional
312
+ if isinstance(image_size, (tuple, list)):
313
+ assert len(image_size) >= 2
314
+ input_size = (3,) + image_size[-2:]
315
+ else:
316
+ input_size = (3, image_size, image_size)
317
+
318
+ aug_cfg_dict.setdefault('color_jitter', None) # disable by default
319
+ # drop extra non-timm items
320
+ aug_cfg_dict.pop('color_jitter_prob', None)
321
+ aug_cfg_dict.pop('gray_scale_prob', None)
322
+
323
+ train_transform = create_transform(
324
+ input_size=input_size,
325
+ is_training=True,
326
+ hflip=0.,
327
+ mean=mean,
328
+ std=std,
329
+ re_mode='pixel',
330
+ interpolation=interpolation,
331
+ **aug_cfg_dict,
332
+ )
333
+ else:
334
+ train_transform = [
335
+ RandomResizedCrop(
336
+ image_size,
337
+ scale=aug_cfg_dict.pop('scale'),
338
+ interpolation=InterpolationMode.BICUBIC,
339
+ ),
340
+ _convert_to_rgb,
341
+ ]
342
+ if aug_cfg.color_jitter_prob:
343
+ assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
344
+ train_transform.extend([
345
+ color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
346
+ ])
347
+ if aug_cfg.gray_scale_prob:
348
+ train_transform.extend([
349
+ gray_scale(aug_cfg.gray_scale_prob)
350
+ ])
351
+ train_transform.extend([
352
+ ToTensor(),
353
+ normalize,
354
+ ])
355
+ train_transform = Compose(train_transform)
356
+ if aug_cfg_dict:
357
+ warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
358
+ return train_transform
359
+ else:
360
+ if resize_mode == 'longest':
361
+ transforms = [
362
+ ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
363
+ CenterCropOrPad(image_size, fill=fill_color)
364
+ ]
365
+ elif resize_mode == 'squash':
366
+ if isinstance(image_size, int):
367
+ image_size = (image_size, image_size)
368
+ transforms = [
369
+ Resize(image_size, interpolation=interpolation_mode),
370
+ ]
371
+ else:
372
+ assert resize_mode == 'shortest'
373
+ if not isinstance(image_size, (tuple, list)):
374
+ image_size = (image_size, image_size)
375
+ if image_size[0] == image_size[1]:
376
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
377
+ transforms = [
378
+ Resize(image_size[0], interpolation=interpolation_mode)
379
+ ]
380
+ else:
381
+ # resize shortest edge to matching target dim for non-square target
382
+ transforms = [ResizeKeepRatio(image_size)]
383
+ transforms += [CenterCrop(image_size)]
384
+
385
+ transforms.extend([
386
+ _convert_to_rgb,
387
+ ToTensor(),
388
+ normalize,
389
+ ])
390
+ return Compose(transforms)
391
+
392
+
393
+ def image_transform_v2(
394
+ cfg: PreprocessCfg,
395
+ is_train: bool,
396
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
397
+ ):
398
+ return image_transform(
399
+ image_size=cfg.size,
400
+ is_train=is_train,
401
+ mean=cfg.mean,
402
+ std=cfg.std,
403
+ interpolation=cfg.interpolation,
404
+ resize_mode=cfg.resize_mode,
405
+ fill_color=cfg.fill_color,
406
+ aug_cfg=aug_cfg,
407
+ )
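
For reference, a minimal usage sketch of the preprocessing factory defined above (not part of the uploaded file; the import path and the sample image filename are assumptions):

    from PIL import Image
    from open_clip.transform import PreprocessCfg, image_transform_v2  # module path assumed

    cfg = PreprocessCfg(size=224, resize_mode='shortest', interpolation='bicubic')
    preprocess = image_transform_v2(cfg, is_train=False)  # eval pipeline: resize -> center crop -> ToTensor -> Normalize
    img = Image.open('example.jpg')                       # placeholder input image
    x = preprocess(img)                                   # torch.FloatTensor of shape (3, 224, 224)
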
src/open_clip/transformer.py ADDED
@@ -0,0 +1,908 @@
1
+ from collections import OrderedDict
2
+ import math
3
+ from typing import Callable, List, Optional, Sequence, Tuple, Union
4
+ from functools import partial
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ from .utils import to_2tuple
12
+ from .pos_embed import get_2d_sincos_pos_embed
13
+
14
+
15
+ class LayerNormFp32(nn.LayerNorm):
16
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
17
+
18
+ def forward(self, x: torch.Tensor):
19
+ orig_type = x.dtype
20
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
21
+ return x.to(orig_type)
22
+
23
+
24
+ class LayerNorm(nn.LayerNorm):
25
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
26
+
27
+ def forward(self, x: torch.Tensor):
28
+ orig_type = x.dtype
29
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
30
+ return x.to(orig_type)
31
+
32
+
33
+ class QuickGELU(nn.Module):
34
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
35
+ def forward(self, x: torch.Tensor):
36
+ return x * torch.sigmoid(1.702 * x)
37
+
38
+
39
+ class LayerScale(nn.Module):
40
+ def __init__(self, dim, init_values=1e-5, inplace=False):
41
+ super().__init__()
42
+ self.inplace = inplace
43
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
44
+
45
+ def forward(self, x):
46
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
47
+
48
+
49
+ class PatchDropout(nn.Module):
50
+ """
51
+ https://arxiv.org/abs/2212.00794
52
+ """
53
+
54
+ def __init__(self, prob, exclude_first_token=True):
55
+ super().__init__()
56
+ assert 0 <= prob < 1.
57
+ self.prob = prob
58
+ self.exclude_first_token = exclude_first_token # exclude CLS token
59
+
60
+ def forward(self, x):
61
+ if not self.training or self.prob == 0.:
62
+ return x
63
+
64
+ if self.exclude_first_token:
65
+ cls_tokens, x = x[:, :1], x[:, 1:]
66
+ else:
67
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
68
+
69
+ batch = x.size()[0]
70
+ num_tokens = x.size()[1]
71
+
72
+ batch_indices = torch.arange(batch)
73
+ batch_indices = batch_indices[..., None]
74
+
75
+ keep_prob = 1 - self.prob
76
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
77
+
78
+ rand = torch.randn(batch, num_tokens)
79
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
80
+
81
+ x = x[batch_indices, patch_indices_keep]
82
+
83
+ if self.exclude_first_token:
84
+ x = torch.cat((cls_tokens, x), dim=1)
85
+
86
+ return x
87
+
88
+
89
+ class Attention(nn.Module):
90
+ def __init__(
91
+ self,
92
+ dim: int,
93
+ num_heads: int = 8,
94
+ qkv_bias: bool = True,
95
+ scaled_cosine: bool = False,
96
+ scale_heads: bool = False,
97
+ logit_scale_max: float = math.log(1. / 0.01),
98
+ batch_first: bool = True,
99
+ attn_drop: float = 0.,
100
+ proj_drop: float = 0.
101
+ ):
102
+ super().__init__()
103
+ self.scaled_cosine = scaled_cosine
104
+ self.scale_heads = scale_heads
105
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
106
+ self.num_heads = num_heads
107
+ self.head_dim = dim // num_heads
108
+ self.scale = self.head_dim ** -0.5
109
+ self.logit_scale_max = logit_scale_max
110
+ self.batch_first = batch_first
111
+ self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')
112
+
113
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
114
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
115
+ if qkv_bias:
116
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
117
+ else:
118
+ self.in_proj_bias = None
119
+
120
+ if self.scaled_cosine:
121
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
122
+ else:
123
+ self.logit_scale = None
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ if self.scale_heads:
126
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
127
+ else:
128
+ self.head_scale = None
129
+ self.out_proj = nn.Linear(dim, dim)
130
+ self.out_drop = nn.Dropout(proj_drop)
131
+
132
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
133
+ if self.batch_first:
134
+ x = x.transpose(0, 1)
135
+
136
+ L, N, C = x.shape
137
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
138
+ q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
139
+ k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
140
+ v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)
141
+
142
+ if attn_mask is not None and attn_mask.dtype == torch.bool:
143
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
144
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
145
+ attn_mask = new_attn_mask
146
+
147
+ if self.logit_scale is not None:
148
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
149
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
150
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
151
+ attn = attn.view(-1, L, L)
152
+ if attn_mask is not None:
153
+ attn = attn + attn_mask
154
+ attn = attn.softmax(dim=-1)
155
+ attn = self.attn_drop(attn)
156
+ x = torch.bmm(attn, v)
157
+ else:
158
+ if self.use_fsdpa:
159
+ x = F.scaled_dot_product_attention(
160
+ q, k, v,
161
+ attn_mask=attn_mask,
162
+ dropout_p=self.attn_drop.p if self.training else 0.,
163
+ )
164
+ else:
165
+ q = q * self.scale
166
+ attn = torch.bmm(q, k.transpose(-1, -2))
167
+ if attn_mask is not None:
168
+ attn += attn_mask
169
+ attn = attn.softmax(dim=-1)
170
+ attn = self.attn_drop(attn)
171
+ x = torch.bmm(attn, v)
172
+
173
+ if self.head_scale is not None:
174
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
175
+ x = x.view(-1, L, C)
176
+
177
+ x = x.transpose(0, 1).reshape(L, N, C)
178
+
179
+ if self.batch_first:
180
+ x = x.transpose(0, 1)
181
+
182
+ x = self.out_proj(x)
183
+ x = self.out_drop(x)
184
+ return x
185
+
186
+
187
+ class AttentionalPooler(nn.Module):
188
+ def __init__(
189
+ self,
190
+ d_model: int,
191
+ context_dim: int,
192
+ n_head: int = 8,
193
+ n_queries: int = 256,
194
+ norm_layer: Callable = LayerNorm,
195
+ ):
196
+ super().__init__()
197
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
198
+ self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
199
+ self.ln_q = norm_layer(d_model)
200
+ self.ln_k = norm_layer(context_dim)
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ N = x.shape[0]
204
+ x = self.ln_k(x)
205
+ q = self.ln_q(self.query)
206
+ out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
207
+ return out
208
+
209
+
210
+ class ResidualAttentionBlock(nn.Module):
211
+ def __init__(
212
+ self,
213
+ d_model: int,
214
+ n_head: int,
215
+ mlp_ratio: float = 4.0,
216
+ ls_init_value: float = None,
217
+ act_layer: Callable = nn.GELU,
218
+ norm_layer: Callable = LayerNorm,
219
+ is_cross_attention: bool = False,
220
+ batch_first: bool = True,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.ln_1 = norm_layer(d_model)
225
+ self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
226
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
227
+ if is_cross_attention:
228
+ self.ln_1_kv = norm_layer(d_model)
229
+
230
+ self.ln_2 = norm_layer(d_model)
231
+ mlp_width = int(d_model * mlp_ratio)
232
+ self.mlp = nn.Sequential(OrderedDict([
233
+ ("c_fc", nn.Linear(d_model, mlp_width)),
234
+ ("gelu", act_layer()),
235
+ ("c_proj", nn.Linear(mlp_width, d_model))
236
+ ]))
237
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
238
+
239
+ def attention(
240
+ self,
241
+ q_x: torch.Tensor,
242
+ k_x: Optional[torch.Tensor] = None,
243
+ v_x: Optional[torch.Tensor] = None,
244
+ attn_mask: Optional[torch.Tensor] = None,
245
+ ):
246
+ k_x = k_x if k_x is not None else q_x
247
+ v_x = v_x if v_x is not None else q_x
248
+
249
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
250
+ return self.attn(
251
+ q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
252
+ )[0]
253
+
254
+ def forward(
255
+ self,
256
+ q_x: torch.Tensor,
257
+ k_x: Optional[torch.Tensor] = None,
258
+ v_x: Optional[torch.Tensor] = None,
259
+ attn_mask: Optional[torch.Tensor] = None,
260
+ ):
261
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
262
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
263
+ x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
264
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
265
+ return x
266
+
267
+
268
+ class CustomResidualAttentionBlock(nn.Module):
269
+ def __init__(
270
+ self,
271
+ d_model: int,
272
+ n_head: int,
273
+ mlp_ratio: float = 4.0,
274
+ ls_init_value: float = None,
275
+ act_layer: Callable = nn.GELU,
276
+ norm_layer: Callable = LayerNorm,
277
+ scale_cosine_attn: bool = False,
278
+ scale_heads: bool = False,
279
+ scale_attn: bool = False,
280
+ scale_fc: bool = False,
281
+ batch_first: bool = True,
282
+ ):
283
+ super().__init__()
284
+
285
+ self.ln_1 = norm_layer(d_model)
286
+ self.attn = Attention(
287
+ d_model,
288
+ n_head,
289
+ scaled_cosine=scale_cosine_attn,
290
+ scale_heads=scale_heads,
291
+ batch_first=batch_first,
292
+ )
293
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
294
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
295
+
296
+ self.ln_2 = norm_layer(d_model)
297
+ mlp_width = int(d_model * mlp_ratio)
298
+ self.mlp = nn.Sequential(OrderedDict([
299
+ ("c_fc", nn.Linear(d_model, mlp_width)),
300
+ ("gelu", act_layer()),
301
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
302
+ ("c_proj", nn.Linear(mlp_width, d_model))
303
+ ]))
304
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
305
+
306
+ def get_reference_weight(self):
307
+ return self.mlp.c_fc.weight
308
+
309
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
310
+ x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
311
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
312
+ return x
313
+
314
+
315
+ def _expand_token(token, batch_size: int):
316
+ return token.view(1, 1, -1).expand(batch_size, -1, -1)
317
+
318
+
319
+ class Transformer(nn.Module):
320
+ def __init__(
321
+ self,
322
+ width: int,
323
+ layers: int,
324
+ heads: int,
325
+ mlp_ratio: float = 4.0,
326
+ ls_init_value: float = None,
327
+ act_layer: Callable = nn.GELU,
328
+ norm_layer: Callable = LayerNorm,
329
+ batch_first: bool = True,
330
+ ):
331
+ super().__init__()
332
+ self.width = width
333
+ self.layers = layers
334
+ self.batch_first = batch_first
335
+ self.grad_checkpointing = False
336
+
337
+ self.resblocks = nn.ModuleList([
338
+ ResidualAttentionBlock(
339
+ width,
340
+ heads,
341
+ mlp_ratio,
342
+ ls_init_value=ls_init_value,
343
+ act_layer=act_layer,
344
+ norm_layer=norm_layer,
345
+ batch_first=batch_first,
346
+ )
347
+ for _ in range(layers)
348
+ ])
349
+
350
+ def get_cast_dtype(self) -> torch.dtype:
351
+ if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
352
+ return self.resblocks[0].mlp.c_fc.int8_original_dtype
353
+ return self.resblocks[0].mlp.c_fc.weight.dtype
354
+
355
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
356
+ if not self.batch_first:
357
+ x = x.transpose(0, 1).contiguous() # NLD -> LND
358
+ for r in self.resblocks:
359
+ if self.grad_checkpointing and not torch.jit.is_scripting():
360
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
361
+ x = checkpoint(r, x, None, None, attn_mask)
362
+ else:
363
+ x = r(x, attn_mask=attn_mask)
364
+ if not self.batch_first:
365
+ x = x.transpose(0, 1) # LND -> NLD
366
+ return x
367
+
368
+
369
+ class CustomTransformer(nn.Module):
370
+ """ A custom transformer that can use different block types. """
371
+ def __init__(
372
+ self,
373
+ width: int,
374
+ layers: int,
375
+ heads: int,
376
+ mlp_ratio: float = 4.0,
377
+ ls_init_value: float = None,
378
+ act_layer: Callable = nn.GELU,
379
+ norm_layer: Callable = LayerNorm,
380
+ batch_first: bool = True,
381
+ block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
382
+ ):
383
+ super().__init__()
384
+ self.width = width
385
+ self.layers = layers
386
+ self.batch_first = batch_first # run transformer stack in batch-first (N, L, D) order
387
+ self.grad_checkpointing = False
388
+
389
+ if isinstance(block_types, str):
390
+ block_types = [block_types] * layers
391
+ assert len(block_types) == layers
392
+
393
+ def _create_block(bt: str):
394
+ if bt == 'CustomResidualAttentionBlock':
395
+ return CustomResidualAttentionBlock(
396
+ width,
397
+ heads,
398
+ mlp_ratio=mlp_ratio,
399
+ ls_init_value=ls_init_value,
400
+ act_layer=act_layer,
401
+ norm_layer=norm_layer,
402
+ batch_first=batch_first,
403
+ )
404
+ else:
405
+ assert False
406
+
407
+ self.resblocks = nn.ModuleList([
408
+ _create_block(bt)
409
+ for bt in block_types
410
+ ])
411
+
412
+ def get_cast_dtype(self) -> torch.dtype:
413
+ weight = self.resblocks[0].get_reference_weight()
414
+ if hasattr(weight, 'int8_original_dtype'):
415
+ return weight.int8_original_dtype
416
+ return weight.dtype
417
+
418
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
419
+ if not self.batch_first:
420
+ x = x.transpose(0, 1) # NLD -> LND
421
+
422
+ for r in self.resblocks:
423
+ if self.grad_checkpointing and not torch.jit.is_scripting():
424
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
425
+ x = checkpoint(r, x, None, None, attn_mask)
426
+ else:
427
+ x = r(x, attn_mask=attn_mask)
428
+
429
+ if not self.batch_first:
430
+ x = x.transpose(0, 1) # LND -> NLD
431
+ return x
432
+
433
+
434
+ class VisionTransformer(nn.Module):
435
+ output_tokens: torch.jit.Final[bool]
436
+
437
+ def __init__(
438
+ self,
439
+ image_size: int,
440
+ patch_size: int,
441
+ width: int,
442
+ layers: int,
443
+ heads: int,
444
+ mlp_ratio: float,
445
+ ls_init_value: float = None,
446
+ attentional_pool: bool = False,
447
+ attn_pooler_queries: int = 256,
448
+ attn_pooler_heads: int = 8,
449
+ output_dim: int = 512,
450
+ patch_dropout: float = 0.,
451
+ no_ln_pre: bool = False,
452
+ pos_embed_type: str = 'learnable',
453
+ pool_type: str = 'tok',
454
+ final_ln_after_pool: bool = False,
455
+ act_layer: Callable = nn.GELU,
456
+ norm_layer: Callable = LayerNorm,
457
+ output_tokens: bool = False,
458
+ ):
459
+ super().__init__()
460
+ assert pool_type in ('tok', 'avg', 'none')
461
+ self.output_tokens = output_tokens
462
+ image_height, image_width = self.image_size = to_2tuple(image_size)
463
+ patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
464
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
465
+ self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled
466
+ self.output_dim = output_dim
467
+
468
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
469
+
470
+ # class embeddings and positional embeddings
471
+ scale = width ** -0.5
472
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
473
+ if pos_embed_type == 'learnable':
474
+ self.positional_embedding = nn.Parameter(
475
+ scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
476
+ elif pos_embed_type == 'sin_cos_2d':
477
+ # fixed sin-cos embedding
478
+ assert self.grid_size[0] == self.grid_size[1],\
479
+ 'currently sin cos 2d pos embedding only supports square input'
480
+ self.positional_embedding = nn.Parameter(
481
+ torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)
482
+ pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
483
+ self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())
484
+ else:
485
+ raise ValueError
486
+
487
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
488
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
489
+
490
+ self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
491
+ self.transformer = Transformer(
492
+ width,
493
+ layers,
494
+ heads,
495
+ mlp_ratio,
496
+ ls_init_value=ls_init_value,
497
+ act_layer=act_layer,
498
+ norm_layer=norm_layer,
499
+ )
500
+
501
+ if attentional_pool:
502
+ if isinstance(attentional_pool, str):
503
+ self.attn_pool_type = attentional_pool
504
+ self.pool_type = 'none'
505
+ if attentional_pool in ('parallel', 'cascade'):
506
+ self.attn_pool = AttentionalPooler(
507
+ output_dim,
508
+ width,
509
+ n_head=attn_pooler_heads,
510
+ n_queries=attn_pooler_queries,
511
+ )
512
+ self.attn_pool_contrastive = AttentionalPooler(
513
+ output_dim,
514
+ width,
515
+ n_head=attn_pooler_heads,
516
+ n_queries=1,
517
+ )
518
+ else:
519
+ assert False
520
+ else:
521
+ self.attn_pool_type = ''
522
+ self.pool_type = pool_type
523
+ self.attn_pool = AttentionalPooler(
524
+ output_dim,
525
+ width,
526
+ n_head=attn_pooler_heads,
527
+ n_queries=attn_pooler_queries,
528
+ )
529
+ self.attn_pool_contrastive = None
530
+ pool_dim = output_dim
531
+ else:
532
+ self.attn_pool = None
533
+ pool_dim = width
534
+ self.pool_type = pool_type
535
+
536
+ self.ln_post = norm_layer(pool_dim)
537
+ self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
538
+
539
+ self.init_parameters()
540
+
541
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
542
+ for param in self.parameters():
543
+ param.requires_grad = False
544
+
545
+ if unlocked_groups != 0:
546
+ groups = [
547
+ [
548
+ self.conv1,
549
+ self.class_embedding,
550
+ self.positional_embedding,
551
+ self.ln_pre,
552
+ ],
553
+ *self.transformer.resblocks[:-1],
554
+ [
555
+ self.transformer.resblocks[-1],
556
+ self.ln_post,
557
+ ],
558
+ self.proj,
559
+ ]
560
+
561
+ def _unlock(x):
562
+ if isinstance(x, Sequence):
563
+ for g in x:
564
+ _unlock(g)
565
+ else:
566
+ if isinstance(x, torch.nn.Parameter):
567
+ x.requires_grad = True
568
+ else:
569
+ for p in x.parameters():
570
+ p.requires_grad = True
571
+
572
+ _unlock(groups[-unlocked_groups:])
573
+
574
+ def init_parameters(self):
575
+ # FIXME OpenAI CLIP did not define an init for the VisualTransformer
576
+ # TODO experiment if default PyTorch init, below, or alternate init is best.
577
+
578
+ # nn.init.normal_(self.class_embedding, std=self.scale)
579
+ # nn.init.normal_(self.positional_embedding, std=self.scale)
580
+ #
581
+ # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
582
+ # attn_std = self.transformer.width ** -0.5
583
+ # fc_std = (2 * self.transformer.width) ** -0.5
584
+ # for block in self.transformer.resblocks:
585
+ # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
586
+ # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
587
+ # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
588
+ # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
589
+ #
590
+ # if self.text_projection is not None:
591
+ # nn.init.normal_(self.text_projection, std=self.scale)
592
+ pass
593
+
594
+ @torch.jit.ignore
595
+ def set_grad_checkpointing(self, enable=True):
596
+ self.transformer.grad_checkpointing = enable
597
+
598
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
599
+ if self.pool_type == 'avg':
600
+ pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
601
+ elif self.pool_type == 'tok':
602
+ pooled, tokens = x[:, 0], x[:, 1:]
603
+ else:
604
+ pooled = tokens = x
605
+
606
+ return pooled, tokens
607
+
608
+ def forward(self, x: torch.Tensor):
609
+ x = self.conv1(x) # shape = [*, width, grid, grid]
610
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
611
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
612
+
613
+ # class embeddings and positional embeddings
614
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
615
+ # shape = [*, grid ** 2 + 1, width]
616
+ x = x + self.positional_embedding.to(x.dtype)
617
+
618
+ x = self.patch_dropout(x)
619
+ x = self.ln_pre(x)
620
+ x = self.transformer(x)
621
+
622
+ if self.attn_pool is not None:
623
+ if self.attn_pool_contrastive is not None:
624
+ # This is untested, WIP pooling that should match paper
625
+ x = self.ln_post(x) # TBD LN first or separate one after each pool?
626
+ tokens = self.attn_pool(x)
627
+ if self.attn_pool_type == 'parallel':
628
+ pooled = self.attn_pool_contrastive(x)
629
+ else:
630
+ assert self.attn_pool_type == 'cascade'
631
+ pooled = self.attn_pool_contrastive(tokens)
632
+ else:
633
+ # this is the original OpenCLIP CoCa setup, does not match paper
634
+ x = self.attn_pool(x)
635
+ x = self.ln_post(x)
636
+ pooled, tokens = self._global_pool(x)
637
+ elif self.final_ln_after_pool:
638
+ pooled, tokens = self._global_pool(x)
639
+ pooled = self.ln_post(pooled)
640
+ else:
641
+ x = self.ln_post(x)
642
+ pooled, tokens = self._global_pool(x)
643
+
644
+ if self.proj is not None:
645
+ pooled = pooled @ self.proj
646
+
647
+ if self.output_tokens:
648
+ return pooled, tokens
649
+
650
+ return pooled
651
+
652
+
653
+ def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
654
+ if pool_type == 'first':
655
+ pooled, tokens = x[:, 0], x[:, 1:]
656
+ elif pool_type == 'last':
657
+ pooled, tokens = x[:, -1], x[:, :-1]
658
+ elif pool_type == 'argmax':
659
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
660
+ assert text is not None
661
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
662
+ else:
663
+ pooled = tokens = x
664
+
665
+ return pooled, tokens
666
+
667
+
668
+ class TextTransformer(nn.Module):
669
+ output_tokens: torch.jit.Final[bool]
670
+
671
+ def __init__(
672
+ self,
673
+ context_length: int = 77,
674
+ vocab_size: int = 49408,
675
+ width: int = 512,
676
+ heads: int = 8,
677
+ layers: int = 12,
678
+ mlp_ratio: float = 4.0,
679
+ ls_init_value: float = None,
680
+ output_dim: int = 512,
681
+ embed_cls: bool = False,
682
+ no_causal_mask: bool = False,
683
+ pad_id: int = 0,
684
+ pool_type: str = 'argmax',
685
+ proj_bias: bool = False,
686
+ act_layer: Callable = nn.GELU,
687
+ norm_layer: Callable = LayerNorm,
688
+ output_tokens: bool = False,
689
+ ):
690
+ super().__init__()
691
+ assert pool_type in ('first', 'last', 'argmax', 'none')
692
+ self.output_tokens = output_tokens
693
+ self.num_pos = self.context_length = context_length
694
+ self.vocab_size = vocab_size
695
+ self.width = width
696
+ self.output_dim = output_dim
697
+ self.heads = heads
698
+ self.pad_id = pad_id
699
+ self.pool_type = pool_type
700
+
701
+ self.token_embedding = nn.Embedding(vocab_size, width)
702
+ if embed_cls:
703
+ self.cls_emb = nn.Parameter(torch.empty(width))
704
+ self.num_pos += 1
705
+ else:
706
+ self.cls_emb = None
707
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
708
+ self.transformer = Transformer(
709
+ width=width,
710
+ layers=layers,
711
+ heads=heads,
712
+ mlp_ratio=mlp_ratio,
713
+ ls_init_value=ls_init_value,
714
+ act_layer=act_layer,
715
+ norm_layer=norm_layer,
716
+ )
717
+ self.ln_final = norm_layer(width)
718
+
719
+ if no_causal_mask:
720
+ self.attn_mask = None
721
+ else:
722
+ self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
723
+
724
+ if proj_bias:
725
+ self.text_projection = nn.Linear(width, output_dim)
726
+ else:
727
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
728
+
729
+ self.init_parameters()
730
+
731
+ def init_parameters(self):
732
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
733
+ nn.init.normal_(self.positional_embedding, std=0.01)
734
+ if self.cls_emb is not None:
735
+ nn.init.normal_(self.cls_emb, std=0.01)
736
+
737
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
738
+ attn_std = self.transformer.width ** -0.5
739
+ fc_std = (2 * self.transformer.width) ** -0.5
740
+ for block in self.transformer.resblocks:
741
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
742
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
743
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
744
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
745
+
746
+ if self.text_projection is not None:
747
+ if isinstance(self.text_projection, nn.Linear):
748
+ nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)
749
+ if self.text_projection.bias is not None:
750
+ nn.init.zeros_(self.text_projection.bias)
751
+ else:
752
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
753
+
754
+ @torch.jit.ignore
755
+ def set_grad_checkpointing(self, enable=True):
756
+ self.transformer.grad_checkpointing = enable
757
+
758
+ def build_causal_mask(self):
759
+ # lazily create causal attention mask, with full attention between the tokens
760
+ # pytorch uses additive attention mask; fill with -inf
761
+ mask = torch.empty(self.num_pos, self.num_pos)
762
+ mask.fill_(float("-inf"))
763
+ mask.triu_(1) # zero out the lower diagonal
764
+ return mask
765
+
766
+ def build_cls_mask(self, text, cast_dtype: torch.dtype):
767
+ cls_mask = (text != self.pad_id).unsqueeze(1)
768
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
769
+ additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
770
+ additive_mask.fill_(0)
771
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
772
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
773
+ return additive_mask
774
+
775
+ def forward(self, text):
776
+ cast_dtype = self.transformer.get_cast_dtype()
777
+ seq_len = text.shape[1]
778
+
779
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
780
+ attn_mask = self.attn_mask
781
+ if self.cls_emb is not None:
782
+ seq_len += 1
783
+ x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)
784
+ cls_mask = self.build_cls_mask(text, cast_dtype)
785
+ if attn_mask is not None:
786
+ attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
787
+
788
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
789
+ x = self.transformer(x, attn_mask=attn_mask)
790
+
791
+ # x.shape = [batch_size, n_ctx, transformer.width]
792
+ if self.cls_emb is not None:
793
+ # presence of appended cls embed (CoCa) overrides pool_type, always take last token
794
+ pooled, tokens = text_global_pool(x, pool_type='last')
795
+ pooled = self.ln_final(pooled) # final LN applied after pooling in this case
796
+ else:
797
+ x = self.ln_final(x)
798
+ pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
799
+
800
+ if self.text_projection is not None:
801
+ if isinstance(self.text_projection, nn.Linear):
802
+ pooled = self.text_projection(pooled)
803
+ else:
804
+ pooled = pooled @ self.text_projection
805
+
806
+ if self.output_tokens:
807
+ return pooled, tokens
808
+
809
+ return pooled
810
+
811
+
812
+ class MultimodalTransformer(Transformer):
813
+ def __init__(
814
+ self,
815
+ width: int,
816
+ layers: int,
817
+ heads: int,
818
+ context_length: int = 77,
819
+ mlp_ratio: float = 4.0,
820
+ ls_init_value: float = None,
821
+ act_layer: Callable = nn.GELU,
822
+ norm_layer: Callable = LayerNorm,
823
+ output_dim: int = 512,
824
+ batch_first: bool = True,
825
+ ):
826
+ super().__init__(
827
+ width=width,
828
+ layers=layers,
829
+ heads=heads,
830
+ mlp_ratio=mlp_ratio,
831
+ ls_init_value=ls_init_value,
832
+ act_layer=act_layer,
833
+ norm_layer=norm_layer,
834
+ batch_first=batch_first,
835
+ )
836
+ self.context_length = context_length
837
+ self.cross_attn = nn.ModuleList([
838
+ ResidualAttentionBlock(
839
+ width,
840
+ heads,
841
+ mlp_ratio,
842
+ ls_init_value=ls_init_value,
843
+ act_layer=act_layer,
844
+ norm_layer=norm_layer,
845
+ is_cross_attention=True,
846
+ batch_first=batch_first,
847
+ )
848
+ for _ in range(layers)
849
+ ])
850
+
851
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
852
+
853
+ self.ln_final = norm_layer(width)
854
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
855
+
856
+ def init_parameters(self):
857
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
858
+ attn_std = self.transformer.width ** -0.5
859
+ fc_std = (2 * self.transformer.width) ** -0.5
860
+ for block in self.transformer.resblocks:
861
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
862
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
863
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
864
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
865
+ for block in self.transformer.cross_attn:
866
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
867
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
868
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
869
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
870
+
871
+ if self.text_projection is not None:
872
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
873
+
874
+ def build_attention_mask(self):
875
+ # lazily create causal attention mask, with full attention between the tokens
876
+ # pytorch uses additive attention mask; fill with -inf
877
+ mask = torch.empty(self.context_length, self.context_length)
878
+ mask.fill_(float("-inf"))
879
+ mask.triu_(1) # zero out the lower diagonal
880
+ return mask
881
+
882
+ def forward(self, image_embs, text_embs):
883
+ seq_len = text_embs.shape[1]
884
+ if not self.batch_first:
885
+ image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
886
+ text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
887
+
888
+ for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
889
+ if self.grad_checkpointing and not torch.jit.is_scripting():
890
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
891
+ text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
892
+ text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
893
+ else:
894
+ text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
895
+ text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
896
+
897
+ if not self.batch_first:
898
+ text_embs = text_embs.permute(1, 0, 2) # LND -> NLD
899
+
900
+ out = self.ln_final(text_embs)
901
+ if self.text_projection is not None:
902
+ out = out @ self.text_projection
903
+
904
+ return out
905
+
906
+ @torch.jit.ignore
907
+ def set_grad_checkpointing(self, enable=True):
908
+ self.grad_checkpointing = enable
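
For reference, a minimal sketch exercising the towers defined above with dummy inputs (not part of the uploaded file; the import path and hyperparameters are illustrative assumptions, not this checkpoint's configuration):

    import torch
    from open_clip.transformer import TextTransformer, VisionTransformer  # module path assumed

    text_tower = TextTransformer(context_length=77, vocab_size=49408,
                                 width=512, heads=8, layers=12, output_dim=512)
    tokens = torch.randint(0, 49408, (2, 77))   # dummy token ids, batch of 2
    with torch.no_grad():
        text_features = text_tower(tokens)      # shape (2, 512); pooled at the argmax (EOT) position

    vision_tower = VisionTransformer(image_size=224, patch_size=16, width=768, layers=12,
                                     heads=12, mlp_ratio=4.0, output_dim=512)
    images = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        image_features = vision_tower(images)   # shape (2, 512); CLS-token pooling then projection
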
src/open_clip/utils.py ADDED
@@ -0,0 +1,89 @@
1
+ from itertools import repeat
2
+ import collections.abc
3
+
4
+ import torch
5
+ from torch import nn as nn
6
+ from torchvision.ops.misc import FrozenBatchNorm2d
7
+
8
+
9
+ def freeze_batch_norm_2d(module, module_match={}, name=''):
10
+ """
11
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
12
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
13
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
14
+
15
+ Args:
16
+ module (torch.nn.Module): Any PyTorch module.
17
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
18
+ name (str): Full module name (prefix)
19
+
20
+ Returns:
21
+ torch.nn.Module: Resulting module
22
+
23
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
24
+ """
25
+ res = module
26
+ is_match = True
27
+ if module_match:
28
+ is_match = name in module_match
29
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
30
+ res = FrozenBatchNorm2d(module.num_features)
31
+ res.num_features = module.num_features
32
+ res.affine = module.affine
33
+ if module.affine:
34
+ res.weight.data = module.weight.data.clone().detach()
35
+ res.bias.data = module.bias.data.clone().detach()
36
+ res.running_mean.data = module.running_mean.data
37
+ res.running_var.data = module.running_var.data
38
+ res.eps = module.eps
39
+ else:
40
+ for child_name, child in module.named_children():
41
+ full_child_name = '.'.join([name, child_name]) if name else child_name
42
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
43
+ if new_child is not child:
44
+ res.add_module(child_name, new_child)
45
+ return res
46
+
47
+
48
+ # From PyTorch internals
49
+ def _ntuple(n):
50
+ def parse(x):
51
+ if isinstance(x, collections.abc.Iterable):
52
+ return x
53
+ return tuple(repeat(x, n))
54
+ return parse
55
+
56
+
57
+ to_1tuple = _ntuple(1)
58
+ to_2tuple = _ntuple(2)
59
+ to_3tuple = _ntuple(3)
60
+ to_4tuple = _ntuple(4)
61
+ to_ntuple = lambda n, x: _ntuple(n)(x)
62
+
63
+ # Replaces all linear layers with linear_replacement
64
+ # TODO: add int8 support for other linear layers including attn and convnets
65
+ def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
66
+ for name, module in model.named_children():
67
+ if len(list(module.children())) > 0:
68
+ replace_linear(module, linear_replacement, include_modules, copy_weights)
69
+
70
+ if isinstance(module, torch.nn.Linear) and name in include_modules:
71
+ old_module = model._modules[name]
72
+ model._modules[name] = linear_replacement(
73
+ module.in_features,
74
+ module.out_features,
75
+ module.bias is not None,
76
+ )
77
+ if copy_weights:
78
+ model._modules[name].weight.data.copy_(old_module.weight.data)
79
+ if model._modules[name].bias is not None:
80
+ model._modules[name].bias.data.copy_(old_module.bias)
81
+
82
+ return model
83
+
84
+ def convert_int8_model_to_inference_mode(model):
85
+ for m in model.modules():
86
+ if hasattr(m, 'prepare_for_eval'):
87
+ int8_original_dtype = m.weight.dtype
88
+ m.prepare_for_eval()
89
+ m.int8_original_dtype = int8_original_dtype
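
For reference, a short sketch of the helpers above in use (not part of the uploaded file; torchvision's resnet18 is only a stand-in for any model containing BatchNorm2d layers):

    from torchvision.models import resnet18
    from open_clip.utils import to_2tuple, freeze_batch_norm_2d  # module path assumed

    to_2tuple(224)          # -> (224, 224)
    to_2tuple((224, 320))   # iterables pass through unchanged

    model = freeze_batch_norm_2d(resnet18())  # every BatchNorm2d becomes a FrozenBatchNorm2d
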
src/open_clip/version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '2.26.1'