File size: 5,103 Bytes
0513aaf
2fec875
0513aaf
 
 
 
2fec875
33dd132
 
0513aaf
2fec875
0513aaf
 
2fec875
 
 
 
 
0513aaf
 
 
 
2fec875
0513aaf
 
dd1add1
 
 
 
2fec875
 
c00162e
dd1add1
 
2fec875
33dd132
0513aaf
c00162e
2fec875
 
33dd132
2fec875
 
33dd132
2fec875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33dd132
2fec875
 
33dd132
2fec875
 
 
33dd132
 
 
2fec875
 
 
 
 
 
 
 
 
 
 
33dd132
 
2fec875
 
0513aaf
 
33dd132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52e4857
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import utils.utils as utils
from PIL import Image
import torch
import math
from torchvision import transforms
from run_pti import run_PTI
from configs import global_config, paths_config
device = global_config.device

# Decade start years as strings: "1880", "1890", ..., "2010".
years = list(map(str, range(1880, 2020, 10)))
# Labels shown in the decade dropdown, e.g. "1880s".
decades = [f"{year}s" for year in years]


# Resize to 256x256, convert to a tensor and rescale each channel from [0, 1]
# to [-1, 1] (mean 0.5, std 0.5).
# NOTE(review): not referenced in the visible part of this file.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# Pretrained, un-tuned generator for every decade, loaded once at start-up.
# predict() uses orig_models[in_year]["G"] as the reference weights when it
# transfers the PTI weight change onto the other decades.
orig_models = {}

for year in years:
    generator, _w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": generator.eval().float()}
    

def run_alignment(image_path, idx=None):
    """Align the face in ``image_path`` with dlib's 68-point landmark model.

    Args:
        image_path: Path of the image file to align.
        idx: Forwarded unchanged to ``align_face`` as its ``idx`` argument.

    Returns:
        Whatever ``align_face`` returns. The caller saves it with ``.save``,
        so presumably a PIL image — TODO confirm.
    """
    # Imported inside the function, presumably so dlib and the alignment
    # module are only needed once alignment actually runs.
    import dlib
    from align_all_parallel import align_face

    landmark_model = dlib.shape_predictor(
        "pretrained_models/shape_predictor_68_face_landmarks.dat"
    )
    return align_face(filepath=image_path, predictor=landmark_model, idx=idx)

def predict(inp, in_decade):
    """Re-render the uploaded face as it would look in every decade.

    Steps:
      1. Save the upload, align/crop the face, and run ``run_PTI`` for the
         selected input decade.
      2. Transfer the tuned-minus-original weight change of the input decade's
         generator onto every other decade's generator.
      3. Synthesize the saved latent with each decade's generator.

    Args:
        inp: The uploaded image as a PIL image (the input component uses
            ``type="pil"``), or ``None`` when nothing was uploaded.
        in_decade: A decade label such as ``"2010s"``.

    Returns:
        A list of ``1 + len(years)`` images: the aligned/cropped input followed
        by one synthesized image per entry of ``years``, in that order.

    Raises:
        ValueError: If no input image was provided.
    """
    if inp is None:
        raise ValueError("No input image provided.")
    in_year = in_decade[:-1]  # "2010s" -> "2010"

    # NOTE(review): these fixed paths (and the "gradio_demo" run name) are
    # shared by every request, so concurrent requests would presumably
    # overwrite each other's files — confirm the app's concurrency settings.
    inp.save("imgs/input.png")
    inversion = run_alignment("imgs/input.png", idx=0)
    inversion.save("imgs/cropped/input.png")
    run_PTI(run_name="gradio_demo", in_year=in_year, use_wandb=False, use_multi_id_training=False)

    pti_models = _load_pti_models(in_year)
    # Latent presumably written by run_PTI for the "input" image — TODO confirm.
    w_pti = torch.load("embeddings/gradio/PTI/input/0.pt", map_location=device)

    return [inversion, *_render_decades(pti_models, w_pti)]


def _load_pti_models(in_year):
    """Build one generator per decade carrying the PTI tuning of ``in_year``.

    The ``in_year`` entry is the tuned checkpoint saved by ``run_PTI``. Every
    other decade's pretrained generator is shifted, parameter by parameter, by
    ``tuned - original`` of the ``in_year`` generator.

    Returns:
        dict mapping each year string in ``years`` to ``{"G": generator}``.
    """
    pti_models = {}
    for year in years:
        if year == in_year:
            continue  # replaced by the tuned checkpoint below; loading it is wasted work
        G, _ = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
        pti_models[year] = {"G": G.eval().float()}

    # NOTE(review): this unpickles a whole module; newer torch versions default
    # torch.load to weights_only=True, which may reject it — verify on upgrade.
    pti_G = torch.load("checkpoints/model_gradio_demo_input.pt", map_location=device).eval().float()
    pti_models[in_year] = {"G": pti_G}

    targets = [pti_models[year]["G"] for year in years if year != in_year]
    with torch.no_grad():
        # zip relies on all generators sharing one architecture, so their
        # parameters() iterate in the same order.
        for p_pti, p_orig, *p_targets in zip(
            pti_G.parameters(),
            orig_models[in_year]["G"].parameters(),
            *(g.parameters() for g in targets),
        ):
            delta = p_pti - p_orig  # computed once per parameter, shared by all decades
            for p in p_targets:
                p += delta
    return pti_models


def _render_decades(pti_models, w_pti):
    """Synthesize ``w_pti`` with each decade's generator, in ``years`` order."""
    # NOTE(review): assumes the saved latent holds 14 w vectors of width 512,
    # the layout these generators' synthesis network takes — TODO confirm.
    ws = w_pti.view(1, 14, 512)
    images = []
    with torch.no_grad():
        for year in years:
            synth = pti_models[year]["G"].synthesis(ws, noise_mode="const", force_fp32=True)
            images.append(utils.tensor2im(synth.squeeze(0)))
    return images



with gr.Blocks() as demo:
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            in_img = gr.Image(label="Input Image", type="pil")
            in_year = gr.Dropdown(label="Input Decade", choices=decades, value="2010s")
            submit = gr.Button(value="Submit")
            examples = gr.Examples(
                examples=[["imgs/Steven-Yeun.jpg", "2010s"], ["imgs/00061_1920.png", "1920s"]],
                inputs=[in_img, in_year],
            )
        # Right column: the cropped input followed by one image per decade,
        # three per row. Building the components from `years` keeps them in
        # lockstep with the list predict() returns (cropped input, then one
        # image per year in `years` order).
        with gr.Column():
            out_labels = ["Cropped Input", *years]
            outs = []
            for row_start in range(0, len(out_labels), 3):
                with gr.Row():
                    for label in out_labels[row_start:row_start + 3]:
                        outs.append(gr.Image(label=label, type="pil").style(height=256, width=256))

    submit.click(predict, inputs=[in_img, in_year], outputs=outs)


demo.launch()  # server_name="0.0.0.0", server_port=8098)