import random
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from einops import rearrange
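
# The forwards below mirror the signature and attribute layout of diffusers'
# BasicTransformerBlock.forward (self.attn1/attn2, self.norm1-3, self.ff, the GLIGEN
# fuser, ada-layer-norm flags) and are presumably bound onto those blocks at runtime:
# `scale_forward` replaces plain self-attention with windowed/dilated attention for
# high-resolution sampling, while `ori_forward` keeps the stock behaviour.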

def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
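    """Build a separable 3D Gaussian kernel of shape (channels, 1, k, k, k) for a grouped conv3d."""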
    x_coord = torch.arange(kernel_size)
    gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
    gaussian_1d = gaussian_1d / gaussian_1d.sum()
    gaussian_3d = gaussian_1d[:, None, None] * gaussian_1d[None, :, None] * gaussian_1d[None, None, :]
    kernel = gaussian_3d[None, None, :, :, :].repeat(channels, 1, 1, 1, 1)
    
    return kernel

def gaussian_filter(latents, kernel_size=3, sigma=1.0):
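    """Apply a depthwise 3D Gaussian blur to a (B, H, W, D) tensor.

    The leading dimension is used as the conv groups, so every batch slice is
    blurred independently; note that the blur acts over all three remaining
    axes, i.e. height, width and the feature dimension.
    """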
    channels = latents.shape[0]
    kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
    blurred_latents = F.conv3d(latents.unsqueeze(0), kernel, padding=kernel_size//2, groups=channels)[0]
    
    return blurred_latents
    
def get_views(height, width, h_window_size=64, w_window_size=64, scale_factor=8, random_jitter=True):
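    """Split a (height x width) latent grid into overlapping attention windows.

    Window sizes and strides are given at the base latent resolution (64 by default)
    and divided by `scale_factor` to match the current layer; adjacent windows overlap
    by half a window. When `random_jitter` is enabled (the default), interior windows
    are shifted by up to +/- 1/8 of the window size and every coordinate is offset by
    that jitter range, so the returned views index into a tensor that has been
    zero-padded by the same margin.

    Illustrative example: get_views(128, 128, 64, 64, scale_factor=1) produces
    3 x 3 = 9 windows of size 64 with stride 32, each shifted by +8 (plus a random
    jitter of at most 8) into the padded frame.

    Returns a list of (h_start, h_end, w_start, w_end) tuples.
    """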
    height = int(height)
    width = int(width)
    h_window_stride = h_window_size // 2
    w_window_stride = w_window_size // 2
    h_window_size = int(h_window_size / scale_factor)
    w_window_size = int(w_window_size / scale_factor)
    h_window_stride = int(h_window_stride / scale_factor)
    w_window_stride = int(w_window_stride / scale_factor)
    num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
    num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * h_window_stride)
        h_end = h_start + h_window_size
        w_start = int((i % num_blocks_width) * w_window_stride)
        w_end = w_start + w_window_size

        if h_end > height:
            h_start = int(h_start + height - h_end)
            h_end = int(height)
        if w_end > width:
            w_start = int(w_start + width - w_end)
            w_end = int(width)
        if h_start < 0:
            h_end = int(h_end - h_start)
            h_start = 0
        if w_start < 0:
            w_end = int(w_end - w_start)
            w_start = 0

        if random_jitter:
            h_jitter_range = h_window_size // 8
            w_jitter_range = w_window_size // 8
            h_jitter = 0
            w_jitter = 0
            
            if (w_start != 0) and (w_end != width):
                w_jitter = random.randint(-w_jitter_range, w_jitter_range)
            elif (w_start == 0) and (w_end != width):
                w_jitter = random.randint(-w_jitter_range, 0)
            elif (w_start != 0) and (w_end == width):
                w_jitter = random.randint(0, w_jitter_range)
            if (h_start != 0) and (h_end != height):
                h_jitter = random.randint(-h_jitter_range, h_jitter_range)
            elif (h_start == 0) and (h_end != height):
                h_jitter = random.randint(-h_jitter_range, 0)
            elif (h_start != 0) and (h_end == height):
                h_jitter = random.randint(0, h_jitter_range)
            h_start += (h_jitter + h_jitter_range)
            h_end += (h_jitter + h_jitter_range)
            w_start += (w_jitter + w_jitter_range)
            w_end += (w_jitter + w_jitter_range)

        views.append((h_start, h_end, w_start, w_end))
    return views

def scale_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
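    """Scale-aware self-attention forward.

    Self-attention runs over overlapping local windows and over a global view
    (full attention, or dilated sub-grids in fast mode); the low-frequency part
    of the windowed result is fused with the high-frequency residual of the
    global result via Gaussian filtering. Cross-attention and the feed-forward
    path are unchanged from the standard block.
    """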
    # Notice that normalization is always applied before the real computation in the following blocks.
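    # Infer the upscaling factor relative to the 512x512 base resolution from
    # self.current_hw; fall back to 1x1 when no target size is set.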
    if self.current_hw:
        current_scale_num_h, current_scale_num_w = self.current_hw[0] // 512, self.current_hw[1] // 512
    else:
        current_scale_num_h, current_scale_num_w = 1, 1

    # 0. Self-Attention
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    # 2. Prepare GLIGEN inputs
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

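    # Recover the 2D latent grid (latent_h x latent_w) from the flattened token axis,
    # derive the local window size (sub_h x sub_w, at base latent resolution) and the
    # jitter margins used to pad the feature map for the local views.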
    ratio_hw = current_scale_num_h / current_scale_num_w
    latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
    latent_w = int(latent_h / ratio_hw)
    scale_factor = 64 * current_scale_num_h / latent_h
    if ratio_hw > 1:
        sub_h = 64
        sub_w = int(64 / ratio_hw)
    else:
        sub_h = int(64 * ratio_hw)
        sub_w = 64

    h_jitter_range = int(sub_h / scale_factor // 8)
    w_jitter_range = int(sub_w / scale_factor // 8)
    views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor=scale_factor)

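    # Offsets of the dilated sub-grids used by the fast-mode global attention path.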
    current_scale_num = max(current_scale_num_h, current_scale_num_w)
    global_views = [[h, w] for h in range(current_scale_num_h) for w in range(current_scale_num_w)]

    four_window = not self.fast_mode
    fourg_window = self.fast_mode

    if four_window:
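        # Default path: self-attention over overlapping local windows (accumulated and
        # averaged where windows overlap) plus one full global self-attention pass.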
        norm_hidden_states_ = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
        norm_hidden_states_ = F.pad(norm_hidden_states_, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
        value = torch.zeros_like(norm_hidden_states_)
        count = torch.zeros_like(norm_hidden_states_)
        for index, view in enumerate(views):
            h_start, h_end, w_start, w_end = view
            local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
            local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
            local_output = self.attn1(
                local_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))

            value[:, h_start:h_end, w_start:w_end, :] += local_output
            count[:, h_start:h_end, w_start:w_end, :] += 1

        value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
        count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
        attn_output = torch.where(count>0, value/count, value)
        
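        # Low-pass the averaged windowed result; its smoothed component supplies the
        # low frequencies of the fused output below.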
        gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)

        attn_output_global = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)

        gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)

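        # Fuse: low-frequency (Gaussian-smoothed) component from the local windows,
        # high-frequency residual from the global attention.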
        attn_output = gaussian_local + (attn_output_global - gaussian_global)
        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')

    elif fourg_window:
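        # Fast mode: same local-window pass as above, but the global pass attends
        # within dilated sub-grids instead of the full token sequence.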
        norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
        norm_hidden_states_ = F.pad(norm_hidden_states, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
        value = torch.zeros_like(norm_hidden_states_)
        count = torch.zeros_like(norm_hidden_states_)
        for index, view in enumerate(views):
            h_start, h_end, w_start, w_end = view
            local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
            local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
            local_output = self.attn1(
                local_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))

            value[:, h_start:h_end, w_start:w_end, :] += local_output
            count[:, h_start:h_end, w_start:w_end, :] += 1

        value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
        count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
        attn_output = torch.where(count>0, value/count, value)
        
        gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)

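        # Dilated global pass: each sub-grid takes every current_scale_num_h-th row and
        # every current_scale_num_w-th column, so each attention call stays at roughly
        # the base resolution.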
        value = torch.zeros_like(norm_hidden_states)
        count = torch.zeros_like(norm_hidden_states)
        for index, global_view in enumerate(global_views):
            h, w = global_view
            global_states = norm_hidden_states[:, h::current_scale_num_h, w::current_scale_num_w, :]
            global_states = rearrange(global_states, 'bh h w d -> bh (h w) d')
            global_output = self.attn1(
                global_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
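            # The sub-grid is (approximately) square because latent_h / latent_w matches
            # current_scale_num_h / current_scale_num_w, so sqrt recovers its height.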
            global_output = rearrange(global_output, 'bh (h w) d -> bh h w d', h = int(global_output.shape[1] ** 0.5))

            value[:, h::current_scale_num_h, w::current_scale_num_w, :] += global_output * 1
            count[:, h::current_scale_num_h, w::current_scale_num_w, :] += 1

        attn_output_global = torch.where(count>0, value/count, value)

        gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)

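        # Fuse exactly as in the default path: low frequencies from the local windows,
        # high-frequency residual from the dilated global attention.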
        attn_output = gaussian_local + (attn_output_global - gaussian_global)
        attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
        
    else:
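        # Fallback: plain full self-attention over all tokens (the original behaviour).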
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 2.5 GLIGEN Control
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
    # 2.5 ends

    # 3. Cross-Attention
    if self.attn2 is not None:
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)

    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )

        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        ff_output = torch.cat(
            [
                self.ff(hid_slice)
                for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
            ],
            dim=self._chunk_dim,
        )
    else:
        ff_output = self.ff(norm_hidden_states)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output

    hidden_states = ff_output + hidden_states

    return hidden_states

def ori_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    class_labels: Optional[torch.LongTensor] = None,
):
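    """Original forward pass: standard self-attention, cross-attention and feed-forward, left unchanged."""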
    # Notice that normalization is always applied before the real computation in the following blocks.
    # 0. Self-Attention
    if self.use_ada_layer_norm:
        norm_hidden_states = self.norm1(hidden_states, timestep)
    elif self.use_ada_layer_norm_zero:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
    else:
        norm_hidden_states = self.norm1(hidden_states)

    # 2. Prepare GLIGEN inputs
    cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

    attn_output = self.attn1(
        norm_hidden_states,
        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
        attention_mask=attention_mask,
        **cross_attention_kwargs,
    )

    if self.use_ada_layer_norm_zero:
        attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = attn_output + hidden_states

    # 2.5 GLIGEN Control
    if gligen_kwargs is not None:
        hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
    # 2.5 ends

    # 3. Cross-Attention
    if self.attn2 is not None:
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    norm_hidden_states = self.norm3(hidden_states)

    if self.use_ada_layer_norm_zero:
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

    if self._chunk_size is not None:
        # "feed_forward_chunk_size" can be used to save memory
        if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
            )

        num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
        ff_output = torch.cat(
            [
                self.ff(hid_slice)
                for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
            ],
            dim=self._chunk_dim,
        )
    else:
        ff_output = self.ff(norm_hidden_states)

    if self.use_ada_layer_norm_zero:
        ff_output = gate_mlp.unsqueeze(1) * ff_output

    hidden_states = ff_output + hidden_states

    return hidden_states