hangg-sai's picture
Initial commit
a342aa8
import os
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download
from seva.model import Seva, SevaParams
def seed_everything(seed: int = 0) -> None:
    """Seed PyTorch's random number generators and request deterministic cuDNN.

    Args:
        seed: Value passed to the CPU seeding call and to both CUDA seeding
            calls (current device and all devices).

    Note:
        Only PyTorch generators are seeded here; Python's ``random`` module
        and NumPy are left untouched.
    """
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print the missing and/or unexpected state-dict keys, if there are any.

    Each non-empty list is printed as a count header followed by one
    tab-indented key per line. When both lists are non-empty, a 79-character
    dashed rule is printed between the two sections. Nothing is printed when
    both lists are empty.

    Args:
        missing: Keys reported as missing.
        unexpected: Keys reported as unexpected.
    """
    has_missing = len(missing) > 0
    has_unexpected = len(unexpected) > 0

    if has_missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if has_missing and has_unexpected:
        # Visual separator, only needed when both sections are shown.
        print("\n" + "-" * 79 + "\n")
    if has_unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_model(
    pretrained_model_name_or_path: str = "stabilityai/stable-virtual-camera",
    weight_name: str = "model.safetensors",
    device: str | torch.device = "cuda",
    verbose: bool = False,
) -> Seva:
    """Build a ``Seva`` model and populate it from a safetensors checkpoint.

    Args:
        pretrained_model_name_or_path: Either an existing local directory that
            contains ``weight_name``, or a Hugging Face Hub repo id to
            download from.
        weight_name: File name of the safetensors checkpoint.
        device: Device the checkpoint tensors are loaded onto (converted to
            its string form for ``safetensors``).
        verbose: When true, print any missing/unexpected state-dict keys.

    Returns:
        The ``Seva`` instance with the checkpoint tensors assigned to it.
    """
    source = pretrained_model_name_or_path
    if not os.path.isdir(source):
        checkpoint_file = hf_hub_download(repo_id=source, filename=weight_name)
        # NOTE(review): the returned path is discarded; presumably this fetch
        # exists only to cache/register the repo's config -- confirm.
        hf_hub_download(repo_id=source, filename="config.yaml")
    else:
        checkpoint_file = os.path.join(source, weight_name)

    weights = safetensors.torch.load_file(checkpoint_file, device=str(device))

    # Construct the module on the meta device; the checkpoint tensors are then
    # attached directly via ``assign=True`` below.
    with torch.device("meta"):
        net = Seva(SevaParams()).to(torch.bfloat16)

    # Non-strict load: key mismatches are collected instead of raising.
    # NOTE(review): parameters absent from the checkpoint would presumably stay
    # on the meta device -- confirm this is acceptable for callers.
    load_result = net.load_state_dict(weights, strict=False, assign=True)
    if verbose:
        print_load_warning(load_result.missing_keys, load_result.unexpected_keys)
    return net