PrakhAI commited on
Commit
c84c172
·
0 Parent(s):

Duplicate from PrakhAI/AIPlane

Browse files
Files changed (6) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. __init__.py +0 -0
  4. app.py +61 -0
  5. local_response_norm.py +11 -0
  6. requirements.txt +1 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AIPlane
3
+ emoji: 🌖
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.25.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: PrakhAI/AIPlane
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import jax
4
+ import jax.numpy as jnp # JAX NumPy
5
+ import numpy as np
6
+ from flax import linen as nn # Linen API
7
+ from huggingface_hub import HfFileSystem
8
+ from flax.serialization import msgpack_restore, from_state_dict
9
+ import time
10
+ from local_response_norm import LocalResponseNorm
11
+
12
+ LATENT_DIM = 100
13
+
14
+ class Generator(nn.Module):
15
+ @nn.compact
16
+ def __call__(self, latent, training=True):
17
+ x = nn.Dense(features=32)(latent)
18
+ # x = nn.BatchNorm(not training)(x)
19
+ x = nn.relu(x)
20
+ x = nn.Dense(features=2*2*256)(x)
21
+ x = nn.BatchNorm(not training)(x)
22
+ x = nn.relu(x)
23
+ x = nn.Dropout(0.5, deterministic=not training)(x)
24
+ x = x.reshape((x.shape[0], 2, 2, -1))
25
+ x4o = nn.ConvTranspose(features=3, kernel_size=(2, 2), strides=(2, 2))(x)
26
+ x4 = nn.ConvTranspose(features=128, kernel_size=(2, 2), strides=(2, 2))(x)
27
+ x4 = LocalResponseNorm()(x4)
28
+ # x4 = nn.BatchNorm(not training)(x4)
29
+ x8 = nn.relu(x4)
30
+ # x8 = nn.Dropout(0.5, deterministic=not training)(x8)
31
+ x8o = nn.ConvTranspose(features=3, kernel_size=(2, 2), strides=(2, 2))(x8)
32
+ x8 = nn.ConvTranspose(features=64, kernel_size=(2, 2), strides=(2, 2))(x8)
33
+ x8 = LocalResponseNorm()(x8)
34
+ # x8 = nn.BatchNorm(not training)(x8)
35
+ x16 = nn.relu(x8)
36
+ # x16 = nn.Dropout(0.5, deterministic=not training)(x16)
37
+ x16o = nn.ConvTranspose(features=3, kernel_size=(2, 2), strides=(2, 2))(x16)
38
+ x16 = nn.ConvTranspose(features=32, kernel_size=(2, 2), strides=(2, 2))(x16)
39
+ x16 = LocalResponseNorm()(x16)
40
+ # x16 = nn.BatchNorm(not training)(x16)
41
+ x32 = nn.relu(x16)
42
+ # x32 = nn.Dropout(0.5, deterministic=not training)(x32)
43
+ x32o = nn.ConvTranspose(features=3, kernel_size=(2, 2), strides=(2, 2))(x32)
44
+ return (nn.tanh(x32o), nn.tanh(x16o), nn.tanh(x8o), nn.tanh(x4o))
45
+
46
+ generator = Generator()
47
+ variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
48
+
49
+ fs = HfFileSystem()
50
+ with fs.open("PrakhAI/AIPlane/g_checkpoint.msgpack", "rb") as f:
51
+ g_state = from_state_dict(variables, msgpack_restore(f.read()))
52
+
53
+ def sample_latent(key):
54
+ return jax.random.normal(key, shape=(1, LATENT_DIM))
55
+
56
+ if st.button('Generate Plane'):
57
+ latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
58
+ (g_out32, g_out16, g_out8, g_out4) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
59
+ img = ((np.array(g_out32[0])+1)*255./2.).astype(np.uint8)
60
+ st.image(Image.fromarray(img))
61
+ st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane")
local_response_norm.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flax import linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ class LocalResponseNorm(nn.Module):
6
+ @nn.compact
7
+ def __call__(
8
+ self,
9
+ value: jax.Array
10
+ ) -> jax.Array:
11
+ return value / jnp.repeat(jnp.expand_dims((1e-8 + (value**2).mean(axis=-1))**0.5, axis=-1), repeats=value.shape[-1], axis=-1)
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ flax