Update app.py
Browse files
app.py
CHANGED
@@ -19,23 +19,23 @@ with fs.open("PrakhAI/AIPlane2/g_checkpoint.msgpack", "rb") as f:
|
|
19 |
def sample_latent(batch, key):
|
20 |
return jax.random.normal(key, shape=(batch, LATENT_DIM))
|
21 |
|
22 |
-
def
|
23 |
-
|
24 |
-
num = images.shape[0]
|
25 |
-
image_width = images.shape[1]
|
26 |
-
image_height = images.shape[2]
|
27 |
-
channels = images.shape[3]
|
28 |
-
width = math.floor(math.sqrt(num))
|
29 |
-
height = math.ceil(math.sqrt(num))
|
30 |
-
if width * height < num:
|
31 |
-
width += 1
|
32 |
-
padded = np.concatenate([images, np.zeros((width*height-num, image_width, image_height, channels))], axis=0)
|
33 |
-
return padded.reshape((width, height, image_width, image_height, -1)).transpose((0, 2, 1, 3, 4)).reshape((width * image_width, height * image_height, -1))
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
if
|
38 |
-
|
39 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
40 |
-
img =
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def sample_latent(batch, key):
|
20 |
return jax.random.normal(key, shape=(batch, LATENT_DIM))
|
21 |
|
22 |
+
def to_img(normalized):
|
23 |
+
return ((normalized+1)*255./2.).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
def generate_images(previous=None):
|
26 |
+
latents = sample_latent(16, jax.random.PRNGKey(int(1_000_000 * time.time())))
|
27 |
+
if previous:
|
28 |
+
latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
|
29 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
30 |
+
img = np.array(to_img(g_out128))
|
31 |
+
for row in range(4):
|
32 |
+
with st.container():
|
33 |
+
for (col_idx, col) in enumerate(st.columns(4)):
|
34 |
+
with col:
|
35 |
+
idx = row*grid_width + col_idx
|
36 |
+
st.image(Image.fromarray(img[idx]))
|
37 |
+
st.button(label="Generate similar", on_click=generate_images, args=latents[idx])
|
38 |
+
|
39 |
+
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
|
40 |
+
if st.button('Generate Random'):
|
41 |
+
generate_images()
|