|
import floret |
|
|
|
|
|
class FloretLangIdentifier:
    """Language identifier backed by a floret model loaded from disk."""

    def __init__(self, model_path):
        """Load the floret model stored at *model_path* via ``floret.load_model``."""
        self.model = floret.load_model(model_path)

    def predict(self, text, k=1, threshold=0.0):
        """Predict the language label(s) of *text*.

        Args:
            text: A string, or a list of strings, to classify.
            k: Number of top labels to return per input (the library
                default is 1).
            threshold: Minimum probability for a label to be returned (the
                library default is 0.0).

        Returns:
            Whatever ``self.model.predict`` returns, unchanged. For
            fastText-style models this is presumably a
            ``(labels, probabilities)`` pair; see the floret/fastText
            ``predict`` documentation for the exact shape.
        """
        # The fastText-derived ``predict`` processes one line per input and
        # raises ValueError if an input contains "\n". Flatten multi-line
        # text (e.g. whole paragraphs) onto a single line first.
        if isinstance(text, str):
            text = text.replace("\n", " ")
        elif isinstance(text, list):
            text = [entry.replace("\n", " ") for entry in text]
        return self.model.predict(text, k=k, threshold=threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import Pipeline |
|
|
|
|
|
class MyPipeline(Pipeline):
    """Hugging Face ``Pipeline`` that delegates to ``self.model.predict_language``.

    ``preprocess`` and ``postprocess`` are pass-throughs; all of the work
    happens in the model's ``predict_language`` method.
    NOTE(review): assumes the model given to this pipeline exposes
    ``predict_language`` (e.g. a language-identification wrapper) — confirm
    against the code that builds the pipeline.
    """

    def _sanitize_parameters(self, **kwargs):
        """Route call-time keyword arguments to the pipeline stages.

        Returns the ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``
        triple expected by ``Pipeline``. Only ``maybe_arg`` is recognised; it
        is routed to ``preprocess``. Other keywords are not forwarded to any
        stage.
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, maybe_arg=2):
        """Return *inputs* unchanged.

        ``maybe_arg`` is accepted so it can be passed at call time, but it is
        currently unused.
        """
        return inputs

    def _forward(self, model_inputs):
        """Run ``self.model.predict_language`` on the preprocessed inputs.

        Mapping inputs are expanded into keyword arguments. Any other input
        (e.g. a plain text string) is passed as a single positional argument.
        """
        from collections.abc import Mapping

        # ``preprocess`` passes the caller's input through untouched, so for a
        # plain string input ``model_inputs`` is a ``str``. Unpacking a
        # non-mapping with ``**`` raises TypeError, so only expand mappings.
        if isinstance(model_inputs, Mapping):
            return self.model.predict_language(**model_inputs)
        return self.model.predict_language(model_inputs)

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs