talk-to-data / src /inference.py
RohitMidha23
return table also
4a19d8a
from typing import Dict, List, Tuple

import pandas as pd
import torch
from transformers import TapasForQuestionAnswering, TapasTokenizer

from src.constants import id2aggregation
def infer(
    query: str,
    file_name: str,
    model_name: str = "google/tapas-base-finetuned-wtq",
) -> Tuple[Dict[str, str], pd.DataFrame]:
    """Answer a natural-language question about a CSV table with a TAPAS model.

    Args:
        query: Question to ask about the table.
        file_name: Location of a comma-delimited file, passed to ``pd.read_csv``.
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and its tokenizer.

    Returns:
        A ``(result, table)`` tuple. ``result`` is a dict with the keys
        ``"query"`` (the question as given) and ``"answer"`` (the predicted
        cell values joined with ``", "``, prefixed with ``"<AGG> : "`` when
        the predicted aggregation label is not ``"NONE"``). ``table`` is the
        loaded DataFrame with every column cast to ``str``.
    """
    # Load the file; the table is cast to strings before being tokenized.
    table = pd.read_csv(file_name, delimiter=",").astype(str)

    # Load the model.
    # NOTE: model and tokenizer are re-loaded on every call.
    model = TapasForQuestionAnswering.from_pretrained(model_name)
    tokenizer = TapasTokenizer.from_pretrained(model_name)

    # Make predictions. Gradients are never needed here, so skip building the
    # autograd graph during the forward pass.
    queries = [query]
    inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # predicted_answer_coordinates: coordinates of the answer cells per query
    # predicted_aggregation_indices: aggregation type index per query
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs, outputs.logits.detach(), outputs.logits_aggregation.detach()
    )
    aggregation_labels = [id2aggregation[index] for index in predicted_aggregation_indices]

    # A single selected cell yields its value unchanged; several cells are
    # comma-separated; no cells yields an empty string.
    answers = [
        ", ".join(str(table.iat[coordinate]) for coordinate in coordinates)
        for coordinates in predicted_answer_coordinates
    ]

    # Create the answer string. The loop variable for the query is discarded
    # so the ``query`` parameter is not shadowed before it is returned.
    answer_str = ""
    for _, answer, predicted_agg in zip(queries, answers, aggregation_labels):
        answer_str = answer if predicted_agg == "NONE" else f"{predicted_agg} : {answer}"

    return {
        "query": query,
        "answer": answer_str,
    }, table