import torch
import torch.nn as nn

from typing import List, Optional
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_textcnn import TextCNNConfig


@dataclass
class TextCNNModelOutput(ModelOutput):
    last_hidden_states: Optional[torch.FloatTensor] = None
    ngram_feature_maps: Optional[List[torch.FloatTensor]] = None
        
        
@dataclass
class TextCNNSequenceClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    last_hidden_states: Optional[torch.FloatTensor] = None
    ngram_feature_maps: Optional[List[torch.FloatTensor]] = None
    

class TextCNNPreTrainedModel(PreTrainedModel):
    config_class = TextCNNConfig
    base_model_prefix = "textcnn"
    
    def _init_weights(self, module):
        # Submodules keep PyTorch's default initialization; raise (rather
        # than return) the exception so an accidental call fails loudly.
        raise NotImplementedError
    
    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs
        
        
class TextCNNModel(TextCNNPreTrainedModel):
    """ A style classifier Text-CNN """
    
    def __init__(self, config):
        super().__init__(config)
        self.embedder = nn.Embedding(config.vocab_size, config.embed_dim)
        # One Conv2d per (num_filters, filter_size) pair; each kernel spans
        # the full embedding dimension, so it acts as an n-gram detector.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, n, (f, config.embed_dim))
            for n, f in zip(config.num_filters, config.filter_sizes)
        ])
        
    def get_input_embeddings(self):
        return self.embedder

    def set_input_embeddings(self, value):
        self.embedder = value
        
    def forward(self, input_ids, attention_mask=None):
        # input_ids.shape == (bsz, seq_len)
        # attention_mask is accepted for interface compatibility (see
        # `dummy_inputs`) but unused: max-pooling runs over the full
        # sequence, padding positions included.
        # x.shape == (bsz, 1, seq_len, emb_dim)
        x = self.embedder(input_ids).unsqueeze(1)  # add channel dim
        feature_maps = []
        pooled = []
        for conv in self.convs:
            # feature_map.shape == (bsz, num_filters[i], seq_len - filter_sizes[i] + 1)
            feature_map = torch.relu(conv(x)).squeeze(3)
            feature_maps.append(feature_map)
            # output.shape == (bsz, num_filters[i])
            output = torch.max_pool1d(feature_map, feature_map.size(2)).squeeze(2)
            pooled.append(output)
        # last_hidden_states.shape == (bsz, sum(num_filters))
        last_hidden_states = torch.cat(pooled, dim=1)

        return TextCNNModelOutput(
            last_hidden_states=last_hidden_states,
            ngram_feature_maps=feature_maps,
        )
        
        
class TextCNNForSequenceClassification(TextCNNPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.feature_dim = sum(config.num_filters)
        self.textcnn = TextCNNModel(config)
        self.fc = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.feature_dim, self.feature_dim // 2),
            nn.ReLU(),
            nn.Linear(self.feature_dim // 2, config.num_labels),
        )
        
    def forward(self, input_ids, attention_mask=None, labels=None):
        # input_ids.shape == (bsz, seq_len)
        # labels.shape == (bsz,)
        outputs = self.textcnn(input_ids, attention_mask=attention_mask)
        # outputs.last_hidden_states.shape == (bsz, feature_dim)
        logits = self.fc(outputs.last_hidden_states)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        
        return TextCNNSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            last_hidden_states=outputs.last_hidden_states,
            ngram_feature_maps=outputs.ngram_feature_maps,
        )
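

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library API). It assumes
    # TextCNNConfig stores the hyperparameters below when passed as keyword
    # arguments, the usual PretrainedConfig pattern. Because this module uses
    # a relative import, run it as `python -m <package>.modeling_textcnn`.
    config = TextCNNConfig(
        vocab_size=100,
        embed_dim=16,
        num_filters=[8, 8],
        filter_sizes=[3, 4],
        dropout=0.1,
        num_labels=2,
        pad_token_id=1,
    )
    model = TextCNNForSequenceClassification(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 20))
    labels = torch.tensor([0, 1])
    out = model(input_ids, labels=labels)
    # loss is a scalar; logits.shape == (2, config.num_labels)
    print(out.loss.item(), out.logits.shape)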