|
import os |
|
import sys |
|
from session import logger |
|
|
|
|
|
def fix_pytorch_int8():
    """Patch the installed ``torch/nn/parameter.py`` for int8 parameter data.

    The first ``sys.path`` entry that actually contains
    ``torch/nn/parameter.py`` is selected.  Every line of that file that
    starts with
    ``return torch.Tensor._make_subclass(cls, data, requires_grad)`` gets a
    guard inserted in front of it, at the same indentation, that sets
    ``requires_grad = False`` when ``data.dtype == torch.int8``.

    The function is idempotent: a file that already contains the guard is
    left untouched.

    Returns:
        Whatever ``logger.info`` returns for the final status message.

    Raises:
        FileNotFoundError: if no ``sys.path`` entry contains
            ``torch/nn/parameter.py``.
    """
    anchor = 'return torch.Tensor._make_subclass(cls, data, requires_grad)'
    marker = 'if data.dtype == torch.int8'

    # Probe for the exact file rather than for any folder whose name contains
    # "torch": sibling packages such as torchvision would otherwise be picked
    # up, and a later sys.path entry could override an earlier, correct one.
    fix_path = None
    for path in sys.path:
        if not path:
            continue
        candidate = os.path.join(path, 'torch', 'nn', 'parameter.py')
        if os.path.isfile(candidate):
            fix_path = candidate
            break
    if fix_path is None:
        raise FileNotFoundError(
            'torch/nn/parameter.py was not found in any sys.path entry'
        )

    # Python source files are UTF-8 by default, so do not depend on the locale.
    with open(fix_path, 'r', encoding='utf-8') as f:
        text = f.read()

    if marker not in text:
        patched_lines = []
        changed = False
        for line in text.splitlines(keepends=True):
            stripped = line.lstrip(' \t')
            if stripped.startswith(anchor):
                # Reuse the indentation of the return statement so that the
                # inserted guard is valid Python at any nesting depth.
                indent = line[:len(line) - len(stripped)]
                patched_lines.append(f'{indent}if data.dtype == torch.int8:\n')
                patched_lines.append(f'{indent}    requires_grad = False\n')
                changed = True
            patched_lines.append(line)

        if not changed:
            # The expected return statement is absent: leave the file alone
            # and do not claim that it was fixed.
            return logger.info('Nothing to fix in torch/nn/parameter.py')

        with open(fix_path, 'w', encoding='utf-8') as f:
            f.write(''.join(patched_lines))

    return logger.info('Fixed torch/nn/parameter.py')
|
|