import streamlit as st
import kornia
import torch
from torch import nn
from torchvision.transforms import functional as F
from torchvision.utils import make_grid
from streamlit_ace import st_ace
from PIL import Image
import numpy as np

# Gate for GPU use. While False, ticking the "Use GPU!" checkbox still falls
# back to the CPU; when True, cuda:0 is used only if torch.cuda.is_available().
IS_LOCAL = False  # Change this

@st.cache_resource
def set_transform(content):
    """Build an augmentation pipeline from the user-edited source in ``content``.

    ``content`` is evaluated as a single Python expression with ``kornia`` and
    ``nn`` (``torch.nn``) in scope, e.g.
    ``nn.Sequential(kornia.augmentation.RandomAffine(degrees=360))``.

    Cached with ``st.cache_resource`` (keyed on ``content``) instead of
    ``st.cache_data``: the result is an arbitrary object (typically an
    ``nn.Module``), and ``cache_data`` must pickle its return value, so an
    unpicklable result (e.g. a ``lambda``) would raise outside the ``try``
    below and crash the app. Note that ``cache_resource`` hands the same
    object to every rerun/session.

    Args:
        content: Python source text of a single expression.

    Returns:
        The evaluated object, or an empty ``nn.Sequential()`` (identity
        pipeline) if evaluation raised; the error message is written to the page.
    """
    # SECURITY: eval() executes arbitrary user-supplied code. Because the
    # globals dict has no "__builtins__" key, Python inserts the full builtins
    # module, so this is NOT a sandbox. Only acceptable for a trusted/local demo.
    try:
        transform = eval(content, {"kornia": kornia, "nn": nn}, None)
    except Exception as e:
        # Syntax errors, unknown names, bad arguments, ... are all shown to the
        # user, and an identity pipeline is used so the page still renders.
        st.write(f"There was an error: {e}")
        transform = nn.Sequential()
    return transform

st.set_page_config(page_title="Kornia Augmentations Demo", layout="wide")

st.markdown("# Kornia Augmentations Demo")
st.sidebar.markdown(
    "[Kornia](https://github.com/kornia/kornia) is a *differentiable* computer vision library for PyTorch."
)

# --- Input image ------------------------------------------------------------
uploaded_file = st.sidebar.file_uploader("Choose a file", type=['png', 'jpg', 'jpeg'])
try:
    # Fall back to the bundled sample image when nothing has been uploaded.
    if uploaded_file is not None:
        im = Image.open(uploaded_file)
    else:
        im = Image.open("./images/pretty_bird.jpg")
    # Normalise to 3-channel RGB: grayscale ("L") images have no channel axis,
    # and RGBA / palette ("P") PNGs have 4 / 1 channels. Either would break the
    # HWC -> CHW transpose below or the colour augmentations.
    im = im.convert("RGB")
except OSError as e:  # PIL.UnidentifiedImageError and missing files are OSErrors
    st.error(f"Could not read the image: {e}")
    st.stop()

# Side length used by the RandomResizedCrop example below.
scaler = im.height // 2
st.sidebar.image(im, caption="Input Image", width=256)

# Convert the RGB PIL image (H, W, 3) uint8 to a float tensor (3, H, W) in [0, 1].
image = torch.from_numpy(np.array(im).transpose((2, 0, 1))).float() / 255.0

# batch size is just for show
batch_size = st.sidebar.slider("batch_size", min_value=4, max_value=16, value=8)

gpu = st.sidebar.checkbox("Use GPU!", value=False)
if not gpu:
    st.sidebar.markdown("Using CPU for operations.")
    device = torch.device("cpu")
else:
    if not IS_LOCAL or not torch.cuda.is_available():
        st.sidebar.markdown("GPU not available, using CPU.")
        device = torch.device("cpu")
    else:
        st.sidebar.markdown("Running on GPU~")
        device = torch.device("cuda:0")

predefined_transforms = [
    """
nn.Sequential(
   kornia.augmentation.RandomAffine(degrees=360,p=0.5),
   kornia.augmentation.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.2, hue=0.3, p=1)
)
# p=0.5 is the probability of applying the transformation
""",
    """
nn.Sequential(
   kornia.augmentation.RandomErasing(scale=(.4, .8), ratio=(.3, 1/.3), p=0.5),
)
""",
    """
nn.Sequential(
   kornia.augmentation.RandomErasing(scale=(.4, .8), ratio=(.3, 1/.3), p=1, same_on_batch=True),
)
#By setting same_on_batch=True you can apply the same transform across the batch
""",
    f"""
nn.Sequential(
    kornia.augmentation.RandomResizedCrop(size=({scaler}, {scaler}), scale=(3., 3.), ratio=(2., 2.), p=1.),
    kornia.augmentation.RandomHorizontalFlip(p=0.7),
    kornia.augmentation.RandomGrayscale(p=0.5),
)
"""
]

selected_transform = st.selectbox(
    "Pick an augmentation pipeline example:", predefined_transforms
)

st.write("Transform to apply:")
readonly = False
content = st_ace(
    value=selected_transform,
    height=150,
    language="python",
    keybinding="vscode",
    show_gutter=True,
    show_print_margin=True,
    wrap=False,
    auto_update=False,
    readonly=readonly,
)

if content:
    transform = set_transform(content)
else:
    # Empty editor: use an identity pipeline rather than leaving `transform`
    # undefined (previously this surfaced as a NameError in the UI).
    transform = nn.Sequential()

# Clicking the button reruns the script, which draws a fresh random batch
# from the same pipeline; its return value is not needed.
st.button("Next Batch")

# Fake dataloader: the same image repeated batch_size times -> (B, 3, H, W).
image_batch = torch.stack(batch_size * [image]).to(device)

transformeds = None
try:
    transformeds = transform(image_batch)
except Exception as e:
    # User-authored pipelines can fail in arbitrary ways; report in the UI.
    st.write(f"There was an error: {e}")

cols = st.columns(4)
if transformeds is not None:
    for idx, x in enumerate(transformeds):
        # Clamp before the uint8 cast: NumPy's float -> uint8 conversion is not
        # saturating, so values outside [0, 1] would produce garbage pixels.
        img_np = x.detach().clamp(0.0, 1.0).cpu().numpy().transpose((1, 2, 0))
        img_np = (img_np * 255).astype(np.uint8)
        # Lay the batch out row by row across the columns.
        cols[idx % len(cols)].image(img_np, use_column_width=True)

st.markdown(
    "There are a lot more transformations available: [Documentation](https://kornia.readthedocs.io/en/latest/augmentation.module.html)"
)
st.markdown(
    "Kornia can do a lot more than augmentations~ [Check it out](https://kornia.readthedocs.io/en/latest/get-started/introduction.html)"
)