import base64
from huggingface_hub import hf_hub_download
import streamlit as st
import io
import gc
import json

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
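# Demo app: encode an image into a compact binary latent code
# (IMG_BITS bits per 8x8 pixel patch) and decode it back to pixels.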

MODEL_REPO = 'BlinkDL/clip-guided-binary-autoencoder'

import torch
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import functional as VF

device = 'cuda' if torch.cuda.is_available() else 'cpu'

IMG_BITS = 13  # bits per 8x8 pixel patch in the binary latent code


class ResBlock(nn.Module):
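    """Two residual 3x3 conv blocks with Mish activations; only the
    first block is preceded by BatchNorm."""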

    def __init__(self, c_x, c_hidden):
        super().__init__()
        self.B0 = nn.BatchNorm2d(c_x)
        self.C0 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
        self.C1 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)
        self.C2 = nn.Conv2d(c_x, c_hidden, kernel_size=3, padding=1)
        self.C3 = nn.Conv2d(c_hidden, c_x, kernel_size=3, padding=1)

    def forward(self, x):
        ACT = F.mish
        x = x + self.C1(ACT(self.C0(ACT(self.B0(x)))))
        x = x + self.C3(ACT(self.C2(x)))
        return x


class REncoderSmall(nn.Module):
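    """Small encoder: three pixel_unshuffle(2) stages shrink the image 8x
    while a global skip (CIN -> pixel_unshuffle(8) -> Bxx) bypasses them;
    COUT projects to IMG_BITS channels, squashed to (0, 1) by a sigmoid."""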

    def __init__(self):
        super().__init__()
        dd = 8
        self.Bxx = nn.BatchNorm2d(dd * 64)

        self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(dd * 4)
        self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(dd * 64)
        self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)

        self.COUT = nn.Conv2d(dd * 64, IMG_BITS, kernel_size=3, padding=1)

    def forward(self, img):
        ACT = F.mish

        x = self.CIN(img)
        xx = self.Bxx(F.pixel_unshuffle(x, 8))  # global skip straight to the 8x-downsampled grid
        x = x + self.Cx1(ACT(self.Cx0(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))

        x = F.pixel_unshuffle(x, 2)
        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))

        x = self.COUT(x + xx)
        return torch.sigmoid(x)  # soft bits in (0, 1); the caller binarizes them


class RDecoderSmall(nn.Module):
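    """Mirror of REncoderSmall: three pixel_shuffle(2) stages grow the
    IMG_BITS-channel code 8x back to an RGB image in (0, 1)."""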

    def __init__(self):
        super().__init__()
        dd = 8
        self.CIN = nn.Conv2d(IMG_BITS, dd * 64, kernel_size=3, padding=1)

        self.B00 = nn.BatchNorm2d(dd * 64)
        self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)
        self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1)
        self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1)

        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)
        self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1)
        self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1)

        self.B20 = nn.BatchNorm2d(dd * 4)
        self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)
        self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1)
        self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1)

        self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
        self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
        self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)

    def forward(self, code):
        ACT = F.mish
        x = self.CIN(code)

        x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
        x = x + self.C03(ACT(self.C02(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
        x = x + self.C13(ACT(self.C12(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
        x = x + self.C23(ACT(self.C22(x)))
        x = F.pixel_shuffle(x, 2)

        x = x + self.Cx1(ACT(self.Cx0(x)))
        x = self.COUT(x)

        return torch.sigmoid(x)


class REncoderLarge(nn.Module):
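    """Large encoder with the same 8x pixel_unshuffle layout as the small
    one, parameterized by base width dd, stem width ee, and ResBlock
    hidden width ff."""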

    def __init__(self, dd, ee, ff):
        super().__init__()
        self.CXX = nn.Conv2d(3, dd, kernel_size=3, padding=1)
        self.BXX = nn.BatchNorm2d(dd)
        self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
        self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
        self.R0 = ResBlock(dd * 4, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 64, ff)
        self.CZZ = nn.Conv2d(dd * 64, IMG_BITS, kernel_size=3, padding=1)

    def forward(self, x):
        ACT = F.mish
        x = self.BXX(self.CXX(x))

        x = x + self.CX1(ACT(self.CX0(x)))
        x = F.pixel_unshuffle(x, 2)
        x = self.R0(x)
        x = F.pixel_unshuffle(x, 2)
        x = self.R1(x)
        x = F.pixel_unshuffle(x, 2)
        x = self.R2(x)

        x = self.CZZ(x)
        return torch.sigmoid(x)


class RDecoderLarge(nn.Module):
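    """Mirror of REncoderLarge: three pixel_shuffle(2) stages grow the
    IMG_BITS-channel code 8x back to an RGB image in (0, 1)."""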

    def __init__(self, dd, ee, ff):
        super().__init__()
        self.CZZ = nn.Conv2d(IMG_BITS, dd * 64, kernel_size=3, padding=1)
        self.BZZ = nn.BatchNorm2d(dd * 64)
        self.R0 = ResBlock(dd * 64, ff)
        self.R1 = ResBlock(dd * 16, ff)
        self.R2 = ResBlock(dd * 4, ff)
        self.CX0 = nn.Conv2d(dd, ee, kernel_size=3, padding=1)
        self.CX1 = nn.Conv2d(ee, dd, kernel_size=3, padding=1)
        self.CXX = nn.Conv2d(dd, 3, kernel_size=3, padding=1)

    def forward(self, x):
        ACT = F.mish
        x = self.BZZ(self.CZZ(x))

        x = self.R0(x)
        x = F.pixel_shuffle(x, 2)
        x = self.R1(x)
        x = F.pixel_shuffle(x, 2)
        x = self.R2(x)
        x = F.pixel_shuffle(x, 2)
        x = x + self.CX1(ACT(self.CX0(x)))

        x = self.CXX(x)
        return torch.sigmoid(x)


@st.cache_resource  # st.cache is deprecated; cache_resource is the API meant for model objects
def prepare_model(model_prefix):
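    """Build the architecture matching the checkpoint prefix and load its
    weights from the Hugging Face Hub. Cached so each model is downloaded
    and constructed only once per session."""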
    gc.collect()

    if model_prefix == 'out-v7c_d8_256-224-13bit-OB32x0.5-745':
        R_ENCODER, R_DECODER = REncoderSmall(), RDecoderSmall()
    else:
        if 'd16_512' in model_prefix:
            dd, ee, ff = 16, 64, 512
        elif 'd32_1024' in model_prefix:
            dd, ee, ff = 32, 128, 1024
        else:
            # avoid a NameError on an unrecognized checkpoint name
            raise ValueError(f'unknown model prefix: {model_prefix}')
        R_ENCODER = REncoderLarge(dd, ee, ff)
        R_DECODER = RDecoderLarge(dd, ee, ff)

    encoder = R_ENCODER.eval().to(device)
    decoder = R_DECODER.eval().to(device)

    encoder.load_state_dict(
        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-E.pth'),
                   map_location=device))
    decoder.load_state_dict(
        torch.load(hf_hub_download(MODEL_REPO, f'{model_prefix}-D.pth'),
                   map_location=device))

    return encoder, decoder


def compute_padding(img_shape):
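    """Return (left, top, right, bottom) padding that rounds the image's
    height and width up to the next multiple of 8, the encoder's total
    downsampling factor (three pixel_unshuffle(2) stages)."""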
    hsize, vsize = (img_shape[1] + 7) // 8 * 8, (img_shape[0] + 7) // 8 * 8
    hpad, vpad = hsize - img_shape[1], vsize - img_shape[0]
    left, top = hpad // 2, vpad // 2
    right, bottom = hpad - left, vpad - top
    return left, top, right, bottom


def encode(model_prefix, img, keep_shape):
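    """Encode a PIL image into a JSON string holding the original image
    shape, the latent shape, the keep_shape flag, and the bit-packed
    latent code as base64 under "data"."""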
    gc.collect()
    encoder, _ = prepare_model(model_prefix)

    with torch.no_grad():
        img = VF.pil_to_tensor(img.convert("RGB"))
        img = VF.convert_image_dtype(img)  # uint8 -> float32 in [0, 1]
        img = img.unsqueeze(0).to(device)
        img_shape = img.shape[2:]

        if keep_shape:
            left, top, right, bottom = compute_padding(img_shape)
            img = VF.pad(img, [left, top, right, bottom], padding_mode='edge')
        else:
            img = VF.resize(img, [224, 224])

        z = torch.floor(encoder(img) + 0.5)  # round soft sigmoid outputs to hard 0/1 bits

    with io.BytesIO() as buffer:
        # pack the 0/1 floats into a bit-stream; base64 makes it a copy-pastable string
        np.save(buffer, np.packbits(z.cpu().numpy().astype('bool')))
        z_b64 = base64.b64encode(buffer.getvalue()).decode()

    return json.dumps({
        "img_shape": list(img_shape),
        "z_shape": list(z.shape[2:]),
        "keep_shape": keep_shape,
        "data": z_b64,
    })


def decode(model_prefix, z_str):
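    """Inverse of encode(): rebuild the binary latent from the JSON
    payload, run the decoder, and return a PIL image."""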
    gc.collect()
    _, decoder = prepare_model(model_prefix)

    z_json = json.loads(z_str)
    with io.BytesIO() as buffer:
        buffer.write(base64.b64decode(z_json["data"]))
        buffer.seek(0)
        z = np.load(buffer)
    img_shape = z_json["img_shape"]
    z_shape = z_json["z_shape"]
    keep_shape = z_json["keep_shape"]

    # unpack the bit-stream back into a [1, IMG_BITS, H, W] float tensor
    z = np.unpackbits(z)[:IMG_BITS * z_shape[0] * z_shape[1]].astype(np.float32)
    z = z.reshape([1, IMG_BITS] + z_shape)

    with torch.no_grad():  # inference only, matching encode()
        img = decoder(torch.from_numpy(z).to(device))

    if keep_shape:
        left, top, right, bottom = compute_padding(img_shape)
        img = img[0, :, top:img.shape[2] - bottom, left:img.shape[3] - right]
    else:
        img = img[0]

    return VF.to_pil_image(img)


st.title("CLIP-Guided Binary Autoencoder")
st.write(
    "Models are from [@BlinkDL](https://huggingface.co/BlinkDL/clip-guided-binary-autoencoder)"
)
model_prefix = st.selectbox('The model to use',
                            ('out-v7c_d8_256-224-13bit-OB32x0.5-745',
                             'out-v7d_d16_512-224-13bit-OB32x0.5-2487',
                             'out-v7d_d32_1024-224-13bit-OB32x0.5-5560'))

encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])

with encoder_tab:
    col_in, col_out = st.columns(2)
    keep_shape = col_in.checkbox(
        'Keep the original input size instead of rescaling to 224x224 (experimental)')
    uploaded_file = col_in.file_uploader('Choose an Image')
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        col_in.image(image, 'Input Image')
        z_str = encode(model_prefix, image, keep_shape)
        col_out.write("Encoded to:")
        col_out.code(z_str, language=None)
        col_out.image(decode(model_prefix, z_str), 'Output Image preview')

with decoder_tab:
    col_in, col_out = st.columns(2)
    z_str = col_in.text_area('Paste encoded string here:')
    if len(z_str) > 0:
        image = decode(model_prefix, z_str)
        col_out.image(image, 'Output Image')