# rutabert-base / rutabert_pipeline / inference_example.py
# LedZeppe1in
# Added custom rutabert pipeline for column type annotation
# 1507360
import pandas as pd
from transformers import pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from rutabert_pipeline.model import BertForClassification
from rutabert_pipeline.pipeline import ColumnTypeAnnotationPipeline
def build_column_frame(table: pd.DataFrame, table_id: str) -> pd.DataFrame:
    """Flatten a table into one row per column for the annotation pipeline.

    Each column of *table* becomes a single record whose ``column_data``
    field is the column's cell values, stringified, stripped of surrounding
    whitespace and joined with single spaces.

    Args:
        table: Source table; its columns are visited in order.
        table_id: Identifier written into the ``table_id`` field of every row.

    Returns:
        A DataFrame with the columns ``table_id``, ``column_id``,
        ``label_id``, ``label`` and ``column_data`` (one row per source
        column; empty if *table* has no columns).
    """
    # Every row carries label_id 0 / label "none".
    # NOTE(review): presumably placeholders because the true column type is
    # unknown at inference time -- confirm against the pipeline's preprocessing.
    rows = [
        [
            table_id,
            column_name,
            0,
            "none",
            " ".join(str(cell).strip() for cell in table[column_name]),
        ]
        for column_name in table.columns
    ]
    return pd.DataFrame(
        rows,
        columns=["table_id", "column_id", "label_id", "label", "column_data"],
    )


def main() -> None:
    """Register the custom pipeline, annotate the example CSV and print the result."""
    PIPELINE_REGISTRY.register_pipeline(
        "column-type-annotation",
        pipeline_class=ColumnTypeAnnotationPipeline,
        pt_model=BertForClassification,
    )

    # Relative path: resolved against the current working directory.
    table = pd.read_csv("../rutabert_pipeline/data/example.csv", header=0)
    df = build_column_frame(table, "example.csv")

    # Bound to a distinct name so the imported `pipeline` factory is not shadowed.
    annotator = pipeline("column-type-annotation", model="sti-team/rutabert-base")
    output = annotator(df)
    print(output)


if __name__ == "__main__":
    main()