import os
import zipfile
from pathlib import Path
from time import time
from typing import Union
import matplotlib.pyplot as plt

import dosma
import numpy as np
import wget
import cv2
import scipy.misc
from PIL import Image

import dicom2nifti
import math
import pydicom
import operator
import moviepy.video.io.ImageSequenceClip
from tkinter import Tcl
import pandas as pd
import warnings

import numpy as np
from skimage.morphology import skeletonize_3d
from scipy.spatial.distance import pdist, squareform
from scipy.interpolate import splprep, splev
import nibabel as nib
from nibabel.processing import resample_to_output

import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

from totalsegmentator.libs import (
    download_pretrained_weights,
    nostdout,
    setup_nnunet,
)

from comp2comp.inference_class_base import InferenceClass
from comp2comp.models.models import Models
from comp2comp.spine import spine_utils
import nibabel as nib

class AortaSegmentation(InferenceClass):
    """Aorta segmentation with a TotalSegmentator-style nnU-Net model.

    Runs nnU-Net task 253 ("Task253_Aorta"), downloading its weights on first
    use. Method names such as ``spine_seg`` and the output file name
    ``spine.nii.gz`` are historical and are kept unchanged so existing callers
    and later pipeline stages keep working.
    """

    def __init__(self, save=True):
        super().__init__()
        self.model_name = "totalsegmentator"
        self.save_segmentations = save

    def __call__(self, inference_pipeline):
        """Segment the converted CT and attach per-slice arrays to the pipeline.

        The input volume is ``<output_dir>/segmentations/converted_dcm.nii.gz``
        and ``<output_dir>/segmentations/spine.nii.gz`` is passed to nnU-Net as
        the output path. Sets ``inference_pipeline.ct_image`` and
        ``inference_pipeline.axial_masks`` to lists of 2D arrays, one per index
        along the third array axis.

        Returns:
            dict: Always empty.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        os.makedirs(self.output_dir_segmentations, exist_ok=True)

        self.model_dir = inference_pipeline.model_dir

        seg, mv = self.spine_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            os.path.join(self.output_dir_segmentations, "spine.nii.gz"),
            inference_pipeline.model_dir,
        )

        seg_data = seg.get_fdata()
        volume_data = mv.get_fdata()

        # One 2D array per index of the third axis, which downstream stages
        # treat as the axial direction.
        inference_pipeline.ct_image = [volume_data[:, :, k] for k in range(volume_data.shape[2])]
        inference_pipeline.axial_masks = [seg_data[:, :, k] for k in range(seg_data.shape[2])]

        return {}

    def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
        """Create the nnU-Net results tree under *model_dir* and point nnU-Net at it.

        Adapted from TotalSegmentator. Sets ``self.weights_dir`` and the
        ``nnUNet_raw_data_base``, ``nnUNet_preprocessed`` and
        ``RESULTS_FOLDER`` environment variables.

        Args:
            model_dir: Root directory for model files.
        """
        config_dir = Path(model_dir) / ("." + self.model_name)
        for architecture in ("3d_fullres", "2d"):
            (config_dir / "nnunet/results/nnUNet" / architecture).mkdir(exist_ok=True, parents=True)

        weights_dir = config_dir / "nnunet/results"
        self.weights_dir = weights_dir

        # Not needed for inference; each just has to be an existing directory.
        os.environ["nnUNet_raw_data_base"] = str(weights_dir)
        os.environ["nnUNet_preprocessed"] = str(weights_dir)
        os.environ["RESULTS_FOLDER"] = str(weights_dir)

    def download_spine_model(self, model_dir: Union[str, Path]):
        """Download the Task253 aorta weights (``fold_0`` and ``plans.pkl``).

        Must be called after :meth:`setup_nnunet_c2c`, which sets
        ``self.weights_dir``. Each artefact is fetched only if it is missing,
        so an interrupted earlier download is completed on the next run.

        Args:
            model_dir: Unused; kept for interface compatibility.
        """
        download_dir = Path(
            os.path.join(
                self.weights_dir,
                "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1",
            )
        )
        print(download_dir)

        fold_0_path = download_dir / "fold_0"
        plans_path = download_dir / "plans.pkl"
        if fold_0_path.exists() and plans_path.exists():
            print("Aorta model already downloaded.")
            return

        download_dir.mkdir(parents=True, exist_ok=True)

        if not fold_0_path.exists():
            zip_path = download_dir / "fold_0.zip"
            # wget.download renames its output when the target already exists,
            # so a stale partial archive would be opened instead of the new one.
            if zip_path.exists():
                zip_path.unlink()
            wget.download(
                "https://huggingface.co/AdritRao/aaa_test/resolve/main/fold_0.zip",
                out=str(zip_path),
            )
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(download_dir)
            zip_path.unlink()

        if not plans_path.exists():
            wget.download(
                "https://huggingface.co/AdritRao/aaa_test/resolve/main/plans.pkl",
                out=str(plans_path),
            )

        print("Aorta model downloaded.")

    def spine_seg(self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir):
        """Run the aorta nnU-Net model (task 253) on a NIfTI volume.

        Args:
            input_path: Path of the input NIfTI volume.
            output_path: Path passed to nnU-Net as the segmentation output.
            model_dir: Directory that holds (or will receive) the model weights.

        Returns:
            tuple: ``(seg, img)``. ``img, seg`` are the two values returned by
            ``nnUNet_predict_image``, unpacked in that order. ``seg`` is
            rebuilt as a ``nib.Nifti1Image`` from its float data.
            NOTE(review): confirm the return order of the vendored
            ``nnUNet_predict_image``.
        """
        print("Segmenting aorta...")
        st = time()
        # os.environ only accepts str values; model_dir may be a Path.
        os.environ["SCRATCH"] = str(model_dir)
        print(model_dir)

        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        task_id = [253]

        self.setup_nnunet_c2c(model_dir)
        self.download_spine_model(model_dir)

        # Imported lazily, after the nnU-Net environment variables are set.
        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag="None",
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        print(f"Total time for aorta segmentation: {end-st:.2f}s.")

        seg_data = seg.get_fdata()
        seg = nib.Nifti1Image(seg_data, seg.affine, seg.header)

        return seg, img

class AortaDiameter(InferenceClass):
    """Measure the aortic diameter per axial slice and along a 3D centerline.

    Reads from the pipeline:
        ``axial_masks`` and ``ct_image`` (lists of 2D arrays, one per slice),
        ``output_dir``, ``dicom_series_path`` and, if present,
        ``output_dir_segmentations``.

    Writes:
        * annotated PNGs per slice to ``<output_dir>/images/slices/``;
        * ``diameter_graph.png``, ``out.png``, ``aaa.mp4`` and
          ``planar_reconstruction.png`` to ``<output_dir>/images/summary/``;
        * ``inference_pipeline.max_diameter``: the largest per-slice
          minor-axis diameter, in cm.
    """

    # Isotropic voxel size (mm) the NIfTI volumes are resampled to for the 3D
    # centerline analysis. Distances measured on that grid are in units of
    # this size.
    RESAMPLED_VOXEL_MM = 1.5

    def __init__(self):
        super().__init__()

    def normalize_img(self, img: np.ndarray) -> np.ndarray:
        """Min-max normalize an image to [0, 1].

        Args:
            img (np.ndarray): Input image.

        Returns:
            np.ndarray: Normalized image. A constant image is mapped to all
            zeros instead of the NaNs a division by zero would produce.
        """
        img_min = img.min()
        value_range = img.max() - img_min
        if value_range == 0:
            return np.zeros_like(img, dtype=float)
        return (img - img_min) / value_range

    def __call__(self, inference_pipeline):
        """Run the per-slice and centerline diameter analyses.

        Raises:
            ValueError: If no slice has a measurable contour, or no centerline
                plane intersects the segmentation.
        """
        axial_masks = inference_pipeline.axial_masks
        ct_img = inference_pipeline.ct_image

        output_dir = inference_pipeline.output_dir
        output_dir_slices = os.path.join(output_dir, "images/slices/")
        output_dir_summary = os.path.join(output_dir, "images/summary/")
        os.makedirs(output_dir_slices, exist_ok=True)
        os.makedirs(output_dir_summary, exist_ok=True)

        pixel_spacing = self._read_pixel_spacing(inference_pipeline.dicom_series_path)
        print("Pixel conversion: " + str(pixel_spacing))

        diameter_dict = self._measure_axial_slices(
            axial_masks, ct_img, pixel_spacing, output_dir_slices
        )
        if not diameter_dict:
            raise ValueError(
                "No measurable aorta contour (at least 5 points) was found in any axial slice."
            )

        self._save_diameter_graph(diameter_dict, output_dir_summary)

        # First slice number with the largest diameter (dict insertion order).
        max_slice = max(diameter_dict, key=diameter_dict.get)
        print(diameter_dict)
        print(max_slice)
        print(diameter_dict[max_slice])
        inference_pipeline.max_diameter = diameter_dict[max_slice]

        self._save_summary_image(ct_img, max_slice, output_dir_slices, output_dir_summary)
        self._save_slice_video(output_dir_slices, output_dir_summary)

        # AortaSegmentation writes its files to <output_dir>/segmentations/ but
        # does not publish that path on the pipeline, so fall back to it.
        seg_dir = getattr(
            inference_pipeline,
            "output_dir_segmentations",
            os.path.join(output_dir, "segmentations/"),
        )
        self._centerline_analysis(seg_dir, output_dir_summary)

        return {}

    # ------------------------------------------------------------------
    # Per-slice (2D) analysis
    # ------------------------------------------------------------------

    @staticmethod
    def _read_pixel_spacing(dicom_dir):
        """Return ``PixelSpacing`` (row, column spacing in mm) of the series.

        Reads only the header of the first non-hidden file (sorted by name),
        so stray files such as ``.DS_Store`` are not picked up.

        Raises:
            FileNotFoundError: If *dicom_dir* has no candidate files.
        """
        candidates = sorted(name for name in os.listdir(dicom_dir) if not name.startswith("."))
        if not candidates:
            raise FileNotFoundError(f"No files found in DICOM directory {dicom_dir!r}")
        dicom = pydicom.dcmread(os.path.join(dicom_dir, candidates[0]), stop_before_pixels=True)
        return dicom.PixelSpacing

    def _to_rgb(self, ct_slice):
        """Clip a slice to [-300, 1800], scale it to 0-255 and replicate it into 3 channels."""
        img = self.normalize_img(np.clip(ct_slice, -300, 1800)) * 255.0
        return np.tile(img[:, :, np.newaxis], (1, 1, 3))

    def _measure_axial_slices(self, axial_masks, ct_img, pixel_spacing, output_dir_slices):
        """Measure every slice whose mask has a contour suitable for ellipse fitting.

        Returns:
            dict: ``{slice_number: diameter_cm}`` with
            ``slice_number = len(ct_img) - index``.
        """
        # NOTE(review): assumes ct_image/axial_masks are on the DICOM pixel
        # grid; nnU-Net is run with resample=1.5, so confirm the returned
        # arrays are in the original resolution.
        row_mm = float(pixel_spacing[0])
        col_mm = float(pixel_spacing[1])

        slice_count = len(ct_img)
        diameters = {}
        for i in range(slice_count):
            mask = axial_masks[i].astype("uint8")
            contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
            if not contours:
                continue

            # The largest contour (first one on ties) is taken as the aorta.
            contour = max(contours, key=cv2.contourArea)
            # cv2.fitEllipse raises for contours with fewer than 5 points.
            if len(contour) < 5:
                continue

            slice_number = slice_count - i
            diameters[slice_number] = self._annotate_slice(
                ct_img[i], contour, slice_number, row_mm, col_mm, output_dir_slices
            )
        return diameters

    @staticmethod
    def _draw_axis(img, xc, yc, angle_deg, radius, color):
        """Draw a 3-px line of half-length *radius* through (xc, yc) at *angle_deg*."""
        theta = math.radians(angle_deg)
        dx = math.cos(theta) * radius
        dy = math.sin(theta) * radius
        cv2.line(img, (int(xc + dx), int(yc + dy)), (int(xc - dx), int(yc - dy)), color, 3)

    def _annotate_slice(self, ct_slice, contour, slice_number, row_mm, col_mm, output_dir_slices):
        """Overlay *contour* and its fitted ellipse on a slice, save it as a PNG and return the diameter.

        The diameter is the minor axis of the fitted ellipse, converted with
        the row spacing, rounded to whole mm and returned in cm.
        """
        img = self._to_rgb(ct_slice)

        # Blend a filled copy of the contour in at 25 % opacity.
        filled = img.copy()
        cv2.drawContours(filled, [contour], 0, (0, 255, 0), -1)
        alpha = 0.25
        img = cv2.addWeighted(img, 1 - alpha, filled, alpha, 0)

        ellipse = cv2.fitEllipse(contour)
        (xc, yc), (d1, d2), angle = ellipse
        cv2.ellipse(img, ellipse, (0, 255, 0), 1)
        cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1)

        rmajor = max(d1, d2) / 2
        rminor = min(d1, d2) / 2

        # The major axis is drawn at the ellipse angle rotated by 90 degrees,
        # and the minor axis perpendicular to it.
        major_angle = angle - 90 if angle > 90 else angle + 90
        minor_angle = major_angle - 90 if major_angle > 90 else major_angle + 90
        self._draw_axis(img, xc, yc, major_angle, rmajor, (0, 0, 255))
        self._draw_axis(img, xc, yc, minor_angle, rminor, (255, 0, 0))

        pixel_length = rminor * 2
        print("Pixel_length_minor: " + str(pixel_length))

        # A pixel covers row_mm * col_mm square millimetres; 1 cm^2 = 100 mm^2.
        area_mm2 = round(cv2.contourArea(contour) * row_mm * col_mm)
        area_cm2 = area_mm2 / 100
        diameter_mm = round(pixel_length * row_mm)
        diameter_cm = diameter_mm / 10

        img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        h, w = img.shape[:2]
        font_scale = min(w, h) / (25 / 0.03)
        labels = [
            "Area (mm^2): " + str(area_mm2) + "mm^2",
            "Area (cm^2): " + str(area_cm2) + "cm^2",
            "Diameter (mm): " + str(diameter_mm) + "mm",
            "Diameter (cm): " + str(diameter_cm) + "cm",
            "Slice: " + str(slice_number),
        ]
        for row, label in enumerate(labels):
            cv2.putText(
                img, label, (10, 40 + 30 * row), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 255, 0), 2
            )

        cv2.imwrite(output_dir_slices + "slice" + str(slice_number) + ".png", img)
        return diameter_cm

    # ------------------------------------------------------------------
    # Summary outputs
    # ------------------------------------------------------------------

    @staticmethod
    def _save_diameter_graph(diameter_dict, output_dir_summary):
        """Save a bar chart of diameter (cm) per slice number."""
        fig, ax = plt.subplots()
        ax.bar(list(diameter_dict.keys()), list(diameter_dict.values()), color="b")
        ax.set_title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$")
        ax.set_xlabel("Slice Number")
        ax.set_ylabel("Diameter Measurement (cm)")
        fig.savefig(output_dir_summary + "diameter_graph.png", dpi=500)
        # Close explicitly so later plots and repeated runs start from a clean figure.
        plt.close(fig)

    def _save_summary_image(self, ct_img, max_slice, output_dir_slices, output_dir_summary):
        """Save the raw slice and the annotated slice with the largest diameter side by side."""
        raw = cv2.rotate(self._to_rgb(ct_img[len(ct_img) - max_slice]), cv2.ROTATE_90_COUNTERCLOCKWISE)
        annotated = cv2.imread(output_dir_slices + "slice" + str(max_slice) + ".png")

        border_size = 3
        annotated = cv2.copyMakeBorder(
            annotated,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[0, 244, 0],
        )
        raw = cv2.copyMakeBorder(
            raw,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[244, 0, 0],
        )
        cv2.imwrite(output_dir_summary + "out.png", np.concatenate((raw, annotated), axis=1))

    @staticmethod
    def _save_slice_video(output_dir_slices, output_dir_summary, fps=20):
        """Assemble the slice PNGs into an MP4 clip."""
        # Tcl's dictionary sort compares embedded numbers numerically, so
        # slice2.png sorts before slice10.png.
        ordered = Tcl().call("lsort", "-dict", os.listdir(output_dir_slices))
        image_files = [
            os.path.join(output_dir_slices, name) for name in ordered if name.endswith(".png")
        ]
        clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)
        clip.write_videofile(output_dir_summary + "aaa.mp4")

    # ------------------------------------------------------------------
    # 3D centerline analysis
    # ------------------------------------------------------------------

    def _centerline_analysis(self, seg_dir, output_dir_summary):
        """Estimate the largest aortic cross-section orthogonal to a fitted centerline."""
        voxel_mm = self.RESAMPLED_VOXEL_MM

        # spine.nii.gz is the segmentation written by AortaSegmentation;
        # converted_dcm.nii.gz is its CT input and is not needed here.
        segmentation = nib.load(os.path.join(seg_dir, "spine.nii.gz"))
        # order=0 (nearest neighbour) keeps label values intact.
        segmentation = resample_to_output(segmentation, (voxel_mm,) * 3, order=0).get_fdata()

        # Kept from the original pipeline: label 42 is remapped to 1
        # (presumably the aorta label of the model output — confirm).
        segmentation[segmentation == 42] = 1
        print(segmentation.shape)
        print(np.unique(segmentation))

        centerline_points = self._compute_centerline_3d(segmentation)
        tck = self._fit_bspline(centerline_points)
        evaluated_points = self._evaluate_bspline(tck)
        interpolated_points = self._interpolate_points(evaluated_points, segmentation.shape[2])
        planes = self._compute_orthogonal_planes(tck)
        _, _, _, max_intersecting_points = self._compute_maximum_diameter(segmentation, planes)
        self._plot_2d_planar_reconstruction(
            interpolated_points,
            max_intersecting_points,
            output_dir_summary + "planar_reconstruction.png",
        )

    @staticmethod
    def _compute_centerline_3d(aorta_segmentation):
        """Skeletonize the mask and return skeleton voxels as (x, y, z) rows sorted by x.

        Here x, y and z index the last, middle and first array axes.
        """
        skeleton = skeletonize_3d(np.asarray(aorta_segmentation) > 0)
        z, y, x = np.where(skeleton)
        points = np.vstack((x, y, z)).T
        return points[points[:, 0].argsort()]

    @staticmethod
    def _fit_bspline(centerline_points, smoothness=1e8):
        """Fit a smoothing parametric B-spline through the centerline points."""
        x, y, z = centerline_points.T
        tck, _ = splprep([x, y, z], s=smoothness)
        return tck

    @staticmethod
    def _evaluate_bspline(tck, num_points=1000):
        """Sample the spline at *num_points* evenly spaced parameter values."""
        u = np.linspace(0, 1, num_points)
        x, y, z = splev(u, tck)
        return np.vstack((x, y, z)).T

    @staticmethod
    def _interpolate_points(data, num_points=32):
        """Resample (x, y, z) rows at integer x = 0..num_points-1 (nearest neighbour) and round."""
        x = data[:, 0]
        y = data[:, 1:]
        f_y = interp1d(x, y, kind="nearest", fill_value="extrapolate", axis=0)
        new_x = np.arange(0, num_points)
        new_y = f_y(new_x)
        return np.round(np.hstack((new_x.reshape(-1, 1), new_y)))

    @staticmethod
    def _compute_orthogonal_planes(tck, num_points=100):
        """Return ``(unit_normal, d)`` planes orthogonal to the spline (``normal . p + d = 0``)."""
        u = np.linspace(0, 1, num_points)
        points = np.vstack(splev(u, tck)).T
        tangents = np.vstack(splev(u, tck, der=1)).T
        normals = tangents / np.linalg.norm(tangents, axis=1)[:, np.newaxis]
        return [(normal, -np.dot(point, normal)) for point, normal in zip(points, normals)]

    def _compute_maximum_diameter(self, aorta_segmentation, planes):
        """Find the plane whose 1-voxel-thick cross-section of the mask is widest.

        Returns:
            tuple: ``(diameters_cm, plane_offset, plane_normal, points)``.
            ``diameters_cm`` holds one value per intersecting plane.
            ``plane_offset`` and ``plane_normal`` describe the widest plane,
            and ``points`` are the mask voxels (x, y, z) lying in it.

        Raises:
            ValueError: If no plane intersects the mask in at least two voxels.
        """
        voxel_mm = self.RESAMPLED_VOXEL_MM
        z, y, x = np.where(aorta_segmentation)
        aorta_points = np.vstack((x, y, z)).T

        diameters_px = []
        hit_planes = []  # kept aligned with diameters_px (planes may be skipped)
        hit_points = []
        for normal, d in planes:
            distances = aorta_points @ normal + d
            intersecting = aorta_points[np.abs(distances) < 0.5]
            if len(intersecting) < 2:
                continue
            # The largest pairwise distance equals the max of the square distance matrix.
            diameters_px.append(pdist(intersecting).max())
            hit_planes.append((normal, d))
            hit_points.append(intersecting)

        if not diameters_px:
            raise ValueError("No centerline plane intersects the aorta segmentation.")

        best = int(np.argmax(diameters_px))
        max_diameter_in_pixels = diameters_px[best]
        print(f"Maximum Diameter in Pixels: {max_diameter_in_pixels}")

        # The volume was resampled to voxel_mm, so voxel distances scale by it
        # (not by the DICOM pixel spacing).
        diameter_mm = round(max_diameter_in_pixels * voxel_mm)
        print(f"Maximum Diameter in mm: {diameter_mm}")

        max_diameter_normal, max_diameter_offset = hit_planes[best]
        diameters_cm = np.array(diameters_px) * voxel_mm / 10
        return diameters_cm, max_diameter_offset, max_diameter_normal, hit_points[best]

    @staticmethod
    def _plot_2d_planar_reconstruction(interpolated_points, max_intersecting_points, output_path):
        """Plot centerline points and the widest cross-section in two 2D projections."""
        fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))

        axs[0].scatter(
            interpolated_points[:, 1].astype(int),
            interpolated_points[:, 0].astype(int),
            color="red",
            s=1,
        )
        axs[0].plot(
            max_intersecting_points[:, 1].astype(int),
            max_intersecting_points[:, 0].astype(int),
            color="blue",
        )

        axs[1].scatter(
            interpolated_points[:, 2].astype(int),
            interpolated_points[:, 0].astype(int),
            color="red",
            s=1,
        )
        axs[1].plot(
            max_intersecting_points[:, 2].astype(int),
            max_intersecting_points[:, 0].astype(int),
            color="blue",
        )

        fig.savefig(output_path)
        plt.close(fig)


class AortaMetricsSaver(InferenceClass):
    """Save the aorta metrics to a CSV file."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Write ``<output_dir>/metrics/aorta_metrics.csv`` from the pipeline state.

        Reads ``max_diameter``, ``dicom_series_path`` and ``output_dir`` from
        *inference_pipeline*.

        Returns:
            dict: Always empty.
        """
        self.max_diameter = inference_pipeline.max_diameter
        self.dicom_series_path = inference_pipeline.dicom_series_path
        self.output_dir = inference_pipeline.output_dir
        self.csv_output_dir = os.path.join(self.output_dir, "metrics")
        os.makedirs(self.csv_output_dir, exist_ok=True)
        self.save_results()
        return {}

    def save_results(self):
        """Write a one-row CSV with the series name and the maximum diameter."""
        # normpath strips a trailing separator, so ".../series/" yields
        # "series" instead of an empty filename.
        filename = os.path.basename(os.path.normpath(self.dicom_series_path))
        df = pd.DataFrame(
            [[filename, str(self.max_diameter)]],
            columns=["Filename", "Max Diameter"],
        )
        df.to_csv(os.path.join(self.csv_output_dir, "aorta_metrics.csv"), index=False)