File size: 7,173 Bytes
223d932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import gradio as gr
from io import BytesIO
import os
import sys

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from PIL import Image
from omegaconf import OmegaConf

import torch
from torchvision import transforms as T

from optvq.models.quantizer import sinkhorn
from optvq.utils.init import seed_everything
seed_everything(42)
from optvq.models.vqgan_hf import VQModelHF
matplotlib.rcParams['font.family'] = 'Times New Roman'

# ---------------------------------------------------------------------------
# Demo configuration
# ---------------------------------------------------------------------------
dim: int = 2  # dimensionality of the toy points used in Demo 2
N_data: int = 50  # number of random data points sampled by draw_process
N_code: int = 20  # number of random codewords sampled by draw_process
device = torch.device("cpu")  # device passed to Handler when run as a script
handler = None  # rebound to a Handler instance inside the __main__ block
# ---------------------------------------------------------------------------

def nearest(src, trg):
    """Return, for each row of ``src``, the index of the closest row of ``trg``.

    Distances come from ``torch.cdist`` with its default p=2 (Euclidean);
    the index is taken with ``argmin`` over the last dimension.
    """
    return torch.cdist(src, trg).argmin(dim=-1)

def normalize(A, dim, mode="all"):
    """Rescale a cost tensor before it is passed to Sinkhorn.

    Args:
        A: tensor of costs (e.g. pairwise distances).
        dim: divisor used by ``mode="dim"``; ignored by the other modes.
        mode: one of
            ``"all"``  - standardize with the global mean/std (1e-6 added to
                         the std to avoid division by zero), then shift so
                         the minimum value is 0;
            ``"dim"``  - divide every entry by ``dim``;
            ``"null"`` - return ``A`` unchanged.

    Returns:
        The rescaled tensor (``A`` itself for ``"null"``).

    Raises:
        ValueError: if ``mode`` is not one of the values above. An unknown
            mode used to fall through silently and skip normalization.
    """
    if mode == "all":
        A = (A - A.mean()) / (A.std() + 1e-6)
        return A - A.min()
    if mode == "dim":
        return A / dim
    if mode == "null":
        return A
    raise ValueError(
        f"unknown normalization mode {mode!r}; expected 'all', 'dim' or 'null'"
    )

def draw_NN(data, code):
    """Plot the nearest-neighbour assignment of ``data`` points to ``code`` points.

    Each data point is assigned to its closest codeword (via ``nearest``) and
    a red arrow is drawn from the point to that codeword.

    Args:
        data: tensor of data points. Columns 0 and 1 are plotted, so at least
            two columns are expected. It is converted with ``.numpy()``, so it
            is assumed to be a CPU tensor without grad.
        code: tensor of codewords with the same column layout as ``data``.

    Returns:
        A fully loaded ``PIL.Image`` of the rendered PNG.
    """
    indices = nearest(data, code)
    data = data.numpy()
    code = code.numpy()

    fig = plt.figure(figsize=(3, 2.5), dpi=400)
    try:
        # One red, semi-transparent (alpha=0.6) arrow per data point,
        # pointing at its assigned codeword.
        for start, idx in zip(data, indices.tolist()):
            end = code[idx]
            plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                      head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
                      ls="-", lw=0.5)
        plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
        plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
        plt.legend(loc="lower right")
        plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
        plt.title("Nearest neighbor")

        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; this
        # runs on every slider change, so without closing it figures pile up.
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now so the returned image no longer needs the buffer
    return image

def draw_optvq(data, code):
    """Plot the OptVQ (Sinkhorn-based) assignment of ``data`` points to ``code`` points.

    Pairwise Euclidean costs are normalized with ``normalize(..., mode="all")``
    and passed to ``sinkhorn``. Each data point is then assigned to the
    codeword with the largest value in its row of the returned matrix, and a
    green arrow is drawn from the point to that codeword.

    Args:
        data: tensor of data points. Columns 0 and 1 are plotted, so at least
            two columns are expected. It is converted with ``.numpy()``, so it
            is assumed to be a CPU tensor without grad.
        code: tensor of codewords with the same column layout as ``data``.

    Returns:
        A fully loaded ``PIL.Image`` of the rendered PNG.
    """
    cost = torch.cdist(data, code, p=2.0)
    # mode="all" ignores its ``dim`` argument; the module-level ``dim`` only
    # fills the positional slot.
    cost = normalize(cost, dim, mode="all")
    # NOTE(review): Q is presumably the (len(data), len(code)) transport plan;
    # confirm against optvq.models.quantizer.sinkhorn.
    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
    indices = torch.argmax(Q, dim=-1)
    data = data.numpy()
    code = code.numpy()

    fig = plt.figure(figsize=(3, 2.5), dpi=400)
    try:
        # One green, semi-transparent (alpha=0.6) arrow per data point,
        # pointing at its assigned codeword.
        for start, idx in zip(data, indices.tolist()):
            end = code[idx]
            plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                      head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
                      ls="-", lw=0.5)
        plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
        plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
        plt.legend(loc="lower right")
        plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
        plt.title("Optimal Transport (OptVQ)")

        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; this
        # runs on every slider change, so without closing it figures pile up.
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now so the returned image no longer needs the buffer
    return image

def draw_process(x, y, std):
    """Sample a random 2-D toy problem and render both matching strategies.

    ``N_data`` standard-normal data points and ``N_code`` codewords are drawn.
    The codewords are scaled by ``std`` and shifted by ``(x, y)``.

    Returns:
        ``(nearest_neighbour_image, optvq_image)`` as produced by ``draw_NN``
        and ``draw_optvq``.
    """
    # Data is sampled before the codebook, keeping the RNG draw order fixed.
    points = torch.randn(N_data, dim)
    codebook = torch.randn(N_code, dim) * std
    for axis, shift in enumerate((x, y)):
        codebook[:, axis] += shift

    nn_panel = draw_NN(points, codebook)
    ot_panel = draw_optvq(points, codebook)
    return nn_panel, ot_panel

class Handler:
    """Load the BaseVQ, VQGAN and OptVQ models and reconstruct images with each.

    ``process_image`` is the Gradio callback for the reconstruction demo.
    """

    def __init__(self, device):
        """Build the preprocessing pipeline and load the three pretrained models.

        Args:
            device: torch device that the models and input tensors are moved to.
        """
        # Resize the short side to 256, center-crop to 256x256, then convert to
        # a float tensor (ToTensor yields values in [0, 1]).
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device

        self.basevq = self._load_model("BorelTHU/basevq-16x16x4")
        self.vqgan = self._load_model("BorelTHU/vqgan-16x16")
        self.optvq = self._load_model("BorelTHU/optvq-16x16x4")

    def _load_model(self, repo_id):
        """Load a ``VQModelHF`` checkpoint, move it to ``self.device`` and set eval mode."""
        model = VQModelHF.from_pretrained(repo_id)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _reconstruct(model, x):
        """Encode ``x`` with ``model`` and decode the first encoder output."""
        quant, *_ = model.encode(x)
        return model.decode(quant)

    def tensor_to_image(self, tensor):
        """Convert a (1, C, H, W) or (C, H, W) tensor in [-1, 1] to an HxWxC uint8 array.

        Values outside [-1, 1] are clipped. The scaled value is truncated, not
        rounded, when cast to uint8.
        """
        img = tensor.detach().squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        # Decoder outputs can overshoot [-1, 1]. Casting out-of-range floats
        # to uint8 overflows/wraps and shows up as speckle artifacts.
        img = np.clip(img, 0, 255)
        return img.astype("uint8")

    def process_image(self, img: np.ndarray):
        """Reconstruct ``img`` with BaseVQ, VQGAN and OptVQ.

        Args:
            img: image array from the Gradio image component, or ``None`` when
                no image has been uploaded.

        Returns:
            ``(input, basevq_rec, vqgan_rec, optvq_rec)`` as uint8 HxWxC arrays,
            where ``input`` is the resized and center-cropped model input. All
            four are ``None`` when ``img`` is ``None``.
        """
        if img is None:
            # Button clicked without an uploaded image: clear the outputs
            # instead of crashing on ``None.astype``.
            return None, None, None, None

        # convert("RGB") guarantees 3 channels (e.g. grayscale or RGBA arrays).
        pil_img = Image.fromarray(img.astype("uint8")).convert("RGB")
        x = self.transform(pil_img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            x = 2 * x - 1  # map ToTensor's [0, 1] range to [-1, 1]
            basevq_rec = self._reconstruct(self.basevq, x)
            vqgan_rec = self._reconstruct(self.vqgan, x)
            optvq_rec = self._reconstruct(self.optvq, x)

        return (
            self.tensor_to_image(x),
            self.tensor_to_image(basevq_rec),
            self.tensor_to_image(vqgan_rec),
            self.tensor_to_image(optvq_rec),
        )

if __name__ == "__main__":
    # Load all reconstruction models once, rebinding the module-level handler.
    handler = Handler(device=device)

    demo1_intro = "This demo shows the image reconstruction comparison between OptVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes."
    demo2_intro = "This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors."

    with gr.Blocks() as demo:
        # ---- Demo 1: reconstructions from BaseVQ / VQGAN / OptVQ ----
        gr.Textbox(value=demo1_intro, label="Demo 1: Image reconstruction results")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
                btn_demo1 = gr.Button(value="Run reconstruction")
            image_basevq = gr.Image(label="BaseVQ rec.")
            image_vqgan = gr.Image(label="VQGAN rec.")
            image_optvq = gr.Image(label="OptVQ rec.")
        btn_demo1.click(
            fn=handler.process_image,
            inputs=[image_input],
            outputs=[image_input, image_basevq, image_vqgan, image_optvq],
        )

        # ---- Demo 2: 2-D nearest-neighbour vs. optimal-transport matching ----
        gr.Textbox(value=demo2_intro, label="Demo 2: 2D visualizations of matching results")
        with gr.Row():
            with gr.Column():
                input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
                input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
                input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
                btn_demo2 = gr.Button(value="Run 2D example")
            output_nn = gr.Image(label="NN")
            output_optvq = gr.Image(label="OptVQ")

        # Moving any slider, or pressing the button, redraws both plots.
        plot_inputs = (input_x, input_y, input_std)
        plot_outputs = (output_nn, output_optvq)
        for slider in plot_inputs:
            slider.change(fn=draw_process, inputs=list(plot_inputs), outputs=list(plot_outputs))
        btn_demo2.click(fn=draw_process, inputs=list(plot_inputs), outputs=list(plot_outputs))

    demo.launch()