import random
from typing import List, Tuple, Union

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# from videoswap.utils.registry import MODEL_REGISTRY


class MLP(nn.Module):
    # Two-layer projection head: in_dim -> mid_dim -> out_dim with a SiLU in between.
    def __init__(self, in_dim, out_dim, mid_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, mid_dim, bias=True),
            nn.SiLU(inplace=False),
            nn.Linear(mid_dim, out_dim, bias=True)
        )

    def forward(self, x):
        return self.mlp(x)


def bilinear_interpolation(level_adapter_state, x, y, frame_idx, interpolated_value):
    # Bilinear splatting: scatter-add `interpolated_value` onto the four integer
    # neighbors of the fractional location (x, y) in frame `frame_idx`.
    # level_adapter_state: (frames, channels, h, w)
    x1 = int(x)
    y1 = int(y)
    x2 = x1 + 1
    y2 = y1 + 1
    x_frac = x - x1
    y_frac = y - y1

    # clamp neighbor indices to the feature-map boundary; at the border the two
    # neighbors may coincide, in which case both weights land on the same pixel
    x1, x2 = max(min(x1, level_adapter_state.shape[3] - 1), 0), max(min(x2, level_adapter_state.shape[3] - 1), 0)
    y1, y2 = max(min(y1, level_adapter_state.shape[2] - 1), 0), max(min(y2, level_adapter_state.shape[2] - 1), 0)

    # bilinear weights; the four weights sum to 1
    w11 = (1 - x_frac) * (1 - y_frac)
    w21 = x_frac * (1 - y_frac)
    w12 = (1 - x_frac) * y_frac
    w22 = x_frac * y_frac

    level_adapter_state[frame_idx, :, y1, x1] += interpolated_value * w11
    level_adapter_state[frame_idx, :, y1, x2] += interpolated_value * w21
    level_adapter_state[frame_idx, :, y2, x1] += interpolated_value * w12
    level_adapter_state[frame_idx, :, y2, x2] += interpolated_value * w22

    return level_adapter_state
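
# Minimal sanity check for the splatting above (an illustrative sketch;
# `_check_bilinear_splat` is a hypothetical helper, not part of the model).
# Because the four bilinear weights sum to 1, summing the splatted map over its
# spatial dimensions recovers the splatted feature exactly.
def _check_bilinear_splat():
    state = torch.zeros((1, 3, 8, 8))
    feat = torch.ones(3)
    state = bilinear_interpolation(state, 2.25, 3.5, 0, feat)
    assert torch.allclose(state.sum(dim=(2, 3))[0], feat)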


# @MODEL_REGISTRY.register()
class SparsePointAdapter(ModelMixin, ConfigMixin):
    # Projects each point's semantic embedding to one feature map per UNet
    # resolution level and splats it along the point's tracked trajectory.

    @register_to_config
    def __init__(
        self,
        embedding_channels=1280,
        channels=[320, 640, 1280, 1280],
        downsample_rate=[8, 16, 32, 64],
        mid_dim=128,
    ):
        super().__init__()

        self.model_list = nn.ModuleList()

        # one projection MLP per resolution level
        for ch in channels:
            self.model_list.append(MLP(embedding_channels, ch, mid_dim))

        self.downsample_rate = downsample_rate
        self.channels = channels
        self.radius = 2  # half-width of the loss-mask window, in feature-map pixels

    def generate_loss_mask(self, point_index_list, point_tracker, num_frames, h, w, loss_type):
        if loss_type == 'global':
            # every location contributes to the loss (mask at the highest-resolution
            # level, 4-channel latent)
            loss_mask = torch.ones((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0]))
        else:
            # only compute the loss around visible points, within a fixed radius
            # that is independent of the downsampling scale
            loss_mask = torch.zeros((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0]))
            for point_idx in point_index_list:
                for frame_idx in range(num_frames):
                    px, py = point_tracker[frame_idx, point_idx]

                    if px < 0 or py < 0:  # negative coordinates mark a point that is not visible in this frame
                        continue
                    px, py = px / self.downsample_rate[0], py / self.downsample_rate[0]

                    x1 = int(px) - self.radius
                    y1 = int(py) - self.radius
                    x2 = int(px) + self.radius
                    y2 = int(py) + self.radius

                    x1, x2 = max(min(x1, loss_mask.shape[3] - 1), 0), max(min(x2, loss_mask.shape[3] - 1), 0)
                    y1, y2 = max(min(y1, loss_mask.shape[2] - 1), 0), max(min(y2, loss_mask.shape[2] - 1), 0)

                    # mark the (2 * radius + 1) ** 2 window around the point in this frame
                    loss_mask[frame_idx, :, y1:y2 + 1, x1:x2 + 1] = 1.0
        return loss_mask

    def forward(self, point_tracker, size, point_embedding, index_list=None, drop_rate=0.0, loss_type='global') -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], torch.Tensor]]:

        # point_tracker: (frames, num_points, 2) pixel coordinates at the full video resolution
        # point_embedding: (num_points, embedding_channels)

        w, h = size
        num_frames, num_points = point_tracker.shape[:2]

        if self.training:
            # point-level dropout: each point is kept with probability (1 - drop_rate)
            point_index_list = [point_idx for point_idx in range(num_points) if random.random() > drop_rate]
            loss_mask = self.generate_loss_mask(point_index_list, point_tracker, num_frames, h, w, loss_type)
        else:
            # keep every point unless an explicit subset is requested via index_list
            point_index_list = [point_idx for point_idx in range(num_points) if index_list is None or point_idx in index_list]

        adapter_state = []
        for level_idx, module in enumerate(self.model_list):

            downsample_rate = self.downsample_rate[level_idx]
            level_w, level_h = w // downsample_rate, h // downsample_rate

            # e.g. (num_points, 1280) -> (num_points, 320)
            point_feat = module(point_embedding)

            # empty per-level feature map, filled below by splatting point features
            level_adapter_state = torch.zeros((num_frames, self.channels[level_idx], level_h, level_w), device=point_feat.device, dtype=point_feat.dtype)

            for point_idx in point_index_list:
                for frame_idx in range(num_frames):
                    px, py = point_tracker[frame_idx, point_idx]

                    if px < 0 or py < 0:  # negative coordinates mark a point that is not visible in this frame
                        continue
                    px, py = px / downsample_rate, py / downsample_rate
                    level_adapter_state = bilinear_interpolation(level_adapter_state, px, py, frame_idx, point_feat[point_idx])
            adapter_state.append(level_adapter_state)

        if self.training:
            return adapter_state, loss_mask
        else:
            return adapter_state
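

if __name__ == '__main__':
    # Smoke test (an illustrative sketch, not part of the model): the shapes
    # below follow the comments in `forward`, and 1280 matches the default
    # `embedding_channels`. Any `loss_type` other than 'global' selects the
    # point-wise loss mask.
    adapter = SparsePointAdapter()

    num_frames, num_points = 8, 4
    w, h = 512, 512
    point_tracker = torch.rand(num_frames, num_points, 2) * min(w, h)  # in-bounds pixel trajectories
    point_embedding = torch.randn(num_points, 1280)

    adapter.eval()
    adapter_state = adapter(point_tracker, (w, h), point_embedding)
    print([tuple(state.shape) for state in adapter_state])
    # [(8, 320, 64, 64), (8, 640, 32, 32), (8, 1280, 16, 16), (8, 1280, 8, 8)]

    adapter.train()
    adapter_state, loss_mask = adapter(point_tracker, (w, h), point_embedding, loss_type='point')
    print(tuple(loss_mask.shape))  # (8, 4, 64, 64)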