Spaces:
Runtime error
Runtime error
import logging | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from typing import Callable, Dict, Generic, Optional, TypeVar | |
from typing_extensions import TypeAlias | |
from steamship.base.client import Client | |
from steamship.invocable import InvocableResponse | |
from steamship.invocable.plugin_service import PluginRequest | |
from steamship.plugin.inputs.train_plugin_input import TrainPluginInput | |
from steamship.plugin.outputs.model_checkpoint import ModelCheckpoint | |
from steamship.plugin.outputs.train_plugin_output import TrainPluginOutput | |
ModelConstructor: TypeAlias = Callable[[], "TrainableModel"] | |
# Global variable to store the model for reuse in memory. | |
MODEL_CACHE: Dict[str, "TrainableModel"] = {} | |
ConfigType = TypeVar("ConfigType") | |
class TrainableModel(ABC, Generic[ConfigType]): | |
"""Base class for trainable models. | |
Trainable models are not plugins. They are a thin wrapper around the state of a model designed to be **used with** | |
the Steamship plugin system. | |
# State Management | |
100% of a TrainableModel's state management should save to & read from a folder on disk via the methods | |
`save_to_folder` and `load_from_folder`. | |
# Remote Saving and Loading | |
`TrainableModel` instances automatically save to a user's Workspace on Steamship via `save_remote` method. They | |
can load themselves from a user's workspace via the `load_remote` method. | |
When saving a model, the caller provides `handle`, such as "V1" or "epoch_23". This allows that particular checkpoint | |
to be re-loaded. By default, every save operation also saves the model to the "default" checkpoint, overwriting it | |
if it already existed. When a user loads a model without specifying a checkpoint, the "default" checkpoint will be used. | |
# Data Scope | |
A TrainableModel's data is saved & loaded with respect to | |
1) The user's active Workspace, and | |
2) The provided Plugin Instance within that workspace. | |
The active workspace is read from the Steamship client context, and the `plugin_instance_id` is supplied as a | |
method argument on the `save_remote` and `load_remote` methods. | |
This organization enables a user to have arbitrarily many trained model instances of the same type colocated within | |
a Workspace. | |
# Training | |
A training job is fully parameterized by the `TrainPluginInput` object. | |
# Result Reporting | |
A training job's results are reported via the `TrainPluginOutput` object. These results include a reference to the | |
`save_remote` output, but they do not include the model parameters themselves. For example, after training, one | |
could write: | |
>>> archive_path_in_steamship = model.save_remote(..) | |
>>> output = TrainPluginOutput(archive_path_in_steamship=archive_path_in_steamship, | |
... | |
) | |
That output is the ultimate return object of the training process, but the Plugin that owns this model need not | |
wait for synchronous completion to update the Steamship Engine with intermediate results. It can use the | |
`Response.post_update` to proactively stream results back to the server. | |
# Third-party / External Models | |
This model class is a convenient wrapper for models running on third party systems (e.g. Google's AutoML). In such | |
a case: | |
- The `train` method would begin the job on the 3rd party system. | |
- The `save_to_folder` method would write the Job ID and any other useful data to the checkpoint path | |
- The `load_from_folder` method would read this Job ID from disk and obtain an authenticated client with the | |
third party system. | |
- Any `run` method the implementer created would ferry back results fetched from the third-party system. | |
- Any status reporting in TrainPluginOutput would ferry back status fetched from the third-party system. | |
""" | |
config: ConfigType = None | |
def receive_config(self, config: ConfigType): | |
"""Stores config from plugin instance, so it is accessible by model on load or train.""" | |
self.config = config | |
def save_to_folder(self, checkpoint_path: Path): | |
"""Saves 100% of the state of this model to the provided path.""" | |
raise NotImplementedError() | |
def load_from_folder(self, checkpoint_path: Path): | |
"""Load 100% of the state of this model to the provided path.""" | |
raise NotImplementedError() | |
def train(self, input: PluginRequest[TrainPluginInput]) -> InvocableResponse[TrainPluginOutput]: | |
"""Train or fine-tune the model, parameterized by the information in the TrainPluginInput object.""" | |
raise NotImplementedError() | |
def train_status( | |
self, input: PluginRequest[TrainPluginInput] | |
) -> InvocableResponse[TrainPluginOutput]: | |
"""Check on the status of an in-process training job, if it is running externally asynchronously.""" | |
raise NotImplementedError() | |
def load_from_local_checkpoint(cls, checkpoint: ModelCheckpoint, config: ConfigType): | |
model = cls() | |
model.receive_config(config=config) | |
model.load_from_folder(checkpoint.folder_path_on_disk()) | |
return model | |
def load_remote( | |
cls, | |
client: Client, | |
plugin_instance_id: str, | |
checkpoint_handle: Optional[str] = None, | |
use_cache: bool = True, | |
model_parent_directory: Path = None, | |
plugin_instance_config: ConfigType = None, | |
): | |
if checkpoint_handle is None: | |
# For some reason doing this defaulting in the signature wasn't working. | |
checkpoint_handle = ModelCheckpoint.DEFAULT_HANDLE | |
model_key = f"{plugin_instance_id}/{checkpoint_handle}" | |
logging.info(f"TrainableModel:load_remote - Model Key: {model_key}") | |
global MODEL_CACHE | |
if use_cache: | |
if model_key in MODEL_CACHE: | |
logging.info(f"TrainableModel:load_remote - Returning cached: {model_key}") | |
return MODEL_CACHE[model_key] | |
checkpoint = ModelCheckpoint( | |
client=client, | |
parent_directory=model_parent_directory, | |
handle=checkpoint_handle, | |
plugin_instance_id=plugin_instance_id, | |
) | |
# If we haven't loaded the model, we need to download and start the model | |
logging.info(f"TrainableModel:load_remote - Downloading: {model_key}") | |
checkpoint.download_model_bundle() | |
logging.info(f"TrainableModel:load_remote - Loading: {model_key}") | |
model = cls.load_from_local_checkpoint(checkpoint, plugin_instance_config) | |
logging.info(f"TrainableModel:load_remote - Loaded: {model_key}") | |
if use_cache: | |
MODEL_CACHE[model_key] = model | |
return model | |
def save_remote( | |
self, | |
client: Client, | |
plugin_instance_id: str, | |
checkpoint_handle: Optional[str] = None, | |
model_parent_directory: Path = None, | |
set_as_default: bool = True, | |
) -> str: | |
if checkpoint_handle is None: | |
# For some reason doing this defaulting in the signature wasn't working. | |
checkpoint_handle = ModelCheckpoint.DEFAULT_HANDLE | |
checkpoint = ModelCheckpoint( | |
client=client, | |
parent_directory=model_parent_directory, | |
handle=checkpoint_handle, | |
plugin_instance_id=plugin_instance_id, | |
) | |
self.save_to_folder(checkpoint.folder_path_on_disk()) | |
checkpoint.upload_model_bundle(set_as_default=set_as_default) | |
return checkpoint.archive_path_in_steamship() | |