import gradio as gr
import gradio.inputs as grinputs
import gradio.outputs as groutputs

import numpy as np
import json

import torch
from torchvision import transforms

import utils 
import utils_img

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so the carrier drawn below is reproducible across runs.
# Keep the statement order: the carrier is sampled from the global torch RNG
# *after* the model has been built.
torch.manual_seed(0)
np.random.seed(0)

print('Building backbone and normalization layer...')
backbone = utils.build_backbone(path='dino_r50.pth')
normlayer = utils.load_normalization_layer(path='out2048.pth')
model = utils.NormLayerWrapper(backbone, normlayer)
# encode()/decode() move their inputs to `device`, so the model has to live
# there as well. It is only ever used as a fixed feature extractor: switch it
# to eval mode and freeze its parameters so that encode()'s backward pass only
# produces gradients for the image being optimized.
# NOTE(review): assumes NormLayerWrapper is a torch.nn.Module -- confirm in utils.
model = model.to(device)
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

print('Building the hypercone...')
FPR = 1e-6
angle = 1.462771101178447 # value for FPR=1e-6 and D=2048
rho = 1 + np.tan(angle)**2
# angle = utils.pvalue_angle(2048, 1, proba=FPR)
# Unit-norm direction of the hypercone axis. It is sampled on the CPU (so the
# seeded values do not depend on the hardware) and only then moved to `device`,
# where the features it is multiplied with are computed.
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device)

# Preprocessing applied to every input image: conversion to a float tensor
# followed by per-channel normalization with the mean/std given below.
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    """Watermark an image by optimizing its pixels with Adam.

    At every step the current image is passed through
    ``utils_img.ssim_attenuation`` and ``utils_img.psnr_clip``, its features
    are extracted with ``model``, and a loss combining the hypercone term
    ``loss_R`` and the squared L2 distance to the original image is minimized.
    One ``__log__`` line (JSON) is printed per step.

    Args:
        image: Input image, in any form accepted by ``default_transform``.
        epochs: Number of optimization steps.
        psnr: Value forwarded to ``utils_img.psnr_clip`` (presumably the
            minimum PSNR, in dB, between output and original -- confirm in
            utils_img).
        lambda_w: Weight of the hypercone (watermark) loss term.
        lambda_i: Weight of the image-distance loss term.

    Returns:
        The watermarked image as a PIL image.
    """
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    # The clone already lives on `device`; it is the only tensor being optimized.
    img = img_orig.clone()
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)

    # The module-level carrier may have been created on another device than the
    # one the features are computed on; a matmul across devices would raise.
    key = carrier.to(img_orig.device)

    for iteration in range(epochs):
        x = utils_img.ssim_attenuation(img, img_orig)
        x = utils_img.psnr_clip(x, img_orig, psnr)

        ft = model(x) # BxCxWxH -> BxD

        dot_product = (ft @ key.T) # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
        # decode() reports "marked" when this quantity is negative.
        loss_R = -(rho * dot_product**2 - norm**2) # B-B -> B

        loss_l2_img = torch.norm(x - img_orig)**2 # CxWxH -> 1
        loss = lambda_w*loss_R + lambda_i*loss_l2_img

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging only: keep the cosine / p-value out of the autograd graph.
        with torch.no_grad():
            cosine = torch.abs(dot_product / norm).item()
        log10_pvalue = float(np.log10(utils.cosine_pvalue(cosine, ft.shape[-1])))

        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue,
        }
        print("__log__:%s" % json.dumps(logs))

    # Apply the same post-processing as inside the loop to the final iterate,
    # then round pixel values before converting back to a PIL image.
    img = utils_img.ssim_attenuation(img, img_orig)
    img = utils_img.psnr_clip(img, img_orig, psnr)
    img = utils_img.round_pixel(img)
    img = img.squeeze(0).detach().cpu()
    img = transforms.ToPILImage()(utils_img.unnormalize_img(img).squeeze(0))

    return img

def decode(image):
    """Test whether an image carries the watermark and report the p-value.

    The image features are extracted with ``model`` and compared with the
    module-level ``carrier``: the image is reported as "marked" when
    ``rho * <ft, carrier>**2 - ||ft||**2`` is strictly positive.

    Args:
        image: Input image, in any form accepted by ``default_transform``.

    Returns:
        A human-readable string stating the decision and the p-value returned
        by ``utils.cosine_pvalue`` for the observed absolute cosine.
    """
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)

    # Detection is a pure forward pass: no autograd graph is needed.
    with torch.no_grad():
        ft = model(img) # BxCxWxH -> BxD

        # Align the carrier with the features' device; a cross-device matmul
        # would raise.
        key = carrier.to(ft.device)
        dot_product = (ft @ key.T) # BxD @ Dx1 -> Bx1
        norm = torch.norm(ft, dim=-1, keepdim=True) # Bx1
        cosine = torch.abs(dot_product / norm).item()
        loss_R = -(rho * dot_product**2 - norm**2) # B-B -> B
        is_marked = loss_R.item() < 0

    # Report the p-value directly: a log10 -> 10** round trip adds floating
    # point error to the message and warns when the p-value is exactly 0.
    p_value = float(utils.cosine_pvalue(cosine, ft.shape[-1]))

    text_marked = "marked" if is_marked else "unmarked"
    return 'Image is {s}, with p-value={p}'.format(s=text_marked, p=p_value)



def on_submit(image, mode):
    """Gradio callback dispatching to ``encode`` or ``decode``.

    Returns an ``(image, message)`` pair: the watermarked image and a success
    message in 'Encode' mode, otherwise the untouched input image together
    with the detection report produced by ``decode``.
    """
    print('{} mode'.format(mode))

    # Anything other than an explicit 'Encode' request is handled as decoding.
    if mode != 'Encode':
        report = decode(image)
        return image, report

    watermarked = encode(image)
    return watermarked, 'Successfully encoded'

# Input widgets: the image to process and the operating mode.
_input_components = [
    grinputs.Image(),
    grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode"),
]

# Output widgets, in the order of the pair returned by on_submit.
_output_components = [
    groutputs.Image(label='Watermarked image'),
    groutputs.Textbox(label='Information'),
]

iface = gr.Interface(
    fn=on_submit,
    inputs=_input_components,
    outputs=_output_components,
    allow_screenshot=False,
    allow_flagging="auto",
)

# Start the web server as soon as the script is executed.
iface.launch()