Spaces:
Runtime error
Runtime error
Commit
Β·
d65e5f6
1
Parent(s):
50f1efd
refactor(model): remove explicit device_type parameter from amp decorators
Browse files
llava/model/qlinear_te.py
CHANGED
|
@@ -98,7 +98,7 @@ class QLinearTE(nn.Linear):
|
|
| 98 |
|
| 99 |
class QuantLinearTE(Function):
|
| 100 |
@staticmethod
|
| 101 |
-
@amp.custom_fwd(cast_inputs=torch.bfloat16
|
| 102 |
def forward(ctx, input, weight, bias, args, layer_name):
|
| 103 |
|
| 104 |
time_bench = os.getenv("TIME_BENCH")
|
|
@@ -149,7 +149,7 @@ class QuantLinearTE(Function):
|
|
| 149 |
return fc_output
|
| 150 |
|
| 151 |
@staticmethod
|
| 152 |
-
@amp.custom_bwd
|
| 153 |
def backward(ctx, grad_output):
|
| 154 |
Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
|
| 155 |
|
|
|
|
| 98 |
|
| 99 |
class QuantLinearTE(Function):
|
| 100 |
@staticmethod
|
| 101 |
+
@amp.custom_fwd(cast_inputs=torch.bfloat16)
|
| 102 |
def forward(ctx, input, weight, bias, args, layer_name):
|
| 103 |
|
| 104 |
time_bench = os.getenv("TIME_BENCH")
|
|
|
|
| 149 |
return fc_output
|
| 150 |
|
| 151 |
@staticmethod
|
| 152 |
+
@amp.custom_bwd
|
| 153 |
def backward(ctx, grad_output):
|
| 154 |
Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
|
| 155 |
|