import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange
from torch import Tensor

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table followed by dropout.

    The table is precomputed once for ``max_len`` positions and stored in the
    ``pe`` buffer with shape ``(1, max_len, d_model)``. Even feature columns
    hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: Size of the feature dimension of the table.
        dropout: Dropout probability applied to the table in ``forward``.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len
        self.d_model = d_model

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # cosine input is trimmed to d_model // 2 columns. For an even d_model
        # the slice is a no-op and the table is unchanged.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self) -> Tensor:
        """Return the whole table after dropout.

        Takes no input: the result always covers all ``max_len`` positions and
        has shape ``(1, max_len, d_model)``.
        """
        table = self.pe[0, : self.max_len]
        return self.dropout(table).unsqueeze(0)


class ResNetFeatureExtractor(nn.Module):
    """Visual feature extractor: ResNet-50 trunk + 1x1 conv + linear projection.

    The ResNet-50 is built without pretrained weights and with its last two
    child modules removed. Its output is projected to 768 channels by a 1x1
    convolution, passed through ReLU, flattened over the two spatial axes, and
    the flattened spatial axis is mapped to ``hidden_dim`` by a linear layer.

    Args:
        hidden_dim: Output size of the linear layer; it becomes the
            sequence-length axis of the result.
        num_spatial_positions: Input size of the linear layer. It must equal
            the product of the two spatial axes of the backbone output for the
            images fed to ``forward``. The default of 192 is the value that
            was previously hard-coded.
    """

    def __init__(self, hidden_dim=512, num_spatial_positions=192):
        super().__init__()

        # ResNet-50 backbone used for visual feature extraction; the last two
        # children are dropped so the output is still a 4-D feature map.
        # NOTE(review): newer torchvision releases prefer ``weights=None`` over
        # ``pretrained=False`` — confirm against the pinned torchvision version.
        resnet50 = models.resnet50(pretrained=False)
        modules = list(resnet50.children())[:-2]
        self.resnet50 = nn.Sequential(*modules)

        # 1x1 convolution down to 768 channels, then a linear layer applied
        # over the flattened spatial axis.
        self.conv1 = nn.Conv2d(2048, 768, 1)
        self.relu1 = F.relu
        self.linear1 = nn.Linear(num_spatial_positions, hidden_dim)

    def forward(self, x):
        """Map a batch of images to features of shape ``(batch, hidden_dim, 768)``."""
        x = self.resnet50(x)
        x = self.conv1(x)
        x = self.relu1(x)
        # b -> batch, e -> embedding dim (768), w/h -> the two spatial axes
        x = rearrange(x, "b e w h -> b e (w h)")
        # The linear layer acts on the flattened spatial axis, not on channels.
        x = self.linear1(x)
        # b -> batch, e -> embedding dim, s -> sequence length (hidden_dim)
        x = rearrange(x, "b e s -> b s e")
        return x

class DocFormerEmbeddings(nn.Module):
    """Build the spatial embeddings for the visual and the text branch.

    Each branch (suffix ``_v`` for visual, ``_t`` for text) owns one
    ``PositionalEncoding`` and sixteen ``nn.Embedding`` tables: eight indexed
    by the columns of ``x_feature`` and eight by the columns of ``y_feature``.
    Per branch, the eight x lookups are concatenated, the eight y lookups are
    concatenated, and both are added to the positional encoding.

    Config keys read: ``hidden_size``, ``max_position_embeddings``,
    ``max_2d_position_embeddings``, ``coordinate_size``, ``shape_size``,
    ``layer_norm_eps`` and ``hidden_dropout_prob``.

    For the final sum to work, ``2 * coordinate_size + 6 * shape_size`` must
    equal ``hidden_size`` and the sequence length of the inputs must broadcast
    against ``max_position_embeddings``.

    ``LayerNorm`` and ``dropout`` are created here but are not applied in
    ``forward``.
    """

    # Column order of the last axis of ``x_feature`` / ``y_feature``. Each name
    # is also the prefix of the ``<name>_embeddings_{v,t}`` submodule.
    _X_FEATURE_NAMES = (
        "x_topleft_position",
        "x_bottomright_position",
        "w_position",
        "x_topleft_distance_to_prev",
        "x_bottomleft_distance_to_prev",
        "x_topright_distance_to_prev",
        "x_bottomright_distance_to_prev",
        "x_centroid_distance_to_prev",
    )
    _Y_FEATURE_NAMES = (
        "y_topleft_position",
        "y_bottomright_position",
        "h_position",
        "y_topleft_distance_to_prev",
        "y_bottomleft_distance_to_prev",
        "y_topright_distance_to_prev",
        "y_bottomright_distance_to_prev",
        "y_centroid_distance_to_prev",
    )
    # Columns 0-1 are coordinates and use ``coordinate_size``; the rest use
    # ``shape_size``.
    _NUM_COORDINATE_COLUMNS = 2
    # Columns from this index on are signed distances to the previous box.
    _FIRST_DISTANCE_COLUMN = 3

    def __init__(self, config):
        super(DocFormerEmbeddings, self).__init__()

        self.config = config

        # Creation order is: visual positional encoding, visual x tables,
        # visual y tables, then the same for text. It matches the previous
        # hand-written layout, so parameter order and seeded initialisation
        # are unchanged.
        for branch in ("v", "t"):
            setattr(
                self,
                f"position_embeddings_{branch}",
                PositionalEncoding(
                    d_model=config["hidden_size"],
                    dropout=0.1,
                    max_len=config["max_position_embeddings"],
                ),
            )
            self._register_tables(self._X_FEATURE_NAMES, branch)
            self._register_tables(self._Y_FEATURE_NAMES, branch)

        self.LayerNorm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def _register_tables(self, names, branch):
        """Create one ``<name>_embeddings_<branch>`` table per feature column."""
        max_2d = self.config["max_2d_position_embeddings"]
        for column, name in enumerate(names):
            if column < self._FIRST_DISTANCE_COLUMN:
                num_embeddings = max_2d
            else:
                # Signed distances are shifted into [0, 2 * max_2d] in forward.
                num_embeddings = 2 * max_2d + 1
            if column < self._NUM_COORDINATE_COLUMNS:
                embedding_dim = self.config["coordinate_size"]
            else:
                embedding_dim = self.config["shape_size"]
            setattr(self, f"{name}_embeddings_{branch}", nn.Embedding(num_embeddings, embedding_dim))

    def _shift_distances(self, feature):
        """Return a copy of ``feature`` with the distance columns made non-negative.

        Columns from index 3 on are clamped to ``[-max_2d, max_2d]`` and then
        offset by ``max_2d`` so they can index an embedding table. The input is
        not modified: shifting it in place would make a second forward pass on
        the same tensor look up wrong indices.
        """
        max_2d = self.config["max_2d_position_embeddings"]
        first = self._FIRST_DISTANCE_COLUMN
        shifted = feature.clone()
        shifted[:, :, first:] = torch.clamp(feature[:, :, first:], -max_2d, max_2d) + max_2d
        return shifted

    def _embed(self, names, feature, branch):
        """Look up every named column of ``feature`` and concatenate the results."""
        lookups = [
            getattr(self, f"{name}_embeddings_{branch}")(feature[:, :, column])
            for column, name in enumerate(names)
        ]
        return torch.cat(lookups, dim=-1)

    def forward(self, x_feature, y_feature):
        """Compute the spatial embeddings of both branches.

        Arguments:
            x_feature: integer tensor of shape (batch size, seq_len, 8)
            y_feature: integer tensor of shape (batch size, seq_len, 8)

        Returns:
            ``(v_bar_s, t_bar_s)``, each of shape
            (batch size, seq_len, 2 * coordinate_size + 6 * shape_size).

        Meaning of the last axis of the inputs:
            0 -> top left x/y
            1 -> bottom right x/y
            2 -> width/height
            3 -> diff top left x/y
            4 -> diff bottom left x/y
            5 -> diff top right x/y
            6 -> diff bottom right x/y
            7 -> centroids diff x/y
        """
        # Clamp the distance columns and add a bias so negative values become
        # valid indices; this works on copies of the inputs.
        x_feature = self._shift_distances(x_feature)
        y_feature = self._shift_distances(y_feature)

        x_calculated_embedding_v = self._embed(self._X_FEATURE_NAMES, x_feature, "v")
        y_calculated_embedding_v = self._embed(self._Y_FEATURE_NAMES, y_feature, "v")
        v_bar_s = x_calculated_embedding_v + y_calculated_embedding_v + self.position_embeddings_v()

        x_calculated_embedding_t = self._embed(self._X_FEATURE_NAMES, x_feature, "t")
        y_calculated_embedding_t = self._embed(self._Y_FEATURE_NAMES, y_feature, "t")
        t_bar_s = x_calculated_embedding_t + y_calculated_embedding_t + self.position_embeddings_t()

        return v_bar_s, t_bar_s



# fmt: off
class PreNorm(nn.Module):
    """Apply layer normalisation to the input before calling a wrapped callable.

    Pre-norm arrangement, see Fig 1 of
    http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf

    Args:
        dim: Size of the last axis normalised by ``nn.LayerNorm``.
        fn: Callable invoked with the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``; keyword arguments are passed through."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)


class PreNormAttn(nn.Module):
    """Normalise four input streams independently, then call a wrapped callable.

    Pre-norm arrangement, see Fig 1 of
    http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf

    Each of ``t_bar``, ``v_bar``, ``t_bar_s`` and ``v_bar_s`` has its own
    ``nn.LayerNorm``.

    Args:
        dim: Size of the last axis normalised by every ``nn.LayerNorm``.
        fn: Callable invoked with the four normalised tensors, in the same
            order as the arguments of ``forward``.
    """

    def __init__(self, dim, fn):
        super().__init__()

        self.norm_t_bar = nn.LayerNorm(dim)
        self.norm_v_bar = nn.LayerNorm(dim)
        self.norm_t_bar_s = nn.LayerNorm(dim)
        self.norm_v_bar_s = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, t_bar, v_bar, t_bar_s, v_bar_s, **kwargs):
        """Return ``fn`` applied to the four normalised inputs plus ``kwargs``."""
        normalised = (
            self.norm_t_bar(t_bar),
            self.norm_v_bar(v_bar),
            self.norm_t_bar_s(t_bar_s),
            self.norm_v_bar_s(v_bar_s),
        )
        return self.fn(*normalised, **kwargs)


class FeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Feature size between the two linear layers.
        dropout: Dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        stages = [
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            project,
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block; the last axis keeps size ``dim``."""
        return self.net(x)


class RelativePosition(nn.Module):
    """Learned embeddings indexed by clipped relative position.

    A table of ``2 * max_relative_position + 1`` rows is learned. The index
    matrix ``final_mat[q, k]`` is ``clamp(k - q, -max_relative_position,
    max_relative_position) + max_relative_position`` and is precomputed for
    ``max_seq_length`` positions.

    Args:
        num_units: Size of each embedding vector.
        max_relative_position: Distances beyond this value (in either
            direction) share the embedding of the clipped distance.
        max_seq_length: Number of query/key positions covered by the
            precomputed index matrix.
    """

    def __init__(self, num_units, max_relative_position, max_seq_length):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.max_length = max_seq_length

        self.embeddings_table = nn.Parameter(torch.empty(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

        range_vec_q = torch.arange(max_seq_length)
        range_vec_k = torch.arange(max_seq_length)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
        # Shift so the smallest clipped distance maps to row 0 of the table.
        final_mat = distance_mat_clipped + max_relative_position
        # Registered as a buffer so it follows the module across devices.
        # Non-persistent keeps it out of state_dict, so checkpoints saved
        # before this change still load.
        self.register_buffer("final_mat", final_mat, persistent=False)

    def forward(self, length_q, length_k):
        """Return embeddings of shape ``(length_q, length_k, num_units)``.

        Lengths larger than ``max_seq_length`` are not rejected; slicing then
        yields a smaller result than requested.
        """
        embeddings = self.embeddings_table[self.final_mat[:length_q, :length_k]]
        return embeddings


class MultiModalAttentionLayer(nn.Module):
    """Multi-head self-attention over a text stream and an image stream.

    For each modality the attention scores are the sum of four terms:
    query-key dot products of the features, two 1D relative-position terms
    (one with the keys, one with the queries), and query-key dot products of
    the spatial features. The spatial query/key projections are shared by both
    modalities. The two attended contexts are added and passed through an
    output projection.

    Args:
        embed_dim: Feature size of all inputs; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        max_relative_position: Clipping distance of the relative-position tables.
        max_seq_length: Maximum sequence length of the relative-position tables.
        dropout: Dropout probability for the attention weights and the output.
    """

    def __init__(self, embed_dim, n_heads, max_relative_position, max_seq_length, dropout):
        super().__init__()
        assert embed_dim % n_heads == 0

        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        self.relative_positions_text = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
        self.relative_positions_img = RelativePosition(self.head_dim, max_relative_position, max_seq_length)

        # text qkv embeddings
        self.fc_k_text = nn.Linear(embed_dim, embed_dim)
        self.fc_q_text = nn.Linear(embed_dim, embed_dim)
        self.fc_v_text = nn.Linear(embed_dim, embed_dim)

        # image qkv embeddings
        self.fc_k_img = nn.Linear(embed_dim, embed_dim)
        self.fc_q_img = nn.Linear(embed_dim, embed_dim)
        self.fc_v_img = nn.Linear(embed_dim, embed_dim)

        # spatial qk embeddings (shared for visual and text)
        self.fc_k_spatial = nn.Linear(embed_dim, embed_dim)
        self.fc_q_spatial = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        # NOTE(review): scores are divided by sqrt(embed_dim) rather than
        # sqrt(head_dim). Left unchanged so existing checkpoints keep their
        # numerics — confirm this is intended.
        self.scale = embed_dim**0.5

    def _split_heads(self, projected):
        """Reshape ``(batch, seq, embed_dim)`` to ``(head, batch, seq, head_dim)``."""
        return rearrange(projected, 'b t (head k) -> head b t k', head=self.n_heads)

    def _attention_scores(self, feat, spatial_feat, fc_q, fc_k, relative_positions, seq_length):
        """Compute the unnormalised attention scores of one modality.

        Returns a tensor of shape ``(head, batch, query_len, key_len)``.
        """
        # h -> head, b -> batch, l -> query position, t / r -> key position,
        # k / d -> head dim.
        query_nh = self._split_heads(fc_q(feat))
        key_nh = self._split_heads(fc_k(feat))
        dots = torch.einsum('hblk,hbtk->hblt', query_nh, key_nh) / self.scale

        # 1D relative positions (query, key)
        rel_pos_embed = relative_positions(seq_length, seq_length)
        rel_pos_key = torch.einsum('hbrd,lrd->hblr', key_nh, rel_pos_embed)
        rel_pos_query = torch.einsum('hbld,lrd->hblr', query_nh, rel_pos_embed)

        # shared spatial projections applied to this modality's spatial features
        query_spatial_nh = self._split_heads(self.fc_q_spatial(spatial_feat))
        key_spatial_nh = self._split_heads(self.fc_k_spatial(spatial_feat))
        dots_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_nh, key_spatial_nh) / self.scale

        return dots + rel_pos_key + rel_pos_query + dots_spatial

    def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
        """Attend over both modalities and return the fused, projected context.

        All four inputs are ``(batch, seq, embed_dim)``. The sequence length of
        ``text_feat`` is used for the relative-position lookups of both
        modalities. The result has shape ``(batch, seq, embed_dim)``.
        """
        seq_length = text_feat.shape[1]

        # Line 38 of pseudo-code
        text_attn_scores = self._attention_scores(
            text_feat, text_spatial_feat,
            self.fc_q_text, self.fc_k_text,
            self.relative_positions_text, seq_length,
        )

        # Line 59 of pseudo-code
        # The image branch uses its own relative-position table. It previously
        # reused the text table, which left ``relative_positions_img`` out of
        # the computation, so it was never trained.
        # NOTE(review): checkpoints trained before this fix hold an untrained
        # ``relative_positions_img`` table — re-train or fine-tune.
        img_attn_scores = self._attention_scores(
            img_feat, img_spatial_feat,
            self.fc_q_img, self.fc_k_img,
            self.relative_positions_img, seq_length,
        )

        text_attn_probs = self.dropout(torch.softmax(text_attn_scores, dim=-1))
        img_attn_probs = self.dropout(torch.softmax(img_attn_scores, dim=-1))

        value_text_nh = self._split_heads(self.fc_v_text(text_feat))
        value_img_nh = self._split_heads(self.fc_v_img(img_feat))

        text_context = torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_text_nh)
        img_context = torch.einsum('hblt,hbtv->hblv', img_attn_probs, value_img_nh)

        context = text_context + img_context

        embeddings = rearrange(context, 'head b t d -> b t (head d)')
        return self.to_out(embeddings)

class DocFormerEncoder(nn.Module):
    """Stack of encoder blocks, each a pre-norm attention and a pre-norm feed-forward.

    ``config['num_hidden_layers']`` blocks are created. Every block is an
    ``nn.ModuleList`` of two entries: a ``PreNormAttn`` wrapping a
    ``MultiModalAttentionLayer`` and a ``PreNorm`` wrapping a ``FeedForward``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([])
        for _ in range(config['num_hidden_layers']):
            attention = PreNormAttn(
                config['hidden_size'],
                MultiModalAttentionLayer(
                    config['hidden_size'],
                    config['num_attention_heads'],
                    config['max_relative_positions'],
                    config['max_position_embeddings'],
                    config['hidden_dropout_prob'],
                ),
            )
            feed_forward = PreNorm(
                config['hidden_size'],
                FeedForward(
                    config['hidden_size'],
                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                    dropout=config['hidden_dropout_prob'],
                ),
            )
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(
            self,
            text_feat,  # text feat or output from last encoder block
            img_feat,
            text_spatial_feat,
            img_spatial_feat,
    ):
        """Run all encoder blocks and return the output of the last one.

        Only ``text_feat`` is replaced by the block output between blocks; the
        other three inputs are passed unchanged to every block.
        """
        # Fig 1 encoder part (skip conn for both attn & FF): https://arxiv.org/abs/1706.03762
        # TODO: ensure 1st skip conn (var "residual") in such a multimodal setting makes sense (most likely does)
        for attention, feed_forward in self.layers:
            residual = text_feat + img_feat + text_spatial_feat + img_spatial_feat
            attended = attention(text_feat, img_feat, text_spatial_feat, img_spatial_feat) + residual
            x = feed_forward(attended) + attended
            text_feat = x
        return x


class LanguageFeatureExtractor(nn.Module):
    """Token embedding lookup initialised from LayoutLM word embeddings.

    On construction the ``microsoft/layoutlm-base-uncased`` checkpoint is
    loaded through ``transformers`` and only its word-embedding weight is kept,
    wrapped with ``nn.Embedding.from_pretrained`` using default arguments.
    NOTE(review): per the PyTorch docs those defaults freeze the weight —
    confirm that a non-trainable embedding is intended.
    """

    def __init__(self):
        super().__init__()
        # Imported here so ``transformers`` is only needed when this class is built.
        from transformers import LayoutLMForTokenClassification

        pretrained = LayoutLMForTokenClassification.from_pretrained(
            "microsoft/layoutlm-base-uncased",
            num_labels=1,
        )
        word_embedding_weight = pretrained.layoutlm.embeddings.word_embeddings.weight
        self.embedding_vector = nn.Embedding.from_pretrained(word_embedding_weight)

    def forward(self, x):
        """Return the embedding vectors for the token ids in ``x``."""
        return self.embedding_vector(x)
        


class ExtractFeatures(nn.Module):
    """Turn an encoding dictionary into the four DocFormer feature tensors.

    Inputs: dictionary with the keys ``resized_scaled_img``, ``input_ids``,
    ``x_features`` and ``y_features``
    Output: v_bar, t_bar, v_bar_s, t_bar_s
    """

    def __init__(self, config):
        super().__init__()
        # The visual extractor's output length is set to the maximum number of
        # positions from the config.
        self.visual_feature = ResNetFeatureExtractor(hidden_dim=config['max_position_embeddings'])
        self.language_feature = LanguageFeatureExtractor()
        self.spatial_feature = DocFormerEmbeddings(config)

    def forward(self, encoding):
        """Compute visual, language and spatial features from ``encoding``."""
        v_bar = self.visual_feature(encoding['resized_scaled_img'])
        t_bar = self.language_feature(encoding['input_ids'])

        v_bar_s, t_bar_s = self.spatial_feature(
            encoding['x_features'],
            encoding['y_features'],
        )

        return v_bar, t_bar, v_bar_s, t_bar_s

    
    
class DocFormer(nn.Module):
    """End-to-end DocFormer: feature extraction, encoder, and a final dropout.

    Thin wrapper whose input is the dictionary produced by the
    ``create_features`` function; it is handed unchanged to
    ``ExtractFeatures``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.extract_feature = ExtractFeatures(config)
        self.encoder = DocFormerEncoder(config)
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])

    def forward(self, x, use_tdi=False):
        """Encode one batch.

        Args:
            x: Encoding dictionary consumed by ``ExtractFeatures``.
            use_tdi: Accepted for backward compatibility and currently unused.
                ``ExtractFeatures.forward`` takes only the encoding, so
                forwarding this flag raised ``TypeError`` on every call.

        Returns:
            The encoder output after dropout.
        """
        v_bar, t_bar, v_bar_s, t_bar_s = self.extract_feature(x)
        # The encoder expects (text, image, text spatial, image spatial).
        output = self.encoder(t_bar, v_bar, t_bar_s, v_bar_s)
        return self.dropout(output)