import gradio as gr
import utils.utils as utils
from PIL import Image
import torch
from torchvision import transforms
from run_pti import run_PTI
from configs import global_config, paths_config
device = global_config.device
years = [str(y) for y in range(1880, 2020, 10)]
decades = [y + "s" for y in years]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# Load the original (frozen) generator for each decade once at startup.
orig_models = {}
for year in years:
    G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": G.eval().float()}
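# These untouched copies serve as the reference point: predict() below
# subtracts them from the PTI-tuned generator to obtain per-parameter weight
# deltas, which are then transplanted onto every other decade's generator.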


def run_alignment(image_path, idx=None):
    # dlib is only needed for alignment, so it is imported lazily here.
    import dlib
    from align_all_parallel import align_face
    predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
    return aligned_image


def predict(inp, in_decade):
    in_year = in_decade[:-1]  # "2010s" -> "2010"
    # Save the upload, align/crop the face, then run Pivotal Tuning Inversion
    # (PTI) against the generator of the requested input decade.
    inp.save("imgs/input.png")
    inversion = run_alignment("imgs/input.png", idx=0)
    inversion.save("imgs/cropped/input.png")
    run_PTI(run_name="gradio_demo", in_year=in_year, use_wandb=False, use_multi_id_training=False)

    # Fresh copies of every decade's generator; the input decade gets the
    # PTI-tuned checkpoint that run_PTI wrote above.
    pti_models = {}
    for year in years:
        G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
        pti_models[year] = {"G": G.eval().float()}
    pti_models[in_year]['G'] = torch.load("checkpoints/model_gradio_demo_input.pt", map_location=device).eval().float()
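    # Transfer the personalization to the other decades: PTI fine-tuned only
    # the input decade's generator, so the change it made to each parameter
    # (delta = tuned - original) is added onto the matching parameter of every
    # other decade's generator. This relies on all decade generators sharing
    # one architecture, so parameters() iterates in the same order for each.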
    for year in years:
        if year != in_year:
            for p_pti, p_orig, (name, p) in zip(pti_models[in_year]['G'].parameters(),
                                                orig_models[in_year]['G'].parameters(),
                                                pti_models[year]['G'].named_parameters()):
                with torch.no_grad():
                    delta = p_pti - p_orig
                    p += delta

    # Decode the PTI pivot latent with every decade's generator. The output
    # list is the cropped input followed by one image per decade, matching
    # the order of the `outs` components below.
    w_pti = torch.load("embeddings/gradio/PTI/input/0.pt", map_location=device)
    dst = [inversion]
    for year in years:
        with torch.no_grad():
            child_tensor = pti_models[year]["G"].synthesis(w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True)
        img = utils.tensor2im(child_tensor.squeeze(0))
        dst.append(img)
    return dst
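

# A minimal way to smoke-test predict() without the UI, assuming the example
# image bundled with the demo is present on disk:
#
#     from PIL import Image
#     outs = predict(Image.open("imgs/Steven-Yeun.jpg"), "2010s")
#     outs[0].save("cropped.png")  # aligned crop; outs[1:] is one image per decade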
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            in_img = gr.Image(label="Input Image", type="pil")
            in_year = gr.Dropdown(label="Input Decade", choices=decades, value="2010s")
            submit = gr.Button(value="Submit")
            examples = gr.Examples(examples=[["imgs/Steven-Yeun.jpg", "2010s"], ["imgs/00061_1920.png", "1920s"]], inputs=[in_img, in_year])
        with gr.Column():
            with gr.Row():
                cropped = gr.Image(label="Cropped Input", type="pil").style(height=256, width=256)
                out_1880 = gr.Image(label="1880", type="pil").style(height=256, width=256)
                out_1890 = gr.Image(label="1890", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1900 = gr.Image(label="1900", type="pil").style(height=256, width=256)
                out_1910 = gr.Image(label="1910", type="pil").style(height=256, width=256)
                out_1920 = gr.Image(label="1920", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1930 = gr.Image(label="1930", type="pil").style(height=256, width=256)
                out_1940 = gr.Image(label="1940", type="pil").style(height=256, width=256)
                out_1950 = gr.Image(label="1950", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1960 = gr.Image(label="1960", type="pil").style(height=256, width=256)
                out_1970 = gr.Image(label="1970", type="pil").style(height=256, width=256)
                out_1980 = gr.Image(label="1980", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1990 = gr.Image(label="1990", type="pil").style(height=256, width=256)
                out_2000 = gr.Image(label="2000", type="pil").style(height=256, width=256)
                out_2010 = gr.Image(label="2010", type="pil").style(height=256, width=256)
    # Output order must match the list returned by predict().
    outs = [cropped, out_1880, out_1890, out_1900, out_1910, out_1920, out_1930,
            out_1940, out_1950, out_1960, out_1970, out_1980, out_1990, out_2000, out_2010]
    submit.click(predict, inputs=[in_img, in_year], outputs=outs)

demo.launch()  # demo.launch(server_name="0.0.0.0", server_port=8098) to serve on a fixed host/port