File size: 1,451 Bytes
223d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import importlib
import random
import numpy as np
from typing import Mapping, Optional

import torch

def initiate_from_config(config: Mapping):
    assert "target" in config, f"Expected key `target` to initialize!"
    module, cls = config["target"].rsplit(".", 1)
    meta_class = getattr(importlib.import_module(module, package=None), cls)
    return meta_class(**config.get("params", dict()))

def initiate_from_config_recursively(config: Mapping):
    assert "target" in config, f"Expected key `target` to initialize!"
    update_config = {"target": config["target"], "params": {}}
    for k, v in config["params"].items():
        if isinstance(v, Mapping) and "target" in v:
            sub_instance = initiate_from_config_recursively(v)
            update_config["params"][k] = sub_instance
        else:
            update_config["params"][k] = v
    return initiate_from_config(update_config)

def seed_everything(seed: Optional[int] = None):
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True