KumaTea commited on
Commit
1155f19
·
1 Parent(s): d2742cf

fix the fix

Browse files
Files changed (3) hide show
  1. app.py +2 -14
  2. fix_int8.py +29 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,17 +1,5 @@
1
- with open('/usr/local/lib/python3.8/dist-packages/torch/nn/parameter.py', 'r') as f:
2
- text = f.read()
3
-
4
- if 'if data.dtype == torch.int8' not in text:
5
- text = text.replace(
6
- ' return torch.Tensor._make_subclass(cls, data, requires_grad)',
7
- ' if data.dtype == torch.int8:\n' \
8
- ' requires_grad = False\n' \
9
- ' return torch.Tensor._make_subclass(cls, data, requires_grad)'
10
- )
11
-
12
- with open('/usr/local/lib/python3.8/dist-packages/torch/nn/parameter.py', 'w') as f:
13
- f.write(text)
14
-
15
 
16
 
17
  # Credit:
 
1
+ from fix_int8 import fix_pytorch_int8
2
+ fix_pytorch_int8()
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  # Credit:
fix_int8.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+
5
+ def fix_pytorch_int8():
6
+ valid_path = [p for p in sys.path if p and os.path.isdir(p)]
7
+
8
+ for path in valid_path:
9
+ for folder in os.listdir(path):
10
+ if 'torch' in folder:
11
+ packages_path = path
12
+ break
13
+
14
+ fix_path = f'{packages_path}/torch/nn/parameter.py'
15
+
16
+ with open(fix_path, 'r') as f:
17
+ text = f.read()
18
+
19
+ if 'if data.dtype == torch.int8' not in text:
20
+ text = text.replace(
21
+ ' return torch.Tensor._make_subclass(cls, data, requires_grad)',
22
+ ' if data.dtype == torch.int8:\n' \
23
+ ' requires_grad = False\n' \
24
+ ' return torch.Tensor._make_subclass(cls, data, requires_grad)'
25
+ )
26
+ with open(fix_path, 'w') as f:
27
+ f.write(text)
28
+
29
+ return print('Fixed torch/nn/parameter.py')
requirements.txt CHANGED
@@ -5,7 +5,7 @@ bitsandbytes>=0.37.1
5
  accelerate>=0.17.1
6
 
7
  # chatglm
8
- protobuf>=3.19.5,<3.20.1
9
  transformers>=4.27.1
10
  icetk
11
  cpm_kernels>=1.0.11
 
5
  accelerate>=0.17.1
6
 
7
  # chatglm
8
+ protobuf>=3.19.5,<4
9
  transformers>=4.27.1
10
  icetk
11
  cpm_kernels>=1.0.11