| import abc | |
| from .artifact import Artifact | |
class InferenceEngine(abc.ABC, Artifact):
    """Abstract base class for inference engines.

    Concrete engines must override :meth:`infer`. It is declared with
    ``abc.abstractmethod`` so that the ``abc`` machinery rejects instantiating
    a subclass that does not provide it. Otherwise such a subclass would
    silently inherit a no-op returning ``None``.
    """

    @abc.abstractmethod
    def infer(self, dataset):
        """Perform inference on the input dataset.

        Args:
            dataset: The instances to run inference on. The expected element
                format is defined by each concrete subclass.

        Returns:
            The predictions for ``dataset``, in a format defined by the
            concrete subclass.
        """
class HFPipelineBasedInferenceEngine(InferenceEngine):
    """Inference engine backed by a Hugging Face ``transformers`` pipeline.

    Fields:
        model_name: Model identifier passed to ``transformers.pipeline``.
        max_new_tokens: Forwarded as ``max_new_tokens`` on every pipeline call.
    """

    model_name: str
    max_new_tokens: int

    def prepare(self):
        """Build the pipeline for ``model_name`` and store it on ``self.model``."""
        # Imported lazily so that ``transformers`` is only needed when this
        # engine is actually prepared, not when the module is imported.
        from transformers import pipeline

        self.model = pipeline(model=self.model_name)

    def infer(self, dataset):
        """Generate text for every instance in ``dataset``.

        Args:
            dataset: Iterable of mapping-like instances, each with a
                ``"source"`` entry that is fed to the pipeline as input.

        Returns:
            A list containing the generated text extracted from each pipeline
            output.
        """
        sources = [instance["source"] for instance in dataset]
        outputs = self.model(sources, max_new_tokens=self.max_new_tokens)
        return [self._generated_text(output) for output in outputs]

    @staticmethod
    def _generated_text(output):
        """Return the ``"generated_text"`` of a single pipeline result."""
        # text2text-generation pipelines yield one dict per input, whereas
        # text-generation pipelines yield a list of dicts per input (one per
        # returned sequence); use the first sequence in that case.
        # NOTE(review): text-generation pipelines include the prompt in
        # "generated_text" by default -- confirm callers expect that.
        if isinstance(output, list):
            output = output[0]
        return output["generated_text"]