import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Generic, Optional, TypeVar

from typing_extensions import TypeAlias

from steamship.base.client import Client
from steamship.invocable import InvocableResponse
from steamship.invocable.plugin_service import PluginRequest
from steamship.plugin.inputs.train_plugin_input import TrainPluginInput
from steamship.plugin.outputs.model_checkpoint import ModelCheckpoint
from steamship.plugin.outputs.train_plugin_output import TrainPluginOutput

ModelConstructor: TypeAlias = Callable[[], "TrainableModel"]

# Global variable to store the model for reuse in memory.
MODEL_CACHE: Dict[str, "TrainableModel"] = {}

ConfigType = TypeVar("ConfigType")


class TrainableModel(ABC, Generic[ConfigType]):
    """Base class for trainable models.

    Trainable models are not plugins. They are a thin wrapper around the state of a model designed to be **used with**
    the Steamship plugin system.

    # State Management

    100% of a TrainableModel's state management should save to & read from a folder on disk via the methods
    `save_to_folder` and `load_from_folder`.

    # Remote Saving and Loading

    `TrainableModel` instances automatically save to a user's Workspace on Steamship via `save_remote` method. They
    can load themselves from a user's workspace via the `load_remote` method.

    When saving a model, the caller provides `handle`, such as "V1" or "epoch_23". This allows that particular checkpoint
    to be re-loaded. By default, every save operation also saves the model to the "default" checkpoint, overwriting it
    if it already existed. When a user loads a model without specifying a checkpoint, the "default" checkpoint will be used.

    # Data Scope

    A TrainableModel's data is saved & loaded with respect to

    1) The user's active Workspace, and
    2) The provided Plugin Instance within that workspace.

    The active workspace is read from the Steamship client context, and the `plugin_instance_id` is supplied as a
    method argument on the `save_remote` and `load_remote` methods.

    This organization enables a user to have arbitrarily many trained model instances of the same type colocated within
    a Workspace.

    # Training

    A training job is fully parameterized by the `TrainPluginInput` object.

    # Result Reporting

    A training job's results are reported via the `TrainPluginOutput` object. These results include a reference to the
    `save_remote` output, but they do not include the model parameters themselves. For example, after training, one
    could write:

    >>> archive_path_in_steamship = model.save_remote(..)
    >>> output = TrainPluginOutput(archive_path_in_steamship=archive_path_in_steamship,
        ...
        )

    That output is the ultimate return object of the training process, but the Plugin that owns this model need not
    wait for synchronous completion to update the Steamship Engine with intermediate results. It can use the
    `Response.post_update` to proactively stream results back to the server.

    # Third-party / External Models

    This model class is a convenient wrapper for models running on third party systems (e.g. Google's AutoML). In such
    a case:

    - The `train` method would begin the job on the 3rd party system.
    - The `save_to_folder` method would write the Job ID and any other useful data to the checkpoint path
    - The `load_from_folder` method would read this Job ID from disk and obtain an authenticated client with the
      third party system.
    - Any `run` method the implementer created would ferry back results fetched from the third-party system.
    - Any status reporting in TrainPluginOutput would ferry back status fetched from the third-party system.

    """

    config: ConfigType = None

    def receive_config(self, config: ConfigType):
        """Stores config from plugin instance, so it is accessible by model on load or train."""
        self.config = config

    @abstractmethod
    def save_to_folder(self, checkpoint_path: Path):
        """Saves 100% of the state of this model to the provided path."""
        raise NotImplementedError()

    @abstractmethod
    def load_from_folder(self, checkpoint_path: Path):
        """Load 100% of the state of this model to the provided path."""
        raise NotImplementedError()

    @abstractmethod
    def train(self, input: PluginRequest[TrainPluginInput]) -> InvocableResponse[TrainPluginOutput]:
        """Train or fine-tune the model, parameterized by the information in the TrainPluginInput object."""
        raise NotImplementedError()

    @abstractmethod
    def train_status(
        self, input: PluginRequest[TrainPluginInput]
    ) -> InvocableResponse[TrainPluginOutput]:
        """Check on the status of an in-process training job, if it is running externally asynchronously."""
        raise NotImplementedError()

    @classmethod
    def load_from_local_checkpoint(cls, checkpoint: ModelCheckpoint, config: ConfigType):
        model = cls()
        model.receive_config(config=config)
        model.load_from_folder(checkpoint.folder_path_on_disk())
        return model

    @classmethod
    def load_remote(
        cls,
        client: Client,
        plugin_instance_id: str,
        checkpoint_handle: Optional[str] = None,
        use_cache: bool = True,
        model_parent_directory: Path = None,
        plugin_instance_config: ConfigType = None,
    ):
        if checkpoint_handle is None:
            # For some reason doing this defaulting in the signature wasn't working.
            checkpoint_handle = ModelCheckpoint.DEFAULT_HANDLE

        model_key = f"{plugin_instance_id}/{checkpoint_handle}"
        logging.info(f"TrainableModel:load_remote - Model Key: {model_key}")

        global MODEL_CACHE

        if use_cache:
            if model_key in MODEL_CACHE:
                logging.info(f"TrainableModel:load_remote - Returning cached: {model_key}")
                return MODEL_CACHE[model_key]

        checkpoint = ModelCheckpoint(
            client=client,
            parent_directory=model_parent_directory,
            handle=checkpoint_handle,
            plugin_instance_id=plugin_instance_id,
        )

        # If we haven't loaded the model, we need to download and start the model
        logging.info(f"TrainableModel:load_remote - Downloading: {model_key}")
        checkpoint.download_model_bundle()
        logging.info(f"TrainableModel:load_remote - Loading: {model_key}")
        model = cls.load_from_local_checkpoint(checkpoint, plugin_instance_config)
        logging.info(f"TrainableModel:load_remote - Loaded: {model_key}")

        if use_cache:
            MODEL_CACHE[model_key] = model

        return model

    def save_remote(
        self,
        client: Client,
        plugin_instance_id: str,
        checkpoint_handle: Optional[str] = None,
        model_parent_directory: Path = None,
        set_as_default: bool = True,
    ) -> str:
        if checkpoint_handle is None:
            # For some reason doing this defaulting in the signature wasn't working.
            checkpoint_handle = ModelCheckpoint.DEFAULT_HANDLE

        checkpoint = ModelCheckpoint(
            client=client,
            parent_directory=model_parent_directory,
            handle=checkpoint_handle,
            plugin_instance_id=plugin_instance_id,
        )
        self.save_to_folder(checkpoint.folder_path_on_disk())
        checkpoint.upload_model_bundle(set_as_default=set_as_default)
        return checkpoint.archive_path_in_steamship()