import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import gc
import gradio as gr
import numpy as np
import os
import pickle
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
import argparse

# Per-resolution self-attention blend factors used by AttnReplaceProcessor,
# keyed by UNet block type ('down' / 'mid' / 'up') and then by the number of
# spatial tokens the layer attends over (hidden_states.shape[1]). The UI labels
# map these to latent resolutions: 4096 = 64x64, 1024 = 32x32, 256 = 16x16,
# 64 = 8x8. Per the "Weight Adjustment" UI text, 1.0 = fully source attention
# and 0.0 = fully target attention. The inner dicts are mutated in place by
# update_value / update_scale.
weights = {
    'down': {
        4096: 0.0,
        1024: 1.0,
        256: 1.0,
    },
    'mid': {
        64: 1.0,
    },
    'up': {
        256: 1.0,
        1024: 1.0,
        4096: 0.0,
    }
}
# Default number of DDIM steps; rebound by update_step (Steps slider) and by
# on_image_change (set to the length of an example's saved QT).
num_inference_steps = 10
model_id = "stabilityai/stable-diffusion-2-1-base"

# Models are loaded at import time and moved to "cuda", so importing this
# module requires a CUDA device.
pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
# DDIMInverseScheduler drives the inversion in reconstruct; DDIMScheduler drives
# the denoising in reconstruct and apply_prompt. Both read the model's scheduler config.
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Every decoded image shown in the UI goes through this safety checker, with
# CLIP preprocessing from feature_extractor, before being returned.
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Cooperative cancel flag: set by stop_reconstruct, polled (and cleared) by
# reconstruct once per optimisation epoch.
should_stop = False

def save_state_to_file(state):
    """Pickle ``state`` into ``state.pkl`` in the current working directory.

    Returns:
        The name of the file that was written (always ``"state.pkl"``).
    """
    target = "state.pkl"
    with open(target, mode="wb") as handle:
        pickle.dump(state, handle)
    return target

def load_state_from_file(filename):
    """Unpickle and return the object stored in ``filename``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading, so
    only call this on trusted files (e.g. the bundled ``assets/*.pkl``).
    """
    with open(filename, mode="rb") as handle:
        return pickle.load(handle)

def stop_reconstruct():
    """Raise the module-level stop flag that ``reconstruct`` polls between epochs."""
    global should_stop
    should_stop = True

def _latents_to_image(latents):
    """Decode SD latents into a safety-checked ``uint8`` image array.

    Args:
        latents: VAE-space latents (already multiplied by the VAE scaling
            factor); assumed to be a batch of one — TODO confirm for new callers.

    Returns:
        ``numpy.uint8`` array in [0, 255], channels-last (via ``permute(1, 2, 0)``).
    """
    # No autograd: latents may come straight out of the null-text optimisation
    # and decoding must not extend (or keep alive) that graph.
    with torch.no_grad():
        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
        safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
        image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
    image_np = image[0].squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
    return (image_np * 255).astype(np.uint8)


def reconstruct(input_img, caption):
    """Invert ``input_img`` and fit per-step null-text embeddings (Gradio generator).

    Stage 1 runs DDIM inversion of the image latents conditioned on ``caption``.
    Stage 2 optimises one unconditional ("null-text") embedding per denoising
    step so that guided DDIM sampling (guidance 7.5) retraces the inversion
    trajectory back to the image latents.

    Args:
        input_img: source image from the ``gr.Image(type="pil")`` input.
        caption: prompt describing the source image.

    Yields:
        ``(image_np, caption, [caption, real_image_initial_latents, QT])`` after
        each optimisation epoch and once more at the end. ``image_np`` is the
        decoded reconstruction, ``real_image_initial_latents`` the inverted
        starting latents and ``QT`` the optimised embeddings (one per step) —
        the triple that ``apply_prompt`` consumes.

    Raises:
        gr.Error: if no source image was provided.
    """
    global should_stop

    if input_img is None:
        raise gr.Error("Please upload a source image first.")

    # Snapshot the step count once: update_step may rebind the global while
    # this generator is still running, and the inversion, QT and the loss
    # indexing below must all agree on a single length.
    steps = int(num_inference_steps)

    img = input_img
    # Grayscale / palette / CMYK uploads would reach the VAE with the wrong
    # channel count; RGB and RGBA keep the original path (alpha is sliced off below).
    if getattr(img, "mode", "RGB") not in ("RGB", "RGBA"):
        img = img.convert("RGB")

    # The prompt embeddings are only ever used detached (cond) or as the initial
    # data of QT (uncond), so no autograd graph through the text encoder is needed.
    with torch.no_grad():
        cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
        uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512, 512)),
        torchvision.transforms.ToTensor(),
    ])
    loaded_image = transform(img).to("cuda").unsqueeze(0)
    if loaded_image.shape[1] == 4:
        loaded_image = loaded_image[:, :3, :, :]

    with torch.no_grad():
        # ToTensor yields [0, 1]; rescale to [-1, 1] before encoding.
        encoded_image = pipe.vae.encode(loaded_image * 2 - 1)
        real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()

    # ---- Stage 1: DDIM inversion -------------------------------------------
    # guidance_scale 1 makes uncond + 1 * (text - uncond) == text.
    guidance_scale = 1
    inverse_scheduler.set_timesteps(steps, device="cuda")
    timesteps = inverse_scheduler.timesteps

    latents = real_image_latents
    # inversed_latents[k] is the latent *before* inversion step k, so
    # inversed_latents[0] is the clean image latent.
    inversed_latents = []

    with torch.no_grad():
        replace_attention_processor(pipe.unet, True)

        for t in tqdm(timesteps, total=len(timesteps), desc="Inference steps"):
            inversed_latents.append(latents)

            latent_model_input = torch.cat([latents] * 2)
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds_combined,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    real_image_initial_latents = latents

    # ---- Stage 2: null-text optimisation -----------------------------------
    # One learnable unconditional embedding per denoising step, initialised
    # from the empty-prompt embedding.
    QT = nn.Parameter(uncond_prompt_embeds.repeat(steps, 1, 1).clone())

    guidance_scale = 7.5
    scheduler.set_timesteps(steps, device="cuda")
    timesteps = scheduler.timesteps

    optimizer = torch.optim.AdamW([QT], lr=0.008)

    pipe.vae.eval()
    pipe.vae.requires_grad_(False)
    pipe.unet.eval()
    pipe.unet.requires_grad_(False)

    last_loss = 1.0

    for _ in range(50):
        gc.collect()
        torch.cuda.empty_cache()

        # Early stop / learning-rate schedule driven by the loss of the last
        # step of the previous epoch.
        if last_loss < 0.02:
            break
        elif last_loss < 0.03:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.003
        elif last_loss < 0.035:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.006

        intermediate_values = real_image_initial_latents.clone()

        for i in range(steps):
            # Each step starts from a detached latent, so gradients reach only QT[i].
            latents = intermediate_values.detach().clone()
            t = timesteps[i]

            prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])
            latent_model_input = torch.cat([latents] * 2)

            noise_pred_model = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # Denoising step i should land where inversion step (steps - 1 - i) started.
            loss = F.mse_loss(inversed_latents[steps - 1 - i].detach(), intermediate_values, reduction="mean")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Plain float: comparisons above need no tensor/graph.
            last_loss = loss.item()

        # The last step's *output* is the reconstruction of the clean latent (it
        # is matched against inversed_latents[0]); ``latents`` is one step short.
        reconstructed_latents = intermediate_values.detach()

        if should_stop:
            should_stop = False
            break

        yield _latents_to_image(reconstructed_latents), caption, [caption, real_image_initial_latents, QT]

    yield _latents_to_image(reconstructed_latents), caption, [caption, real_image_initial_latents, QT]


class AttnReplaceProcessor(AttnProcessor2_0):
    """Attention processor that can blend self-attention scores across batch quarters.

    With ``replace_all=False`` this is plain (non-fused) scaled dot-product
    attention. With ``replace_all=True`` it changes self-attention calls only
    (no ``encoder_hidden_states``):

    - The pre-softmax scores are split into four equal batch chunks
      ``[uncond_src, uncond_dst, cond_src, cond_dst]``.
    - Each ``*_dst`` chunk is overwritten with ``w * src + (1 - w) * dst``.
    - ``w`` is looked up by the layer's query-token count.

    Cross-attention is never modified.
    """

    def __init__(self, replace_all, weight):
        """
        Args:
            replace_all: enable score blending on self-attention calls.
            weight: mapping from query-token count (``hidden_states.shape[1]``)
                to blend factor ``w``. Only indexed when ``replace_all`` is
                True, so a placeholder (e.g. ``0.0``) is fine otherwise.
        """
        super().__init__()
        self.replace_all = replace_all
        self.weight = weight

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # NOTE(review): attention_mask and temb are accepted for interface
        # compatibility but not applied. The UNet calls in this file pass no
        # mask; confirm the UNet never supplies one before reusing this elsewhere.
        residual = hidden_states

        is_cross = encoder_hidden_states is not None

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Explicit scores (instead of fused SDPA) so they can be edited before softmax.
        attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))

        if self.replace_all and not is_cross:
            # Query-token count is the resolution key into self.weight.
            w = self.weight[hidden_states.shape[1]]
            # chunk() returns views, so copy_ writes straight into attention_scores.
            ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
            attn_scores_dst.copy_(w * attn_scores_src + (1.0 - w) * attn_scores_dst)
            ucond_attn_scores_dst.copy_(w * ucond_attn_scores_src + (1.0 - w) * ucond_attn_scores_dst)

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        del attention_probs

        # Output projection followed by the output dropout, as the stock
        # diffusers processors do (a no-op in eval mode).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

def replace_attention_processor(unet, clear=False):
    """Install an AttnReplaceProcessor on every self-attention (``attn1``) layer.

    Args:
        unet: model whose ``named_modules()`` are scanned.
        clear: when True, every ``attn1`` layer gets a pass-through processor
            (``replace_all=False``). Otherwise layers under the down/mid/up
            blocks get a blending processor bound to the live per-resolution
            dict ``weights[<block type>]``; other layers are left untouched.
    """
    for name, module in unet.named_modules():
        # Match the attn1 module itself, never its children
        # (attn1.to_q, attn1.to_out.0, ...).
        if name.rsplit('.', 1)[-1] != 'attn1':
            continue

        if clear:
            module.processor = AttnReplaceProcessor(False, 0.0)
            continue

        # "down_blocks.0...." -> "down", "mid_block...." -> "mid", "up_blocks.1...." -> "up"
        layer_type = name.split('.')[0].split('_')[0]
        if layer_type in ('down', 'mid', 'up'):
            module.processor = AttnReplaceProcessor(True, weights[layer_type])

def apply_prompt(meta_data, new_prompt):
    """Re-sample from the inverted latents with ``new_prompt``, blending source attention.

    Two branches are denoised together from the same starting latents:

    - branch 0 is conditioned on the original caption (the reconstruction);
    - branch 1 is conditioned on ``new_prompt``.

    While sampling, AttnReplaceProcessor mixes branch 0's self-attention
    scores into branch 1 using the global ``weights``.

    Args:
        meta_data: ``[caption, real_image_initial_latents, QT]`` as produced
            by ``reconstruct`` or loaded from an example pickle.
        new_prompt: target prompt text.

    Returns:
        ``numpy.uint8`` channels-last image of branch 1 after the safety checker.
    """
    caption, real_image_initial_latents, QT = meta_data

    # QT holds one optimised null-text embedding per step, so the step count
    # comes from it rather than from the (possibly changed) Steps slider.
    inference_steps = len(QT)
    guidance_scale = 7.5

    with torch.no_grad():
        cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
        new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

        scheduler.set_timesteps(inference_steps, device="cuda")
        timesteps = scheduler.timesteps

        # latents = [source branch, edit branch]. After the CFG doubling below,
        # the UNet batch is [uncond_src, uncond_dst, cond_src, cond_dst]; this
        # is the order in which AttnReplaceProcessor chunks the scores.
        latents = torch.cat([real_image_initial_latents] * 2)

        replace_attention_processor(pipe.unet)
        try:
            for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
                modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
                latent_model_input = torch.cat([latents] * 2)

                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=modified_prompt_embeds,
                    cross_attention_kwargs=None,
                    return_dict=False,
                )[0]

                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        finally:
            # Always restore pass-through processors, even if sampling raised,
            # so the UNet is never left with blending processors installed.
            replace_attention_processor(pipe.unet, True)

        # Decode only the edit branch.
        image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
        safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
        image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
        image_np = image[0].squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)

    return image_np



def on_image_change(filepath):
    """Load a bundled example's cached inversion and render its preset edit.

    Triggered when a gr.Examples row fills the hidden ``example_input``.

    Args:
        filepath: path of the selected example image; its basename (without
            extension) selects the cached ``assets/<name>.pkl`` state.

    Returns:
        6-tuple for ``[image_input, reconstructed_image, meta_data,
        steps_slider, invisible_slider, interpolate_slider]``. For a ``None``
        or unrecognised path every output is left unchanged via
        ``gr.update()``. Previously this path raised UnboundLocalError.
    """
    global num_inference_steps

    # Preset (attention scale, target prompt) for each bundled example.
    presets = {
        "example1": (7, "a photo of a tree, summer, colourful"),
        "example2": (8, "a photo of a panda, two ears, white background"),
        "example3": (7, "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"),
        "example4": (7, "a photo of plastic bottle on some sand, beach background, sky background"),
    }

    filename = os.path.splitext(os.path.basename(filepath))[0] if filepath else None
    if filename not in presets:
        return tuple(gr.update() for _ in range(6))

    scale_value, new_prompt = presets[filename]

    # NOTE: unpickling runs arbitrary code; only bundled, trusted assets are read here.
    meta_data_raw = load_state_from_file(f"assets/{filename}.pkl")
    _, _, QT_raw = meta_data_raw

    # The cached embeddings fix the step count for this example.
    num_inference_steps = len(QT_raw)

    update_scale(scale_value)
    img = apply_prompt(meta_data_raw, new_prompt)

    return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value

def update_value(value, key, res):
    """Set one per-resolution blend factor from an Advanced Options number box.

    Args:
        value: new blend factor (1.0 = source attention, 0.0 = target
            attention). ``None`` (an emptied ``gr.Number``) is ignored so the
            attention processor never multiplies scores by ``None``.
        key: block type, ``"down"``, ``"mid"`` or ``"up"``.
        res: self-attention token count (e.g. 4096, 1024, 256, 64).
    """
    if value is None:
        return
    # Mutate in place: the inner dicts are the same objects handed to the
    # installed AttnReplaceProcessor instances.
    weights[key][res] = value

def update_step(value):
    """Record the Steps slider value as the global DDIM step count for the next reconstruction."""
    global num_inference_steps
    num_inference_steps = value

def update_scale(scale):
    """Map the 0-9 Cross-Attention Influence scale onto all seven blend weights.

    The seven slots are ordered down-64x64 ... mid ... up-64x64. Starting from
    all 1.0, ``(9 - scale) * 0.5`` reduction steps are applied symmetrically
    from the outermost (highest-resolution) slots inward. A whole step sets a
    pair to 0.0; a trailing half step sets the next pair to 0.5. Scale 9 keeps
    everything at 1.0.

    The global ``weights`` is always updated in place. Previously scale 9
    returned early without touching it.

    Returns:
        Tuple of the seven weights in slot order, for the Advanced Options boxes.
    """
    slots = (
        ('down', 4096), ('down', 1024), ('down', 256),
        ('mid', 64),
        ('up', 256), ('up', 1024), ('up', 4096),
    )

    values = [1.0] * len(slots)
    reduction_steps = (9 - scale) * 0.5

    for i in range(4):  # 4 symmetric positions (the middle slot pairs with itself)
        if reduction_steps >= 1:
            values[i] = values[-(i + 1)] = 0.0
            reduction_steps -= 1
        elif reduction_steps > 0:
            values[i] = values[-(i + 1)] = 0.5
            break

    # Mutate the inner dicts in place so existing references stay live.
    for (section, res), value in zip(slots, values):
        weights[section][res] = value

    return tuple(weights[section][res] for section, res in slots)
            

# ---------------------------------------------------------------------------
# Gradio UI: layout first, then event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
            '''
            <div style="text-align: center;">
                <div style="display: flex; justify-content: center;">
                    <img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
                </div>
                <h1>Out of Focus 1.0</h1>
                <p style="font-size:16px;">Out of AI presents a flexible tool to manipulate your images. This is our first version of Image modification tool through prompt manipulation by reconstruction through diffusion inversion process</p>
            </div>
            <br>
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a> &ensp;
                <a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Ashleigh%20Watson"></a> &ensp;
                <a href="https://twitter.com/banterless_ai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Alex%20Nasa"></a>
            </div>
            <br>
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <p style="display: flex;gap: 6px;">
                <a href="https://huggingface.co/spaces/fffiloni/OutofFocus?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
                </a> to skip the queue and enjoy faster inference on the GPU of your choice 
                </p>
            </div>
            '''
        )
    with gr.Row():
      # Left column: source image, step count, caption and the reconstruct/stop controls.
      with gr.Column():

          with gr.Row():
            # Hidden filepath sink filled by gr.Examples; its .change runs on_image_change.
            example_input = gr.Image(height=512, width=512, type="filepath", visible=False)
            image_input = gr.Image(height=512, width=512, type="pil", label="Upload Source Image")
          steps_slider = gr.Slider(minimum=5, maximum=25, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
          prompt_input = gr.Textbox(label="Prompt", info="Give an initial prompt in details, describing the image")
          reconstruct_button = gr.Button("Reconstruct")
          stop_button = gr.Button("Stop", variant="stop", interactive=False)
      # Right column: output image, edit controls, per-resolution weights and examples.
      with gr.Column():
        reconstructed_image = gr.Image(type="pil", label="Reconstructed")

        with gr.Row():
            # Hidden copy of the scale: on_image_change writes it, and its .change
            # re-runs update_scale to refresh the weight boxes below.
            invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
            interpolate_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, label="Cross-Attention Influence", info="Scales the related influence the source image has on the target image")
        with gr.Row():  
            new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or word addition at the end, achieve the best results by swapping words instead of adding or removing in between")
        with gr.Row():
            apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
        with gr.Row():
            with gr.Accordion(label="Advanced Options", open=False):
                    gr.Markdown(
                        '''
                        <div style="text-align: center;">
                            <h1>Weight Adjustment</h1>
                            <p style="font-size:16px;">Specific Cross-Attention Influence weights can be manually modified for given resolutions (1.0 = Fully Source Attn 0.0 = Fully Target Attn)</p>
                        </div>
                        '''
                    )
                    # One box per entry of the global `weights` dict (block type x token count).
                    down_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][4096], label="Self-Attn Down 64x64")
                    down_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][1024], label="Self-Attn Down 32x32")
                    down_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][256], label="Self-Attn Down 16x16")
                    mid_slider_64 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['mid'][64], label="Self-Attn Mid 8x8")
                    up_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][256], label="Self-Attn Up 16x16")
                    up_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][1024], label="Self-Attn Up 32x32")
                    up_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][4096], label="Self-Attn Up 64x64")

        with gr.Row():
            # Each row: [example image, source caption, target prompt]; the image
            # path must match a cached assets/<name>.pkl used by on_image_change.
            show_case = gr.Examples(
                examples=[
                    ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background"],
                    ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful"], 
                    ["assets/example2.png", "a photo of a cat, two ears, white background", "a photo of a panda, two ears, white background"], 
                    ["assets/example3.png", "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"],
                     
                ],
                inputs=[example_input, prompt_input, new_prompt_input],
                label=None
            )

    # Per-session [caption, real_image_initial_latents, QT] consumed by apply_prompt.
    meta_data = gr.State()

    # Picking an example: load its cached inversion, render the edit, then
    # enable the edit controls.
    example_input.change(
        fn=on_image_change,
        inputs=example_input,
        outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]
    ).then(
        lambda: gr.update(interactive=True),
        outputs=apply_button
    ).then(
        lambda: gr.update(interactive=True),
        outputs=new_prompt_input
    )
    # Steps slider rebinds the global step count; the scale sliders recompute
    # all seven weights and push them into the Advanced Options boxes.
    steps_slider.release(update_step, inputs=steps_slider)
    interpolate_slider.release(update_scale, inputs=interpolate_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])
    invisible_slider.change(update_scale, inputs=invisible_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096 ])

    # Any change to a weight box (typed or pushed by update_scale) writes that
    # single entry of the global weights dict.
    up_slider_4096.change(update_value, inputs=[up_slider_4096, gr.State('up'), gr.State(4096)])
    up_slider_1024.change(update_value, inputs=[up_slider_1024, gr.State('up'), gr.State(1024)])
    up_slider_256.change(update_value, inputs=[up_slider_256, gr.State('up'), gr.State(256)])

    down_slider_4096.change(update_value, inputs=[down_slider_4096, gr.State('down'), gr.State(4096)])
    down_slider_1024.change(update_value, inputs=[down_slider_1024, gr.State('down'), gr.State(1024)])
    down_slider_256.change(update_value, inputs=[down_slider_256, gr.State('down'), gr.State(256)])

    mid_slider_64.change(update_value, inputs=[mid_slider_64, gr.State('mid'), gr.State(64)])

    # Reconstruct streams progress images from the reconstruct generator; after
    # it finishes, re-enable Reconstruct / New Prompt / Generate and disable Stop.
    reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, meta_data]).then(
        lambda: gr.update(interactive=True),
        outputs=reconstruct_button
    ).then(
        lambda: gr.update(interactive=True),
        outputs=new_prompt_input
    ).then(
        lambda: gr.update(interactive=True),
        outputs=apply_button
    ).then(
        lambda: gr.update(interactive=False),
        outputs=stop_button
    )

    # Also on the Reconstruct click: lock Reconstruct and Generate, unlock Stop.
    reconstruct_button.click(
        lambda: gr.update(interactive=False),
        outputs=reconstruct_button
    )

    reconstruct_button.click(
        lambda: gr.update(interactive=True),
        outputs=stop_button
    )

    reconstruct_button.click(
        lambda: gr.update(interactive=False),
        outputs=apply_button
    )

    # Stop disables itself; the second listener below raises the flag that
    # reconstruct polls between epochs.
    stop_button.click(
        lambda: gr.update(interactive=False),
        outputs=stop_button
    )

    # Generate: run apply_prompt on the stored reconstruction with the new prompt.
    apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
    stop_button.click(stop_reconstruct)

if __name__ == "__main__":
    # Script entry point: the only CLI option is --share, forwarded to demo.launch(share=...).
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--share", action="store_true")
    cli_args = cli_parser.parse_args()

    # Enable the request queue (needed for the streaming reconstruct generator), then serve.
    demo.queue()
    demo.launch(share=cli_args.share)