import pandas as pd
import tensorflow as tf
import tf_keras as keras
from constants import (PROCESSED_DATA_DIR,
                       METADATA_FILEPATH,
                       BATCH_SIZE,
                       EPOCHS,
                       BERT_BASE,
                       MAX_SEQUENCE_LENGHT,
                       PROJECT_NAME,
                       FilePath,
                       PageMetadata,
                       ImageSize,
                       ImageInputShape)
from pandera.typing import DataFrame
from typing import Tuple, List
from transformers import TFBertModel
from tf_keras import layers, models
from PIL import Image

# Disable PIL's decompression-bomb guard (a warning above MAX_IMAGE_PIXELS,
# an error above twice that): some scanned documents are very large.
# NOTE(review): this only affects code that opens images through PIL;
# `load_image` in this module decodes with tf.io / tf.image instead.
# Only safe as long as the images come from a trusted source.
Image.MAX_IMAGE_PIXELS = None


def stratified_split(
        df: pd.DataFrame,
        train_frac: float,
        val_frac: float,
        test_frac: float,
        shuffle: bool = False,
        random_state=None,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into train/val/test while keeping per-label proportions.

    Within each ``label`` group the first ``int(n * train_frac)`` rows go to
    train, the next ``int(n * val_frac)`` rows go to validation, and the
    remainder goes to test, so truncation leftovers always land in test.

    Args:
        df: Rows to split; must contain a ``label`` column. Rows whose label
            is NaN are dropped by ``groupby`` and appear in no split.
        train_frac: Fraction of each label group used for training.
        val_frac: Fraction of each label group used for validation.
        test_frac: Fraction used for testing. Only validated; the test split
            is whatever remains after train and validation.
        shuffle: Shuffle rows inside each label group before splitting. Off
            by default to keep the order-based split; without it an ordered
            ``df`` (e.g. sorted by document) yields non-random splits.
        random_state: Seed forwarded to ``DataFrame.sample`` when shuffling.

    Returns:
        ``(train_df, val_df, test_df)``, each with a fresh RangeIndex.

    Raises:
        ValueError: If a fraction is negative or the fractions do not sum
            to 1.
    """
    fractions = (train_frac, val_frac, test_frac)
    if any(frac < 0 for frac in fractions):
        raise ValueError(f'Split fractions must be non-negative, got {fractions}')
    if abs(sum(fractions) - 1.0) > 1e-6:
        raise ValueError(f'Split fractions must sum to 1, got {fractions}')

    train_dfs, val_dfs, test_dfs = [], [], []

    for _, group in df.groupby('label'):
        if shuffle:
            group = group.sample(frac=1, random_state=random_state)
        n = len(group)
        train_end = int(n * train_frac)
        val_end = train_end + int(n * val_frac)

        train_dfs.append(group.iloc[:train_end])
        val_dfs.append(group.iloc[train_end:val_end])
        test_dfs.append(group.iloc[val_end:])

    if not train_dfs:
        # No groups (empty df or all-NaN labels): pd.concat([]) would raise,
        # so return empty frames that keep df's columns.
        empty = df.iloc[:0].reset_index(drop=True)
        return empty.copy(), empty.copy(), empty.copy()

    train_df = pd.concat(train_dfs).reset_index(drop=True)
    val_df = pd.concat(val_dfs).reset_index(drop=True)
    test_df = pd.concat(test_dfs).reset_index(drop=True)

    return train_df, val_df, test_df


def dataset_from_dataframe(df: pd.DataFrame) -> tf.data.Dataset:
    """Slice a metadata DataFrame into a dataset with one element per row.

    Each element is ``(img_filepath, input_ids, attention_mask, label)``.

    NOTE(review): if ``df`` comes straight from ``pd.read_csv`` the
    ``input_ids`` / ``attention_mask`` columns arrive as strings, not integer
    sequences — confirm they are parsed upstream.
    """
    return tf.data.Dataset.from_tensor_slices((
        df['img_filepath'].values,
        # A column of per-row token lists has dtype=object, which TensorFlow
        # cannot convert; a plain list of lists is packed into one
        # (rows, seq_len) tensor instead (assumes every row is padded to the
        # same length).
        df['input_ids'].tolist(),
        df['attention_mask'].tolist(),
        df['label'].values,
    ))


def load_image(image_path: FilePath, image_size: ImageSize) -> tf.Tensor:
    """Read a JPEG file and return it resized and scaled to ``[0, 1]``.

    Args:
        image_path: Path of a JPEG file (a string tensor inside tf.data).
        image_size: ``(height, width)`` to resize to. This is the ordering
            ``train`` builds from the metadata medians, and the one
            ``tf.image.resize`` expects.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` with values in
        ``[0, 1]``.
    """
    img_height, img_width = image_size

    raw_bytes = tf.io.read_file(image_path)
    # decode_jpeg only accepts JPEG data; channels=3 forces RGB output.
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    # tf.image.resize takes [new_height, new_width] and returns float32.
    image = tf.image.resize(image, [img_height, img_width])
    return image / 255.0


def prepare_dataset(
        ds: tf.data.Dataset,
        image_size: ImageSize,
        batch_size=32,
        buffer_size=1000
) -> tf.data.Dataset:
    """Turn ``(img_path, input_ids, attention_mask, label)`` records into
    shuffled, batched, prefetched model input.

    Args:
        ds: Dataset of per-row records (see ``dataset_from_dataframe``).
        image_size: ``(height, width)`` forwarded to ``load_image``.
        batch_size: Number of elements per batch.
        buffer_size: Size of the shuffle buffer.

    Returns:
        A dataset of ``((image, input_ids, attention_mask), label)`` batches.
    """
    def load_image_and_format_tensor_shape(
            img_path: tf.Tensor,
            input_ids: tf.Tensor,
            attention_mask: tf.Tensor,
            label: tf.Tensor,
    ):
        image = load_image(img_path, image_size)
        # Group the three model inputs so Keras can match them positionally
        # to the model's input list.
        return (image, input_ids, attention_mask), label

    # Shuffle the lightweight records *before* decoding. The shuffle buffer
    # then holds file paths and token ids instead of up to `buffer_size`
    # decoded float images.
    return (
        ds.shuffle(buffer_size=buffer_size)
        .map(
            load_image_and_format_tensor_shape,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


def prepare_data(
        df: DataFrame[PageMetadata],
        img_size=None,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    """Split page metadata, log the split to W&B, and build tf.data pipelines.

    Args:
        df: Page metadata with the columns read by ``dataset_from_dataframe``,
            plus ``label``, plus ``height`` / ``width`` when ``img_size`` is
            omitted.
        img_size: ``(height, width)`` that every image is resized to. Defaults
            to the median page height/width of ``df``, which is the same size
            ``train`` uses for the model's input shape.

    Returns:
        ``(train_ds, val_ds, test_ds)``: shuffled, batched, prefetched
        datasets built from a 70/15/15 stratified split.
    """
    # Imported here because the module never imports wandb at top level.
    import wandb

    if img_size is None:
        img_size = (int(df['height'].median()), int(df['width'].median()))

    print('Splitting the DataFrame into training, validation and test')
    train_df, val_df, test_df = stratified_split(
        df,
        train_frac=0.7,
        val_frac=0.15,
        test_frac=0.15,
    )

    # wandb.init's keyword is `project` (there is no `project_name`).
    run = wandb.init(project=PROJECT_NAME, name='split-dataset')
    try:
        split_dataset_artifact = wandb.Artifact('split-dataset-metadata', type='dataset')
        for table_name, split_df in (
                ('train_metadata', train_df),
                ('val_metadata', val_df),
                ('test_metadata', test_df),
        ):
            split_dataset_artifact.add(wandb.Table(dataframe=split_df), name=table_name)
        run.log_artifact(split_dataset_artifact)
    finally:
        # Close the run even if logging the artifact fails.
        run.finish()

    print('Batching and shuffling the datasets')
    train_ds, val_ds, test_ds = (
        prepare_dataset(dataset_from_dataframe(split_df), img_size, batch_size=BATCH_SIZE)
        for split_df in (train_df, val_df, test_df)
    )

    return train_ds, val_ds, test_ds


def build_image_model(input_shape: ImageInputShape) -> keras.Model:
    """Build the convolutional image branch and print its summary.

    The branch has four stages of ``Conv2D(3x3, relu)`` followed by
    ``MaxPooling2D(2x2)``, using 32, 64, 128 and 128 filters. These are
    followed by ``Flatten`` and a 512-unit relu ``Dense`` layer.

    NOTE(review): ``Flatten``'s output size grows with the input resolution,
    so the Dense layer's parameter count grows with the image size too.
    """
    stack = [layers.Input(shape=input_shape)]
    for filters in (32, 64, 128, 128):
        stack.append(layers.Conv2D(filters, (3, 3), activation='relu'))
        stack.append(layers.MaxPooling2D((2, 2)))
    stack += [layers.Flatten(), layers.Dense(512, activation='relu')]

    image_branch = models.Sequential(stack, name='image_classification')
    image_branch.summary()
    return image_branch


def build_text_model() -> keras.Model:
    """Wrap the pretrained ``BERT_BASE`` encoder as a two-input Keras model.

    The inputs are ``input_ids`` and ``attention_mask``, each an int32
    sequence of length ``MAX_SEQUENCE_LENGHT``. The output is BERT's pooled
    output. The model summary is printed before returning.
    """
    encoder = TFBertModel.from_pretrained(BERT_BASE)

    sequence_shape = (MAX_SEQUENCE_LENGHT,)
    token_ids = layers.Input(shape=sequence_shape, dtype=tf.int32, name='input_ids')
    token_mask = layers.Input(shape=sequence_shape, dtype=tf.int32, name='attention_mask')

    # Index 1 of the BERT output is the pooled output, i.e. the
    # representation of the [CLS] token.
    encoder_outputs = encoder(input_ids=token_ids, attention_mask=token_mask)
    pooled = encoder_outputs[1]

    bert_wrapper = models.Model(
        inputs=[token_ids, token_mask],
        outputs=pooled,
        name='bert',
    )
    bert_wrapper.summary()
    return bert_wrapper


def build_multimodal_model(
        num_classes: int,
        img_input_shape: ImageInputShape
) -> keras.Model:
    """Combine the image CNN and the BERT text branch into one classifier.

    Args:
        num_classes: Number of output classes (size of the softmax layer).
        img_input_shape: ``(height, width, channels)`` of the image input.

    Returns:
        A model with inputs ``[image, input_ids, attention_mask]`` that
        outputs per-class softmax probabilities.
    """
    img_model = build_image_model(img_input_shape)
    text_model = build_text_model()

    img_input = layers.Input(shape=img_input_shape, name='img_input')
    text_input_ids = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_ids'
    )
    text_input_mask = layers.Input(
        shape=(MAX_SEQUENCE_LENGHT,), dtype=tf.int32, name='text_input_mask'
    )

    img_features = img_model(img_input)
    text_features = text_model([text_input_ids, text_input_mask])

    # Use tf_keras layers here, not tf.keras. From TF 2.16, tf.keras is
    # Keras 3 (unless TF_USE_LEGACY_KERAS=1), and its layers cannot be put
    # inside a tf_keras Sequential.
    classification_layers = keras.Sequential([
        layers.Dense(512, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ], name='classification_layers')
    concat_features = layers.concatenate([img_features, text_features],
                                         name='concatenate_features')
    outputs = classification_layers(concat_features)

    multimodal_model = models.Model(
        inputs=[img_input, text_input_ids, text_input_mask],
        outputs=outputs,
        name='multimodal_document_page_classifier'
    )
    return multimodal_model


def train():
    """Train the multimodal page classifier on the metadata CSV.

    Images are resized to the median page size. Classes are the
    sub-directories of ``PROCESSED_DATA_DIR``, in sorted order.

    Returns:
        The fitted multimodal model.

    Raises:
        ValueError: If string labels in the metadata have no matching class
            directory.
    """
    metadata_df: DataFrame[PageMetadata] = pd.read_csv(METADATA_FILEPATH)

    # The (height, width) ordering matches tf.image.resize and the
    # (H, W, C) Keras image input shape.
    median_height = int(metadata_df['height'].median())
    median_width = int(metadata_df['width'].median())

    img_size: ImageSize = (median_height, median_width)
    img_input_shape: ImageInputShape = img_size + (3,)

    label_names: List[str] = sorted(
        d.name for d in PROCESSED_DATA_DIR.iterdir() if d.is_dir()
    )
    num_classes = len(label_names)

    # sparse_categorical_crossentropy needs integer class ids. If labels are
    # stored as class names, map each one to its index in `label_names`.
    if not pd.api.types.is_numeric_dtype(metadata_df['label']):
        label_to_index = {name: index for index, name in enumerate(label_names)}
        unknown = set(metadata_df['label']).difference(label_to_index)
        if unknown:
            raise ValueError(
                'Labels without a matching class directory: '
                f'{sorted(map(str, unknown))}'
            )
        metadata_df = metadata_df.assign(
            label=metadata_df['label'].map(label_to_index)
        )

    # The test split is held out and not used during training.
    train_ds, val_ds, _ = prepare_data(metadata_df)

    multimodal_model = build_multimodal_model(num_classes, img_input_shape)
    multimodal_model.summary()
    multimodal_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    # No batch_size here: the tf.data pipelines are already batched, and
    # Keras says not to pass batch_size for dataset inputs.
    multimodal_model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
    )
    return multimodal_model


if __name__ == '__main__':
    train()
    # TODO: run an evaluation on the held-out test split once it is
    # implemented. `evaluate` is not defined in this module, so the previous
    # call raised NameError right after training.