File size: 3,010 Bytes
7bf4b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import os.path as osp
import pickle
import torch
import json


def read_from_file(file_path):
    """
    Read content from a file based on its extension.

    The extension selects the loader: ``.txt`` is returned as a string,
    ``.json`` is parsed with ``json.load`` and ``.pkl`` is deserialized with
    ``pickle.load``.

    Args:
        file_path (str | os.PathLike): Path to the file.

    Returns:
        content: Content of the file.

    Raises:
        NotImplementedError: If the file type is not supported.
    """
    # Normalize first so pathlib.Path (or any os.PathLike) works as well as
    # a plain string; str input is returned unchanged by os.fspath.
    path = os.fspath(file_path)

    if path.endswith('.txt'):
        # Plain text is read with the platform default encoding, as before.
        with open(path, 'r') as f:
            return f.read()
    elif path.endswith('.json'):
        # JSON is decoded as UTF-8 explicitly rather than relying on the
        # locale's preferred encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    elif path.endswith('.pkl'):
        # SECURITY: unpickling can execute arbitrary code; only load pickle
        # files that come from a trusted source.
        with open(path, 'rb') as f:
            return pickle.load(f)
    else:
        raise NotImplementedError(f'File type not supported: {file_path}')


def write_to_file(file_path, content):
    """
    Write content to a file based on its extension.

    The extension selects the writer: ``.txt`` writes ``content`` as text,
    ``.json`` serializes it as JSON with a 4-space indent and ``.pkl``
    pickles it.

    Args:
        file_path (str | os.PathLike): Path to the file.
        content: Content to write.

    Raises:
        NotImplementedError: If the file type is not supported.
    """
    # Normalize first so pathlib.Path (or any os.PathLike) works as well as
    # a plain string; str input is returned unchanged by os.fspath.
    path = os.fspath(file_path)

    if path.endswith('.txt'):
        # Plain text is written with the platform default encoding, as before.
        with open(path, 'w') as f:
            f.write(content)
    elif path.endswith('.json'):
        # Serialize before opening: opening with 'w' truncates the target, so
        # a non-serializable value would otherwise destroy an existing file
        # and leave partial JSON behind.
        serialized = json.dumps(content, indent=4)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(serialized)
    elif path.endswith('.pkl'):
        with open(path, 'wb') as f:
            pickle.dump(content, f)
    else:
        raise NotImplementedError(f'File type not supported: {file_path}')


def save_files(save_path, **kwargs):
    """
    Save multiple files in a specified directory.

    Each ``dict`` value is pickled to ``<key>.pkl`` and each ``torch.Tensor``
    value is stored with ``torch.save`` as ``<key>.pt``. All values are
    checked before anything is written, so an unsupported value does not
    leave a partially populated directory behind.

    Args:
        save_path (str): Directory to save the files. Created if missing.
        **kwargs: Keyword arguments where keys are filenames (without extension) and values are the contents.

    Raises:
        NotImplementedError: If any value is neither a dict nor a torch.Tensor.
    """
    # Validate up front: fail before touching the filesystem instead of after
    # some of the files have already been written.
    for key, value in kwargs.items():
        if not isinstance(value, (dict, torch.Tensor)):
            raise NotImplementedError(f'File type not supported for key: {key}')

    os.makedirs(save_path, exist_ok=True)
    for key, value in kwargs.items():
        if isinstance(value, dict):
            with open(osp.join(save_path, f'{key}.pkl'), 'wb') as f:
                pickle.dump(value, f)
        else:
            # Only torch.Tensor can reach this branch after the check above.
            torch.save(value, osp.join(save_path, f'{key}.pt'))


def load_files(save_path):
    """
    Load every supported file found directly inside a directory.

    Sub-directories are skipped. ``.pkl`` files are read with ``pickle.load``
    and ``.pt`` files with ``torch.load``; any other file aborts the load.

    Args:
        save_path (str): Directory to load the files from.

    Returns:
        dict: Mapping from each filename without its extension to the loaded content.

    Raises:
        NotImplementedError: If the directory contains a file whose extension
            is neither ``.pkl`` nor ``.pt``.
    """
    contents = {}
    for entry in os.listdir(save_path):
        full_path = osp.join(save_path, entry)
        if osp.isdir(full_path):
            # Only regular entries at the top level are considered.
            continue

        stem, suffix = osp.splitext(entry)
        if suffix == '.pt':
            contents[stem] = torch.load(full_path)
            continue
        if suffix != '.pkl':
            raise NotImplementedError(f'File type not supported: {entry}')

        # NOTE: unpickling can execute arbitrary code; the directory is
        # expected to hold trusted files only.
        with open(full_path, 'rb') as handle:
            contents[stem] = pickle.load(handle)
    return contents