File size: 1,525 Bytes
8cef139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import torch
import torch.nn as nn

def Generator():
    """Build the image generator network as a single ``nn.Sequential``.

    A 256-dim input is projected by a two-layer MLP to 9216 = 1024 * 3 * 3
    features. These are layer-normalised without affine parameters, reshaped
    into a 1024-channel 3x3 map, and passed through five 2x upsampling
    stages (1024 -> 512 -> 256 -> 128 -> 64 -> 3 channels). A sigmoid closes
    the network.

    The ``state_dict`` keys of a Sequential are layer indices, so the layer
    order built here must match the order the saved weights were produced
    with.
    """
    def upsample_stage(in_ch, out_ch):
        # One stage: a 2x bilinear upsample, then three 3x3 same-padding
        # convs, each preceded by an in-place LeakyReLU(0.1). Only the last
        # conv changes the channel count.
        stage = [nn.Upsample(None, 2, 'bilinear')]
        for width in (in_ch, in_ch, out_ch):
            stage += [nn.LeakyReLU(0.1, True), nn.Conv2d(in_ch, width, 3, 1, 1)]
        return stage

    channel_widths = [1024, 512, 256, 128, 64, 3]
    layers = [
        nn.Linear(256, 1024),
        nn.LeakyReLU(0.1, True),
        nn.Linear(1024, 9216),
        nn.LayerNorm(9216, 1e-6, False),  # eps=1e-6, elementwise_affine=False
        nn.Unflatten(1, (1024, 3, 3)),
    ]
    for c_in, c_out in zip(channel_widths, channel_widths[1:]):
        layers.extend(upsample_stage(c_in, c_out))
    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)

# Build the generator for inference: gradients are disabled on every
# parameter and the module is switched to eval mode.
model = Generator().requires_grad_(False).eval()
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load
# on a CPU-only host; gen() feeds the model CPU tensors and calls .numpy().
# NOTE(review): torch.load unpickles the file -- only load trusted
# checkpoints (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load('weights.pt', map_location='cpu'))
# Modulus of the multiplicative PRNG used in gen(): the Mersenne prime 2**31 - 1.
p = 2147483647

def gen(state):
    """Generate one image from a PRNG seed and return the advanced seed.

    The seed drives a multiplicative congruential (Lehmer) generator,
    ``state <- state * 48271 mod p``. Its next 256 outputs are scaled into
    the open interval (0, 1), then mapped through the inverse standard-normal
    CDF to form the model's 256-dim input.

    Args:
        state: Seed value; it is rounded to an int and clamped to at least 1.

    Returns:
        A ``(state, img)`` tuple. ``state`` is the PRNG state after drawing
        the input vector, so feeding it back yields the next image in the
        sequence. ``img`` is the image as an ``(H, W, 3)`` uint8 NumPy array.
    """
    state = max(round(state), 1)
    # Reducing mod p leaves the generated sequence unchanged (the update is
    # itself mod p). A seed that is a multiple of p would otherwise collapse
    # to state 0 forever, giving ndtri(0) = -inf and a NaN image.
    state = state % p or 1
    uniforms = []
    for _ in range(256):
        state = state * 48271 % p
        uniforms.append(state / p)  # strictly inside (0, 1) since 0 < state < p
    # Inverse normal CDF computed in float64, then cast to float32 for the model.
    x = torch.special.ndtri(torch.tensor([uniforms], dtype=torch.float64)).float()
    # Sigmoid output in [0, 1] -> rounded 8-bit pixel values.
    y = model(x).mul(255).round().byte()
    # (1, C, H, W) -> (H, W, C) for the image widget.
    img = y[0].permute(1, 2, 0).numpy()
    return state, img

with gr.Blocks() as demo:
    # The slider holds the PRNG state. It is the seed passed to gen() and
    # also receives the advanced state gen() returns, so repeated clicks
    # step through successive images.
    state_slider = gr.Slider(
        minimum=1,
        maximum=p - 1,
        value=4,
        step=1,
        label='PRNG State',
    )
    img_output = gr.Image(label="Generated Image", format='png')
    click_btn = gr.Button(value='Generate')
    click_btn.click(
        fn=gen,
        inputs=[state_slider],
        outputs=[state_slider, img_output],
    )

demo.launch()