Update app.py
Browse files
app.py
CHANGED
@@ -25,7 +25,8 @@ def to_img(normalized):
|
|
25 |
ROWS = 4
|
26 |
COLUMNS = 4
|
27 |
def generate_images(previous=None):
|
28 |
-
|
|
|
29 |
if previous:
|
30 |
latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
|
31 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
@@ -36,7 +37,7 @@ def generate_images(previous=None):
|
|
36 |
with col:
|
37 |
idx = row*COLUMNS + col_idx
|
38 |
st.image(Image.fromarray(img[idx]))
|
39 |
-
st.button(label="Generate similar", on_click=generate_images, args=(latents[idx]))
|
40 |
|
41 |
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
|
42 |
if st.button('Generate Random'):
|
|
|
25 |
ROWS = 4
|
26 |
COLUMNS = 4
|
27 |
def generate_images(previous=None):
|
28 |
+
unique_id = int(1_000_000 * time.time())
|
29 |
+
latents = sample_latent(ROWS * COLUMNS, jax.random.PRNGKey(unique_id))
|
30 |
if previous:
|
31 |
latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
|
32 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
|
|
37 |
with col:
|
38 |
idx = row*COLUMNS + col_idx
|
39 |
st.image(Image.fromarray(img[idx]))
|
40 |
+
st.button(label="Generate similar", key="%d_%d" % (unique_id, idx), on_click=generate_images, args=(latents[idx]))
|
41 |
|
42 |
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
|
43 |
if st.button('Generate Random'):
|