File size: 890 Bytes
221f925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import sys


def fix_pytorch_int8():
    """Patch ``torch/nn/parameter.py`` so int8 ``Parameter`` data never requires grad.

    Searches ``sys.path`` (in order) for the first directory that contains
    ``torch/nn/parameter.py``. Unless the patch is already present, it
    inserts an ``if data.dtype == torch.int8: requires_grad = False`` guard
    immediately before the
    ``return torch.Tensor._make_subclass(cls, data, requires_grad)`` line.
    The file is rewritten in place.

    If the expected anchor line is not found, the file is left untouched and
    a message is printed instead of falsely reporting success.

    Returns:
        None.

    Raises:
        FileNotFoundError: if no ``sys.path`` entry contains
            ``torch/nn/parameter.py``.
    """
    target = '            return torch.Tensor._make_subclass(cls, data, requires_grad)'
    patched = (
        '            if data.dtype == torch.int8:\n'
        '                requires_grad = False\n'
        + target
    )

    # Take the first sys.path entry that actually ships torch/nn/parameter.py.
    # sys.path order is import priority, and checking the exact file avoids
    # false hits on sibling packages such as "torchvision".
    fix_path = None
    for path in sys.path:
        if not path:
            continue
        candidate = os.path.join(path, 'torch', 'nn', 'parameter.py')
        if os.path.isfile(candidate):
            fix_path = candidate
            break
    if fix_path is None:
        raise FileNotFoundError('torch/nn/parameter.py not found on sys.path')

    with open(fix_path, 'r', encoding='utf-8') as f:
        text = f.read()

    if 'if data.dtype == torch.int8' in text:
        return None  # already patched; keep the operation idempotent
    if target not in text:
        # Different torch layout: do not rewrite the file or claim success.
        print(f'Could not patch {fix_path}: expected line not found')
        return None

    with open(fix_path, 'w', encoding='utf-8') as f:
        f.write(text.replace(target, patched))
    print('Fixed torch/nn/parameter.py')
    return None