# Copyright (c) OpenMMLab. All rights reserved.
import os

# Install/upgrade the OpenMMLab tooling at import time, ahead of the
# mmengine / mmcv / mmseg imports below. Each command runs through the shell
# via os.system and its exit status is not checked.
for _install_cmd in ('pip install -U openmim',
                     'mim install mmengine',
                     'mim install "mmcv>=2.0.0"'):
    os.system(_install_cmd)

from argparse import ArgumentParser

import cv2
from mmengine.model.utils import revert_sync_batchnorm

from mmseg.apis import inference_model, init_model
from mmseg.apis.inference import show_result_pyplot
import torch
import time
import gradio as gr
import plotly.express as px
import json

# Pre-computed per-stitch FLOPs and single-scale evaluation results, resolved
# relative to the directory the demo is launched from.
_FLOPS_JSON = './model_flops/snnet_flops_setr_naive_512x512_160k_b16_ade20k_deit_3_s_l_224_snnetv2.json'
_EVAL_JSON = './results/eval_single_scale_20230507_235400.json'


def _build_arg_parser():
    """Build the command-line parser for the SN-Netv2 segmentation demo.

    ``--palette``, ``--show`` and ``--show-wait-time`` are still accepted for
    CLI compatibility, but ``main`` does not read them.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--config',
        default='configs/snnet/setr_naive_512x512_160k_b16_ade20k_deit_3_s_l_224_snnetv2.py',
        help='Config file')
    parser.add_argument(
        '--checkpoint',
        help='Checkpoint file',
        default='setr_naive_512x512_160k_b16_ade20k_snnetv2_deit3_s_l_lora_16_iter_160000.pth')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='cityscapes',
        help='Color palette used for segmentation map')
    parser.add_argument(
        '--show', action='store_true', help='Whether to show draw result')
    parser.add_argument(
        '--show-wait-time', default=1, type=int, help='Wait time after imshow')
    parser.add_argument(
        '--output-file', default=None, type=str, help='Output video file path')
    parser.add_argument(
        '--output-fourcc',
        default='MJPG',
        type=str,
        help='Fourcc of the output video')
    parser.add_argument(
        '--output-fps', default=30, type=int, help='FPS of the output video')
    parser.add_argument(
        '--output-height',
        default=-1,
        type=int,
        help='Frame height of the output video')
    parser.add_argument(
        '--output-width',
        default=-1,
        type=int,
        help='Frame width of the output video')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    return parser


def _load_stitch_metrics(flops_path, results_path):
    """Load per-stitch GFLOPs and mIoU, both keyed by integer stitch id.

    Assumes ``results_path`` maps string config ids to
    ``{'metric': {'mIoU': float, ...}, ...}`` and ``flops_path`` maps the same
    string ids to a raw FLOPs count -- TODO confirm against the generators.

    Returns:
        tuple[dict[int, float], dict[int, float]]: ``(flops_res, eval_res)``
        where ``flops_res`` holds FLOPs / 1e9 and ``eval_res`` holds mIoU * 100
        (presumably a [0, 1] fraction converted to percent). Both dicts are
        filled in the same order, so their ``values()`` line up.

    Raises:
        KeyError: if a config id in the results file has no FLOPs entry.
    """
    with open(flops_path, 'r', encoding='utf-8') as f:
        flops_params = json.load(f)
    with open(results_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    flops_res = {}
    eval_res = {}
    for cfg_id, res in results.items():
        key = int(cfg_id)
        flops_res[key] = flops_params[cfg_id] / 1e9
        eval_res[key] = res['metric']['mIoU'] * 100
    return flops_res, eval_res


def _normalize_stitch_id(stitch_id):
    """Map a slider value onto the stitch id used by the backbone and tables."""
    # The slider value may arrive as a float; the backbone expects an int id.
    stitch_id = int(stitch_id)
    # 13 is equivalent to 0
    return 0 if stitch_id == 13 else stitch_id


def main():
    """Build the SN-Netv2 segmentation model and launch the Gradio demo.

    Parses command-line options, loads the model from ``--config`` /
    ``--checkpoint``, loads pre-computed per-stitch GFLOPs and mIoU from
    ``_FLOPS_JSON`` / ``_EVAL_JSON``, and serves an image-segmentation page
    where the stitch id is chosen with a slider.
    """
    args = _build_arg_parser().parse_args()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        # convert SyncBN layers to regular BN for non-distributed CPU inference
        model = revert_sync_batchnorm(model)

    flops_res, eval_res = _load_stitch_metrics(_FLOPS_JSON, _EVAL_JSON)

    def visualize_stitch_pos(stitch_id):
        """Scatter GFLOPs vs mIoU for every stitch and highlight ``stitch_id``."""
        stitch_id = _normalize_stitch_id(stitch_id)

        names = [f'ID {key}' for key in flops_res]

        fig = px.scatter(x=list(flops_res.values()),
                         y=list(eval_res.values()),
                         hover_name=names)
        fig.update_layout(
            title=f"SN-Netv2 - Stitch ID - {stitch_id}",
            title_x=0.5,
            xaxis_title="GFLOPs",
            yaxis_title="mIoU",
            font=dict(
                family="Courier New, monospace",
                size=18,
                color="RebeccaPurple"
            ),
            legend=dict(
                yanchor="bottom",
                y=0.99,
                xanchor="left",
                x=0.01),
        )
        fig.update_traces(marker=dict(size=10,
                                      line=dict(width=2)),
                          selector=dict(mode='markers'))

        # Only highlight stitches that have recorded metrics; the slider range
        # is fixed and may include ids absent from the results file.
        if stitch_id in flops_res and stitch_id in eval_res:
            fig.add_scatter(x=[flops_res[stitch_id]], y=[eval_res[stitch_id]],
                            mode='markers', marker=dict(size=15),
                            name='Current Stitch')
        return fig

    def segment_video(video, stitch_id):
        """Segment every frame of ``video`` with the chosen stitch.

        Blended frames are written with ``--output-fourcc`` to
        ``--output-file`` (default ``./temp_video.avi``). Not currently wired
        to the UI; only the image tab is built below.

        Returns:
            tuple: ``(output_video_path, plotly figure)``.

        Raises:
            gr.Error: if the video cannot be opened.
        """
        stitch_id = _normalize_stitch_id(stitch_id)
        model.backbone.reset_stitch_id(stitch_id)

        output_video_path = args.output_file or './temp_video.avi'
        cap = cv2.VideoCapture(video)
        if not cap.isOpened():
            cap.release()
            raise gr.Error(f'Failed to open video: {video}')

        writer = None
        try:
            input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
            input_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            input_fps = cap.get(cv2.CAP_PROP_FPS)

            fourcc = cv2.VideoWriter_fourcc(*args.output_fourcc)
            # Non-positive CLI values fall back to the input video's properties.
            output_fps = args.output_fps if args.output_fps > 0 else input_fps
            output_height = (args.output_height if args.output_height > 0
                             else int(input_height))
            output_width = (args.output_width if args.output_width > 0
                            else int(input_width))
            writer = cv2.VideoWriter(output_video_path, fourcc, output_fps,
                                     (output_width, output_height), True)

            while True:
                flag, frame = cap.read()
                if not flag:
                    break

                # test a single image
                result = inference_model(model, frame)

                # blend raw image and prediction
                draw_img = show_result_pyplot(model, frame, result,
                                              opacity=args.opacity,
                                              show=False,
                                              with_labels=False)

                if draw_img.shape[:2] != (output_height, output_width):
                    draw_img = cv2.resize(draw_img,
                                          (output_width, output_height))
                writer.write(draw_img)
        finally:
            # Release both handles even if writer construction or inference fails.
            if writer is not None:
                writer.release()
            cap.release()

        return output_video_path, visualize_stitch_pos(stitch_id)

    def segment_image(image, stitch_id):
        """Segment ``image`` with the chosen stitch.

        Returns:
            tuple: the blended image and the stitch-position plot.

        Raises:
            gr.Error: if no image was provided.
        """
        if image is None:
            raise gr.Error('Please upload an image first.')

        stitch_id = _normalize_stitch_id(stitch_id)
        model.backbone.reset_stitch_id(stitch_id)
        result = inference_model(model, image)
        draw_img = show_result_pyplot(model, image, result,
                                      opacity=args.opacity,
                                      show=False,
                                      with_labels=True)
        return draw_img, visualize_stitch_pos(stitch_id)

    with gr.Blocks() as image_demo:
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label='Input Image')
                stitch_slider = gr.Slider(minimum=0, maximum=134, step=1,
                                          label="Stitch ID")
                with gr.Row():
                    clear_button = gr.ClearButton()
                    submit_button = gr.Button()

            with gr.Column():
                image_output = gr.Image(label='Segmentation Results')
                stitch_plot = gr.Plot(label='Stitch Position')

            submit_button.click(
                fn=segment_image,
                inputs=[image_input, stitch_slider],
                outputs=[image_output, stitch_plot],
            )

            # Redraw only the trade-off plot while the slider moves.
            stitch_slider.change(
                fn=visualize_stitch_pos,
                inputs=[stitch_slider],
                outputs=[stitch_plot],
                show_progress=False
            )

            # Reset input image, slider, output image and plot.
            clear_button.click(
                lambda: [None, 0, None, None],
                outputs=[image_input, stitch_slider, image_output, stitch_plot],
            )

        gr.Examples(
            [
                ['./demo_1.jpg', 0],
                ['./demo_2.jpg', 1],
                ['./demo_3.jpg', 93],
                ['./demo_4.jpg', 3],
            ],
            inputs=[
                image_input,
                stitch_slider
            ],
            # Results belong in the output components, not the input image.
            outputs=[
                image_output,
                stitch_plot
            ],
        )

    with gr.Blocks() as demo:
        with gr.Column():
            gr.HTML("""
                <h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; ">Stitched ViTs are Flexible Vision Backbones</h1>
                <div align="center"> <img align="center" src='file/gradio_banner.png' width="70%"> </div>
                <h3 align="center" >This is the semantic segmentation demo page of SN-Netv2, a flexible vision backbone that allows for 100+ runtime speed and performance trade-offs. You can also run this gradio demo on your local GPUs at <a href="https://github.com/ziplab/SN-Netv2">https://github.com/ziplab/SN-Netv2</a>, Paper link: <a href="https://arxiv.org/abs/2307.00154">https://arxiv.org/abs/2307.00154</a>.</h3>
                """)
        gr.TabbedInterface([image_demo], ['Image'])

    # allowed_paths lets the HTML above reference local files such as the banner.
    demo.launch(allowed_paths=['./'])


if __name__ == '__main__':
    main()