import os
import tempfile
import time
from functools import lru_cache
from typing import Any
import boto3
import gradio as gr
import numpy as np
import rembg
import torch
from gradio_litmodel3d import LitModel3D
from PIL import Image
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
import sf3d.utils as sf3d_utils
from sf3d.system import SF3D

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse

import datetime

# AWS credentials are read from the ACCESS / SECRET environment variables;
# os.getenv yields None for a variable that is not set.
ACCESS = os.getenv("ACCESS")
SECRET = os.getenv("SECRET")

# Keyword arguments shared by every AWS client created below.
_AWS_CLIENT_KWARGS = {
    "aws_access_key_id": ACCESS,
    "aws_secret_access_key": SECRET,
    "region_name": "us-east-1",
}

bedrock = boto3.client(service_name="bedrock", **_AWS_CLIENT_KWARGS)
bedrock_runtime = boto3.client(service_name="bedrock-runtime", **_AWS_CLIENT_KWARGS)
# Client used by upload_file_to_s3 to store generated meshes.
s3_client = boto3.client("s3", **_AWS_CLIENT_KWARGS)


# FastAPI application object; the /process-image/ route is registered on it.
app = FastAPI()

# One rembg session created at import time and reused by remove_background().
rembg_session = rembg.new_session()

# Size (in pixels) every conditioning image is resized to before inference;
# also passed to create_intrinsic_from_fov_deg below.
COND_WIDTH = 512
COND_HEIGHT = 512
# Argument to sf3d_utils.default_cond_c2w (presumably the camera distance
# used for the conditioning view - confirm against sf3d.utils).
COND_DISTANCE = 1.6
# Vertical field of view, in degrees, handed to create_intrinsic_from_fov_deg.
COND_FOVY_DEG = 40
# RGB colour (in the same 0..1 range as the normalised image) that
# create_batch() blends in wherever the mask is 0.
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Computed once at import time and reused for every request: these depend
# only on the constants above, never on the input image.
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)


# Load the SF3D model once at import time so every request reuses it.
# NOTE(review): "stabilityai/stable-fast-3d" is presumably a Hugging Face Hub
# repository id resolved by SF3D.from_pretrained - confirm against sf3d.system.
model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
# Switch to eval mode and move the model to the GPU. The return value is
# discarded, so this relies on .eval()/.cuda() modifying `model` in place
# (true for torch.nn.Module; assumes SF3D is one - TODO confirm).
model.eval().cuda()

# Paths of the bundled demo images. The directory is optional for the API
# service, so a missing folder yields an empty list instead of raising
# FileNotFoundError while the module is being imported. Entries are sorted
# because os.listdir returns names in arbitrary order.
_EXAMPLES_DIR = "demo_files/examples"
example_files = (
    [os.path.join(_EXAMPLES_DIR, f) for f in sorted(os.listdir(_EXAMPLES_DIR))]
    if os.path.isdir(_EXAMPLES_DIR)
    else []
)


def run_model(input_image):
    """Generate a mesh from a conditioning image and write it to a GLB file.

    Args:
        input_image: Image handed to ``create_batch``; the API route passes the
            RGBA output of ``resize_foreground``.

    Returns:
        Path of a newly created ``.glb`` temporary file. The file is created
        with ``delete=False``, so the caller is responsible for removing it.

    Raises:
        Whatever the model or the mesh export raises; if the export fails the
        temporary file is removed before the exception propagates.
    """
    start = time.time()
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        model_batch = create_batch(input_image)
        model_batch = {k: v.cuda() for k, v in model_batch.items()}
        # 1024 is forwarded positionally to generate_mesh (presumably the
        # texture resolution - confirm against sf3d.system).
        trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
        # The batch holds a single image, so keep the first (only) mesh.
        trimesh_mesh = trimesh_mesh[0]

    # Reserve a unique path and close the handle immediately: leaving it open
    # leaks one file descriptor per call, and the export below reopens the
    # path by name anyway.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp_file:
        glb_path = tmp_file.name

    try:
        trimesh_mesh.export(glb_path, file_type="glb", include_normals=True)
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        os.remove(glb_path)
        raise

    print("Generation took:", time.time() - start, "s")

    return glb_path


def create_batch(input_image: Image) -> dict[str, Any]:
    """Build the model input dictionary for a single conditioning image.

    The image is resized to ``(COND_WIDTH, COND_HEIGHT)`` and scaled to the
    0..1 range. Its last channel is used as the mask, and its first three
    channels are blended over ``BACKGROUND_COLOR`` using that mask.

    Returns:
        Dict with keys ``rgb_cond``, ``mask_cond``, ``c2w_cond``,
        ``intrinsic_cond`` and ``intrinsic_normed_cond``. Every value has a
        leading batch dimension of size 1; the three camera tensors carry one
        additional leading singleton dimension.
    """
    resized = input_image.resize((COND_WIDTH, COND_HEIGHT))
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    img_cond = torch.from_numpy(pixels).float().clip(0, 1)

    # Last channel is the mask; keep the channel axis so it broadcasts.
    mask_cond = img_cond[:, :, -1:]
    background = torch.tensor(BACKGROUND_COLOR).unsqueeze(0).unsqueeze(0)
    # lerp(start, end, weight): background where mask == 0, image where mask == 1.
    rgb_cond = torch.lerp(background, img_cond[:, :, :3], mask_cond)

    unbatched = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Prepend the batch dimension to every entry.
    return {name: tensor.unsqueeze(0) for name, tensor in unbatched.items()}


@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Return a three-channel checkerboard pattern as a float array.

    Args:
        squares: Number of squares along each side.
        size: Requested edge length in pixels. Each square is
            ``size // squares`` pixels wide, so the actual edge length is
            ``(size // squares) * squares``.
        min_value: Value of the darker squares; the lighter squares are 1.

    Returns:
        Array of shape ``(edge, edge, 3)`` with identical channels. The result
        is memoised by ``lru_cache``, so repeated calls return the same array
        object; callers should not modify it in place.
    """
    tile = np.zeros((squares, squares)) + min_value
    # Cells whose row and column indices have different parity are light.
    row_idx, col_idx = np.indices((squares, squares))
    tile[(row_idx + col_idx) % 2 == 1] = 1

    scale = size // squares
    upscaled = np.repeat(np.repeat(tile, scale, axis=0), scale, axis=1)
    # Add a channel axis and replicate it three times.
    return np.repeat(upscaled[:, :, None], 3, axis=-1)


def remove_background(input_image: Image) -> Image:
    """Run rembg on ``input_image`` using the shared module-level session."""
    cutout = rembg.remove(input_image, session=rembg_session)
    return cutout


def resize_foreground(
    image: Image,
    ratio: float,
) -> Image:
    """Centre the visible foreground on a transparent square and resize it.

    The image is cropped to the bounding box of all pixels with alpha > 0,
    padded with transparent pixels to a square, padded again so the foreground
    spans roughly ``ratio`` of the edge, and finally resized to
    ``(COND_WIDTH, COND_HEIGHT)``.

    Args:
        image: Four-channel (RGBA) image.
        ratio: Fraction of the output edge the foreground should occupy; must
            lie in ``(0, 1]``. Very small values create a correspondingly
            large intermediate array (edge = foreground size / ratio).

    Returns:
        RGBA image of size ``(COND_WIDTH, COND_HEIGHT)``.

    Raises:
        ValueError: If ``ratio`` is outside ``(0, 1]``, the image does not
            have four channels, or no pixel has alpha > 0.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # The comparison also rejects NaN.
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in the interval (0, 1], got {ratio!r}")

    pixels = np.array(image)
    if pixels.ndim != 3 or pixels.shape[-1] != 4:
        raise ValueError(
            f"expected an RGBA image with 4 channels, got array of shape {pixels.shape}"
        )

    rows, cols = np.nonzero(pixels[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("image has no foreground: alpha is zero for every pixel")

    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()
    # crop the foreground; max() is the index of the last foreground row or
    # column, so the exclusive slice end has to be one past it.
    fg = pixels[y1 : y2 + 1, x1 : x2 + 1]

    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )

    # compute padding according to the ratio
    new_size = int(size / ratio)
    # pad to size, double side
    before = (new_size - size) // 2
    after = new_size - size - before
    new_image = np.pad(
        new_image,
        ((before, after), (before, after), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    return Image.fromarray(new_image, mode="RGBA").resize((COND_WIDTH, COND_HEIGHT))


def square_crop(input_image: Image) -> Image:
    """Centre-crop the image to a square and resize it.

    The square's edge is the shorter image side; the result is resized to
    ``(COND_WIDTH, COND_HEIGHT)``.
    """
    width, height = input_image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    # left + side equals (width + side) // 2 for integer sizes; same for top.
    box = (left, top, left + side, top + side)
    cropped = input_image.crop(box)
    return cropped.resize((COND_WIDTH, COND_HEIGHT))


def show_mask_img(input_image: Image) -> Image:
    """Composite an RGBA image over a checkerboard to visualise its alpha.

    The checkerboard is fixed at 32 squares over 512 pixels, so the input is
    expected to be 512x512 for the blend to broadcast.
    """
    rgba = np.array(input_image)
    # Alpha scaled to 0..1, with a trailing axis so it broadcasts over RGB.
    opacity = (rgba[:, :, 3] / 255.0)[:, :, None]
    background = checkerboard(32, 512) * 255
    foreground = rgba[..., :3]
    blended = foreground * opacity + background * (1 - opacity)
    return Image.fromarray(blended.astype(np.uint8), mode="RGB")



def upload_file_to_s3(file_path, bucket_name, object_name=None):
    """Upload a local file to S3 through the module-level ``s3_client``.

    Args:
        file_path: Path of the local file to upload.
        bucket_name: Name of the destination bucket.
        object_name: Key to store the object under. Defaults to the base name
            of ``file_path``.

    Returns:
        True if the upload succeeded. False if it failed because of missing or
        partial credentials or an ``S3UploadFailedError``; the reason is
        printed. Other exceptions propagate.
    """
    # upload_file needs a string key; passing the None default through would
    # fail parameter validation, so fall back to the file's own name.
    if object_name is None:
        object_name = os.path.basename(file_path)

    try:
        s3_client.upload_file(file_path, bucket_name, object_name)
    except NoCredentialsError:
        print("Credentials not available")
    except PartialCredentialsError:
        print("Incomplete credentials provided")
    except boto3.exceptions.S3UploadFailedError as e:
        print(f"Upload failed: {e}")
    else:
        return True
    return False

    
@app.post("/process-image/")
async def process_image(file: UploadFile = File(...), foreground_ratio: float = 0.85):
    """Turn an uploaded image into a GLB mesh and return its S3 URL.

    Pipeline: decode to RGBA -> remove background -> centre square crop ->
    resize foreground -> run the model -> upload the GLB to S3.

    Args:
        file: Uploaded image file.
        foreground_ratio: Fraction of the frame the foreground should occupy;
            must lie in ``(0, 1]``.

    Returns:
        ``{"glb2_path": <https URL of the uploaded object>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image, 422 if
            ``foreground_ratio`` is out of range or no foreground is found,
            502 if the upload to S3 fails.

    NOTE(review): every step below is a blocking call made directly inside
    this coroutine, so the event loop is occupied for the whole request.
    """
    # Local import: the file-level fastapi import does not include HTTPException.
    from fastapi import HTTPException

    if not 0 < foreground_ratio <= 1:
        raise HTTPException(
            status_code=422,
            detail="foreground_ratio must be in the interval (0, 1]",
        )

    try:
        input_image = Image.open(file.file).convert("RGBA")
    except OSError as exc:
        # PIL's UnidentifiedImageError is an OSError subclass.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from exc

    rem_removed = remove_background(input_image)
    sqr_crop = square_crop(rem_removed)
    try:
        fr_res = resize_foreground(sqr_crop, foreground_ratio)
    except ValueError as exc:
        # Raised when nothing is left after background removal.
        raise HTTPException(
            status_code=422,
            detail="No foreground could be extracted from the image",
        ) from exc

    glb_file = run_model(fr_res)

    bucket_name = "framebucket3d"
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    object_name = f"object_{timestamp}.glb"

    try:
        uploaded = upload_file_to_s3(glb_file, bucket_name, object_name)
    finally:
        # run_model creates the file with delete=False; remove it here so
        # temporary files do not pile up with every request.
        try:
            os.remove(glb_file)
        except OSError:
            pass  # best-effort cleanup; the response does not depend on it

    if not uploaded:
        # Previously this fell through and answered 200 with a null body.
        raise HTTPException(
            status_code=502, detail="Failed to upload the generated model to S3"
        )

    return {"glb2_path": f"https://{bucket_name}.s3.amazonaws.com/{object_name}"}

# Script entry point: serve the FastAPI app with uvicorn on all interfaces,
# port 7860, when this file is executed directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)