# Standard library
import csv
import sys

# Third-party
import gradio as gr
import numpy as np
import torch
from torchvision import transforms

# Local modules
from helpers import draw, get_projections
from monoscene.monoscene import MonoScene

# Allow very large CSV fields (used when loading projection/calibration data).
csv.field_size_limit(sys.maxsize)
# Inference-only demo: disable autograd globally.
torch.set_grad_enabled(False)


# Lite MonoScene checkpoint for KITTI: half resolution, w/o the 3D CRP module.
model = MonoScene.load_from_checkpoint(
    "monoscene_kitti.ckpt",
    dataset="kitti",
    n_classes=20,
    feature=64,
    project_scale=4,
    full_scene_size=(256, 256, 32),
)
model.eval()  # eval mode for deterministic BatchNorm/Dropout behavior

# Input image size (width, height); gr.Image below resizes uploads to match.
img_W, img_H = 1220, 370


# ImageNet normalization for the RGB input, built once and reused per call.
normalize_rgb = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)


def predict(img):
    # Scale pixels to [0, 1], then apply ImageNet normalization (HWC -> CHW).
    img = np.array(img, dtype=np.float32, copy=False) / 255.0
    img = normalize_rgb(img)

    # Projection lookup tables for the Sequence 08 camera; add the image and
    # a batch dimension to every tensor (CPU-only inference).
    batch = get_projections(img_W, img_H)
    batch["img"] = img
    for k in batch:
        batch[k] = batch[k].unsqueeze(0)  # append .cuda() here for GPU inference

    # Forward pass, then render the voxels visible in the camera frustum.
    pred = model(batch).squeeze()
    fig = draw(pred, batch["fov_mask_2"])
    return fig

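# Optional local sanity check, bypassing the UI. A sketch only: the sample
# path is one of the bundled examples, MONOSCENE_SMOKE_TEST is a hypothetical
# opt-in flag, and it assumes draw() returns a matplotlib figure (as gr.Plot
# expects).
import os

if os.environ.get("MONOSCENE_SMOKE_TEST"):
    from PIL import Image

    sample = Image.open("images/08/000010.jpg").resize((img_W, img_H))
    predict(np.asarray(sample)).savefig("smoke_test.png")
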
description = """
MonoScene demo on the SemanticKITTI validation set (Sequence 08); predictions use the <b>camera parameters of Sequence 08</b>.
Inference runs on <b>CPU only</b>, so predicting a scene can take up to 20s. \n
This is a <b>smaller</b> model at half resolution and <b>w/o 3D CRP</b>. The full model is available at: <a href="https://huggingface.co/spaces/CVPR/MonoScene">https://huggingface.co/spaces/CVPR/MonoScene</a>
<center>
    <a href="https://cv-rits.github.io/MonoScene/">
        <img style="display:inline" alt="Project page" src="https://img.shields.io/badge/Project%20Page-MonoScene-red">
    </a>
    <a href="https://arxiv.org/abs/2112.00726"><img style="display:inline" src="https://img.shields.io/badge/arXiv%20%2B%20supp-2112.00726-purple"></a>
    <a href="https://github.com/cv-rits/MonoScene"><img style="display:inline" src="https://img.shields.io/github/stars/cv-rits/MonoScene?style=social"></a>
</center>
"""
title = "MonoScene Lite - Half resolution, w/o 3D CRP"
article="""
<center>
    <img src='https://visitor-badge.glitch.me/badge?page_id=anhquancao.MonoScene_lite&left_color=darkmagenta&right_color=purple' alt='visitor badge'>
</center>
"""

examples = [
    'images/08/001385.jpg',
    'images/08/000295.jpg',
    'images/08/002505.jpg',
    'images/08/000085.jpg',
    'images/08/000290.jpg',
    'images/08/000465.jpg',
    'images/08/000790.jpg',
    'images/08/001005.jpg',
    'images/08/001380.jpg',
    'images/08/001530.jpg',
    'images/08/002360.jpg',
    'images/08/004059.jpg',
    'images/08/003149.jpg',
    'images/08/001446.jpg',
    'images/08/000010.jpg',
    'images/08/001122.jpg',
    'images/08/003533.jpg',
    'images/08/003365.jpg',
    'images/08/002944.jpg',
    'images/08/000822.jpg',
    'images/08/000103.jpg',
    'images/08/002716.jpg',
    'images/08/000187.jpg',
    'images/08/002128.jpg',
    'images/08/000511.jpg',
    'images/08/000618.jpg',
    'images/08/002010.jpg',
    'images/08/000234.jpg',
    'images/08/001842.jpg',
    'images/08/001687.jpg',
    'images/08/003929.jpg',
    'images/08/002272.jpg',
]




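# Interface wiring: gr.Image(shape=(1220, 370)) resizes uploads to the
# calibration resolution before they reach predict; gr.Plot renders the
# returned figure. Queueing is enabled at launch so long CPU predictions
# don't time out.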
demo = gr.Interface(
    predict,
    gr.Image(shape=(1220, 370)),
    gr.Plot(),
    article=article,
    title=title,
    cache_examples=False,
    live=False,
    examples=examples,
    description=description,
)


demo.launch(enable_queue=True, debug=False)
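# For local experiments, launch(share=True) would also expose a temporary
# public URL; on Hugging Face Spaces the default launch above is sufficient.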