from PIL import Image

import matplotlib as mpl
from utils import prep_for_plot

import torch.multiprocessing
import torchvision.transforms as T

from utils_gee import extract_img, transform_ee_img

import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots

import os   
# Allow more than one OpenMP runtime to be loaded in the process (the
# KMP_DUPLICATE_LIB_OK environment variable is a known workaround for
# duplicate-OpenMP-runtime errors).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Segmentation classes: colors[i] is the display colour of class_names[i]
# (both are passed together to the pie chart in plot_imgs_labels).
colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
# Colormap used by transform_to_pil to paint the predictions (entry i -> colors[i]).
cmap = mpl.colors.ListedColormap(colors)

# Weight of each class (same order as class_names) in the score computed by
# values_from_output.
scores_init = [1, 2, 4, 3, 4, 1, 0]

# NOTE: a duplicate definition of `segment_loc` was removed here; it was
# shadowed by the later definition in this module and never executed.


# NOTE: a duplicate definition of `segment_group` was removed here; it was
# shadowed by the later definition in this module and never executed.

# NOTE: a duplicate definition of `values_from_output` was removed here; it was
# shadowed by the later definition in this module and never executed.


# NOTE: a duplicate definition of `values_from_outputs` was removed here; it was
# shadowed by the later definition in this module and never executed.



def _segment_pie(values):
    """Return the pie trace showing the per-class pixel counts *values*.

    Slices are labelled with the module-level ``class_names`` and coloured
    with ``colors``, in that order.
    """
    return go.Pie(
        labels=class_names,
        values=values,
        marker_colors=colors,
        name="Segment repartition",
        textposition='inside',
        texttemplate="%{percent:.0%}",
        textfont_size=14,
    )


def _score_traces(scores):
    """Return one score line per frame; trace k plots scores[0..k].

    Points sit at x = 1..k+1 and only the newest point carries a text label
    (its score rounded to 2 decimals).
    """
    traces = []
    history = []
    for score in scores:
        history.append(score)
        labels = [""] * len(history)
        labels[-1] = str(round(score, 2))
        traces.append(go.Scatter(
            x=list(range(1, len(history) + 1)),
            y=list(history),
            mode="lines+markers+text",
            marker_color="black",
            text=labels,
            textposition="top center",
        ))
    return traces


def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores):
    """Build an animated 4-panel Plotly figure with one frame per observation.

    Panels: input image, labeled (blended) image, pie of the class
    repartition, and the score history. A Play button and a slider labelled
    with *months* step through the frames.

    Parameters
    ----------
    months : sequence of str
        Frame labels, used for the slider steps and per-frame titles.
    imgs, imgs_label : sequence of np.ndarray
        Input and labeled images, one per frame. The width and height of
        ``imgs[0]`` set the axis ranges of the image panel.
    nb_values : sequence of sequence of int
        Per-class pixel counts for each frame (the pie values).
    scores : sequence of float
        Score of each frame.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If *imgs* is empty.
    """
    number_frames = len(imgs)
    if number_frames == 0:
        raise ValueError("plot_imgs_labels() needs at least one image")
    last_frame = number_frames - 1
    height, width = imgs[0].shape[0], imgs[0].shape[1]

    # px.imshow with animation_frame=0 yields one frame per image; its image
    # traces are reused as the traces of the combined figure.
    fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
    fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
    scatters = _score_traces(scores)

    # A pie trace cannot live in an 'xy' subplot, so column 3 must be a
    # 'domain' subplot. This also makes the score panel (column 4) use
    # xaxis3/yaxis3, which is what the layout update below configures.
    fig = make_subplots(
        rows=1, cols=4,
        specs=[[{"type": "xy"}, {"type": "xy"}, {"type": "domain"}, {"type": "xy"}]],
        column_widths=[0.6, 0.6, 0.3, 0.3],
        subplot_titles=("Localisation visualization", "labeled visualisation", "Segments repartition", "Biodiversity scores"),
    )

    fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
    fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
    fig.add_trace(_segment_pie(nb_values[0]), row=1, col=3)
    fig.add_trace(scatters[0], row=1, col=4)

    # Each frame replaces traces 0-3 of fig.data: image, labeled image, pie, scores.
    frames = [
        dict(
            name=str(k),
            data=[
                fig2["frames"][k]["data"][0],
                fig3["frames"][k]["data"][0],
                _segment_pie(nb_values[k]),
                scatters[k],
            ],
            traces=[0, 1, 2, 3],
        )
        for k in range(number_frames)
    ]

    updatemenus = [dict(type='buttons',
                        buttons=[dict(label='Play',
                                      method='animate',
                                      args=[[f'{k}' for k in range(number_frames)],
                                            dict(frame=dict(duration=500, redraw=False),
                                                 transition=dict(duration=0),
                                                 easing='linear',
                                                 fromcurrent=True,
                                                 mode='immediate')])],
                        direction='left',
                        pad=dict(r=10, t=85),
                        showactive=True, x=0.1, y=0.13, xanchor='right', yanchor='top')]

    sliders = [{'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
                'transition': {'duration': 500.0, 'easing': 'linear'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9, 'x': 0.1, 'y': 0,
                # Step args use the same string frame names as the frames and the Play button.
                'steps': [{'args': [[str(k)], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
                                               'transition': {'duration': 0, 'easing': 'linear'}}],
                           'label': months[k], 'method': 'animate'} for k in range(number_frames)]}]

    fig.update(frames=frames)

    for k, frame in enumerate(fig["frames"]):
        # NOTE(review): the tiny k/100000 offset makes each frame's range
        # slightly different; presumably a workaround to force plotly to
        # re-apply it on every frame. Kept as-is.
        frame.update(
            layout={
                "xaxis": {"range": [0, width + k / 100000]},
                "yaxis": {"range": [height + k / 100000, 0]},
            })
        frame.update(layout_title_text=months[k])

    fig.update(layout_title_text='tot')
    fig.update(
        layout={
            # Image panel: hide grid, zero line and tick labels.
            "xaxis": {
                "range": [0, width + last_frame / 100000],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis": {
                "range": [height + last_frame / 100000, 0],
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            # Score panel (column 4): x covers the frame positions 1..n, y the score.
            "xaxis3": {
                "range": [0, len(scores) + 1],
                'autorange': False,
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
            "yaxis3": {
                "range": [0, 1.5],
                'autorange': False,
                'showgrid': False,
                'zeroline': False,
                'visible': False,
            },
        },
        legend=dict(
            yanchor="bottom",
            y=0.99,
            xanchor="center",
            x=0.01,
        ),
    )

    fig.update_layout(updatemenus=updatemenus, sliders=sliders)
    fig.update_layout(margin=dict(b=0, r=0))

    return fig



# NOTE: an older duplicate of `segment_region` (taking a single `location`
# argument) was removed here; it was shadowed by the later definition in this
# module and never executed.
# Pipeline applied to every image before inference (see segment_loc):
# convert to a PIL image, resize to 320x320, convert to a tensor, then
# normalise each channel with the usual ImageNet mean/std values.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Fetch one Earth Engine image for a location/period and segment it.
# Three ways are offered to limit cloudy images: monthly, bi-monthly or
# yearly images (the bi-monthly case is requested with how='year' plus
# month_end/year_end by segment_group).
def segment_loc(model, location, month, year, how="month", month_end='12', year_end=None):
    """Fetch an Earth Engine image for *location* and run *model* on it.

    Parameters
    ----------
    model : object
        Segmentation model; must expose ``net``, ``linear_probe`` and
        ``cfg.n_images`` as used below.
    location : object
        Passed unchanged to ``extract_img``.
    month, year : str
        Start month ('01'..'12') and year ('2020') of the period; they are
        concatenated into 'YYYY-MM-DD' date strings.
    how : {'month', 'year'}
        'month': dates <year>-<month>-01 .. <year>-<month>-28.
        'year': dates <year>-<month>-01 .. <year_end or year>-<month_end>-28,
        with ``width=0.04, len=0.04`` passed to ``extract_img``.
    month_end : str
        End month, used when how='year'.
    year_end : str or None
        End year, used when how='year'; defaults to *year*.

    Returns
    -------
    dict
        'img': the preprocessed input batch; 'linear_preds': argmax over
        dim 1 of the linear-probe output (presumably per-pixel class ids).
        Both are truncated to ``model.cfg.n_images`` and moved to the CPU.

    Raises
    ------
    ValueError
        If *how* is neither 'month' nor 'year'.
    """
    if how == 'month':
        img = extract_img(location, year + '-' + month + '-01', year + '-' + month + '-28')
    elif how == 'year':
        last_year = year if year_end is None else year_end
        img = extract_img(location, year + '-' + month + '-01', last_year + '-' + month_end + '-28',
                          width=0.04, len=0.04)
    else:
        # Previously fell through with `img` unbound (UnboundLocalError).
        raise ValueError(f"how must be 'month' or 'year', got {how!r}")

    img_test = transform_ee_img(img, max=0.25)

    # Preprocess, add a batch dimension and keep inference on the CPU.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()

    with torch.no_grad():
        _feats, code = model.net(x)
        linear_preds = model.linear_probe(x, code).argmax(1)
        n_images = model.cfg.n_images
        outputs = {
            'img': x[:n_images].detach().cpu(),
            'linear_preds': linear_preds[:n_images].detach().cpu(),
        }
    return outputs


# Segment every image of a location over a date range; each result is paired
# with its 'YYYY-MM' date label.
def segment_group(location, start_date, end_date, how='month', model=None):
    """Segment all images of *location* between *start_date* and *end_date*.

    Parameters
    ----------
    location : object
        Passed unchanged to ``segment_loc``.
    start_date, end_date : str
        'YYYY-MM...' strings; only the year (characters 0-4) and the month
        (characters 5-7) are read. Both ends are inclusive.
    how : {'month', '2months', 'year'}
        'month': one image per calendar month.
        '2months': one image per month, spanning that month and the next
        (December rolls over into January of the next year).
        'year': one image per calendar year, spanning the covered months.
    model : object, optional
        Segmentation model forwarded to ``segment_loc``. It is required as
        soon as at least one period has to be segmented.

    Returns
    -------
    list of (str, dict)
        ``('YYYY-MM', segment_loc(...) result)`` pairs in chronological order.

    Raises
    ------
    ValueError
        If *how* is not one of the supported values.
    TypeError
        If there is something to segment but *model* is None.
    """
    if how not in ('month', '2months', 'year'):
        raise ValueError(f"how must be 'month', '2months' or 'year', got {how!r}")

    st_year, st_month = int(start_date[0:4]), int(start_date[5:7])
    end_year, end_month = int(end_date[0:4]), int(end_date[5:7])

    # (label, month, year, extra segment_loc kwargs) for every image to segment.
    requests = []
    for year in range(st_year, end_year + 1):
        first = st_month if year == st_year else 1
        last = end_month if year == end_year else 12
        year_str = str(year)

        if how == 'year':
            requests.append((f"{year_str}-{first:02d}", f"{first:02d}", year_str,
                             {'how': 'year', 'month_end': f"{last:02d}"}))
            continue

        for month in range(first, last + 1):
            month_str = f"{month:02d}"
            if how == 'month':
                extra = {}
            else:  # '2months': window from this month to the next one.
                next_month = month % 12 + 1
                next_year = year + 1 if next_month < month else year
                extra = {'how': 'year', 'month_end': f"{next_month:02d}", 'year_end': str(next_year)}
            requests.append((f"{year_str}-{month_str}", month_str, year_str, extra))

    if requests and model is None:
        raise TypeError("segment_group() needs a `model` to segment the requested images")

    return [(label, segment_loc(model, location, m, y, **extra))
            for label, m, y, extra in requests]


# Convert one segment_loc() result into PIL images.
def transform_to_pil(outputs, alpha=0.3):
    """Return (img, label, labeled_img) PIL images for one segmentation result.

    *img* is the first input image; *label* paints the first prediction map
    with the colours of ``cmap``; *labeled_img* is ``Image.blend`` of the two
    (converted to RGBA), where *alpha* is the weight of the label layer.
    """
    to_pil = T.ToPILImage()

    # Input image: the last axis of prep_for_plot's result is moved to the
    # front (presumably HWC -> CHW) before the PIL conversion.
    base = to_pil(torch.moveaxis(prep_for_plot(outputs['img'][0]), -1, 0))

    # Colour lookup table: row i is the RGBA colour of cmap entry i.
    palette = np.array([np.array(cmap(idx)) for idx in range(cmap.N)])
    # NOTE(review): predictions are shifted down by one before indexing, so a
    # prediction of 0 indexes -1 (the last colour) — confirm the label range.
    pred_idx = np.array(outputs['linear_preds'][0]) - 1
    label = to_pil((palette[pred_idx] * 255).astype(np.uint8))

    labeled_img = Image.blend(base.convert("RGBA"), label.convert("RGBA"), alpha)
    return base, label, labeled_img

def values_from_output(output):
    """Turn one segment_loc() result into plottable arrays, class counts and a score.

    Parameters
    ----------
    output : dict
        A segment_loc() result with 'img' and 'linear_preds' entries; only the
        first element of each is used.

    Returns
    -------
    img : np.ndarray
        RGB array of the input image.
    labeled_img : np.ndarray
        RGB array of the input image blended with the class colours (alpha 0.3).
    nb_values : list of int
        Pixel count of each class, in ``class_names`` / ``scores_init`` order.
    score : float
        Pixel-weighted mean of ``scores_init``, divided by ``max(scores_init)``.
        It is 0.0 when no pixel falls in a counted class.
    """
    img, _label, labeled_img = transform_to_pil(output, alpha=0.3)
    img = np.array(img.convert('RGB'))
    labeled_img = np.array(labeled_img.convert('RGB'))

    # NOTE(review): prediction value c is counted as class c-1, so labels look
    # 1-based here and a prediction of 0 is not counted at all — confirm
    # against the model's label range.
    preds = output['linear_preds'][0]
    nb_values = [np.count_nonzero(preds == cls + 1) for cls in range(len(scores_init))]

    total = sum(nb_values)
    if total == 0:
        # No counted pixel: avoid the ZeroDivisionError the old code raised.
        score = 0.0
    else:
        weighted = sum(w * n for w, n in zip(scores_init, nb_values))
        score = weighted / total / max(scores_init)

    return img, labeled_img, nb_values, score


# values_from_output (above) extracts, for one observation, the image and the labeled image (as RGB arrays), nb_values (the number of pixels in each class) and the score.



# Function that extract from outputs (from segment_group function) all dates/ all images 
def values_from_outputs(outputs) : 
    months = []
    imgs = []
    imgs_label = []
    nb_values = []
    scores = []

    for output in outputs:
        img, labeled_img, nb_value, score = values_from_output(output[1])
        months.append(output[0])
        imgs.append(img)
        imgs_label.append(labeled_img)
        nb_values.append(nb_value)
        scores.append(score)
  
    return months, imgs, imgs_label, nb_values, scores





# Entry point: segment a location over a date range and build the animated figure.
# how = 'month' or '2months' or 'year'
def segment_region(latitude, longitude, start_date, end_date, how='month', model=None):
    """Segment all images of a location between two dates and return the figure.

    Parameters
    ----------
    latitude, longitude : str or float
        Coordinates, converted with float() and passed on as
        ``[latitude, longitude]``.
    start_date, end_date : str
        'YYYY-MM...' strings; see segment_group.
    how : str or list/tuple of str
        'month', '2months' or 'year'. When a list/tuple is given (for example
        the value of a multi-select input — assumption about the caller), its
        first element is used.
    model : object, optional
        Segmentation model forwarded to segment_group. When None,
        segment_group is called without a model argument, as before.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure built by plot_imgs_labels.
    """
    location = [float(latitude), float(longitude)]
    # Only unwrap real sequences: indexing the plain string 'month' would give
    # 'm', which segment_group does not recognise.
    if isinstance(how, (list, tuple)):
        how = how[0]

    group_kwargs = {'how': how}
    if model is not None:
        group_kwargs['model'] = model
    # Extract the outputs for each image.
    outputs = segment_group(location, start_date, end_date, **group_kwargs)

    # Extract the values to plot from each output.
    months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)

    # Create the figure.
    return plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)