Spaces:
Runtime error
Runtime error
Update audio_diffusion_attacks_forhf/src/balancer.py
Browse files
audio_diffusion_attacks_forhf/src/balancer.py
CHANGED
|
@@ -90,10 +90,7 @@ class Balancer:
|
|
| 90 |
grads = {}
|
| 91 |
for name, loss in losses.items():
|
| 92 |
# Compute partial derivative of the less with respect to the input.
|
| 93 |
-
|
| 94 |
-
loss.requires_grad = True
|
| 95 |
-
#Andy edited--CHECK WITH WILLIAM: grad, = autograd.grad(loss, [input], retain_graph=True)
|
| 96 |
-
grad, = autograd.grad(loss, [input], retain_graph=True, allow_unused=True)
|
| 97 |
if self.per_batch_item:
|
| 98 |
# We do not average the gradient over the batch dimension.
|
| 99 |
dims = tuple(range(1, grad.dim()))
|
|
|
|
| 90 |
grads = {}
|
| 91 |
for name, loss in losses.items():
|
| 92 |
# Compute partial derivative of the less with respect to the input.
|
| 93 |
+
grad, = autograd.grad(loss, [input], retain_graph=True)
|
|
|
|
|
|
|
|
|
|
| 94 |
if self.per_batch_item:
|
| 95 |
# We do not average the gradient over the batch dimension.
|
| 96 |
dims = tuple(range(1, grad.dim()))
|