Spaces:
Runtime error
Runtime error
Commit
·
647e1a1
1
Parent(s):
9589cd1
Fix diffusion sampler not working on gpu
Browse files- models.py +1 -1
- shell_vars.sh +1 -1
models.py
CHANGED
@@ -147,7 +147,7 @@ class DiffusionGenerationModel(nn.Module):
|
|
147 |
return self.model(x)
|
148 |
|
149 |
def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
|
150 |
-
noise = torch.randn(x.shape)
|
151 |
return self.model.sample(noise, num_steps=num_steps)
|
152 |
|
153 |
|
|
|
147 |
return self.model(x)
|
148 |
|
149 |
def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
|
150 |
+
noise = torch.randn(x.shape).to(x)
|
151 |
return self.model.sample(noise, num_steps=num_steps)
|
152 |
|
153 |
|
shell_vars.sh
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
export DATASET_ROOT="
|
2 |
export WANDB_PROJECT="RemFX"
|
3 |
export WANDB_ENTITY="mattricesound"
|
|
|
1 |
+
export DATASET_ROOT="./data/egfx"
|
2 |
export WANDB_PROJECT="RemFX"
|
3 |
export WANDB_ENTITY="mattricesound"
|