Spaces:
Running
on
T4
Running
on
T4
Update Modules/ControllabilityGAN/GAN.py
Browse files
Modules/ControllabilityGAN/GAN.py
CHANGED
|
@@ -5,7 +5,7 @@ from Modules.ControllabilityGAN.wgan.init_wgan import create_wgan
|
|
| 5 |
|
| 6 |
class GanWrapper:
|
| 7 |
|
| 8 |
-
def __init__(self, path_wgan, device):
|
| 9 |
self.device = device
|
| 10 |
self.path_wgan = path_wgan
|
| 11 |
|
|
@@ -20,15 +20,18 @@ class GanWrapper:
|
|
| 20 |
|
| 21 |
self.z_list = list()
|
| 22 |
|
| 23 |
-
|
| 24 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
self.z = self.z_list[0]
|
| 26 |
|
| 27 |
def set_latent(self, seed):
|
| 28 |
-
self.z = self.
|
| 29 |
-
|
| 30 |
-
def reset_default_latent(self):
|
| 31 |
-
self.z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
|
| 32 |
|
| 33 |
def load_model(self, path):
|
| 34 |
gan_checkpoint = torch.load(path, map_location="cpu")
|
|
@@ -53,7 +56,7 @@ class GanWrapper:
|
|
| 53 |
self.mean = gan_checkpoint["dataset_mean"]
|
| 54 |
self.std = gan_checkpoint["dataset_std"]
|
| 55 |
|
| 56 |
-
def compute_controllability(self, n_samples=
|
| 57 |
_, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
|
| 58 |
intermediate = intermediate.cpu()
|
| 59 |
z = z.cpu()
|
|
|
|
| 5 |
|
| 6 |
class GanWrapper:
|
| 7 |
|
| 8 |
+
def __init__(self, path_wgan, device, num_cached_voices=10):
|
| 9 |
self.device = device
|
| 10 |
self.path_wgan = path_wgan
|
| 11 |
|
|
|
|
| 20 |
|
| 21 |
self.z_list = list()
|
| 22 |
|
| 23 |
+
while len(self.z_list) < num_cached_voices + 2:
|
| 24 |
+
z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)
|
| 25 |
+
sims = [-1.0]
|
| 26 |
+
for other_z in self.z_list:
|
| 27 |
+
sims.append(torch.nn.functional.cosine_similarity(z, other_z))
|
| 28 |
+
print(max(sims), len(self.z_list))
|
| 29 |
+
if max(sims) < 0.25:
|
| 30 |
+
self.z_list.append(z)
|
| 31 |
self.z = self.z_list[0]
|
| 32 |
|
| 33 |
def set_latent(self, seed):
|
| 34 |
+
self.z = self.z_list[seed]
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
def load_model(self, path):
|
| 37 |
gan_checkpoint = torch.load(path, map_location="cpu")
|
|
|
|
| 56 |
self.mean = gan_checkpoint["dataset_mean"]
|
| 57 |
self.std = gan_checkpoint["dataset_std"]
|
| 58 |
|
| 59 |
+
def compute_controllability(self, n_samples=200000):
|
| 60 |
_, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
|
| 61 |
intermediate = intermediate.cpu()
|
| 62 |
z = z.cpu()
|