from flask import Flask, request, jsonify ,send_file
from PIL import Image
import requests
import base64
import spaces
import multiprocessing
from loadimg import load_img
from io import BytesIO
import numpy as np
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import threading
import logging
import uuid
from transformers import AutoModelForImageSegmentation,AutoModelForCausalLM, AutoProcessor
import torch
from torchvision import transforms
import subprocess
import logging
import json 
# Install flash-attn when the module is imported, asking its build to skip
# the CUDA compilation step.
# NOTE(review): `env=` replaces the child's whole environment, so only
# FLASH_ATTENTION_SKIP_CUDA_BUILD is set for this command — confirm `pip`
# still resolves to the intended interpreter.
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# Flask application object that the route decorators below register on.
app = Flask(__name__)

# NOTE(review): `kwargs` is not referenced anywhere else in the visible file.
kwargs = {}
kwargs['torch_dtype'] = torch.bfloat16

# Vision-language models keyed by Hugging Face model id. Each one is loaded,
# moved to the GPU and put in eval mode at import time.
models = {
    "microsoft/Phi-3-vision-128k-instruct": AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").cuda().eval()
}

# Processors keyed by the same model ids as `models`.
processors = {
    "microsoft/Phi-3-vision-128k-instruct": AutoProcessor.from_pretrained("microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
}

# NOTE(review): this repeats the install command issued above and runs after
# the model has already been loaded — it looks redundant; confirm before removing.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

# Chat-template fragments that process_vision concatenates into the prompt.
user_prompt = '<|user|>\n'
assistant_prompt = '<|assistant|>\n'
prompt_suffix = "<|end|>\n"

def get_image_from_url(url, timeout=30):
    """Download an image over HTTP(S) and open it with Pillow.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the remote server before giving up.
            Without a timeout ``requests`` may block indefinitely on an
            unresponsive host, tying up the worker handling the request.

    Returns:
        The ``PIL.Image.Image`` opened from the response body.

    Raises:
        Exception: Any error raised while fetching or opening the image is
            logged and then re-raised unchanged.
    """
    # NOTE(review): callers pass URLs taken from request bodies; consider
    # restricting allowed hosts/schemes if this service is publicly reachable.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Turn HTTP error statuses into exceptions
        return Image.open(BytesIO(response.content))
    except Exception as e:
        logging.error(f"Error fetching image from URL: {e}")
        raise


# Turn a base64 string into an RGB PIL.Image.Image
def decode_image_from_base64(image_data):
    """Decode base64-encoded image bytes and return them as an RGB PIL image."""
    raw_bytes = base64.b64decode(image_data)
    byte_stream = BytesIO(raw_bytes)
    return Image.open(byte_stream).convert("RGB")

# Serialize a PIL image into a base64 string
def encode_image_to_base64(image):
    """Render *image* as PNG and return the bytes as a UTF-8 base64 string."""
    png_stream = BytesIO()
    # PNG is used so that images carrying an alpha channel (RGBA) can be saved.
    image.save(png_stream, format="PNG")
    png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')

def extract_image(image_data):
    """Load an image given either an HTTP(S) URL or a base64 payload."""
    # Anything that does not look like an http(s) URL is treated as base64.
    looks_like_url = image_data.startswith(('http://', 'https://'))
    if not looks_like_url:
        return decode_image_from_base64(image_data)
    # Otherwise download the image from the URL.
    return get_image_from_url(image_data)

@spaces.GPU
def process_vision(image, text_input=None, model_id="microsoft/Phi-3-vision-128k-instruct"):
    """Run the vision-language model on one image and return its text answer.

    Args:
        image: Image object exposing ``convert("RGB")`` (a PIL image in the
            callers visible in this file).
        text_input: Instruction or question placed after the image tag in the
            prompt. ``None`` is treated as an empty instruction.
        model_id: Key into the module-level ``models`` / ``processors`` dicts.

    Returns:
        The decoded text generated by the model, excluding the prompt tokens.

    Raises:
        KeyError: If ``model_id`` is not present in ``models`` / ``processors``.
    """
    model = models[model_id]
    processor = processors[model_id]

    # Fix: the default ``None`` used to be formatted into the prompt as the
    # literal text "None".
    instruction = "" if text_input is None else text_input
    prompt = f"{user_prompt}<|image_1|>\n{instruction}{prompt_suffix}{assistant_prompt}"
    image = image.convert("RGB")

    inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=4128,
        eos_token_id=processor.tokenizer.eos_token_id,
    )
    # Drop the echoed prompt: keep only the tokens after the input length.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response


def _strip_code_fence(text):
    """Remove a surrounding Markdown code fence (``` or ```json) from *text*."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        # An optional language tag directly follows the opening fence.
        if cleaned.startswith("json"):
            cleaned = cleaned[4:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


@app.route('/api/vision', methods=['POST'])
def process_api_vision():
    """Handle POST /api/vision: run the vision model and return its JSON answer.

    The request body must be a JSON object with:
        image: An http(s) URL or a base64-encoded image.
        prompt: The instruction passed to the model.

    Returns:
        The model output parsed as JSON on success; a 400 JSON error when the
        request body is malformed; a 500 JSON error when processing fails or
        the model output is not valid JSON.
    """
    # Reject malformed requests up front with a client error instead of
    # letting a KeyError/TypeError surface as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data or 'prompt' not in data:
        return jsonify({'error': "Request body must be JSON with 'image' and 'prompt' fields"}), 400

    try:
        image = extract_image(data['image'])
        result = process_vision(image, data['prompt'])

        # The model often wraps its answer in a Markdown code fence, possibly
        # with surrounding whitespace; remove it before parsing.
        result = _strip_code_fence(result)

        # Convert the string result to a Python object
        try:
            logging.info(result)
            result_dict = json.loads(result)
        except json.JSONDecodeError as e:
            logging.error(f"JSON decoding error: {e}")
            return jsonify({'error': 'Invalid JSON format in the response'}), 500

        return jsonify(result_dict)
    except Exception as e:
        logging.error(f"Error occurred: {e}")
        return jsonify({'error': str(e)}), 500

# Emit log records at INFO level and above
logging.basicConfig(level=logging.INFO)

# Person-detector placeholders; both stay None until the detector is loaded
model = detector = None

def load_model():
    """Download the person-detection ONNX model and publish it in module globals.

    Sets both ``model`` and ``detector`` to the prepared detector instance.
    """
    global model, detector

    weights_path = huggingface_hub.hf_hub_download(
        "public-data/insightface", "models/scrfd_person_2.5g.onnx"
    )

    # Thread counts for the ONNX Runtime session.
    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 8
    session_options.inter_op_num_threads = 8

    execution_providers = ["CPUExecutionProvider", "CUDAExecutionProvider"]
    onnx_session = ort.InferenceSession(
        weights_path, sess_options=session_options, providers=execution_providers
    )

    model = insightface.model_zoo.retinaface.RetinaFace(
        model_file=weights_path, session=onnx_session
    )
    # NOTE(review): a ctx_id of -1 presumably selects CPU execution — confirm
    # against the insightface documentation.
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model
    logging.info("Model loaded successfully.")

# Float32 matmul precision setting ("high")
torch.set_float32_matmul_precision("high")

# Segmentation network used by rm_background, placed on the GPU
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
birefnet.to("cuda")

# Preprocessing applied to every image before it is fed to birefnet:
# resize to 1024x1024, convert to a tensor, then normalize per channel
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])



def save_image(img):
    """Save *img* under a fresh ``<uuid4>.png`` name and return that name."""
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename

# Decode a base64 payload into an RGB PIL.Image.Image
# NOTE(review): a function with this name is also defined earlier in the file;
# being the later definition, this is the one in effect at runtime.
def decode_image_from_base64(image_data):
    """Return the image stored in the base64 string *image_data*, as RGB."""
    decoded_bytes = base64.b64decode(image_data)
    opened = Image.open(BytesIO(decoded_bytes))
    rgb_image = opened.convert("RGB")
    return rgb_image

# Encode a PIL image as a base64 string
# NOTE(review): a function with this name is also defined earlier in the file;
# being the later definition, this is the one in effect at runtime.
def encode_image_to_base64(image):
    """Save *image* as PNG in memory and return it base64-encoded as text."""
    with BytesIO() as buffer:
        # PNG is used so that images with an alpha channel (RGBA) can be saved.
        image.save(buffer, format="PNG")
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode('utf-8')
    
@spaces.GPU
def rm_background(image):
    """Remove the background of an image using the BiRefNet segmentation model.

    Args:
        image: Any input accepted by ``load_img``.

    Returns:
        The loaded image with the predicted foreground mask applied as its
        alpha channel.
    """
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    image_size = im.size
    # The unused ``origin = im.copy()`` was dropped: it duplicated the whole
    # image on every call without ever being read.
    image = load_img(im)
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction: the last model output passed through a sigmoid, moved to CPU
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # The model runs at a fixed resolution; scale the mask back to the
    # original image size before using it as the alpha channel.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image

@spaces.GPU
def remove_background(image):
    """Remove the background of *image* with ``transparent_background.Remover``.

    Accepts a PIL image or a NumPy array (converted to PIL first); any other
    type raises ``TypeError``. Returns whatever ``Remover.process`` produces.
    """
    remover = Remover()
    if isinstance(image, np.ndarray):
        pil_input = Image.fromarray(image)
    elif isinstance(image, Image.Image):
        pil_input = image
    else:
        raise TypeError("Unsupported image type")
    return remover.process(pil_input)
    
@spaces.GPU
def detect_and_segment_persons(image, clothes):
    """Detect people in an image and save the segmented clothing of each one.

    Args:
        image: Input image convertible with ``np.array``; the channel flip
            below assumes it is RGB — TODO confirm for callers other than
            ``detect``.
        clothes: Clothing category names forwarded to ``segment_clothing``.

    Returns:
        A list of file names of the saved PNG images. When no person is
        detected, the list contains a single background-removed copy of the
        whole input image.
    """
    img = np.array(image)
    img = img[:, :, ::-1]  # RGB -> BGR

    if detector is None:
        load_model()  # Ensure the model is loaded

    bboxes, _ = detector.detect(img)
    if bboxes.shape[0] == 0:
        return [save_image(rm_background(image))]

    height, width, _ = img.shape
    # Keep only the four box coordinates and clamp them to the image bounds.
    bboxes = np.round(bboxes[:, :4]).astype(int)
    bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, width)
    bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in bboxes:
        # A box that collapses to zero area after clipping yields an empty
        # crop, which previously made the whole request fail; skip it.
        if x2 <= x1 or y2 <= y1:
            continue

        person_img = img[y1:y2, x1:x2]
        pil_img = Image.fromarray(person_img[:, :, ::-1])  # BGR -> RGB

        img_rm_background = rm_background(pil_img)
        segmented_result = segment_clothing(img_rm_background, clothes)
        image_paths = [save_image(piece) for piece in segmented_result]
        logging.info("Saved segmented images: %s", image_paths)
        all_segmented_images.extend(image_paths)

    return all_segmented_images

@app.route('/', methods=['GET'])
def welcome():
    """Landing route: reply with a plain-text greeting."""
    greeting = "Welcome to Clothing Segmentation API"
    return greeting

@app.route('/api/detect', methods=['POST'])
def detect():
    """Handle POST /api/detect: segment the clothing of each detected person.

    The request body must be a JSON object with:
        image: A base64-encoded image.

    Returns:
        ``{"images": [...]}`` listing the saved image file names on success;
        a 400 JSON error when the request body is malformed; a 500 JSON error
        when processing fails.
    """
    # Reject malformed requests up front with a client error instead of
    # letting a KeyError/TypeError surface as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({'error': "Request body must be JSON with an 'image' field"}), 400

    try:
        image = decode_image_from_base64(data['image'])

        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        result = detect_and_segment_persons(image, clothes)

        return jsonify({'images': result})
    except Exception as e:
        logging.error(f"Error occurred: {e}")
        return jsonify({'error': str(e)}), 500
        
# Route that serves a previously generated image
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """Serve a generated PNG by the file name returned from the detect API.

    Args:
        image_id: File name of the form ``<uuid>.png``, as produced by
            ``save_image``.

    Returns:
        The PNG file, or a 404 JSON error when the name is not a generated
        image name or the file does not exist.
    """
    # Security: ``image_id`` comes straight from the URL. Only canonical
    # "<uuid>.png" names are accepted so that a client cannot request other
    # files (source code, configuration, ...) from the serving directory.
    stem, dot, extension = image_id.rpartition(".")
    try:
        is_generated_name = (
            dot == "."
            and extension == "png"
            and str(uuid.UUID(stem)) == stem
        )
    except ValueError:
        is_generated_name = False
    if not is_generated_name:
        return jsonify({'error': 'Image not found'}), 404

    # The file name must match the one used when the image was saved.
    image_path = image_id

    # Send the image back
    try:
        return send_file(image_path, mimetype='image/png')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404
        

if __name__ == "__main__":
    import os

    # Security: debug mode enables the interactive Werkzeug debugger, which
    # allows arbitrary code execution and must not be exposed on a server
    # bound to 0.0.0.0. It is now opt-in via the FLASK_DEBUG environment
    # variable instead of always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)