|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
|
|
def Generator():
    """Build the image decoder as a flat ``nn.Sequential``.

    A 256-d latent vector is lifted by two linear layers to 9216 = 1024 * 3 * 3
    features, layer-normalized (no learnable affine parameters) and reshaped to
    a (1024, 3, 3) feature map. Five upsampling stages then each double the
    spatial size (3 -> 96) while narrowing channels
    1024 -> 512 -> 256 -> 128 -> 64 -> 3. A final sigmoid maps every output
    into the unit interval.

    The module order is significant: state-dict parameter names are the
    Sequential indices, so the layout must match the saved checkpoint.
    """

    def upsample_stage(in_ch, out_ch):
        # Bilinear 2x upsample, then three 3x3 same-padding convs, each preceded
        # by a LeakyReLU; only the last conv changes the channel count.
        return [
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        ]

    layers = [
        nn.Linear(256, 1024),
        nn.LeakyReLU(negative_slope=0.1, inplace=True),
        nn.Linear(1024, 9216),
        nn.LayerNorm(9216, eps=1e-6, elementwise_affine=False),
        nn.Unflatten(1, (1024, 3, 3)),
    ]

    # Consecutive pairs give the (in, out) channels of each upsampling stage.
    widths = (1024, 512, 256, 128, 64, 3)
    for in_ch, out_ch in zip(widths, widths[1:]):
        layers.extend(upsample_stage(in_ch, out_ch))

    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)
|
|
|
# Build the generator once at import time. It is only used for inference, so
# gradients are disabled on all parameters and the module is put in eval mode.
model = Generator().requires_grad_(False).eval()

# map_location='cpu' lets a checkpoint saved from a GPU load on CPU-only hosts.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs; a bare torch.load can execute arbitrary pickled code.
model.load_state_dict(
    torch.load('weights.pt', map_location='cpu', weights_only=True)
)

# Modulus of the Lehmer PRNG used by gen(): 2**31 - 1, a Mersenne prime.
p = 2147483647
|
|
|
def gen(state):
    """Generate one image from a Lehmer (MINSTD, multiplier 48271) PRNG seed.

    The seed is advanced 256 times with ``state = state * 48271 % p``. Each
    intermediate value, divided by ``p``, is a uniform draw in (0, 1). The
    inverse standard-normal CDF maps these draws to a 256-d latent vector,
    which ``model`` decodes.

    Args:
        state: Seed from the UI slider. It is rounded to an int, values below
            1 are raised to 1, and the result is reduced modulo ``p``.

    Returns:
        ``(next_state, img)``: the PRNG state after the 256 draws (so clicking
        again continues the stream) and the image as an (H, W, 3) uint8
        NumPy array.
    """
    state = max(round(state), 1) % p
    if state == 0:
        # A multiple of p collapses the recurrence to its fixed point 0: every
        # draw would be 0, ndtri(0) = -inf, and the decoded pixels NaN.
        state = 1

    draws = []
    for _ in range(256):
        state = state * 48271 % p
        # p is prime and state != 0 (mod p), so state stays in [1, p - 1] and
        # the draw is strictly inside (0, 1), keeping ndtri finite.
        draws.append(state / p)

    # Shape (1, 256): a batch of one latent vector. The transform is done in
    # float64 and only then cast to the model's float32.
    x = torch.special.ndtri(torch.tensor([draws], dtype=torch.float64)).float()

    with torch.no_grad():
        y = model(x).mul(255).round().byte()

    # (1, C, H, W) -> (H, W, C) for the image component.
    img = y[0].permute(1, 2, 0).numpy()
    return state, img
|
|
|
# Single-page UI: a seed slider, an image pane and a button. Clicking the
# button runs gen() on the slider value and writes back both the advanced
# seed (so repeated clicks walk the PRNG stream) and the rendered image.
with gr.Blocks() as demo:
    seed_input = gr.Slider(
        minimum=1,
        maximum=p - 1,
        value=4,
        step=1,
        label='PRNG State',
    )
    image_view = gr.Image(label="Generated Image", format='png')
    generate_button = gr.Button(value='Generate')
    generate_button.click(
        fn=gen,
        inputs=seed_input,
        outputs=[seed_input, image_view],
    )

demo.launch()