#!/usr/bin/env python
# coding=utf-8
""" BaseAligner: a subclass of BasePipeline.
"""
from lmflow.pipeline.base_pipeline import BasePipeline | |
class BaseAligner(BasePipeline):
    """Alignable pipeline base: a BasePipeline subclass exposing ``align``.

    The ``align`` method defined on this class always raises
    ``NotImplementedError``; the class itself holds no alignment logic.
    """

    def __init__(self, *args, **kwargs):
        """Accept arbitrary arguments and ignore them.

        No attributes are set here and the parent initializer is not
        invoked by this method.
        """

    def _check_if_alignable(self, model, dataset, reward_model):
        """Placeholder for a compatibility check; currently a no-op.

        The arguments are accepted but not inspected, and ``None`` is
        always returned.

        TODO: check if the model is alignable and dataset is compatible
        TODO: add reward_model
        """
        return None

    def align(self, model, dataset, reward_model):
        """Unconditionally raise ``NotImplementedError``.

        Raises:
            NotImplementedError: always; no alignment is performed here.
        """
        message = ".align is not implemented"
        raise NotImplementedError(message)