import os
from abc import ABC, abstractmethod
from typing import List

import cv2
import numpy as np
from retinaface import RetinaFace
from retinaface.model import retinaface_model

from .box_utils import convert_to_square


class FaceDetector(ABC):
    """Template for detectors that turn an image into cropped face images.

    Subclasses provide ``detect_crops`` (raw boxes) and ``postprocess_crops``
    (box adjustment). ``detect_faces`` chains the steps:

    1. detect,
    2. sort by area, largest first,
    3. post-process,
    4. round and clamp to the image bounds,
    5. slice the faces out of the image,
    6. pass them through ``unify_and_merge``.
    """

    def __init__(self, target_size):
        # Kept for subclasses; the base class itself never reads it.
        self.target_size = target_size

    @abstractmethod
    def detect_crops(self, img, *args, **kwargs) -> List[np.ndarray]:
        """Locate faces in ``img``.

        ``img`` is a uint8 RGB numpy array with values in [0, 255].
        Implementations return boxes as rows of ``[x1, y1, x2, y2]``; ``None``
        or an empty sequence means that no face was found.
        """

    @abstractmethod
    def postprocess_crops(self, crops, *args, **kwargs) -> List[np.ndarray]:
        """Adjust the area-sorted boxes; the default returns them untouched."""
        return crops

    def sort_faces(self, crops):
        """Stack ``crops`` into one array ordered by box area, largest first.

        Boxes of equal area keep their original relative order (``sorted`` is
        stable).
        """
        def _negative_area(box):
            return -(box[2] - box[0]) * (box[3] - box[1])

        return np.stack(sorted(crops, key=_negative_area), axis=0)

    def fix_range_crops(self, img, crops):
        """Round every coordinate and clamp it to the image extent.

        x coordinates are limited to ``[0, W]`` and y coordinates to
        ``[0, H]``, where ``H, W`` come from ``img.shape``. Rounding uses the
        built-in ``round``. Returns an ``np.int32`` array with one row per box.
        """
        height, width, _ = img.shape

        def _clamp(value, upper):
            return max(min(round(value), upper), 0)

        bounded = [
            [_clamp(x1, width), _clamp(y1, height), _clamp(x2, width), _clamp(y2, height)]
            for x1, y1, x2, y2 in crops
        ]
        return np.array(bounded, dtype=np.int32)

    def crop_faces(self, img, crops) -> List[np.ndarray]:
        """Slice ``img[y1:y2, x1:x2, :]`` for every box (views, not copies)."""
        return [img[y1:y2, x1:x2, :] for x1, y1, x2, y2 in crops]

    def unify_and_merge(self, cropped_images):
        """Hook for subclasses to normalise the crops; identity by default."""
        return cropped_images

    def __call__(self, img):
        """Shortcut for :meth:`detect_faces`."""
        return self.detect_faces(img)

    def detect_faces(self, img):
        """Run the whole pipeline on ``img``.

        Returns ``(faces, boxes)``: the result of ``unify_and_merge`` and the
        final int32 boxes. When nothing is detected, returns ``([], [])``.
        """
        raw_boxes = self.detect_crops(img)
        if raw_boxes is None or len(raw_boxes) == 0:
            return [], []
        ordered = self.sort_faces(raw_boxes)
        adjusted = self.postprocess_crops(ordered)
        boxes = self.fix_range_crops(img, adjusted)
        faces = self.unify_and_merge(self.crop_faces(img, boxes))
        return faces, boxes


class StatRetinaFaceDetector(FaceDetector):
    """RetinaFace detector that pads each face box by fixed relative margins.

    Boxes reported by RetinaFace are grown on every side by a fraction of the
    box width/height (``self.relative_offsets``), then passed through
    ``convert_to_square``. When ``target_size`` is not ``None``, every face
    crop is resized to ``target_size x target_size`` and the results are
    stacked into a single array.
    """

    def __init__(self, target_size=None):
        super().__init__(target_size)
        self.model = retinaface_model.build_model()
        # Margins as fractions of box size, ordered [left, top, right, bottom]:
        # left/right scale with box width, top/bottom with box height.
        # An earlier variant used [0.3258, 0.5225, 0.3258, 0.1290].
        self.relative_offsets = [0.3619, 0.5830, 0.3619, 0.1909]

    def postprocess_crops(self, crops, *args, **kwargs) -> np.ndarray:
        """Pad every box by the relative offsets and make it square.

        For each ``[x1, y1, x2, y2]`` row, ``x1``/``x2`` move outwards by
        ``offset * width`` and ``y1``/``y2`` by ``offset * height``. The result
        is cast back to the row's dtype and handed to ``convert_to_square``.
        The processed rows are returned stacked into one array.
        """
        left, top, right, bottom = self.relative_offsets
        padded = []
        for box in crops:
            bx1, by1, bx2, by2 = box
            width = bx2 - bx1
            height = by2 - by1
            # NOTE(review): for integer boxes this cast truncates the
            # fractional part of the margin (toward zero) -- confirm intended.
            grown = np.array(
                [
                    bx1 - width * left,
                    by1 - height * top,
                    bx2 + width * right,
                    by2 + height * bottom,
                ],
                dtype=box.dtype,
            )
            padded.append(convert_to_square(grown))
        return np.stack(padded, axis=0)

    def detect_crops(self, img, *args, **kwargs):
        """Run RetinaFace on ``img`` and collect each face's ``facial_area``.

        Returns an array with one ``[x1, y1, x2, y2]`` row per detection, or an
        empty list when nothing is found.
        """
        detections = RetinaFace.detect_faces(img, model=self.model)
        # A tuple result is treated as "no faces detected".
        if isinstance(detections, tuple):
            return []

        def _as_box(face):
            x1, y1, x2, y2 = face['facial_area']
            return np.array([x1, y1, x2, y2])

        boxes = [_as_box(face) for face in detections.values()]
        if not boxes:
            return boxes
        return np.stack(boxes, axis=0)

    def unify_and_merge(self, cropped_images):
        """Resize crops to a common square size when ``target_size`` is set.

        With ``target_size`` of ``None`` the crops are returned as-is;
        otherwise each is resized with bilinear interpolation to
        ``(target_size, target_size)`` and all of them are stacked on axis 0.
        """
        side = self.target_size
        if side is None:
            return cropped_images
        resized = [
            cv2.resize(face, (side, side), interpolation=cv2.INTER_LINEAR)
            for face in cropped_images
        ]
        return np.stack(resized, axis=0)