import io
from typing import Union

import torch

def get_target_dtype_ref(target_dtype: Union[str, torch.dtype]) -> torch.dtype:
	"""Resolve a dtype name to a torch.dtype; torch.dtype values pass through."""
	if isinstance(target_dtype, torch.dtype):
		return target_dtype

	if target_dtype == "float16":
		return torch.float16
	elif target_dtype == "float32":
		return torch.float32
	elif target_dtype == "bfloat16":
		return torch.bfloat16
	else:
		raise ValueError(f"Invalid target_dtype: {target_dtype!r}")

def convert_ckpt_to_safetensors(
	ckpt_upload: Union[io.BytesIO, bytes],
	target_dtype: Union[str, torch.dtype],
) -> dict:
	"""Load an uploaded .ckpt and return a flat dict of tensors cast to
	target_dtype, ready for safetensors serialization."""
	# Accept raw bytes as well as a file-like object.
	if isinstance(ckpt_upload, bytes):
		ckpt_upload = io.BytesIO(ckpt_upload)

	target_dtype = get_target_dtype_ref(target_dtype)
	
	# Load the checkpoint. Note: torch.load unpickles arbitrary objects; for
	# untrusted uploads, consider weights_only=True (available since PyTorch 1.13).
	loaded_dict = torch.load(ckpt_upload, map_location="cpu")
	
	tensor_dict = {}
	
	# Textual Inversion embeddings store their vectors under string_to_param['*'].
	is_embedding = 'string_to_param' in loaded_dict
	if is_embedding:
		emb_tensor = loaded_dict.get('string_to_param', {}).get('*', None)
		if emb_tensor is not None:
			emb_tensor = emb_tensor.to(dtype=target_dtype)
			tensor_dict = {
				'emb_params': emb_tensor
			}
	else:
		# Full checkpoints often nest their weights under a 'state_dict' key;
		# fall back to the top-level dict if that key is absent.
		state_dict = loaded_dict.get('state_dict', loaded_dict)
		# Keep only tensor values, cast to the target dtype.
		for key, val in state_dict.items():
			if isinstance(val, torch.Tensor):
				tensor_dict[key] = val.to(dtype=target_dtype)
	
	return tensor_dict
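
# A minimal usage sketch (assumes the `safetensors` package is installed; the
# file name "model.ckpt" is illustrative):
#
#     from safetensors.torch import save_file
#
#     with open("model.ckpt", "rb") as f:
#         tensors = convert_ckpt_to_safetensors(f.read(), "float16")
#     save_file(tensors, "model.safetensors")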

if __name__ == '__main__':
	print('This module is meant to be imported, not run directly.')