import os
from transformers import pipeline
import gradio
import base64
from PIL import Image, ImageDraw
from io import BytesIO
from sentence_transformers import SentenceTransformer, util
import spaces

# --- Whole-image segmentation models (consumed by getImageDetails) ---
backgroundPipe = pipeline(task="image-segmentation", model="facebook/maskformer-swin-large-coco")
PersonPipe = pipeline(task="image-segmentation", model="mattmdjaga/segformer_b2_clothes")

# --- Text embedding model used to match the user's prompt to a mask label ---
sentenceModal = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# --- Per-person part segmentation, run on each detected person crop ---
personDetailsPipe = pipeline(task="image-segmentation", model="yolo12138/segformer-b2-human-parse-24")

# --- Face parsing, run on a detected face crop ---
faceModal = pipeline(task="image-segmentation", model="jonathandinu/face-parsing")

# --- Object detectors that supply the crop boxes for the models above ---
faceDetectionModal = pipeline(task="object-detection", model="aditmohan96/detr-finetuned-face")
PersonDetectionpipe = pipeline(task="object-detection", model="hustvl/yolos-tiny")

def getPersonDetail(image):
    """Segment body parts for every person detected in *image*.

    Each "person" box from ``PersonDetectionpipe`` is cropped and passed to
    ``personDetailsPipe``. Every resulting part mask is placed back at the
    crop's position on a full-size canvas via ``cbiwm``.

    Args:
        image: The input PIL image.

    Returns:
        dict mapping a lowercase label to a full-size RGBA mask image.
        - If exactly one person is detected, keys are the bare part labels.
        - If several are detected, keys are prefixed "person <n> " (1-based).
        - If no person is detected, the dict is empty.
    """
    detections = PersonDetectionpipe(image)
    boxes = [det["box"] for det in detections if det["label"].lower() == "person"]

    # Only disambiguate labels by person index when there is more than one person.
    multiple_people = len(boxes) > 1

    ret = {}
    for n, cord in enumerate(boxes, start=1):
        crop_box = (cord["xmin"], cord["ymin"], cord["xmax"], cord["ymax"])
        for dt in personDetailsPipe(image.crop(crop_box)):
            label = dt["label"]
            key = f"Person {n} {label}".lower() if multiple_people else label.lower()
            # The mask is crop-sized; cbiwm re-anchors it at the box's top-left corner.
            ret[key] = cbiwm(image, dt["mask"], cord)
    return ret

def cbiwm(image, mask, coordinates):
    """Create a black image with *mask* pasted at a bounding-box corner.

    Builds an opaque black RGBA canvas the same size as *image*. It then
    pastes *mask* at ``(coordinates['xmin'], coordinates['ymin'])``, using
    *mask* itself as the paste transparency mask. *mask* is presumably an
    "L"-mode mask returned by a segmentation pipeline (TODO confirm), so
    non-zero mask pixels brighten the canvas and zero pixels leave it black.

    Args:
        image: Reference image; only its ``size`` is used.
        mask: Mask image to paste (typically crop-sized).
        coordinates: Box dict providing at least 'xmin' and 'ymin'.

    Returns:
        The new RGBA canvas image.
    """
    top_left = (coordinates['xmin'], coordinates['ymin'])
    canvas = Image.new("RGBA", image.size, (0, 0, 0, 255))
    canvas.paste(mask, top_left, mask)
    return canvas

def processFaceDetails(image):
    """Add face-part masks to the per-person masks for *image*.

    Starts from ``getPersonDetail(image)``. It then runs ``faceDetectionModal``
    and, if at least one face is found, parses the cropped face with
    ``faceModal``. Each face-part mask is placed on a full-size canvas via
    ``cbiwm`` and stored under its label, lowercased with ".png" removed.
    Face labels overwrite any person-part entry with the same key.

    Args:
        image: The input PIL image.

    Returns:
        dict mapping lowercase label -> full-size RGBA mask image. If no face
        is detected, this is just the ``getPersonDetail`` result.
    """
    ret = getPersonDetail(image)
    detections = faceDetectionModal(image)
    if not detections:
        return ret

    # NOTE(review): with several detections this uses the *second* one
    # (index 1), not the first or highest-scoring one. Preserved from the
    # original implementation; confirm this is intentional.
    cordinates = detections[1 if len(detections) > 1 else 0]["box"]
    crop_box = (cordinates["xmin"], cordinates["ymin"], cordinates["xmax"], cordinates["ymax"])

    for imask in faceModal(image.crop(crop_box)):
        ret[imask["label"].replace(".png", "").lower()] = cbiwm(image, imask["mask"], cordinates)
    return ret

def getImageDetails(image) -> dict:
    """Gather every available segmentation mask for *image*, keyed by label.

    Masks are combined in this order:
    1. ``processFaceDetails`` (person and face parts).
    2. ``backgroundPipe`` segments.
    3. ``PersonPipe`` segments.

    Later sources overwrite earlier ones whenever their lowercase labels
    collide. Pipeline masks are stored exactly as the pipelines return them.

    Args:
        image: The input PIL image.

    Returns:
        dict mapping lowercase label -> mask image.
    """
    collected = processFaceDetails(image)
    person_segments = PersonPipe(image)
    background_segments = backgroundPipe(image)
    # Apply base64 image converter here if needed.
    collected.update({seg["label"].lower(): seg["mask"] for seg in background_segments})
    collected.update({seg["label"].lower(): seg["mask"] for seg in person_segments})
    return collected

def processSentence(sentence: str, semilist: list):
    """Return the entry of *semilist* whose embedding best matches *sentence*.

    Both the query and every candidate are encoded with ``sentenceModal``.
    Candidates are scored with ``util.dot_score`` and the highest-scoring one
    is returned. On ties, the earliest candidate wins.

    Args:
        sentence: Free-text query from the user.
        semilist: Candidate label strings.

    Returns:
        The best-matching element of *semilist*.

    Raises:
        ValueError: if *semilist* is empty.
    """
    if not semilist:
        raise ValueError("semilist must contain at least one candidate label")
    query_embedding = sentenceModal.encode(sentence)
    passage_embedding = sentenceModal.encode(semilist)
    # One row of scores: query vs. each candidate.
    scores = util.dot_score(query_embedding, passage_embedding)[0]
    # argmax returns the first maximal index, same as max() + list.index().
    return semilist[int(scores.argmax())]

def process_image(image):
    """Convert a black/white mask into a black mask on a transparent background.

    The image is converted to RGBA and each pixel is remapped:
    - pure black ``(0, 0, 0, a)``       -> fully transparent ``(0, 0, 0, 0)``
    - pure white ``(255, 255, 255, a)`` -> black, alpha kept ``(0, 0, 0, a)``
    - any other colour                  -> unchanged

    This is the net effect of swapping black and white and then clearing
    white. It is done in a single pass, so no intermediate image is built.

    Args:
        image: A PIL image (any mode convertible to RGBA).

    Returns:
        A new RGBA PIL image of the same size.
    """
    rgba_image = image.convert("RGBA")
    remapped = [
        (0, 0, 0, 0) if px[:3] == (0, 0, 0)
        else (0, 0, 0, px[3]) if px[:3] == (255, 255, 255)
        else px
        for px in rgba_image.getdata()
    ]
    processed_image = Image.new("RGBA", rgba_image.size)
    processed_image.putdata(remapped)
    return processed_image

@spaces.GPU()
def processAndGetMask(image: Image.Image, text: str):
    """Gradio handler: return the mask whose label best matches *text*.

    Collects all segmentation masks for *image*, then picks the label most
    similar to *text* via ``processSentence``. The chosen mask is returned
    after conversion by ``process_image``.

    Args:
        image: Input PIL image. The Gradio input component uses type="pil";
            it is None if the user submitted without an image.
        text: Description of the region to segment.

    Returns:
        RGBA PIL image of the selected mask.

    Raises:
        gradio.Error: if no image was provided or no segments were found.
    """
    if image is None:
        raise gradio.Error("Please provide an input image.")
    datas = getImageDetails(image)
    labs = list(datas.keys())
    if not labs:
        raise gradio.Error("No segments were found in the image.")
    selector = processSentence(text, labs)
    imageout = datas[selector]
    print(f"Selected : {selector} Among : {labs}")
    return process_image(imageout)

# Web UI: an image plus a text prompt in, the matching mask out.
gr = gradio.Interface(
    fn=processAndGetMask,
    inputs=[
        gradio.Image(label="Input Image", type="pil"),
        gradio.Text(label="Input text to segment"),
    ],
    outputs=gradio.Image(label="Output Image", type="pil"),
)
gr.launch(share=True)