from transformers import Pipeline | |
from transformers.pipelines import PIPELINE_REGISTRY | |
import floret | |
from huggingface_hub import hf_hub_download | |
class Pipeline_One(Pipeline):
    """Hugging Face ``Pipeline`` that delegates prediction to ``self.model.predict``.

    The input text is passed through ``preprocess`` unchanged and handed to the
    model's ``predict`` method. The model's raw output is returned by
    ``postprocess`` unchanged.

    Call-time options ``k`` and ``threshold`` are routed to ``_forward`` via
    ``_sanitize_parameters``. When they are omitted, the model is called with
    ``k=1`` exactly as before.
    """

    # NOTE(review): an earlier draft overrode __init__ to store the model and set
    # ``self.framework = "floret"`` before calling ``super().__init__(model=model)``.
    # Confirm whether the base Pipeline constructor accepts a floret model as-is.

    # Call-time keyword arguments that are forwarded to ``_forward``.
    _FORWARD_PARAMS = ("k", "threshold")

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)`` tuple.
            Only ``k`` and ``threshold`` are recognised, and both go to
            ``_forward``. Any other keyword is ignored, as in the original
            implementation.
        """
        forward_params = {
            name: kwargs[name] for name in self._FORWARD_PARAMS if name in kwargs
        }
        return {}, forward_params, {}

    def preprocess(self, text, **kwargs):
        """Return ``text`` unchanged; it is fed directly to ``self.model.predict``.

        NOTE(review): fastText-style ``predict`` implementations typically
        reject text containing newlines. Confirm whether inputs need
        normalising here.
        """
        return text

    def _forward(self, inputs, k=1, threshold=None):
        """Run ``self.model.predict`` on the text produced by ``preprocess``.

        Args:
            inputs: The text returned by ``preprocess``.
            k: Number of predictions requested from the model (default 1).
            threshold: Optional threshold forwarded to ``predict``. It is only
                passed when given, so the model's own default applies otherwise.

        Returns:
            Whatever ``self.model.predict`` returns.
        """
        # ``preprocess`` yields a plain string, not a mapping, so it must be passed
        # positionally. The original ``**inputs`` raised TypeError on every call.
        if threshold is None:
            return self.model.predict(inputs, k=k)
        return self.model.predict(inputs, k=k, threshold=threshold)

    def postprocess(self, outputs, **kwargs):
        """Return the model output from ``_forward`` unchanged."""
        return outputs