K00B404 committed on
Commit
d902dc8
·
verified ·
1 Parent(s): 4ee27ce

Create CLIP.py

Browse files
Files changed (1) hide show
  1. CLIP.py +28 -0
CLIP.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
2
+ import torch
3
+ from safetensors.torch import load_file as load_safetensor
4
+
5
# Device configuration: run on the GPU when CUDA is available, else fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+
8
def load(tokenizer_path="tokenizer", text_encoder_path="text_encoder",
         weights_name="model.fp16.safetensors"):
    """Load the CLIP text encoder and its tokenizer from local directories.

    Parameters
    ----------
    tokenizer_path : str
        Directory containing the saved ``CLIPTokenizer`` files.
    text_encoder_path : str
        Directory containing the text encoder's ``config.json`` and
        safetensor weight files.
    weights_name : str
        Name of the safetensor file inside ``text_encoder_path`` whose
        weights are loaded into the model. Defaults to the fp16 checkpoint;
        pass ``"model.safetensors"`` for the full-precision weights.

    Returns
    -------
    tuple
        ``(clip_model, tokenizer)`` — the ``CLIPTextModelWithProjection``
        moved to the module-level ``device``, and the tokenizer.
    """
    safetensor_file = f"./{text_encoder_path}/{weights_name}"
    config_path = f"./{text_encoder_path}/config.json"

    # Load tokenizer
    tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    # Instantiate the model from the config. Use fp16 only when a GPU is
    # available — many CPU kernels do not support half precision.
    clip_model = CLIPTextModelWithProjection.from_pretrained(
        text_encoder_path,
        config=config_path,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

    # NOTE(review): from_pretrained above already loads weights from
    # text_encoder_path; the explicit load below re-reads the chosen
    # checkpoint and overwrites them. Kept for parity with the original
    # behavior (it guarantees the *named* safetensor file's weights win),
    # but the double read costs time and memory — confirm whether
    # from_pretrained alone would suffice.
    state_dict = load_safetensor(safetensor_file)
    clip_model.load_state_dict(state_dict)
    clip_model = clip_model.to(device)

    return clip_model, tokenizer