Robin-7b / lmflow /pipeline /base_aligner.py
NingKanae's picture
Duplicate from OptimalScale/Robin-7b
98f2419
raw
history blame contribute delete
593 Bytes
#!/usr/bin/env python
# coding=utf-8
""" BaseAligner: a subclass of BasePipeline for alignment pipelines.
"""
from lmflow.pipeline.base_pipeline import BasePipeline
class BaseAligner(BasePipeline):
    """Base class for alignment pipelines, built on ``BasePipeline``.

    This class only declares the interface: ``align`` raises
    ``NotImplementedError`` and ``_check_if_alignable`` is an empty
    placeholder, so concrete aligners are expected to override ``align``.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any positional or keyword arguments.

        ``BasePipeline.__init__`` is deliberately not invoked here; no
        instance state is set up.
        """

    def _check_if_alignable(self, model, dataset, reward_model):
        """Placeholder validation hook; currently performs no checks.

        Returns:
            None.
        """
        # TODO: check if the model is alignable and dataset is compatible
        # TODO: add reward_model

    def align(self, model, dataset, reward_model):
        """Run alignment; not implemented in this base class.

        Judging by the parameter names, subclasses presumably align
        ``model`` on ``dataset`` guided by ``reward_model`` — confirm
        against concrete implementations.

        Raises:
            NotImplementedError: always, until a subclass overrides it.
        """
        message = ".align is not implemented"
        raise NotImplementedError(message)