|
from transformers import Pipeline |
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
import floret |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
class Pipeline_One(Pipeline):
    """Detect the language of a text with a floret (fastText-compatible) model.

    The floret ``.bin`` model is loaded from ``model_path`` when one is given;
    otherwise it is downloaded from the Hugging Face Hub location described by
    the ``DEFAULT_*`` class attributes.
    """

    # Hub location of the model used when no local ``model_path`` is supplied.
    DEFAULT_REPO_ID = "Maslionok/pipeline1"
    DEFAULT_FILENAME = "LID-40-3-2000000-1-4.bin"
    DEFAULT_REVISION = "main"

    def __init__(self, model_path: "str | None" = None, **kwargs):
        """Initialize the pipeline and load the floret model.

        Args:
            model_path: Path to a local floret ``.bin`` model file. When
                ``None``, the file is fetched with ``hf_hub_download`` from
                ``DEFAULT_REPO_ID`` / ``DEFAULT_FILENAME`` at
                ``DEFAULT_REVISION``.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline.__init__``.
                NOTE(review): that initializer appears to expect a ``model``
                argument — confirm what callers pass here.
        """
        super().__init__(**kwargs)
        if model_path is None:
            model_path = hf_hub_download(
                repo_id=self.DEFAULT_REPO_ID,
                filename=self.DEFAULT_FILENAME,
                revision=self.DEFAULT_REVISION,
            )
        # Attach the floret model only after the base initializer has run, so
        # anything it may have assigned to ``self.model`` is replaced.
        self.model = floret.load_model(model_path)

    def _sanitize_parameters(self, k=None, **kwargs):
        """Split call-time options into (preprocess, forward, postprocess) params.

        Only ``k`` (how many labels to return) is recognized; it is routed to
        ``_forward``. Any other keyword is ignored, as before.
        """
        forward_params = {} if k is None else {"k": k}
        return {}, forward_params, {}

    def preprocess(self, text, **kwargs):
        """Prepare one input string for floret prediction.

        fastText-style ``predict`` processes a single line and raises on
        embedded newline characters, so they are replaced with spaces.
        """
        return text.replace("\n", " ")

    def _forward(self, inputs, k=1):
        """Run floret prediction on the preprocessed text.

        Args:
            inputs: The string produced by ``preprocess``.
            k: Number of most likely labels to request (default 1).

        Returns:
            The raw result of ``self.model.predict(inputs, k=k)``.
        """
        # ``inputs`` is a plain string, so it is passed positionally; the
        # previous ``**inputs`` raised TypeError on every call.
        return self.model.predict(inputs, k=k)

    def postprocess(self, outputs, **kwargs):
        """Return the raw floret prediction unchanged."""
        return outputs
|
|
|
|
|
|
|
# Register Pipeline_One with transformers' PIPELINE_REGISTRY under the
# "language-detection" task name. No default checkpoint is declared
# (``"model": None``), so callers presumably must supply one — verify.
PIPELINE_REGISTRY.register_pipeline(
    "language-detection",
    pipeline_class=Pipeline_One,
    default={"model": None},
)