import torch
from pathlib import Path

import torch
from pathlib import Path


def save_model(model, optimizer, epoch, loss, directory, model_name='model', **kwargs):
    """
    Save a PyTorch model checkpoint.

    The checkpoint is a dict with keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``loss`` plus every extra keyword argument, and is
    written with ``torch.save``. The filename encodes the epoch, the loss and
    any extra metrics, e.g. ``model_epoch=3_loss=0.1234_f1_score=0.9000.pth``.

    Args:
        model: Trained model; its ``state_dict()`` is saved.
        optimizer: Optimizer used for training; its ``state_dict()`` is saved.
        epoch: The last epoch the model was trained on.
        loss: The last loss recorded during training (must support ``:.4f``).
        directory: Directory to save into; created (with parents) if missing.
        model_name: Base name for the model file, defaults to ``'model'``.
        **kwargs: Additional metrics. They are stored in the checkpoint and
            appended to the filename as ``key=value`` (values must support
            ``:.4f`` formatting).

    Returns:
        pathlib.Path: Path of the written checkpoint file.

    Raises:
        ValueError: If a metric name would overwrite the model or optimizer
            state stored in the checkpoint.

    Example:
        >>> save_model(model, optimizer, epoch, loss, './model_dir', f1_score=val_f1score)  # doctest: +SKIP
    """
    # 'epoch' and 'loss' cannot clash (they are named parameters), but these
    # two keys can; letting a metric overwrite them would corrupt the checkpoint.
    clashes = {'model_state_dict', 'optimizer_state_dict'}.intersection(kwargs)
    if clashes:
        raise ValueError(f'metric names clash with checkpoint keys: {sorted(clashes)}')

    # Create the directory if it does not exist
    out_dir = Path(directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Build the filename; metric segments are only added when present so the
    # name never ends in a dangling '_' before the extension.
    parts = [f'{model_name}_epoch={epoch}_loss={loss:.4f}']
    parts.extend(f'{key}={value:.4f}' for key, value in kwargs.items())
    path = out_dir / ('_'.join(parts) + '.pth')

    # Save the model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        **kwargs
    }, path)
    return path


def get_device() -> torch.device:
    """
    Pick the best available Torch device for running computations.

    Backends are probed in order of preference: NVIDIA CUDA first, then
    Apple Metal (MPS), and finally the CPU as the universal fallback.

    Returns:
        torch.device: ``cuda`` if CUDA is available, otherwise ``mps`` if the
        MPS backend is available, otherwise ``cpu``.

    Examples:
        >>> device = get_device()
        >>> print(device)  # doctest: +SKIP
        cuda
    """
    if torch.cuda.is_available():
        return torch.device("cuda")  # NVIDIA GPU
    if torch.backends.mps.is_available():
        return torch.device("mps")  # Apple GPU
    # Neither GPU backend is usable on this machine.
    return torch.device("cpu")


def load_checkpoint(model, optimizer, filename, map_location=None):
    """
    Load a PyTorch model checkpoint.

    Expects the dict layout written by ``save_model``: keys ``epoch``,
    ``model_state_dict``, ``optimizer_state_dict``, ``loss`` plus any extra
    metric entries.

    Args:
        model: Model to load the weights into (in place).
        optimizer: Optimizer to load the state into (in place), or ``None`` to
            skip restoring optimizer state (e.g. for inference).
        filename: The path of the checkpoint file.
        map_location: Forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            checkpoint saved on GPU on a machine without CUDA. Defaults to
            ``None`` (tensors are restored to the device they were saved from).

    Returns:
        tuple: ``(epoch, loss, metrics)`` where ``metrics`` is a dict of every
        checkpoint entry other than the four standard keys.

    Raises:
        KeyError: If one of the standard checkpoint keys is missing.
    """
    standard_keys = {'epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'}

    # NOTE: torch.load unpickles the file (fully so on torch versions where
    # weights_only defaults to False); only load checkpoints from trusted sources.
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    # Extract additional metrics
    metrics = {key: value for key, value in checkpoint.items() if key not in standard_keys}

    return epoch, loss, metrics

# Example usage of load_checkpoint:
# epoch, loss, metrics = load_checkpoint(model, optimizer, 'model_checkpoint.pth')