import streamlit as st
import pandas as pd
import plotly.express as px
import sahi.utils.cv
import sahi.utils.file
from PIL import Image
from sahi import AutoDetectionModel
from utils import sahi_yolov8m_inference
from ultralyticsplus.hf_utils import download_from_hub
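
# Streamlit app: detect valves and pumps on P&ID diagrams with a YOLOv8 model
# ('DanielCerda/pid_yolov8' on the Hugging Face Hub) using SAHI sliced inference.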

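# example diagrams and the precomputed prediction visual, hosted on CloudFront and keyed by local filename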
IMAGE_TO_URL = {
    'factory_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/factory-pid.png',
    'plant_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
    'processing_pid.png' : 'https://d1afc1j4569hs1.cloudfront.net/processing-pid.png',
    'prediction_visual.png' : 'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png'
    }

st.set_page_config(
    page_title="P&ID Object Detection",
    layout="wide",
    initial_sidebar_state="expanded"
    )

st.title('P&ID Object Detection')
st.subheader('Identify valves and pumps with a deep learning model', divider='rainbow')
st.markdown(
    """
    <a href='https://cl.linkedin.com/in/daniel-cerda-escobar' target='_blank'><img src="https://img.icons8.com/fluency/48/000000/linkedin.png" height="30"></a>
    """,
    unsafe_allow_html=True,
)

@st.cache_resource(show_spinner=False)
def get_model(postprocess_match_threshold):
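    """Load the YOLOv8 P&ID detector from the Hugging Face Hub as a SAHI model.

    Cached with st.cache_resource, so the model is only rebuilt when the
    confidence threshold changes; the slider value is passed to SAHI as
    confidence_threshold.
    """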
    yolov8_model_path = download_from_hub('DanielCerda/pid_yolov8')
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=yolov8_model_path,
        confidence_threshold=postprocess_match_threshold,
        device="cpu",
    )
    return detection_model
    
@st.cache_data(show_spinner=False)
def download_comparison_images():
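    """Fetch the example diagram and its precomputed prediction visual shown before any inference."""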
    sahi.utils.file.download_from_url(
        'https://d1afc1j4569hs1.cloudfront.net/plant-pid.png',
        'plant_pid.png',
    )
    sahi.utils.file.download_from_url(
        'https://d1afc1j4569hs1.cloudfront.net/prediction_visual.png',
        'prediction_visual.png',
    )

download_comparison_images()

# placeholder results for the precomputed example prediction (shown before any inference run)
coco_df = pd.DataFrame({
    'category': ['centrifugal-pump'] * 2 + ['gate-valve'] * 9,
    'score': [0.88, 0.85, 0.87, 0.87, 0.86, 0.86, 0.85, 0.84, 0.81, 0.81, 0.76],
})
output_df = pd.DataFrame({
    'category': ['ball-valve', 'butterfly-valve', 'centrifugal-pump', 'check-valve', 'gate-valve'],
    'count': [0, 0, 2, 0, 9],
    'percentage': [0, 0, 18.2, 0, 81.8],
})

# session state: pre-populate the result tabs with the example diagram and its precomputed prediction
if "output_1" not in st.session_state:
    img_1 = Image.open('plant_pid.png')
    st.session_state["output_1"] = img_1.resize((4960,3508))

if "output_2" not in st.session_state:
    img_2 = Image.open('prediction_visual.png')
    st.session_state["output_2"] = img_2.resize((4960,3508))

if "output_3" not in st.session_state:
    st.session_state["output_3"] = coco_df
    
if "output_4" not in st.session_state:
    st.session_state["output_4"] = output_df
    

col1, col2, col3 = st.columns(3, gap='medium')
with col1:
    with st.expander('How to use it'):
        st.markdown(
        '''
        1) Upload a P&ID or select one of the example diagrams  👆🏻
        2) Set the model parameters 📈
        3) Press the button to run inference   🚀
        4) Explore the model predictions  🔎
        '''
        )   

st.write('##')
   
col1, col2, col3 = st.columns(3, gap='large')
with col1:
    st.markdown('##### Set Input Image')
    # set input image by upload
    image_file = st.file_uploader(
        'Upload your P&ID', type = ['jpg','jpeg','png']
    )
    # set input images from examples
    def radio_func(option):
        option_to_id = {
            'factory_pid.png' : 'A',
            'plant_pid.png' : 'B',
            'processing_pid.png' : 'C',
        }
        return option_to_id[option]
    radio = st.radio(
        'Select from the following examples',
        options = ['factory_pid.png', 'plant_pid.png', 'processing_pid.png'],
        format_func = radio_func,
    )
with col2:
    # preview the uploaded image, or fetch the selected example from its URL
    if image_file is not None:
        image = Image.open(image_file)
    else:
        image = sahi.utils.cv.read_image_as_pil(IMAGE_TO_URL[radio])
    st.markdown('##### Preview')
    with st.container(border = True):
        st.image(image, use_column_width = True)
        
with col3:
    # set SAHI parameters
    st.markdown('##### Set model parameters')
    slice_number = st.select_slider(
        'Slices per Image',
        options=[1, 4, 16, 64],
        value=4,
    )
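    # fraction of each slice that overlaps its neighbours when SAHI tiles the diagram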
    overlap_ratio = st.slider(
        label = 'Slicing Overlap Ratio', 
        min_value=0.0, 
        max_value=0.5, 
        value=0.1, 
        step=0.1
    )
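    # minimum detection confidence, passed to get_model() and on to SAHI as confidence_threshold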
    postprocess_match_threshold = st.slider(
        label = 'Confidence Threshold',
        min_value = 0.0,
        max_value = 1.0,
        value = 0.85,
        step = 0.05
    )

st.write('##')

col1, col2, col3 = st.columns([4, 1, 4])
with col2:
    submit = st.button("🚀 Perform Prediction")
    
if submit:
    # perform prediction
    with st.spinner(text="Downloading model weights ... "):
        detection_model = get_model(postprocess_match_threshold)
        
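    # square slices: side length is the assumed 4960 px diagram width divided by
    # the square root of the slice count (e.g. 4 slices -> 2480 px per slice)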
    slice_size = int(4960/(float(slice_number)**0.5))
    image_size = 4960

    with st.spinner(text="Performing prediction ... "):
        output_visual, coco_df, output_df = sahi_yolov8m_inference(
            image,
            detection_model,
            image_size=image_size,
            slice_height=slice_size,
            slice_width=slice_size,
            overlap_height_ratio=overlap_ratio,
            overlap_width_ratio=overlap_ratio,
        )

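    # keep the latest results in session state so they persist across Streamlit reruns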
    st.session_state["output_1"] = image
    st.session_state["output_2"] = output_visual
    st.session_state["output_3"] = coco_df
    st.session_state["output_4"] = output_df

st.write('##')

col1, col2, col3 = st.columns([1, 5, 1], gap='small')
with col2:  
    st.markdown(f"#### Object Detection Result")
    with st.container(border = True):
        tab1, tab2, tab3, tab4 = st.tabs(['Original Image','Inference Prediction','Data','Insights'])
        with tab1:
            st.image(st.session_state["output_1"])
        with tab2:
            st.image(st.session_state["output_2"])
        with tab3:
            col1,col2,col3 = st.columns([1,2,1])
            with col2:
                st.dataframe(
                    st.session_state["output_3"],
                    column_config = {
                        'category' : 'Predicted Category',
                        'score' : 'Confidence',
                    },
                    use_container_width = True,
                    hide_index = True,
                )
        with tab4:
            col1,col2,col3 = st.columns([1,5,1])
            with col2:
                chart_data = st.session_state["output_4"]
                fig = px.bar(chart_data, x='category', y='count', color='category')
                fig.update_layout(
                    title='Objects Detected',
                    xaxis_title=None,
                    yaxis_title=None,
                    showlegend=False,
                    yaxis=dict(tick0=0, dtick=1),
                    bargap=0.5,
                )
                st.plotly_chart(fig, use_container_width=True, theme='streamlit')