# Source: pipeline1 / impresso_langident_wrapper.py
# Gleb Vinarskis - first commit (dc2b383), 922 Bytes
# (File-viewer navigation residue: "raw", "history blame".)
import floret # Assuming Floret is already installed
class FloretLangIdentifier:
    """Thin wrapper that owns a Floret model and delegates prediction to it.

    The model is loaded once, at construction time, with
    ``floret.load_model`` and stored on ``self.model``.
    """

    def __init__(self, model_path):
        """Load the Floret model located at ``model_path``.

        Args:
            model_path: Location handed straight to ``floret.load_model``.
        """
        loaded_model = floret.load_model(model_path)
        self.model = loaded_model

    def predict(self, text):
        """Run the wrapped model on ``text`` and return its result unchanged.

        Args:
            text: Input forwarded as-is to ``self.model.predict``.

        Returns:
            Whatever ``self.model.predict(text)`` returns; no post-processing
            is applied here.
        """
        return self.model.predict(text)
from transformers import Pipeline
class MyPipeline(Pipeline):
    """Pipeline subclass that hands its inputs to ``self.model.predict_language``.

    ``preprocess`` and ``postprocess`` are pass-throughs; the only real work
    happens in ``_forward``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Sort call-time keyword arguments into per-stage dictionaries.

        Only ``maybe_arg`` is recognised, and it is routed to ``preprocess``.
        Every other keyword is silently dropped.

        Returns:
            A 3-tuple ``(preprocess_kwargs, forward_kwargs,
            postprocess_kwargs)``; the last two are always empty dicts.
        """
        preprocess_kwargs = {
            key: kwargs[key] for key in ("maybe_arg",) if key in kwargs
        }
        forward_kwargs, postprocess_kwargs = {}, {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs, maybe_arg=2):
        """Return ``inputs`` untouched.

        ``maybe_arg`` is accepted (default ``2``) but not used by this
        implementation.
        """
        return inputs

    def _forward(self, model_inputs):
        """Call ``self.model.predict_language`` with ``model_inputs`` as kwargs.

        ``model_inputs`` is unpacked with ``**``, so it must be a mapping whose
        keys match the keyword parameters of ``predict_language``.
        """
        # NOTE(review): preprocess() returns its input unchanged, so callers
        # presumably have to pass a dict (e.g. {"model_input": ...}) into the
        # pipeline for this unpacking to work - confirm against actual usage.
        return self.model.predict_language(**model_inputs)

    def postprocess(self, model_outputs):
        """Return the model outputs exactly as produced by ``_forward``."""
        return model_outputs