File size: 1,287 Bytes
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import yaml
import json
from typing import List
import torch


def tensor2list(d: dict):
    """Convert ``torch.Tensor`` values in ``d`` to (nested) Python lists, in place.

    Handled values:

    * a tensor -> ``tensor.detach().cpu().tolist()`` (a 0-d tensor becomes a
      Python scalar);
    * a list -> every tensor element is converted, other elements are kept
      (empty lists and mixed lists are fine);
    * a nested dict -> processed recursively, in place.

    Any other value is left unchanged.

    NOTE(review): a second ``tensor2list`` is defined later in this module;
    being defined last, that one is what the name is bound to at import time.

    Returns:
        The same dict object ``d``, after mutation.
    """

    def _convert(value):
        # Convert a single value; recursion stays local to this function so it
        # does not depend on which module-level ``tensor2list`` is bound.
        if isinstance(value, torch.Tensor):
            # detach()/cpu() so grad-tracking and non-CPU tensors convert too;
            # Tensor.tolist() also handles dtypes numpy lacks (e.g. bfloat16).
            return value.detach().cpu().tolist()
        if isinstance(value, dict):
            for key in value:
                value[key] = _convert(value[key])
            return value
        if isinstance(value, list):
            # Element-wise, so empty or mixed lists never raise.
            return [_convert(item) for item in value]
        return value

    for k in d:
        d[k] = _convert(d[k])
    return d


def write_json(json_serializable_dict, fout, indent=2):
    """Serialize ``json_serializable_dict`` as JSON into the file at ``fout``.

    The file is opened in text write mode, so existing content is replaced.
    ``indent`` is forwarded unchanged to :func:`json.dump`.
    """
    with open(fout, mode="w") as stream:
        json.dump(obj=json_serializable_dict, fp=stream, indent=indent)


def write_yaml(json_serializable_dict, fout):
    """Serialize ``json_serializable_dict`` as YAML into the file at ``fout``.

    Block style is forced via ``default_flow_style=False``. The file is opened
    in text write mode, so existing content is replaced.
    """
    with open(fout, mode="w") as stream:
        yaml.dump(json_serializable_dict, stream=stream, default_flow_style=False)


def detach_dict(x_dict):
    """Detach every tensor in ``x_dict`` from autograd and move it to the CPU.

    Nested dicts are processed recursively. The dict is updated in place and
    also returned. Values that are neither tensors nor dicts are left as-is.
    """
    with torch.no_grad():
        for key, value in x_dict.items():
            # Only existing keys are reassigned, so iterating items() is safe.
            if isinstance(value, dict):
                x_dict[key] = detach_dict(value)
            elif isinstance(value, torch.Tensor):
                x_dict[key] = value.detach().cpu()
    return x_dict


def tensor2list(xdict):
    """Recursively convert ``torch.Tensor`` values in ``xdict`` to Python lists.

    Conversion happens in place:

    * a tensor -> ``tensor.detach().cpu().tolist()`` (a 0-d tensor becomes a
      Python scalar);
    * a nested dict -> processed recursively;
    * a list -> every element is converted (tensors, dicts or nested lists
      inside it are handled; empty lists are fine).

    Any other value is left unchanged. This definition is the one bound to
    ``tensor2list`` at import time, since it comes after the earlier
    definition of the same name in this module.

    Returns:
        The same dict object ``xdict``, after mutation.
    """

    def _convert(value):
        # Convert a single value; nested containers recurse through here.
        if isinstance(value, torch.Tensor):
            # .numpy() alone fails on grad-tracking or non-CPU tensors, so
            # detach and move to CPU first (as detach_dict does).
            return value.detach().cpu().tolist()
        if isinstance(value, dict):
            for key in value:
                value[key] = _convert(value[key])
            return value
        if isinstance(value, list):
            return [_convert(item) for item in value]
        return value

    for k in xdict:
        xdict[k] = _convert(xdict[k])
    return xdict