import streamlit as st
import yaml, os, json, random, time, re, torch, random, warnings, shutil, sys
import seaborn as sns
import plotly.graph_objs as go
from itertools import chain
from PIL import Image
import pandas as pd
from io import BytesIO
from streamlit_extras.let_it_rain import rain
from annotated_text import annotated_text
from vouchervision.LeafMachine2_Config_Builder import write_config_file
from vouchervision.VoucherVision_Config_Builder import build_VV_config, run_demo_tests_GPT, run_demo_tests_Palm , TestOptionsGPT, TestOptionsPalm, check_if_usable, run_api_tests
from vouchervision.vouchervision_main import voucher_vision, voucher_vision_OCR_test
from vouchervision.general_utils import test_GPU, get_cfg_from_full_path, summarize_expense_report, create_google_ocr_yaml_config, validate_dir
from vouchervision.model_maps import ModelMaps
from vouchervision.API_validation import APIvalidation
from vouchervision.utils_hf import upload_to_drive, image_to_base64, setup_streamlit_config, save_uploaded_file, check_prompt_yaml_filename
########################################################################################################
###  ADDED FOR HUGGING FACE                                                                         ####
########################################################################################################
# Seed the per-session defaults. A key that is already present in
# st.session_state is left untouched, so existing values are never overwritten.
for _state_key, _initial_value in (
    # Used as the widget key of the image uploader; other code increments it.
    ('uploader_idk', 1),
    # Paths of the thumbnail copies shown in the gallery.
    ('input_list_small', []),
    # Paths of the uploaded images.
    ('input_list', []),
    ('user_clicked_load_prompt_yaml', None),
    ('new_prompt_yaml_filename', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Upper bound on how many thumbnails the gallery renders at once.
MAX_GALLERY_IMAGES = 50
# Bounding-box edge (pixels) passed to Image.thumbnail for gallery copies.
GALLERY_IMAGE_SIZE = 128
def content_input_images_hf():
    """Render the "Run name" / "Input Images" controls and the image gallery.

    Left column:
      * a text input bound to ``config['leafmachine']['project']['run_name']``;
      * a multi-file JPEG uploader. When files are uploaded the previous
        gallery is cleared, every file is saved under ``<dir_home>/uploads``,
        a thumbnail (at most GALLERY_IMAGE_SIZE px per side) is saved under
        ``<dir_home>/uploads_small``, ``dir_images_local`` is pointed at the
        uploads folder and the uploader's widget key is incremented;
      * a "Use Test Image" button wired to ``use_test_image``.

    Right column: a gallery showing at most MAX_GALLERY_IMAGES thumbnails
    from ``st.session_state['input_list_small']``.

    Reads and writes ``st.session_state``; returns None.
    """
    st.write('---')
    col1, col2 = st.columns([2,8])
    # Alias for the nested project settings dict; writes through it mutate
    # the same object stored in st.session_state.config.
    project_cfg = st.session_state.config['leafmachine']['project']
    with col1:
        st.header('Run name')
        project_cfg['run_name'] = st.text_input("Run name", project_cfg.get('run_name', ''),
                                                label_visibility='collapsed',key=995)
        st.write("Run name will be the name of the final zipped folder.")
        st.write('---')
        st.header('Input Images')
        st.session_state['dir_uploaded_images'] = os.path.join(st.session_state.dir_home,'uploads')
        st.session_state['dir_uploaded_images_small'] = os.path.join(st.session_state.dir_home,'uploads_small')
        uploaded_files = st.file_uploader("Upload Images", type=['jpg', 'jpeg'], accept_multiple_files=True, key=st.session_state['uploader_idk'])
        if uploaded_files:
            # Clear input image gallery and input list
            clear_image_gallery()
            # Process the new images
            for uploaded_file in uploaded_files:
                file_path = save_uploaded_file(st.session_state['dir_uploaded_images'], uploaded_file)
                st.session_state['input_list'].append(file_path)
                # Context manager guarantees the underlying file handle is
                # released even if thumbnailing or saving raises.
                with Image.open(file_path) as img:
                    img.thumbnail((GALLERY_IMAGE_SIZE, GALLERY_IMAGE_SIZE), Image.Resampling.LANCZOS)
                    # presumably writes `img` (the thumbnail) instead of the
                    # raw upload when a third argument is given -- confirm in utils_hf
                    file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
                st.session_state['input_list_small'].append(file_path_small)
                print(uploaded_file.name)
            # Set the local images to the uploaded images
            dir_images_local = st.session_state['dir_uploaded_images']
            project_cfg['dir_images_local'] = dir_images_local
            # Count regular files only; sub-directories are ignored.
            n_images = sum(
                os.path.isfile(os.path.join(dir_images_local, entry))
                for entry in os.listdir(dir_images_local)
            )
            st.session_state['processing_add_on'] = f" {n_images} Images"
            # Give the uploader a new widget key for the next rerun.
            st.session_state['uploader_idk'] += 1
            st.info(f"Processing **{n_images}** images from {dir_images_local}")
        st.button("Use Test Image",help="This will clear any uploaded images and load the 1 provided test image.",on_click=use_test_image)
    with col2:
        thumbnails = st.session_state['input_list_small']
        if thumbnails:
            st.subheader('Image Gallery')
            # Slicing already copes with lists shorter than the cap, so no
            # separate length check is needed.
            st.image(thumbnails[:MAX_GALLERY_IMAGES])
def create_download_button(zip_filepath):
    """Show a primary download button that serves the zip at *zip_filepath*.

    The archive is read fully into memory first, so the file handle is
    already closed by the time the button is created. The label is suffixed
    with ``st.session_state['processing_add_on']``, and the download is
    offered under the archive's base name.
    """
    with open(zip_filepath, 'rb') as zip_file:
        archive_buffer = BytesIO(zip_file.read())

    button_label = f"Download Results for{st.session_state['processing_add_on']}"
    download_name = os.path.basename(zip_filepath)
    st.download_button(
        label=button_label,
        type='primary',
        data=archive_buffer,
        file_name=download_name,
        mime='application/zip',
    )
def delete_directory(dir_path):
    """Recursively remove *dir_path* and reset the gallery path lists.

    If the removal raises OSError, the error is shown with ``st.error``
    and the ``input_list`` / ``input_list_small`` session entries are left
    as they were.
    """
    try:
        shutil.rmtree(dir_path)
    except OSError as err:
        st.error(f"Error: {dir_path} : {err.strerror}")
    else:
        # Only forget the stored paths once the files are actually gone.
        st.session_state['input_list'] = []
        st.session_state['input_list_small'] = []
def clear_image_gallery():
    """Delete both upload folders, then pass each one to ``validate_dir``.

    The folders are the full-size and thumbnail directories stored in
    ``st.session_state``. Both deletions happen before either
    ``validate_dir`` call.
    """
    upload_dir_keys = ('dir_uploaded_images', 'dir_uploaded_images_small')
    for state_key in upload_dir_keys:
        delete_directory(st.session_state[state_key])
    for state_key in upload_dir_keys:
        # presumably recreates the (now missing) folder -- confirm in general_utils
        validate_dir(st.session_state[state_key])
def use_test_image():
    """Switch the run over to the bundled demo images.

    Points ``dir_images_local`` at ``<dir_home>/demo/demo_images``, records
    how many files that folder holds in ``processing_add_on``, clears the
    upload gallery and increments the uploader's widget key.
    """
    demo_dir = os.path.join(st.session_state.dir_home, 'demo', 'demo_images')
    st.info(f"Processing images from {demo_dir}")
    st.session_state.config['leafmachine']['project']['dir_images_local'] = demo_dir

    # Regular files only; sub-directories are not counted.
    demo_files = [
        entry for entry in os.listdir(demo_dir)
        if os.path.isfile(os.path.join(demo_dir, entry))
    ]
    st.session_state['processing_add_on'] = f" {len(demo_files)} Images"

    clear_image_gallery()
    st.session_state['uploader_idk'] += 1
def create_download_button_yaml(file_path, selected_yaml_file):
    """Show a full-width download button for the YAML file at *file_path*.

    *selected_yaml_file* is only used in the button label. The download is
    offered under the base name of *file_path*. The open file handle itself
    is passed to Streamlit as the payload.
    """
    button_text = f"Download {selected_yaml_file}"
    with open(file_path, 'rb') as yaml_handle:
        st.download_button(
            label=button_text,
            data=yaml_handle,
            file_name=os.path.basename(file_path),
            mime='application/x-yaml',
            use_container_width=True,
        )
def upload_local_prompt_to_server(dir_prompt):
    """Let the user upload a prompt ``.yaml`` file and store it in *dir_prompt*.

    Renders a file uploader restricted to ``yaml`` files. When a file is
    provided and its name ends with ``.yaml``, its bytes are written to
    ``dir_prompt/<name>`` (overwriting any existing file of that name) and a
    success message is shown; otherwise an error message is shown.

    Returns None.
    """
    uploaded_file = st.file_uploader("Upload a custom prompt file", type=['yaml'])
    if uploaded_file is None:
        return

    # The name comes from the client. Keep only its final path component so
    # a crafted name such as "../../x.yaml" cannot write outside dir_prompt.
    file_name = os.path.basename(uploaded_file.name)

    # Check the file extension
    if not file_name.endswith('.yaml'):
        st.error("Please upload a .yaml file that you previously created using this Prompt Builder tool.")
        return

    # Save the file
    file_path = os.path.join(dir_prompt, file_name)
    with open(file_path, 'wb') as f:
        f.write(uploaded_file.getbuffer())
    st.success(f"Saved file {file_name} in {dir_prompt}")
def refresh():
    """Increment the uploader's widget key and emit an empty ``st.write``."""
    next_uploader_key = st.session_state['uploader_idk'] + 1
    st.session_state['uploader_idk'] = next_uploader_key
    st.write('')
# def display_image_gallery():
#     # Initialize the container
#     con_image = st.empty()
    
#     # Start the div for the image grid
#     img_grid_html = """
#     
#     """
    
#     # Loop through each image in the input list
#     # with con_image.container():
#     for image_path in st.session_state['input_list']:
#         # Open the image and create a thumbnail
#         img = Image.open(image_path)
#         img.thumbnail((120, 120), Image.Resampling.LANCZOS)  
#         # Convert the image to base64
#         base64_image = image_to_base64(img)
#         # Append the image to the grid HTML
#         # img_html = f"""
#         #     
#         #         

#         #     
#         #     """
#         img_html = f"""
#                 

#             """
#         img_grid_html += img_html
#         # st.markdown(img_html, unsafe_allow_html=True)
    
#     # Close the div for the image grid
#     img_grid_html += "