# app_api.py
import os
# Must be set before torch is imported: lets ops without MPS support fall back to CPU
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import base64
from io import BytesIO
import random
import numpy as np
import cv2
from PIL import Image
import torch
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List
from celery import states
from celery_app import celery_app
from celery.result import AsyncResult
import requests

# Limit intra-op threads; avoids CPU oversubscription when several worker processes share a host
torch.set_num_threads(1)

# Initialize FastAPI
app = FastAPI()

# Device configuration function
def get_device():
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda"
    else:
        return "cpu"

# Device configuration
DEVICE = get_device()

# Load models at import time (both the API process and each Celery worker import this module).
# The checkpoint is the full ViT-H SAM, so the objects are named accordingly; the predictor
# is kept available for mask generation, though the endpoint below expects a client-supplied mask.
sam = sam_model_registry['vit_h'](checkpoint='data/ckpt/sam_vit_h_4b8939.pth').to(DEVICE)
sam.eval()
sam_predictor = SamPredictor(sam)

# Load BrushNet and Stable Diffusion pipeline
base_model_path = "data/ckpt/realisticVisionV60B1_v51VAE"  # Update to your base model path
brushnet_path = "data/ckpt/segmentation_mask_brushnet_ckpt"  # Update to your BrushNet checkpoint path

# from_pretrained takes no device argument; device placement is handled by
# enable_model_cpu_offload below
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=False,
)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Offload keeps submodules on the CPU and moves them to the accelerator on demand;
# it needs an accelerator target, so fall back to a plain .to() on CPU-only hosts
# (note: float16 on CPU is slow and some ops may not support it)
if DEVICE != "cpu":
    pipe.enable_model_cpu_offload(device=DEVICE)
else:
    pipe = pipe.to(DEVICE)
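
# If memory is still tight, attention slicing is a lighter-weight option that can be
# combined with or substituted for offloading (standard diffusers call, trades speed for RAM):
#   pipe.enable_attention_slicing()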

# Pydantic models for request and response validation
class InpaintingRequest(BaseModel):
    input_image: str  # Base64-encoded image
    input_mask: Optional[str] = None  # Base64-encoded mask image
    prompt: str
    negative_prompt: Optional[str] = 'ugly, low quality'
    control_strength: float = 1.0
    guidance_scale: float = 12.0
    num_inference_steps: int = 50
    seed: int = 0  # Set default seed to 0
    randomize_seed: bool = False
    blended: bool = False
    invert_mask: bool = True
    count: int = 1
    webhook_url: Optional[str] = None  # URL to send the webhook to

class InpaintingResponse(BaseModel):
    images: List[str]  # Base64-encoded images
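
# A minimal request body for /inpaint (a sketch; "<base64>" stands in for real
# base64-encoded PNG data, and the remaining fields take the defaults above):
#   {
#       "input_image": "<base64>",
#       "input_mask": "<base64>",
#       "prompt": "a red sofa in a bright living room"
#   }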

# Resize so the short side equals `resolution`, snapping both sides to multiples of 64
def resize_image(input_image, resolution):
    H, W, _ = input_image.shape
    k = float(resolution) / min(H, W)
    H = int(np.round(H * k / 64.0)) * 64
    W = int(np.round(W * k / 64.0)) * 64
    # Lanczos when upscaling, area averaging when downscaling
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
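
# For example, a 768x1024 input with resolution=512 scales by k = 512/768 to
# 512x682.7, then snaps to 512x704 (the nearest multiples of 64).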

# Main processing function
def process_api(request: InpaintingRequest):
    # Decode the base64-encoded input image
    input_image_data = base64.b64decode(request.input_image)
    original_image = Image.open(BytesIO(input_image_data)).convert("RGB")
    original_image_np = np.array(original_image)

    # Resize the original image if its smallest dimension is greater than 512 pixels
    if min(original_image_np.shape[:2]) > 512:
        original_image_np = resize_image(original_image_np, 512)

    # The endpoint requires a mask even though the field is Optional in the schema
    if not request.input_mask:
        raise ValueError("Input mask is required")
    input_mask_data = base64.b64decode(request.input_mask)
    input_mask_image = Image.open(BytesIO(input_mask_data)).convert("RGB")
    input_mask_np = np.array(input_mask_image)

    # Prepare other parameters
    if request.randomize_seed:
        seed = random.randint(0, 2147483647)
    else:
        seed = request.seed  # This will always be an int due to the default value

    generator = torch.Generator(DEVICE).manual_seed(seed)

    # Prepare mask and init images
    H, W = original_image_np.shape[:2]
    original_mask = cv2.resize(input_mask_np, (W, H))

    if request.invert_mask:
        original_mask = 255 - original_mask

    # A pixel counts as masked when its RGB channel sum exceeds 255, yielding a
    # binary {0, 1} mask with a trailing channel axis for broadcasting
    mask = 1.0 * (original_mask.sum(-1) > 255)[:, :, np.newaxis]
    # Black out the masked region of the init image, as BrushNet's conditioning expects
    masked_image = original_image_np * (1 - mask)

    init_image = Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
    mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("RGB")

    # Generate images
    images = pipe(
        [request.prompt] * request.count,
        init_image,
        mask_image,
        num_inference_steps=request.num_inference_steps,
        guidance_scale=request.guidance_scale,
        generator=generator,
        brushnet_conditioning_scale=float(request.control_strength),
        negative_prompt=[request.negative_prompt] * request.count,
    ).images

    # Optionally paste the generated content back onto the original image,
    # feathering the seam with a blurred copy of the mask
    if request.blended:
        if request.control_strength < 1.0:
            raise ValueError('Blurred blending requires a control strength of 1.0')
        blended_images = []
        mask_np = np.array(mask_image) / 255.0
        mask_blurred = cv2.GaussianBlur(mask_np, (21, 21), 0)
        # Soft union of the hard mask and its blur: fully opaque inside, feathered at the edge
        mask_combined = 1 - (1 - mask_np) * (1 - mask_blurred)
        for img in images:
            img_np = np.array(img)
            image_pasted = original_image_np * (1 - mask_combined) + img_np * mask_combined
            blended_images.append(Image.fromarray(image_pasted.astype(np.uint8)))
        images = blended_images

    # Encode images to base64
    output_images = []
    for img in images:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
        output_images.append(img_str)

    return output_images
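
# Quick smoke test of process_api without going through Celery (a sketch;
# assumes "test.png" and "test_mask.png" exist next to this file):
#
#   with open("test.png", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#   with open("test_mask.png", "rb") as f:
#       mask_b64 = base64.b64encode(f.read()).decode("utf-8")
#   out = process_api(InpaintingRequest(
#       input_image=img_b64, input_mask=mask_b64, prompt="a red sofa"))
#   Image.open(BytesIO(base64.b64decode(out[0]))).save("result.png")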

# Celery task function
@celery_app.task(bind=True)
def generate_image_task(self, request_data):
    try:
        # Convert request_data to InpaintingRequest object
        request = InpaintingRequest(**request_data)
        # Call the processing function
        output_images = process_api(request)
        # If a webhook URL is provided, POST the images to it; a delivery
        # failure is logged but does not fail the task itself
        if request.webhook_url:
            payload = {'images': output_images}
            try:
                requests.post(request.webhook_url, json=payload, timeout=30)
            except requests.RequestException as e:
                print(f"Webhook delivery failed: {e}")
        return output_images
    except Exception as e:
        # Optionally, log the error
        print(f"Exception in generate_image_task: {e}")
        # Re-raise the exception to let Celery handle it
        raise

# API endpoint to start the image generation task
@app.post("/inpaint")
def inpaint(request: InpaintingRequest):
    try:
        # Pydantic v1 .dict(); use request.model_dump() on Pydantic v2
        task = generate_image_task.delay(request.dict())
        return {"job_id": task.id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
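
# Example submission over HTTP (a sketch; assumes the server runs on localhost:8000):
#   curl -X POST http://localhost:8000/inpaint \
#        -H "Content-Type: application/json" \
#        -d '{"input_image": "<base64>", "input_mask": "<base64>", "prompt": "a red sofa"}'
# The returned job_id is then polled via /status/{job_id} and /result/{job_id}.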

# API endpoint to check task status
@app.get("/status/{job_id}")
def get_status(job_id: str):
    task = AsyncResult(job_id, app=celery_app)
    if task.state == states.FAILURE:
        # On failure, task.info holds the exception raised in the worker
        return {"status": "FAILURE", "error": str(task.info)}
    # PENDING, STARTED, SUCCESS, etc. all report the state name directly
    return {"status": str(task.state)}

# API endpoint to retrieve results
@app.get("/result/{job_id}")
def get_result(job_id: str):
    task = AsyncResult(job_id, app=celery_app)
    if task.state == states.SUCCESS:
        return {"images": task.result}
    elif task.state == states.FAILURE:
        raise HTTPException(status_code=500, detail=str(task.info))
    else:
        # 202 Accepted signals the job exists but has not finished
        raise HTTPException(status_code=202, detail="Task not completed yet.")

# Run the app
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app_api:app", host="0.0.0.0", port=8000)
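
# To serve jobs end to end, a Celery worker must run alongside the API
# (a sketch; assumes celery_app.py points at a reachable broker and result backend):
#   celery -A celery_app worker --loglevel=info
#   python app_api.py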