import time

import streamlit as st
from PIL import Image
from transformers import pipeline

def main():
    """Run the Streamlit app.

    Configures the page, shows a header and a file uploader, and defines the
    helper functions used further down. Once a file has been uploaded, the
    image is passed through the detection and classification helpers and the
    resulting summary tables are reported.
    """
    # Page-level settings and the headline shown above the uploader
    st.set_page_config(page_title="Unmasked the Target Customers", page_icon="🦜")
    st.header("Turn the photos taken in the campaign to useful marketing insights")
    # The processing section below only runs when this is not None
    uploaded_file = st.file_uploader("Select an Image...")


    # Helper that cuts a rectangular region out of an image.
    def extract_subimage(image, xmin, xmax, ymin, ymax):
        """Return the region of ``image`` bounded by the given coordinates.

        The parameters arrive as the x pair followed by the y pair; they are
        rearranged into the ``(xmin, ymin, xmax, ymax)`` tuple that is handed
        to ``image.crop``.
        """
        crop_box = (xmin, ymin, xmax, ymax)
        return image.crop(crop_box)
    
    def pipeline_1_final(image_lst):
        """Detect people in an image and crop each detection.

        Args:
            image_lst: The image to analyse. Despite the name, the caller in
                this file passes a single image, not a list.

        Returns:
            A ``(sub_image_lst, person_count)`` tuple: the crops of every
            detection labelled ``'person'`` and the number of such detections.
        """
        pipe = pipeline("object-detection", model="hustvl/yolos-tiny")
        # Work on the argument itself instead of a variable captured from the
        # enclosing scope, so the function does what its signature says.
        preds = pipe(image_lst)
        sub_image_lst = []
        for pred in preds:
            if pred['label'] != 'person':
                continue
            box = pred['box']
            # Read the coordinates by key; unpacking box.values() would
            # silently depend on the dict's insertion order.
            sub_image = extract_subimage(
                image_lst, box['xmin'], box['xmax'], box['ymin'], box['ymax']
            )
            sub_image_lst.append(sub_image)
        return sub_image_lst, len(sub_image_lst)
    
    def pipeline_2_final(image_lst):
        """Return one reporting age bracket per image in ``image_lst``.

        Each image is classified with the ``nateraw/vit-age-classifier`` model;
        the label of the first prediction (presumably the top-scoring one) is
        translated into the bracket names used for reporting. A label missing
        from the translation table raises ``KeyError``.
        """
        # Classifier label -> bracket name shown in the report
        label_to_bracket = {
            "0-2": "lower than 10",
            "3-9": "lower than 10",
            "10-19": "10-19",
            "20-29": "20-29",
            "30-39": "30-39",
            "40-49": "40-49",
            "50-59": "50-59",
            "60-69": "60-69",
            "more than 70": "70 or above",
        }
        classifier = pipeline("image-classification", model="nateraw/vit-age-classifier")
        return [label_to_bracket[classifier(img)[0]['label']] for img in image_lst]
    
    def pipeline_3_final(image_lst):
        """Return one gender label per image in ``image_lst``.

        Each image is classified with the
        ``mikecho/NTQAI_pedestrian_gender_recognition_v1`` model and the label
        of the first prediction is kept.
        """
        classifier = pipeline(
            "image-classification",
            model="mikecho/NTQAI_pedestrian_gender_recognition_v1",
        )
        return [classifier(img)[0]['label'] for img in image_lst]
    
    def gender_prediciton_model_NTQAI_pedestrian_gender_recognition(image_lst):
        """Return one gender label per image using the original NTQAI model.

        Same procedure as the other gender helper, but with the
        ``NTQAI/pedestrian_gender_recognition`` checkpoint: classify each image
        and keep the label of the first prediction.
        """
        classifier = pipeline(
            "image-classification",
            model="NTQAI/pedestrian_gender_recognition",
        )
        return [classifier(img)[0]['label'] for img in image_lst]

    
    def pipeline_4_final(image_lst):
        """Return one emotion label per image in ``image_lst``.

        Each image is classified with the
        ``dima806/facial_emotions_image_detection`` model and the label of the
        first prediction is kept.
        """
        # The unused start-time measurement was removed: its value was never
        # read, and it relied on the `time` module.
        pipe = pipeline("image-classification", model="dima806/facial_emotions_image_detection")
        return [pipe(image)[0]['label'] for image in image_lst]
        
    def generate_gender_tables(gender_list, age_list, emotion_list):
        """Summarise per-person predictions into two tables keyed by age group.

        The three lists are read in lockstep, one entry per person; surplus
        entries in a longer list are ignored by ``zip``.

        Args:
            gender_list: Gender label per person. Must be ``'male'`` or
                ``'female'``; any other label raises ``KeyError``.
            age_list: Age-group label per person.
            emotion_list: Emotion label per person.

        Returns:
            ``(table1, table2)``. Each row of ``table1`` is
            ``[age, male_count, female_count]``. Each row of ``table2`` is
            ``[age, male_happy_percent, female_happy_percent]``, where a
            percentage is ``0`` if that gender has nobody in the age group.
            Rows are ordered by first appearance of each age group.
        """
        # NOTE(review): the exact label the emotion model emits for a happy
        # face is not visible in this file; both spellings are accepted.
        # Confirm against the model card.
        happy_labels = ('happiness', 'happy')

        gender_count = {}
        happy_count = {}
        for gender, age, emotion in zip(gender_list, age_list, emotion_list):
            if age not in gender_count:
                gender_count[age] = {'male': 0, 'female': 0}
                happy_count[age] = {'male': 0, 'female': 0}
            # Count every person, not only the first one seen per age group
            gender_count[age][gender] += 1
            if emotion in happy_labels:
                happy_count[age][gender] += 1

        table1 = [
            [age, count['male'], count['female']]
            for age, count in gender_count.items()
        ]

        table2 = []
        for age, happy in happy_count.items():
            male_count = gender_count[age]['male']
            female_count = gender_count[age]['female']
            # Guard against division by zero when a gender is absent
            male_percentage = (happy['male'] / male_count) * 100 if male_count > 0 else 0
            female_percentage = (happy['female'] / female_count) * 100 if female_count > 0 else 0
            table2.append([age, male_percentage, female_percentage])

        return table1, table2

    # Everything below runs only once the user has uploaded a file
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(uploaded_file, caption="Processing Image", use_column_width=True)

        # Detect people first; every later model runs on the per-person crops
        pipeline_1_out, person_count = pipeline_1_final(image)
        pipeline_2_age = pipeline_2_final(pipeline_1_out)
        pipeline_3_gender = pipeline_3_final(pipeline_1_out)
        # Emotions must come from the emotion classifier, not the gender one
        pipeline_4_emotion = pipeline_4_final(pipeline_1_out)
        table1, table2 = generate_gender_tables(pipeline_3_gender, pipeline_2_age, pipeline_4_emotion)

        # st.text accepts a single body string, so the count is formatted in
        st.text(f'The detected number of person: {person_count}')

        # Write the rows to the page; print() would only reach the server console
        st.text('\nGender and Age Group Distribution')
        st.text('Age, Male, Female')
        for age, male_count, female_count in table1:
            st.text(f'{age}, {male_count}, {female_count}')

        st.text('\nShare of Happiness')
        st.text('Age, Male, Female')
        for age, male_percentage, female_percentage in table2:
            st.text(f'{age}, {male_percentage:.1f}%, {female_percentage:.1f}%')


# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()