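# Gradio demo: invert an input face into a decade-specific StyleGAN2 generator
# with Pivotal Tuning Inversion (PTI), transfer the resulting weight delta to
# generators trained on other decades, and render the same person across the
# 1880s-2010s.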
import gradio as gr
import utils.utils as utils
from PIL import Image
import torch
from torchvision import transforms
from run_pti import run_PTI
from configs import global_config, paths_config
device = global_config.device
# Decades covered by the pretrained generators: 1880s through 2010s.
years = [str(y) for y in range(1880, 2020, 10)]
decades = [y + "s" for y in years]
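# Preprocessing: resize to 256x256 and normalize RGB channels to [-1, 1].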
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# Keep an untouched copy of every decade's generator; these frozen weights are
# the reference point for the PTI delta transfer in predict().
orig_models = {}
for year in years:
    G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": G.eval().float()}
def run_alignment(image_path, idx=None):
    """Detect and align the face in image_path using dlib's 68-point landmark model."""
    import dlib
    from align_all_parallel import align_face
    predictor = dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(filepath=image_path, predictor=predictor, idx=idx)
    return aligned_image
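# predict() drives the full pipeline: align the uploaded face, run PTI against
# the generator of the selected decade, copy the tuned-minus-original weight
# delta onto every other decade's generator, then synthesize the pivot latent
# through each one.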
def predict(inp, in_decade):
    # Drop the trailing "s" (e.g. "2010s" -> "2010") to match the model keys.
    in_year = in_decade[:-1]
    # Save the upload, then align and crop the face for inversion.
    inp.save("imgs/input.png")
    inversion = run_alignment("imgs/input.png", idx=0)
    inversion.save("imgs/cropped/input.png")
    # Pivotal Tuning Inversion: invert the crop into the selected decade's
    # generator and fine-tune that generator around the pivot latent.
    run_PTI(run_name="gradio_demo", in_year=in_year, use_wandb=False, use_multi_id_training=False)
    # Reload a fresh generator for every decade, then swap in the PTI-tuned
    # generator for the input decade.
    pti_models = {}
    for year in years:
        G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
        pti_models[year] = {"G": G.eval().float()}
    pti_models[in_year]['G'] = torch.load("checkpoints/model_gradio_demo_input.pt", map_location=device).eval().float()
    # Transfer the PTI weight delta (tuned minus original weights of the input
    # decade's generator) onto every other decade's generator, so the identity
    # personalization carries over across decades.
    for year in years:
        if year != in_year:
            for p_pti, p_orig, (name, p) in zip(pti_models[in_year]['G'].parameters(),
                                                orig_models[in_year]['G'].parameters(),
                                                pti_models[year]['G'].named_parameters()):
                with torch.no_grad():
                    delta = p_pti - p_orig
                    p += delta
    # Load the pivot latent code produced by PTI and collect the outputs,
    # starting with the aligned crop.
    w_pti = torch.load("embeddings/gradio/PTI/input/0.pt", map_location=device)
    dst = [inversion]
    # Synthesize the same pivot latent through every decade's generator.
    for year in years:
        with torch.no_grad():
            child_tensor = pti_models[year]["G"].synthesis(w_pti.view(1, 14, 512), noise_mode="const", force_fp32=True)
        img = utils.tensor2im(child_tensor.squeeze(0))
        dst.append(img)
    return dst
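# Gradio UI: an input column (image upload, decade dropdown, examples) next to
# a grid of 256x256 outputs (the cropped input plus one image per decade).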
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
in_img = gr.Image(label="Input Image", type="pil")
in_year = gr.Dropdown(label="Input Decade", choices=decades, value="2010s")
submit = gr.Button(value="Submit")
examples = gr.Examples(examples=[["imgs/Steven-Yeun.jpg", "2010s"], ["imgs/00061_1920.png", "1920s"]], inputs=[in_img, in_year])
        with gr.Column():
with gr.Row():
                cropped = gr.Image(label="Cropped Input", type="pil").style(height=256, width=256)
                out_1880 = gr.Image(label="1880", type="pil").style(height=256, width=256)
                out_1890 = gr.Image(label="1890", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1900 = gr.Image(label="1900", type="pil").style(height=256, width=256)
                out_1910 = gr.Image(label="1910", type="pil").style(height=256, width=256)
                out_1920 = gr.Image(label="1920", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1930 = gr.Image(label="1930", type="pil").style(height=256, width=256)
                out_1940 = gr.Image(label="1940", type="pil").style(height=256, width=256)
                out_1950 = gr.Image(label="1950", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1960 = gr.Image(label="1960", type="pil").style(height=256, width=256)
                out_1970 = gr.Image(label="1970", type="pil").style(height=256, width=256)
                out_1980 = gr.Image(label="1980", type="pil").style(height=256, width=256)
            with gr.Row():
                out_1990 = gr.Image(label="1990", type="pil").style(height=256, width=256)
                out_2000 = gr.Image(label="2000", type="pil").style(height=256, width=256)
                out_2010 = gr.Image(label="2010", type="pil").style(height=256, width=256)
outs = [cropped, out_1880, out_1890, out_1900, out_1910, out_1920, out_1930, out_1940, out_1950, out_1960, out_1970, out_1980, out_1990, out_2000, out_2010]
submit.click(predict, inputs=[in_img, in_year], outputs=outs)
demo.launch()  # To expose the demo externally: demo.launch(server_name="0.0.0.0", server_port=8098)