import os
import argparse
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image

# Face detection pipeline (MTCNN), created once at import time and shared by
# every call to process().
# NOTE(review): process() only calls mtcnn.detect(); image_size/margin appear
# to matter only for MTCNN's own face-tensor extraction — confirm against the
# facenet_pytorch docs before tuning them for cropping.
mtcnn = MTCNN(image_size=400, margin=150)

# Create an inception resnet (in eval mode):
# NOTE(review): `resnet` is not referenced anywhere in this file; loading it
# (presumably fetching the pretrained 'vggface2' weights) only slows startup.
# Confirm nothing imports it from this module before removing.
resnet = InceptionResnetV1(pretrained='vggface2').eval()

def _pick_face_box(img):
    """Return the most confident face box MTCNN detects in *img*.

    Falls back to ``[0, 0, width, width]`` (and prints a notice) when no face
    is detected.
    """
    boxes, probs = mtcnn.detect(img)

    if boxes is None:
        print("Face not found, using default box")
        # NOTE(review): the width is used for both extents, so on portrait
        # images this selects a top-anchored square — confirm this is intended.
        return [0, 0, img.size[0], img.size[0]]

    # Rank by probability only: sorting (prob, box) tuples falls back to
    # comparing the box arrays when two probabilities tie, which raises.
    _, best_box = max(zip(probs, boxes), key=lambda pair: pair[0])
    return best_box


def _square_crop_box(box, img_size, pad):
    """Turn a face box into a square ``(left, top, right, bottom)`` crop box.

    The box is grown by *pad* pixels on every side and clamped to *img_size*
    ``(width, height)``. It is then shrunk along its longer axis to a centred
    square whose side is the smaller clamped dimension.

    Raises:
        ValueError: if the clamped box has no area inside the image.
    """
    img_w, img_h = img_size

    # Grow the box by the padding, then clamp it to the image bounds.
    box_l = max(0, int(box[0]) - pad)
    box_t = max(0, int(box[1]) - pad)
    box_r = min(img_w, int(box[2]) + pad)
    box_b = min(img_h, int(box[3]) + pad)

    box_w = box_r - box_l
    box_h = box_b - box_t

    print("image size", img_size)
    print("original box", (box_l, box_t, box_r, box_b))
    print("original box size", box_w, "x", box_h)

    side = min(box_w, box_h)
    if side <= 0:
        raise ValueError(
            f"face box ({box_l}, {box_t}, {box_r}, {box_b}) "
            "has no area inside the image"
        )

    # Centre the square along the longer axis (both offsets are >= 0).
    box_l = int(box_l + (box_w - side) / 2)
    box_t = int(box_t + (box_h - side) / 2)
    box_r = box_l + side
    box_b = box_t + side

    print("adjusted box", (box_l, box_t, box_r, box_b))
    print("adjusted size", side, "x", side)

    return box_l, box_t, box_r, box_b


def process(in_file, out_file, box=None, *, size=300, pad=25):
    """Crop the main face in *in_file* to a square and save it to *out_file*.

    Args:
        in_file: path of the source image.
        out_file: path the cropped image is written to; the format follows
            Pillow's handling of its extension.
        box: optional ``(left, top, right, bottom)`` face box in pixels. When
            omitted, the most confident MTCNN detection is used.
        size: side length in pixels of the saved square image (default 300).
        pad: pixels added around the face box before squaring (default 25).

    Raises:
        ValueError: if the padded box has no area inside the image.
    """
    # Use a context manager so the source file handle is released before
    # out_file (which may be the same path) is written.
    with Image.open(in_file) as img:
        if box is None:
            box = _pick_face_box(img)
        crop_box = _square_crop_box(box, img.size, pad)
        im_new = img.crop(crop_box).resize((size, size), Image.Resampling.LANCZOS)

    im_new.save(out_file)

def auto_crop(input_dir, output_dir):
    """Face-crop a single image, or every regular file in a directory.

    Args:
        input_dir: a directory, or the path of a single image file. In a
            directory, each regular file is processed in name order.
            Subdirectories are skipped. A failure on one file is reported,
            and the batch continues.
        output_dir: existing directory that receives the crops under the
            same file names.

    Prints an error and returns without processing anything if *output_dir*
    is not an existing directory or *input_dir* does not exist.
    """
    if not os.path.isdir(output_dir):
        print("Error: output directory does not exist")
        return
    if not os.path.exists(input_dir):
        print("Error: input path does not exist")
        return

    if not os.path.isdir(input_dir):
        # Single-file mode: errors propagate to the caller.
        name = os.path.basename(input_dir)
        print("Processing file", name)
        process(input_dir, os.path.join(output_dir, name))
        return

    for name in sorted(os.listdir(input_dir)):
        in_file = os.path.join(input_dir, name)
        # Subdirectories and other non-regular entries cannot be images.
        if not os.path.isfile(in_file):
            continue
        out_file = os.path.join(output_dir, name)
        print("Processing file", in_file)
        try:
            process(in_file, out_file)
        except Exception as err:
            # Best-effort batch: report and move on. KeyboardInterrupt and
            # SystemExit are not Exception subclasses, so they still propagate.
            print("Error processing file", name, "-", err)

if __name__ == '__main__':
    # Command-line entry point; auto_crop() accepts either a folder or a
    # single image as input, and requires the output folder to exist.
    parser = argparse.ArgumentParser(
        description="Batch Auto Cropping: save a square, face-centred crop "
        "of each input image."
    )
    parser.add_argument('-i', '--input', required=True,
                        help='Input image file or folder of images')
    parser.add_argument('-o', '--output', required=True,
                        help='Output folder (must already exist)')
    args = parser.parse_args()
    auto_crop(args.input, args.output)