import itertools
import os
from pathlib import Path
from typing import List, Tuple, Union

import safetensors.torch
import torch
from omegaconf import DictConfig, OmegaConf
from torch import Tensor
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first parameter or buffer of a module."""
    try:
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first parameter (or, failing that, the first buffer) of a module."""
    try:
        params = tuple(parameter.parameters())
        if len(params) > 0:
            return params[0].dtype

        buffers = tuple(parameter.buffers())
        if len(buffers) > 0:
            return buffers[0].dtype

    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype
def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
    """Return the parent directory of `save_path` as a `Path`."""
    path_obj = Path(save_path)
    return path_obj.parent


def get_base_name(save_path: Union[str, os.PathLike]) -> str:
    """Return the final component of `save_path` (file name with extension)."""
    path_obj = Path(save_path)
    return path_obj.name
def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from a path, using safetensors when the path indicates it."""
    # Cast to str so the substring check also works for os.PathLike inputs.
    if 'safetensors' in str(path):
        state_dict = safetensors.torch.load_file(path)
    else:
        state_dict = torch.load(path, map_location="cpu")
    return state_dict
def replace_extension(path, new_extension):
    """Replace the extension of `path` with `new_extension` (leading dot optional)."""
    if not new_extension.startswith('.'):
        new_extension = '.' + new_extension
    return os.path.splitext(path)[0] + new_extension


def make_config_path(save_path):
    """Derive the config path by swapping the checkpoint extension for `.yaml`."""
    config_path = replace_extension(save_path, '.yaml')
    return config_path
def save_config(config, config_path):
    """Save a dict or DictConfig as YAML, creating the parent directory if needed."""
    assert isinstance(config, (dict, DictConfig))
    os.makedirs(get_parent_directory(config_path), exist_ok=True)
    if isinstance(config, dict):
        config = OmegaConf.create(config)
    OmegaConf.save(config, config_path)
def save_state_dict_and_config(state_dict, config, save_path):
    """Save a state dict (safetensors or torch) alongside a YAML config with the same stem."""
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # Save the config next to the checkpoint.
    config_path = make_config_path(save_path)
    save_config(config, config_path)

    # Save the model weights. Cast to str so the check also works for os.PathLike inputs.
    if 'safetensors' in str(save_path):
        safetensors.torch.save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)
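

# A minimal usage sketch (not part of the original module): round-trips a toy
# model's weights and config through the helpers above. The output path and the
# config keys are illustrative assumptions, not values required by this module.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    config = {"model": {"in_features": 4, "out_features": 2}}

    # Saving with a `.safetensors` path writes the weights via safetensors and
    # a `.yaml` config with the same stem (here, checkpoints/linear.yaml).
    save_path = "checkpoints/linear.safetensors"
    save_state_dict_and_config(model.state_dict(), config, save_path)

    # Reload the weights and the config, then restore them into a fresh module.
    state_dict = load_state_dict_from_path(save_path)
    restored = torch.nn.Linear(4, 2)
    restored.load_state_dict(state_dict)
    loaded_config = OmegaConf.load(make_config_path(save_path))

    print(get_parameter_device(restored), get_parameter_dtype(restored))
    print(OmegaConf.to_yaml(loaded_config))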