Spaces:
Runtime error
Runtime error
Update image_wgan.py
Browse files- image_wgan.py +1 -3
image_wgan.py
CHANGED
|
@@ -106,14 +106,12 @@ class ImageWgan:
|
|
| 106 |
|
| 107 |
def generate(
|
| 108 |
self,
|
| 109 |
-
sample_folder: str = 'samples'
|
| 110 |
-
seed: int = 'seed'
|
| 111 |
):
|
| 112 |
if not exists(sample_folder):
|
| 113 |
mkdir(sample_folder)
|
| 114 |
|
| 115 |
Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
|
| 116 |
-
np.random.seed(seed)
|
| 117 |
z = Variable(Tensor(np.random.normal(0, 1, (self.image_shape[0], self.latent_space_dimension))))
|
| 118 |
gen_imgs = self.generator(z)
|
| 119 |
generator_loss = -torch.mean(self.discriminator(gen_imgs))
|
|
|
|
| 106 |
|
| 107 |
def generate(
|
| 108 |
self,
|
| 109 |
+
sample_folder: str = 'samples'
|
|
|
|
| 110 |
):
|
| 111 |
if not exists(sample_folder):
|
| 112 |
mkdir(sample_folder)
|
| 113 |
|
| 114 |
Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
|
|
|
|
| 115 |
z = Variable(Tensor(np.random.normal(0, 1, (self.image_shape[0], self.latent_space_dimension))))
|
| 116 |
gen_imgs = self.generator(z)
|
| 117 |
generator_loss = -torch.mean(self.discriminator(gen_imgs))
|