# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
# Requires mmcv: pip install mmcv

from torchvision.utils import save_image
from PIL import Image
import subprocess
from collections import OrderedDict
import numpy as np
import cv2
import textwrap
import torch
import os
from annotator.util import resize_image, HWC3
import mmcv
import random

# Runs on the CPU by default. To use a GPU instead, set
#   device = "cuda" if torch.cuda.is_available() else "cpu"
# (more than 15GB of GPU memory is required).
device = "cpu"

# Feature switches. use_blip gates the BLIP2 initialisation below;
# use_gradio is not read in the visible part of this script.
use_blip = True
use_gradio = True

# Half precision is only used off-CPU; on the CPU everything stays float32.
data_type = torch.float32 if device == 'cpu' else torch.float16
# Diffusion init using diffusers.

# diffusers==0.14.0 required.
from diffusers.utils import load_image

base_model_path = "stabilityai/stable-diffusion-2-inpainting"

# Display name -> model repository id (presumably Hugging Face hub ids).
# Entries are inserted one at a time so their display order is explicit.
config_dict = OrderedDict()
config_dict['SAM Pretrained(v0-1): Good Natural Sense'] = 'shgao/edit-anything-v0-1-1'
config_dict['LAION Pretrained(v0-3): Good Face'] = 'shgao/edit-anything-v0-3'
config_dict['SD Inpainting: Not keep position'] = 'stabilityai/stable-diffusion-2-inpainting'

# Segment-Anything init.
# pip install git+https://github.com/facebookresearch/segment-anything.git
try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
    import importlib
    import sys

    print('segment_anything not installed')
    # Use the running interpreter's pip so the package is installed into the
    # same environment this script imports from (a bare `pip` on PATH may
    # belong to a different Python).
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install',
         'git+https://github.com/facebookresearch/segment-anything.git'],
        check=True)
    print(f'Install segment_anything {result}')
    # A package installed after interpreter start-up may not be found until
    # the import system's finder caches are invalidated.
    importlib.invalidate_caches()
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# Single source of truth for the checkpoint location (checked, downloaded to,
# and loaded from the same path).
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
if not os.path.exists(sam_checkpoint):
    os.makedirs(os.path.dirname(sam_checkpoint), exist_ok=True)
    # torch.hub's downloader needs no external `wget` binary and writes to a
    # temporary file that is moved into place only when the download finishes,
    # so an interrupted run does not leave a truncated checkpoint behind.
    torch.hub.download_url_to_file(
        'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
        sam_checkpoint)
    print(f'Download sam_vit_h_4b8939.pth to {sam_checkpoint}')
model_type = "default"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)


# BLIP2 init.
if use_blip:
    # need the latest transformers
    # pip install git+https://github.com/huggingface/transformers.git
    from transformers import AutoProcessor, Blip2ForConditionalGeneration
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=data_type)
    # Put the model on the same device the caption inputs are sent to
    # (region_classify_w_blip2 moves its inputs to `device`); without this,
    # any non-CPU `device` fails with a device mismatch in generate().
    blip_model.to(device)


def region_classify_w_blip2(image):
    """
    Caption one image region with the module-level BLIP2 model.
    Args:
        image: region crop accepted by the BLIP2 processor (a numpy array when
            called from region_level_semantic_api)
    Returns:
        str: the first decoded caption (at most 15 new tokens), with special
            tokens removed and surrounding whitespace stripped
    """
    model_inputs = processor(image, return_tensors="pt")
    # Match the model's device and dtype.
    model_inputs = model_inputs.to(device, data_type)
    token_ids = blip_model.generate(**model_inputs, max_new_tokens=15)
    captions = processor.batch_decode(token_ids, skip_special_tokens=True)
    return captions[0].strip()

def region_level_semantic_api(image, topk=5):
    """
    rank regions by area, and classify each region with blip2
    Args:
        image: numpy array given to the SAM mask generator (assumed HWC RGB
            uint8 - TODO confirm against resize_image's output)
        topk: int, maximum number of regions to label, largest area first
    Returns:
        topk_region_w_class_label: list of SAM annotation dicts ordered by
            descending 'area', each with an added key 'class_label' holding
            the BLIP2 caption of that region; [] if SAM finds no region
    """
    anns = mask_generator.generate(image)
    if not anns:
        return []
    # Largest regions first. The original looped over the unsorted `anns`,
    # so this ranking was silently ignored.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    topk_region_w_class_label = []
    # max(..., 0) keeps a negative topk selecting nothing rather than
    # slicing from the end.
    for ann in sorted_anns[:max(topk, 0)]:
        m = ann['segmentation']
        # Black out pixels outside the region; the (H, W, 1) mask broadcasts
        # across the colour channels.
        masked_image = image * m[:, :, np.newaxis]
        bbox = ann['bbox']
        # bbox is read as (x, y, w, h) and converted to corner coordinates
        # for mmcv.imcrop.
        corners = np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]])
        region = mmcv.imcrop(masked_image, corners, scale=1)
        ann['class_label'] = region_classify_w_blip2(region)
        print(ann['class_label'], str(bbox))
        topk_region_w_class_label.append(ann)
    return topk_region_w_class_label

def show_semantic_image_label(anns):
    """
    show semantic image label for each region
    Args:
        anns: list of dict with key 'class_label'
    Returns:
        full_img: numpy array
    """
    full_img = None
    # generate mask image
    for i in range(len(anns)):
        m = anns[i]['segmentation']
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img*255
    # add text on this mask image
    for i in range(len(anns)):
        m = anns[i]['segmentation']
        class_label = anns[i]['class_label']
        # add text to region
        # Calculate the centroid of the region to place the text
        y, x = np.where(m != 0)
        x_center, y_center = int(np.mean(x)), int(np.mean(y))

        # Split the text into multiple lines
        max_width = 20  # Adjust this value based on your preferred maximum width
        wrapped_text = textwrap.wrap(class_label, width=max_width)

        # Add text to region
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.2
        font_thickness = 2
        font_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))  # red
        line_spacing = 40  # Adjust this value based on your preferred line

        for idx, line in enumerate(wrapped_text):
            y_offset = y_center - (len(wrapped_text) - 1) * line_spacing // 2 + idx * line_spacing
            text_size = cv2.getTextSize(line, font, font_scale, font_thickness)[0]
            x_offset = x_center - text_size[0] // 2
            # Draw the text multiple times with small offsets to create a bolder appearance
            offsets = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
            for off_x, off_y in offsets:
                cv2.putText(full_img, line, (x_offset + off_x, y_offset + off_y), font, font_scale, font_color, font_thickness, cv2.LINE_AA)

    return full_img



image_path = "images/sa_224577.jpg"
detect_resolution = 1024
# Convert through PIL so grayscale, palette and RGBA files all become
# (H, W, 3) uint8 arrays; the `with` block closes the file handle.
with Image.open(image_path) as pil_image:
    input_image = np.array(pil_image.convert("RGB"), dtype=np.uint8)
input_image = resize_image(input_image, detect_resolution)
region_level_annots = region_level_semantic_api(input_image, topk=5)
output = show_semantic_image_label(region_level_annots)

# Downscale both panels to 512 and store them as uint8. The label image is
# float, so clip first: any out-of-range value then saturates instead of
# wrapping when cast.
input_image = np.asarray(resize_image(input_image, 512), dtype=np.uint8)
output = np.clip(resize_image(output, 512), 0, 255).astype(np.uint8)

image_list = [torch.from_numpy(input_image).float(),
              torch.from_numpy(output).float()]
for each in image_list:
    print(each.shape, type(each))
    print(each.max(), each.min())

# (N, H, W, C) -> (N, C, H, W), the layout save_image expects.
image_list = torch.stack(image_list).permute(0, 3, 1, 2)
print(image_list.shape)

# Map the fixed 0-255 range to [0, 1] directly; normalize=True would instead
# stretch by the batch's own min/max and could shift colours.
save_image(image_list / 255.0, "images/sample_semantic.jpg", nrow=2)