Spaces:
Runtime error
Runtime error
File size: 890 Bytes
221f925 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import os
import sys
def fix_pytorch_int8(search_paths=None):
    """Patch the installed torch so that int8 ``Parameter``s never require grad.

    Finds ``torch/nn/parameter.py`` in the first directory of *search_paths*
    that contains it and inserts, directly above every line that reads
    ``return torch.Tensor._make_subclass(cls, data, requires_grad)``::

        if data.dtype == torch.int8:
            requires_grad = False

    The guard is indented to match the line it precedes, so the patched file
    stays syntactically valid whatever nesting depth the installed torch
    version uses. The file is rewritten in place; running the function again
    on an already-patched file changes nothing.

    Args:
        search_paths: Directories to search for the ``torch`` package, in
            priority order. Empty entries and non-directories are skipped.
            Defaults to ``sys.path``.

    Returns:
        None. Progress is reported with ``print``.

    Raises:
        FileNotFoundError: If no searched directory contains
            ``torch/nn/parameter.py``.
    """
    if search_paths is None:
        search_paths = sys.path
    anchor = 'return torch.Tensor._make_subclass(cls, data, requires_grad)'
    guard = 'if data.dtype == torch.int8'

    # Take the FIRST directory that really holds torch/nn/parameter.py: that is
    # the copy the import system loads. Matching any folder whose name merely
    # contains 'torch' (torchvision, torch-*.dist-info, ...) could select a
    # directory without the torch package, and breaking only out of the inner
    # loop used to let the LAST match win.
    fix_path = None
    for path in search_paths:
        if not path or not os.path.isdir(path):
            continue
        candidate = os.path.join(path, 'torch', 'nn', 'parameter.py')
        if os.path.isfile(candidate):
            fix_path = candidate
            break
    if fix_path is None:
        raise FileNotFoundError(
            'torch/nn/parameter.py not found in any search path')

    with open(fix_path, 'r', encoding='utf-8') as f:
        text = f.read()

    if guard not in text:
        patched = []
        hits = 0
        for line in text.splitlines(keepends=True):
            if line.strip() == anchor:
                # Reuse the anchor's own indentation rather than a hard-coded
                # prefix: a fixed-width prefix matched as a substring of a more
                # deeply indented line and left the inserted lines misaligned.
                indent = line[:len(line) - len(line.lstrip())]
                patched.append(f'{indent}{guard}:\n')
                patched.append(f'{indent}    requires_grad = False\n')
                hits += 1
            patched.append(line)
        if not hits:
            # Unknown torch layout: do not touch the file or claim success.
            print('torch/nn/parameter.py: expected return statement not found; '
                  'file left unchanged')
            return None
        with open(fix_path, 'w', encoding='utf-8') as f:
            f.write(''.join(patched))
    print('Fixed torch/nn/parameter.py')
    return None
|