# Hugging Face Spaces status captured together with this source:
# Runtime error
import copy
import functools
import math
import os

import gradio as gr
from PIL import Image
import torch
from torchvision import transforms

import utils.utils as utils
from run_pti import run_PTI
from configs import global_config, paths_config
device = global_config.device

# Decade start years covered by the per-decade checkpoints: "1880", "1890", ..., "2010".
years = [str(start) for start in range(1880, 2020, 10)]
# Human-readable dropdown labels, e.g. "1880s".
decades = [f"{year}s" for year in years]

# Resize to 256x256, convert to a tensor, then shift/scale each channel from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# Unmodified generator for every decade, loaded once at startup and keyed by year string.
orig_models = {}
for year in years:
    generator, _w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": generator.eval().float()}
@functools.lru_cache(maxsize=1)
def _get_shape_predictor():
    """Build dlib's 68-landmark shape predictor once and return the cached instance afterwards."""
    import dlib

    return dlib.shape_predictor("pretrained_models/shape_predictor_68_face_landmarks.dat")


def run_alignment(image_path, idx=None):
    """Align the face found in the image at `image_path` using `align_face`.

    Args:
        image_path: path of the image file to align.
        idx: forwarded unchanged to `align_face`; its meaning is defined there.

    Returns:
        Whatever `align_face` returns. `predict` calls `.save()` on it and shows it
        as the "Cropped Input" image.
    """
    # Lazy import, as in the original: only needed when an alignment is requested.
    from align_all_parallel import align_face

    # Reading the landmark model file on every request is wasteful, so the predictor
    # is built once (see _get_shape_predictor) and reused for every later call.
    predictor = _get_shape_predictor()
    return align_face(filepath=image_path, predictor=predictor, idx=idx)
def _invert_input(inp, in_year):
    """Save and align the uploaded image, then run PTI for the `in_year` generator.

    Returns the aligned image. run_PTI presumably writes the tuned generator and the
    pivot latent that `predict` later loads from fixed paths (TODO confirm in run_pti).
    """
    # Make sure the output directory for the aligned crop exists before saving into it.
    os.makedirs("imgs/cropped", exist_ok=True)
    inp.save("imgs/input.png")
    aligned = run_alignment("imgs/input.png", idx=0)
    if aligned is None:
        raise gr.Error("No face could be aligned in the input image.")
    aligned.save("imgs/cropped/input.png")
    run_PTI(run_name="gradio_demo", in_year=in_year, use_wandb=False, use_multi_id_training=False)
    return aligned


def _build_decade_generators(in_year):
    """Return {year: generator} for every entry of `years`, all carrying the PTI update.

    The `in_year` entry is the tuned checkpoint itself. Every other decade starts from a
    copy of its original generator, shifted by (tuned - original) of the `in_year` model.
    """
    tuned_G = torch.load("checkpoints/model_gradio_demo_input.pt", map_location=device).eval().float()
    source_G = orig_models[in_year]["G"]
    with torch.no_grad():
        # Parameters are paired positionally: this assumes every decade's generator has
        # the same architecture, so parameters() yields them in the same order.
        deltas = [p_tuned - p_src for p_tuned, p_src in zip(tuned_G.parameters(), source_G.parameters())]
        generators = {in_year: tuned_G}
        for year in years:
            if year == in_year:
                continue
            # Copy the cached original instead of re-reading the pickle from disk; the
            # copy keeps the shared module-level weights in orig_models untouched.
            shifted_G = copy.deepcopy(orig_models[year]["G"])
            for p, delta in zip(shifted_G.parameters(), deltas):
                p += delta
            generators[year] = shifted_G
    return generators


def predict(inp, in_decade):
    """Invert `inp` into the `in_decade` generator and re-render it in every decade.

    Args:
        inp: the uploaded PIL image, or None if the user submitted without one.
        in_decade: a dropdown value such as "2010s".

    Returns:
        A list of images: the aligned input first, then one synthesis per entry of
        `years`, in the order of the UI's output components.

    Raises:
        gr.Error: if no image was provided or alignment produced nothing.
    """
    if inp is None:
        raise gr.Error("Please upload an input image.")
    in_year = in_decade[:-1]  # "2010s" -> "2010"

    # NOTE(review): the intermediate files (imgs/..., checkpoints/..., embeddings/...)
    # use fixed names, so concurrent requests would overwrite each other; requests
    # must be serialized (e.g. through the Gradio queue).
    # Gradients are needed inside run_PTI, so only synthesis below runs under no_grad.
    inversion = _invert_input(inp, in_year)
    generators = _build_decade_generators(in_year)
    w_pti = torch.load("embeddings/gradio/PTI/input/0.pt", map_location=device)

    outputs = [inversion]
    with torch.no_grad():
        # The pivot latent is reshaped to (1, 14, 512) for the synthesis network
        # (presumably 14 W+ vectors for this resolution).
        w_plus = w_pti.view(1, 14, 512)
        for year in years:
            child_tensor = generators[year].synthesis(w_plus, noise_mode="const", force_fp32=True)
            outputs.append(utils.tensor2im(child_tensor.squeeze(0)))
    return outputs
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            in_img = gr.Image(label="Input Image", type="pil")
            in_year = gr.Dropdown(label="Input Decade", choices=decades, value="2010s")
            submit = gr.Button(value="Submit")
            examples = gr.Examples(
                examples=[["imgs/Steven-Yeun.jpg", "2010s"], ["imgs/00061_1920.png", "1920s"]],
                inputs=[in_img, in_year],
            )
        with gr.Column():
            # One slot for the aligned input followed by one per decade, in the same
            # order predict() returns its images; laid out three per row. Deriving the
            # slots from `years` keeps the UI in sync with the loaded models.
            output_labels = ["Cropped Input"] + years
            outs = []
            for row_start in range(0, len(output_labels), 3):
                with gr.Row():
                    for label in output_labels[row_start:row_start + 3]:
                        outs.append(gr.Image(label=label, type="pil").style(height=256, width=256))
    submit.click(predict, inputs=[in_img, in_year], outputs=outs)

# Serialize requests: predict() reads and writes fixed file paths, so two requests
# running at once would clobber each other's intermediate files.
demo.queue()
demo.launch()  # server_name="0.0.0.0", server_port=8098)