File size: 1,525 Bytes
8cef139 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import gradio as gr
import torch
import torch.nn as nn
def Generator():
    """Build the image generator as a flat ``nn.Sequential``.

    The network takes a ``(batch, 256)`` latent tensor, projects it through
    two linear layers to 9216 features, normalizes them (LayerNorm without
    learnable affine parameters) and reshapes to a ``1024 x 3 x 3`` feature
    map.  Five upsampling stages then double the spatial size each time
    (3 -> 96) while narrowing the channels 1024 -> 512 -> 256 -> 128 -> 64
    -> 3, and a final sigmoid maps the output into ``[0, 1]``.

    The layers are kept in one flat ``nn.Sequential`` on purpose: parameter
    names in the state dict are derived from each layer's position, so the
    order below must not change if existing checkpoints are to load.

    Returns:
        nn.Sequential: the assembled, freshly initialized generator.
    """

    def _act():
        # In-place leaky ReLU with slope 0.1, used throughout the network.
        return nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def _conv3x3(n_in, n_out):
        # 3x3 convolution, stride 1, padding 1: spatial size is preserved.
        return nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1)

    def _stage(n_in, n_out):
        # One upsampling stage: 2x bilinear upsample followed by three
        # activation + conv pairs; only the last conv changes the width.
        return [
            nn.Upsample(size=None, scale_factor=2, mode='bilinear'),
            _act(),
            _conv3x3(n_in, n_in),
            _act(),
            _conv3x3(n_in, n_in),
            _act(),
            _conv3x3(n_in, n_out),
        ]

    layers = [
        nn.Linear(256, 1024),
        _act(),
        nn.Linear(1024, 9216),
        nn.LayerNorm(9216, eps=1e-6, elementwise_affine=False),
        nn.Unflatten(1, (1024, 3, 3)),
    ]

    # Channel width entering each stage, ending with the 3 output channels.
    widths = (1024, 512, 256, 128, 64, 3)
    for n_in, n_out in zip(widths, widths[1:]):
        layers.extend(_stage(n_in, n_out))

    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)
# Build the generator for inference only: gradients are disabled on every
# parameter and the module is switched to eval mode.
model = Generator().requires_grad_(False).eval()

# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only host (this app runs the model on CPU tensors).  weights_only=True
# restricts unpickling to tensors and plain containers, so a tampered
# checkpoint file cannot execute arbitrary code while being loaded.
model.load_state_dict(
    torch.load('weights.pt', map_location='cpu', weights_only=True)
)

# Modulus of the multiplicative congruential PRNG used by gen():
# 2147483647 == 2**31 - 1.  It also bounds the "PRNG State" slider.
p = 2147483647
def gen(state):
    """Generate one image from a PRNG state and return the advanced state.

    The state seeds a multiplicative congruential generator
    (``s <- s * 48271 mod p``).  256 successive outputs are scaled to
    uniforms in (0, 1), pushed through the inverse normal CDF to obtain a
    Gaussian latent vector, and fed to the module-level ``model``.

    Args:
        state: current PRNG state (int or float; it is rounded).  Values
            below 1 are treated as 1.  Values of ``p`` or above are reduced
            modulo ``p``.

    Returns:
        tuple: ``(next_state, img)`` where ``next_state`` is the integer PRNG
        state after 256 steps (always in ``[1, p - 1]``) and ``img`` is a
        ``uint8`` NumPy array in height x width x channel order.
    """
    # Bring the seed into the generator's valid range [1, p - 1].  Reducing
    # modulo p leaves the sequence unchanged for any state that is not a
    # multiple of p.  A multiple of p would collapse the recurrence to 0
    # forever (uniform 0 -> ndtri(0) == -inf), so it is mapped to 1 instead.
    seed = max(round(state), 1) % p
    if seed == 0:
        seed = 1

    # Advance the generator 256 times, keeping the exact integer outputs.
    draws = []
    for _ in range(256):
        seed = seed * 48271 % p
        draws.append(seed)

    # Every draw is < 2**31, hence exactly representable in float64; the
    # division therefore matches a per-element float(draw) / p.
    uniforms = torch.tensor(draws, dtype=torch.float64).div(p).unsqueeze(0)

    # Inverse normal CDF in float64 for accuracy, then down to float32.
    latent = torch.special.ndtri(uniforms).float()

    # Inference only: make sure no autograd state is recorded.
    with torch.no_grad():
        pixels = model(latent).mul(255).round().byte()

    # Drop the batch dimension and move channels last for image display.
    img = pixels[0].permute(1, 2, 0).numpy()
    return seed, img
# Gradio UI: a slider holding the PRNG state, an image output and a button.
# Each click runs gen() on the slider value, shows the generated image and
# writes the advanced PRNG state back into the slider, so repeated clicks
# walk through the generator's sequence.
with gr.Blocks() as demo:
    state_slider = gr.Slider(
        minimum=1,
        maximum=p - 1,
        value=4,
        step=1,
        label='PRNG State',
    )
    img_output = gr.Image(label="Generated Image", format='png')
    click_btn = gr.Button('Generate')
    click_btn.click(
        fn=gen,
        inputs=state_slider,
        outputs=[state_slider, img_output],
    )

# Start the web server (the stray trailing "|" after this call was a syntax
# error and has been removed).
demo.launch()