"""Helpers for seeding torch and loading Seva weights from disk or the Hub."""

import os

import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from seva.model import Seva, SevaParams


def seed_everything(seed: int = 0):
    """Seed torch's CPU/CUDA RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every torch seeding call.

    Note: only torch RNGs are seeded here; Python's ``random`` and NumPy
    are not touched.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print the keys reported by ``load_state_dict(strict=False)``.

    Each non-empty list is printed as a header followed by one tab-indented
    key per line. A dashed separator is printed between the two reports
    only when both lists are non-empty; nothing is printed when both are
    empty.

    Args:
        missing: Keys the model expected but the checkpoint lacked.
        unexpected: Keys present in the checkpoint but unused by the model.
    """
    if missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if missing and unexpected:
        print("\n" + "-" * 79 + "\n")
    if unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_model(
    pretrained_model_name_or_path: str = "stabilityai/stable-virtual-camera",
    weight_name: str = "model.safetensors",
    device: str | torch.device = "cuda",
    verbose: bool = False,
) -> Seva:
    """Build a ``Seva`` model and fill it with weights from a safetensors file.

    Args:
        pretrained_model_name_or_path: A local directory containing
            ``weight_name``, or otherwise a Hugging Face Hub repo id.
        weight_name: Filename of the safetensors checkpoint.
        device: Device the checkpoint tensors are loaded onto. Because the
            tensors are assigned directly into the model, this is also
            where the loaded parameters end up.
        verbose: If True, print missing/unexpected checkpoint keys.

    Returns:
        The ``Seva`` model holding the checkpoint tensors.
    """
    source = pretrained_model_name_or_path
    if os.path.isdir(source):
        weight_path = os.path.join(source, weight_name)
    else:
        # Download the weights first, then config.yaml. The config's local
        # path is discarded here -- presumably the download is only meant to
        # populate the Hub cache (TODO confirm).
        weight_path, _config_path = (
            hf_hub_download(repo_id=source, filename=filename)
            for filename in (weight_name, "config.yaml")
        )

    state_dict = safetensors.torch.load_file(weight_path, device=str(device))

    # Build on the meta device so no real parameter memory is allocated
    # before the checkpoint tensors are assigned in.
    with torch.device("meta"):
        model = Seva(SevaParams()).to(torch.bfloat16)

    # assign=True swaps the checkpoint tensors in as the parameters, so their
    # dtype/device come from the file rather than the bfloat16 meta template.
    # NOTE(review): with strict=False, any key missing from the checkpoint is
    # left on the meta device without an error.
    missing_keys, unexpected_keys = model.load_state_dict(
        state_dict, strict=False, assign=True
    )
    if verbose:
        print_load_warning(missing_keys, unexpected_keys)
    return model