|
import torch |
|
from torch import nn |
|
from huggingface_hub import hf_hub_download |
|
from torchvision.utils import save_image |
|
import gradio as gr |
|
|
|
class Generator(nn.Module):
    """DCGAN-style generator that maps latent noise to small images.

    Three upsampling stages (transposed conv -> BatchNorm -> ReLU) are
    followed by a final transposed conv and ``tanh``. For a
    ``(N, nz, 1, 1)`` latent input the spatial size grows
    1 -> 3 -> 5 -> 12 -> 24, and the result has ``nc`` channels with values
    in [-1, 1].

    Args:
        nc: number of output channels.
        nz: number of input (latent) channels.
        ngf: base feature width; the hidden stages use 4x, 2x and 1x of it.
    """

    def __init__(self, nc=4, nz=100, ngf=64):
        super().__init__()

        def stage(c_in, c_out, kernel, stride, padding):
            # bias=False: the BatchNorm right after supplies a learnable shift.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        # Modules are created in the same order as the flat listing, so the
        # Sequential indices (and therefore the state_dict keys expected by
        # the pretrained checkpoint) are unchanged.
        layers = [
            *stage(nz, ngf * 4, 3, 1, 0),
            *stage(ngf * 4, ngf * 2, 3, 2, 1),
            *stage(ngf * 2, ngf, 4, 2, 0),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, input):
        """Run the generator on a latent batch and return the image batch."""
        return self.network(input)
|
|
|
# Instantiate the generator, then fetch the pretrained CryptoPunks weights
# from the Hugging Face Hub and load them onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint, which trusts the
# 'nateraw/cryptopunks-gan' repo; consider weights_only=True (torch>=1.13).
# NOTE(review): model.eval() is never called, so BatchNorm normalizes with
# per-batch statistics at inference time -- confirm that is intended.
model = Generator()
weights_path = hf_hub_download(repo_id='nateraw/cryptopunks-gan', filename='generator.pth')
model.load_state_dict(
    torch.load(weights_path, map_location='cpu')
)
|
|
|
def predict(seed, num_punks):
    """Sample a batch of punks from the generator and save them as one PNG grid.

    Args:
        seed: seed for the latent noise; the same seed and count reproduce
            the same image. Converted with ``int()`` because Gradio sliders
            may deliver floats.
        num_punks: number of punks to sample (the batch size). Also
            converted with ``int()``.

    Returns:
        The path ``'punks.png'`` of the written image grid.

    Raises:
        ValueError: if ``num_punks`` is less than 1.
    """
    num_punks = int(num_punks)
    if num_punks < 1:
        raise ValueError(f"num_punks must be at least 1, got {num_punks}")

    # A private CPU generator seeded per call draws the same noise that
    # torch.manual_seed(seed) would, without mutating the global RNG state.
    rng = torch.Generator().manual_seed(int(seed))
    z = torch.randn(num_punks, 100, 1, 1, generator=rng)

    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        punks = model(z)

    # NOTE(review): a fixed filename in the working directory can be
    # overwritten by a concurrent request before Gradio reads it.
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'
|
|
|
# Gradio UI: the two sliders are passed to predict(seed, num_punks) and the
# returned PNG path is displayed as an image. Each example row fills both
# sliders.
demo = gr.Interface(
    predict,
    inputs=[
        # step=1 keeps seeds integral; the maximum must cover every example
        # seed below (e.g. 1337), otherwise that example is out of range.
        gr.Slider(0, 10000, label='Seed', value=42, step=1),
        gr.Slider(4, 64, label='Number of Punks', step=1, value=10),
    ],
    outputs="image",
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
)

# share=True additionally requests a public Gradio share link.
demo.launch(share=True)