import pandas as pd | |
from transformers import pipeline | |
from transformers.pipelines import PIPELINE_REGISTRY | |
from rutabert_pipeline.model import BertForClassification | |
from rutabert_pipeline.pipeline import ColumnTypeAnnotationPipeline | |
def build_column_dataframe(table: pd.DataFrame, table_id: str) -> pd.DataFrame:
    """Flatten *table* into one row per column.

    Each output row holds the table identifier, the column name, a
    placeholder label (``0`` / ``"none"``) and the column's cell values
    converted to ``str``, stripped of surrounding whitespace and joined
    with single spaces.

    Args:
        table: Source table whose columns are serialized.
        table_id: Identifier stored in the ``table_id`` field of every row.

    Returns:
        A DataFrame with the columns ``table_id``, ``column_id``,
        ``label_id``, ``label`` and ``column_data``.
    """
    # Placeholder labels: the script has no ground truth for the columns.
    # NOTE(review): presumably the pipeline ignores them at inference time --
    # confirm against ColumnTypeAnnotationPipeline.
    label_id = 0
    label = "none"

    rows = [
        [
            table_id,
            column_name,
            label_id,
            label,
            " ".join(str(cell).strip() for cell in column),
        ]
        for column_name, column in table.items()
    ]
    return pd.DataFrame(
        rows,
        columns=["table_id", "column_id", "label_id", "label", "column_data"],
    )


if __name__ == "__main__":
    # Register the custom pipeline class under its task name so that the
    # `pipeline(...)` factory below can resolve "column-type-annotation".
    PIPELINE_REGISTRY.register_pipeline(
        "column-type-annotation",
        pipeline_class=ColumnTypeAnnotationPipeline,
        pt_model=BertForClassification,
    )

    # NOTE: this path is relative to the current working directory, not to
    # the location of this file.
    table = pd.read_csv("../rutabert_pipeline/data/example.csv", header=0)
    df = build_column_dataframe(table, "example.csv")

    # Bind the pipeline instance to its own name instead of rebinding
    # `pipeline`, which would shadow the imported factory function.
    annotator = pipeline("column-type-annotation", model="sti-team/rutabert-base")
    output = annotator(df)
    print(output)