Spaces:
Runtime error
Runtime error
import logging | |
from abc import ABC, abstractmethod | |
from steamship.invocable import InvocableResponse, post | |
from steamship.invocable.plugin_service import PluginRequest, PluginService, TrainablePluginService | |
from steamship.plugin.inputs.block_and_tag_plugin_input import BlockAndTagPluginInput | |
from steamship.plugin.inputs.train_plugin_input import TrainPluginInput | |
from steamship.plugin.inputs.training_parameter_plugin_input import TrainingParameterPluginInput | |
from steamship.plugin.outputs.block_and_tag_plugin_output import BlockAndTagPluginOutput | |
from steamship.plugin.outputs.train_plugin_output import TrainPluginOutput | |
from steamship.plugin.outputs.training_parameter_plugin_output import TrainingParameterPluginOutput | |
from steamship.plugin.trainable_model import TrainableModel | |
# Note! | |
# ===== | |
# | |
# This is the PLUGIN IMPLEMENTOR's View of a Tagger. | |
# | |
# If you are using the Steamship Client, you probably want steamship.client.operations.tagger instead | |
# of this file. | |
# | |
class Tagger(PluginService[BlockAndTagPluginInput, BlockAndTagPluginOutput], ABC):
    """Plugin-implementor base class for taggers.

    Concrete taggers subclass this and override `run`. The `run_endpoint`
    method exposes that operation to the Steamship Engine over HTTP.
    """

    @abstractmethod
    def run(
        self, request: PluginRequest[BlockAndTagPluginInput]
    ) -> InvocableResponse[BlockAndTagPluginOutput]:
        """Tag the input carried by `request`; must be overridden by subclasses.

        Args:
            request: Plugin request wrapping a `BlockAndTagPluginInput`.

        Returns:
            An `InvocableResponse` wrapping a `BlockAndTagPluginOutput`.

        Raises:
            NotImplementedError: if this base implementation is reached
                (e.g. through `super().run(...)`).
        """
        raise NotImplementedError()

    # Register the route; without this the method is never exposed despite
    # the docstring promising POST /tag.
    @post("tag")
    def run_endpoint(self, **kwargs) -> InvocableResponse[BlockAndTagPluginOutput]:
        """Exposes the Tagger's `run` operation to the Steamship Engine via the expected HTTP path POST /tag

        The request's keyword arguments are validated into a
        `PluginRequest[BlockAndTagPluginInput]` before being passed to `run`.
        """
        return self.run(PluginRequest[BlockAndTagPluginInput].parse_obj(kwargs))
class TrainableTagger(TrainablePluginService[BlockAndTagPluginInput, BlockAndTagPluginOutput], ABC):
    """Plugin-implementor base class for taggers backed by a trainable model.

    Concrete taggers subclass this and override `run_with_model`. This class
    exposes the tagging, training-parameter and training operations to the
    Steamship Engine over HTTP.
    """

    @abstractmethod
    def run_with_model(
        self, request: PluginRequest[BlockAndTagPluginInput], model: TrainableModel
    ) -> InvocableResponse[BlockAndTagPluginOutput]:
        """Tag the input in `request` using `model`; must be overridden by subclasses.

        Args:
            request: Plugin request wrapping a `BlockAndTagPluginInput`.
            model: The trainable model to tag with.

        Returns:
            An `InvocableResponse` wrapping a `BlockAndTagPluginOutput`.

        Raises:
            NotImplementedError: if this base implementation is reached
                (e.g. through `super().run_with_model(...)`).
        """
        raise NotImplementedError()

    # noinspection PyUnusedLocal
    @post("tag")
    def run_endpoint(self, **kwargs) -> InvocableResponse[BlockAndTagPluginOutput]:
        """Exposes the Tagger's `run` operation to the Steamship Engine via the expected HTTP path POST /tag"""
        # NOTE(review): `run` is not defined in this class; presumably the
        # inherited TrainablePluginService.run loads a model and delegates to
        # `run_with_model` — confirm.
        return self.run(PluginRequest[BlockAndTagPluginInput].parse_obj(kwargs))

    # noinspection PyUnusedLocal
    @post("getTrainingParameters")
    def get_training_parameters_endpoint(
        self, **kwargs
    ) -> InvocableResponse[TrainingParameterPluginOutput]:
        """Exposes the Service's `get_training_parameters` operation to the Steamship Engine via the expected HTTP path POST /getTrainingParameters"""
        # Parse with `parse_obj`, consistent with the other endpoints in this class.
        return self.get_training_parameters(
            PluginRequest[TrainingParameterPluginInput].parse_obj(kwargs)
        )

    # noinspection PyUnusedLocal
    @post("train")
    def train_endpoint(self, **kwargs) -> InvocableResponse[TrainPluginOutput]:
        """Exposes the Service's `train` operation to the Steamship Engine via the expected HTTP path POST /train

        Builds a fresh model and configures it with this plugin's `config`.
        It then dispatches to `train_status` when the request is a status
        check, and to `train` otherwise.
        """
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logging.info("Tagger:train_endpoint called. Calling train %s", kwargs)
        arg = PluginRequest[TrainPluginInput].parse_obj(kwargs)

        # The result of `model_cls()` is itself called, so it is presumably a
        # model class; instantiate it and hand it this plugin's config.
        model = self.model_cls()()
        model.receive_config(config=self.config)

        if arg.is_status_check:
            return self.train_status(arg, model)
        return self.train(arg, model)