import pandas as pd
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY

from rutabert_pipeline.model import BertForClassification
from rutabert_pipeline.pipeline import ColumnTypeAnnotationPipeline


def build_column_dataframe(table, table_id):
    """Flatten a table into one row per column for column-type annotation.

    Each column's cells are converted with ``str``, stripped of surrounding
    whitespace, and joined with single spaces into one ``column_data`` string.

    Args:
        table: Source ``pandas.DataFrame``; every column becomes one output row.
        table_id: Value written to the ``table_id`` field of every output row.

    Returns:
        A ``pandas.DataFrame`` with columns ``table_id``, ``column_id``,
        ``label_id``, ``label`` and ``column_data``. ``column_id`` holds the
        original column label. ``label_id``/``label`` are always the
        placeholders ``0`` and ``"none"``. Presumably the pipeline ignores
        them at inference time; confirm against ColumnTypeAnnotationPipeline.
    """
    rows = [
        [
            table_id,
            column,
            0,       # placeholder label id
            "none",  # placeholder label name
            " ".join(str(cell).strip() for cell in cells),
        ]
        for column, cells in table.items()
    ]
    return pd.DataFrame(
        rows,
        columns=["table_id", "column_id", "label_id", "label", "column_data"],
    )


if __name__ == "__main__":
    # Make the custom task name resolvable by `transformers.pipeline(...)`.
    PIPELINE_REGISTRY.register_pipeline(
        "column-type-annotation",
        pipeline_class=ColumnTypeAnnotationPipeline,
        pt_model=BertForClassification,
    )

    # NOTE(review): this path is resolved against the current working
    # directory, not this script's location.
    # NOTE(review): by default pandas reads empty cells as NaN, which
    # `str()` turns into the literal "nan" in column_data. Confirm this matches
    # the preprocessing the model was trained with.
    table = pd.read_csv("../rutabert_pipeline/data/example.csv", header=0)
    df = build_column_dataframe(table, table_id="example.csv")

    # Bind to a new name: assigning to `pipeline` would shadow the imported
    # transformers factory function.
    annotator = pipeline("column-type-annotation", model="sti-team/rutabert-base")
    output = annotator(df)
    print(output)