import torch
import lightning
# from torch.utils.data import Dataset
# from typing import Any, Dict
# import argparse
from pydantic import BaseModel
# from get_dataset_dictionaries import get_dict_pair
# import os
# import shutil

# import optuna
# from optuna.integration import PyTorchLightningPruningCallback
# from functools import partial

class FFNModule(torch.nn.Module):
    """
    A PyTorch module that regresses from a hidden state representation of a word
    to its continuous linguistic feature norm vector.

    It is a FFN with the general structure of:
    input -> (linear -> nonlinearity -> dropout) x (num_layers - 1) -> linear -> output
    """
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
    ):
        super().__init__()

        layers = []
        for _ in range(num_layers - 1):
            layers.append(torch.nn.Linear(input_size, hidden_size))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(dropout))
            # after the first hidden layer, inputs are hidden_size wide
            input_size = hidden_size
        # final projection; when num_layers == 1 the loop never runs, so this
        # maps directly from the original input size rather than hidden_size
        layers.append(torch.nn.Linear(input_size, output_size))
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
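
# A minimal usage sketch for FFNModule, kept as a comment so importing this
# module stays side-effect free. The sizes are illustrative placeholders, not
# values tied to any particular language model or feature norm set:
#
#   module = FFNModule(input_size=768, output_size=300, hidden_size=400,
#                      num_layers=2, dropout=0.1)
#   dummy_hidden_states = torch.randn(8, 768)      # batch of 8 word embeddings
#   predicted_norms = module(dummy_hidden_states)  # shape: (8, 300)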
    
class FFNParams(BaseModel):
    input_size: int
    output_size: int
    hidden_size: int
    num_layers: int
    dropout: float

class TrainingParams(BaseModel):
    num_epochs: int
    batch_size: int
    learning_rate: float
    weight_decay: float
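
# Sketch of how the two pydantic configs fit together (the values are
# placeholders, not recommended hyperparameters). pydantic validates the field
# types, and FFNParams.model_dump() yields the plain dict that
# FeatureNormPredictor below unpacks into FFNModule:
#
#   ffn_params = FFNParams(input_size=768, output_size=300, hidden_size=400,
#                          num_layers=2, dropout=0.1)
#   training_params = TrainingParams(num_epochs=50, batch_size=32,
#                                    learning_rate=1e-3, weight_decay=0.0)
#   model = FeatureNormPredictor(ffn_params, training_params)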

class FeatureNormPredictor(lightning.LightningModule):
    def __init__(self, ffn_params: FFNParams, training_params: TrainingParams):
        super().__init__()
        self.save_hyperparameters()
        self.ffn_params = ffn_params
        self.training_params = training_params
        self.model = FFNModule(**ffn_params.model_dump())
        self.loss_function = torch.nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("test_loss", loss)
        return loss

    def forward(self, x):
        # nn.Module.__call__ dispatches to forward, so predictor(x) still works
        # without overriding __call__, and Lightning's hooks stay intact
        return self.model(x)

    def predict(self, batch):
        return self.model(batch)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), 
            lr=self.training_params.learning_rate,
            weight_decay=self.training_params.weight_decay,
        )
        return optimizer
    
    def save_model(self, path: str):
        # saves only the inner FFNModule weights, not a full Lightning checkpoint
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str):
        # counterpart to save_model: restores the inner FFNModule weights
        self.model.load_state_dict(torch.load(path))
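
# Minimal end-to-end sketch (the dataloaders and checkpoint path are
# hypothetical placeholders; the commented-out train() below shows the full
# pipeline this module was written for):
#
#   model = FeatureNormPredictor(ffn_params, training_params)
#   trainer = lightning.Trainer(max_epochs=training_params.num_epochs,
#                               accelerator="cpu")
#   trainer.fit(model, train_dataloader, validation_dataloader)
#
#   # save_hyperparameters() stored both pydantic configs in the checkpoint,
#   # so the trained model can be restored directly:
#   restored = FeatureNormPredictor.load_from_checkpoint("example.ckpt")
#   predicted_norms = restored(hidden_state_batch)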

    
# class HiddenStateFeatureNormDataset(Dataset):
#     def __init__(
#         self, 
#         input_embeddings: Dict[str, torch.Tensor],
#         feature_norms: Dict[str, torch.Tensor],
#     ):
        
#         # Invariant: input_embeddings and target_feature_norms have exactly the same keys
#         # this should be done by the train/test split and upstream data processing
#         assert(input_embeddings.keys() == feature_norms.keys())

#         self.words = list(input_embeddings.keys())
#         self.input_embeddings = torch.stack([
#             input_embeddings[word] for word in self.words
#         ])
#         self.feature_norms = torch.stack([
#             feature_norms[word] for word in self.words
#         ])
        
#     def __len__(self):
#         return len(self.words)
    
#     def __getitem__(self, idx):
#         return self.input_embeddings[idx], self.feature_norms[idx]

# # this is used when not optimizing
# def train(args : Dict[str, Any]):

#     # input_embeddings = torch.load(args.input_embeddings)
#     # feature_norms = torch.load(args.feature_norms)
#     # words = list(input_embeddings.keys())

#     input_embeddings, feature_norms, norm_list = get_dict_pair(
#         args.norm,
#         args.embedding_dir,
#         args.lm_layer,
#         translated=not args.raw_buchanan,
#         normalized=args.normal_buchanan
#     ) 
#     with open(os.path.join(args.save_dir, args.save_model_name + '.txt'), 'w') as norms_file:
#         norms_file.write("\n".join(norm_list))

#     words = list(input_embeddings.keys())
    
#     model = FeatureNormPredictor(
#         FFNParams(
#             input_size=input_embeddings[words[0]].shape[0],
#             output_size=feature_norms[words[0]].shape[0],
#             hidden_size=args.hidden_size,
#             num_layers=args.num_layers,
#             dropout=args.dropout,
#         ),
#         TrainingParams(
#             num_epochs=args.num_epochs,
#             batch_size=args.batch_size,
#             learning_rate=args.learning_rate,
#             weight_decay=args.weight_decay,
#         ),
#     )

#     # train/val split
#     train_size = int(len(words) * 0.8)
#     valid_size = len(words) - train_size
#     train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])

#     # TODO: Methodology Decision: should we be normalizing the hidden states/feature norms?
#     train_embeddings = {word: input_embeddings[word] for word in train_words}
#     train_feature_norms = {word: feature_norms[word] for word in train_words}
#     validation_embeddings = {word: input_embeddings[word] for word in validation_words}
#     validation_feature_norms = {word: feature_norms[word] for word in validation_words}

#     train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
#     train_dataloader = torch.utils.data.DataLoader(
#         train_dataset,
#         batch_size=args.batch_size,
#         shuffle=True,
#     )
#     validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
#     validation_dataloader = torch.utils.data.DataLoader(
#         validation_dataset,
#         batch_size=args.batch_size,
#         shuffle=False,  # validation data does not need shuffling
#     )

#     callbacks = [
#         lightning.pytorch.callbacks.ModelCheckpoint(
#             save_last=True,
#             dirpath=args.save_dir,
#             filename=args.save_model_name,
#         ),
#     ]
#     if args.early_stopping is not None:
#         callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
#             monitor="val_loss",
#             patience=args.early_stopping,
#             mode='min',
#             min_delta=0.0
#         ))

#     #TODO Design Decision - other trainer args? Is device necessary?
#     # cpu is fine for the scale of this model - only a few layers and a few hundred words
#     trainer = lightning.Trainer(
#         max_epochs=args.num_epochs,
#         callbacks=callbacks,
#         accelerator="cpu",
#         log_every_n_steps=7
#     )

#     trainer.fit(model, train_dataloader, validation_dataloader)

#     trainer.validate(model, validation_dataloader)

#     return model

# # this is used when optimizing
# def objective(trial: optuna.trial.Trial, args: Dict[str, Any]) -> float:
#     # optimizing hidden size, batch size, and learning rate
#     input_embeddings, feature_norms, norm_list = get_dict_pair(
#         args.norm,
#         args.embedding_dir,
#         args.lm_layer,
#         translated=not args.raw_buchanan,
#         normalized=args.normal_buchanan
#     )
#     with open(os.path.join(args.save_dir, args.save_model_name + '.txt'), 'w') as norms_file:
#         norms_file.write("\n".join(norm_list))

#     words = list(input_embeddings.keys())
#     input_size=input_embeddings[words[0]].shape[0]
#     output_size=feature_norms[words[0]].shape[0]
#     min_size = min(output_size, input_size)
#     max_size = min(output_size, 2 * input_size) if min_size == input_size else min(2 * output_size, input_size)
#     hidden_size = trial.suggest_int("hidden_size", min_size, max_size, log=True)
#     batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
#     learning_rate = trial.suggest_float("learning_rate", 1e-6, 1, log=True)

#     model = FeatureNormPredictor(
#         FFNParams(
#             input_size=input_size,
#             output_size=output_size,
#             hidden_size=hidden_size,
#             num_layers=args.num_layers,
#             dropout=args.dropout,
#         ),
#         TrainingParams(
#             num_epochs=args.num_epochs,
#             batch_size=batch_size,
#             learning_rate=learning_rate,
#             weight_decay=args.weight_decay,
#         ),
#     )

#     # train/val split
#     train_size = int(len(words) * 0.8)
#     valid_size = len(words) - train_size
#     train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])

#     train_embeddings = {word: input_embeddings[word] for word in train_words}
#     train_feature_norms = {word: feature_norms[word] for word in train_words}
#     validation_embeddings = {word: input_embeddings[word] for word in validation_words}
#     validation_feature_norms = {word: feature_norms[word] for word in validation_words}

#     train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
#     train_dataloader = torch.utils.data.DataLoader(
#         train_dataset,
#         batch_size=batch_size,  # use the trial-suggested batch size, not args.batch_size
#         shuffle=True,
#     )
#     validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
#     validation_dataloader = torch.utils.data.DataLoader(
#         validation_dataset,
#         batch_size=batch_size,  # use the trial-suggested batch size, not args.batch_size
#         shuffle=False,  # validation data does not need shuffling
#     )

#     callbacks = [
#         # all trial models will be saved in temporary directory
#         lightning.pytorch.callbacks.ModelCheckpoint(
#             save_last=True,
#             dirpath=os.path.join(args.save_dir,'optuna_trials'),
#             filename="{}".format(trial.number)
#         ),
#     ]

#     if args.prune:  # store_true flag, so this is a bool rather than None
#         callbacks.append(PyTorchLightningPruningCallback(
#             trial,
#             monitor='val_loss'
#         ))
        
#     if args.early_stopping is not None:
#         callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
#             monitor="val_loss",
#             patience=args.early_stopping,
#             mode='min',
#             min_delta=0.0
#         ))
#     # note: early stopping is still governed by --early_stopping here; pruning (if enabled) handles unpromising trials
#     #TODO Design Decision - other trainer args? Is device necessary?
#     # cpu is fine for the scale of this model - only a few layers and a few hundred words
#     trainer = lightning.Trainer(
#         max_epochs=args.num_epochs,
#         callbacks=callbacks,
#         accelerator="cpu",
#         log_every_n_steps=7,
#         # enable_checkpointing=False
#     )

#     trainer.fit(model, train_dataloader, validation_dataloader)

#     trainer.validate(model, validation_dataloader)
    
#     return trainer.callback_metrics['val_loss'].item()

# if __name__ == "__main__":
#     # parse args
#     parser = argparse.ArgumentParser()
#     #TODO: Design Decision: Should we input paths, to the pre-extracted layers, or the model/layer we want to generate them from
#     # required inputs
#     parser.add_argument("--norm", type=str, required=True, help="feature norm set to use")
#     parser.add_argument("--embedding_dir", type=str, required=True, help=" directory containing embeddings")
#     parser.add_argument("--lm_layer", type=int, required=True, help="layer of embeddings to use")
#     # if user selects optimize, hidden_size, batch_size and learning_rate will be optimized. 
#     parser.add_argument("--optimize", action="store_true", help="optimize hyperparameters for training")
#     parser.add_argument("--prune", action="store_true", help="prune unpromising trials when optimizing")
#     # optional hyperparameter specs
#     parser.add_argument("--num_layers", type=int, default=2, help="number of layers in FFN")
#     parser.add_argument("--hidden_size", type=int, default=100, help="hidden size of FFN")
#     parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate of FFN")
#     # set this to at least 100 if doing early stopping
#     parser.add_argument("--num_epochs", type=int, default=10, help="number of epochs to train for")
#     parser.add_argument("--batch_size", type=int, default=32, help="batch size for training")
#     parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate for training")
#     parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for training")
#     parser.add_argument("--early_stopping", type=int, default=None, help="number of epochs to wait for early stopping")
#     # optional dataset specs, for buchanan really
#     parser.add_argument('--raw_buchanan', action="store_true", help="do not use translated values for buchanan")
#     parser.add_argument('--normal_buchanan', action="store_true", help="use normalized features for buchanan")
#     # required for output
#     parser.add_argument("--save_dir", type=str, required=True, help="directory to save model to")
#     parser.add_argument("--save_model_name", type=str, required=True, help="name of model to save")

#     args = parser.parse_args()

#     if args.early_stopping is not None:
#         args.num_epochs = max(50, args.num_epochs)

#     torch.manual_seed(10)

#     if args.optimize:
#         # call optimizer code here
#         print("optimizing for learning rate, batch size, and hidden size")
#         pruner = optuna.pruners.MedianPruner() if args.prune else optuna.pruners.NopPruner()
#         sampler = optuna.samplers.TPESampler(seed=10)

#         study = optuna.create_study(direction='minimize', pruner=pruner, sampler=sampler)
#         study.optimize(partial(objective, args=args), n_trials = 100, timeout=600)

#         other_params = {
#             "num_layers": args.num_layers,
#             "num_epochs": args.num_epochs,
#             "dropout": args.dropout,
#             "weight_decay": args.weight_decay,
#         }

#         print("Number of finished trials: {}".format(len(study.trials)))

#         trial = study.best_trial
#         print("Best trial: "+str(trial.number))
        

#         print("  Validation Loss: {}".format(trial.value))

#         print("  Optimized Params: ")
#         for key, value in trial.params.items():
#             print("    {}: {}".format(key, value))

#         print("  User Defined Params: ")
#         for key, value in other_params.items():
#             print("    {}: {}".format(key, value))
        
#         print('saving best trial')
#         for filename in os.listdir(os.path.join(args.save_dir,'optuna_trials')):
#             if filename == "{}.ckpt".format(trial.number):
#                 shutil.move(os.path.join(args.save_dir,'optuna_trials',filename), os.path.join(args.save_dir, "{}.ckpt".format(args.save_model_name)))
#         shutil.rmtree(os.path.join(args.save_dir,'optuna_trials'))

#     else:   
#         model = train(args)