PrakhAI commited on
Commit
fae14cf
·
1 Parent(s): 827c2c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -18
app.py CHANGED
@@ -19,23 +19,23 @@ with fs.open("PrakhAI/AIPlane2/g_checkpoint.msgpack", "rb") as f:
19
  def sample_latent(batch, key):
20
  return jax.random.normal(key, shape=(batch, LATENT_DIM))
21
 
22
- def gridify(images): # num x image_width x image_height x channels
23
- # Every num can be padded to make a grid of size floor(sqrt(num)) x ceil(sqrt(num)) or ceil(sqrt(num)) x ceil(sqrt(num))
24
- num = images.shape[0]
25
- image_width = images.shape[1]
26
- image_height = images.shape[2]
27
- channels = images.shape[3]
28
- width = math.floor(math.sqrt(num))
29
- height = math.ceil(math.sqrt(num))
30
- if width * height < num:
31
- width += 1
32
- padded = np.concatenate([images, np.zeros((width*height-num, image_width, image_height, channels))], axis=0)
33
- return padded.reshape((width, height, image_width, image_height, -1)).transpose((0, 2, 1, 3, 4)).reshape((width * image_width, height * image_height, -1))
34
 
35
- st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
36
- num_images = st.number_input(label="Number of images to generate", min_value=1, max_value=256, value=16)
37
- if st.button('Generate Planes'):
38
- latents = sample_latent(num_images, jax.random.PRNGKey(int(1_000_000 * time.time())))
39
  (g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
40
- img = (np.array(gridify(g_out128)+1)*255./2.).astype(np.uint8)
41
- st.image(Image.fromarray(img))
 
 
 
 
 
 
 
 
 
 
 
19
  def sample_latent(batch, key):
20
  return jax.random.normal(key, shape=(batch, LATENT_DIM))
21
 
22
+ def to_img(normalized):
23
+ return ((normalized+1)*255./2.).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
24
 
25
+ def generate_images(previous=None):
26
+ latents = sample_latent(16, jax.random.PRNGKey(int(1_000_000 * time.time())))
27
+ if previous:
28
+ latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
29
  (g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
30
+ img = np.array(to_img(g_out128))
31
+ for row in range(4):
32
+ with st.container():
33
+ for (col_idx, col) in enumerate(st.columns(4)):
34
+ with col:
35
+ idx = row*grid_width + col_idx
36
+ st.image(Image.fromarray(img[idx]))
37
+ st.button(label="Generate similar", on_click=generate_images, args=latents[idx])
38
+
39
+ st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
40
+ if st.button('Generate Random'):
41
+ generate_images()