"""Gradio demo: ask a natural-language question about an uploaded CSV table.

The same question is answered by three models: TAPEX (fine-tuned on
WikiTableQuestions) and two TAPAS pipelines (WikiTableQuestions and WikiSQL).
"""

import gradio as gr
import pandas as pd
from transformers import TapexTokenizer, BartForConditionalGeneration, pipeline

# Initialize TAPEX (Microsoft) model and tokenizer
tokenizer_tapex = TapexTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq")
model_tapex = BartForConditionalGeneration.from_pretrained("microsoft/tapex-large-finetuned-wtq")

# Initialize TAPAS (Google) models and pipelines
pipe_tapas = pipeline(task="table-question-answering", model="google/tapas-large-finetuned-wtq")
pipe_tapas2 = pipeline(task="table-question-answering", model="google/tapas-large-finetuned-wikisql-supervised")


def process_table_query(query, table_data):
    """Answer ``query`` about ``table_data`` using the TAPEX model.

    Args:
        query: Natural-language question as a string.
        table_data: pandas DataFrame holding the table.

    Returns:
        The decoded text of the first generated sequence.
    """
    # Convert all columns in the table to strings for TAPEX compatibility.
    # This is a no-op when the caller has already stringified the table.
    table_data = table_data.astype(str)

    # Inputs longer than 1024 tokens (query + flattened table) are truncated.
    encoding = tokenizer_tapex(
        table=table_data,
        query=query,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    outputs = model_tapex.generate(**encoding)
    return tokenizer_tapex.batch_decode(outputs, skip_special_tokens=True)[0]


def _add_datetime_parts(table_data):
    """Append ``<col>_year/_month/_day/_time`` columns for each datetime column.

    Must run BEFORE the table is converted to strings: after ``astype(str)``
    no column has a datetime dtype and nothing would be added.

    NOTE: ``pd.read_csv`` is called without ``parse_dates`` in this file, so
    date-like CSV columns arrive as plain strings and are left untouched here.
    """
    # Snapshot the column labels because new columns are added while looping.
    for column in list(table_data.columns):
        if pd.api.types.is_datetime64_any_dtype(table_data[column]):
            stamps = table_data[column]
            table_data[f'{column}_year'] = stamps.dt.year
            table_data[f'{column}_month'] = stamps.dt.month
            table_data[f'{column}_day'] = stamps.dt.day
            table_data[f'{column}_time'] = stamps.dt.strftime('%H:%M:%S')
    return table_data


def _first_tapas_cell(tapas_pipe, table_data, query):
    """Return the first cell selected by a TAPAS pipeline, or "" if none.

    Guards the ``['cells'][0]`` lookup so that a prediction without selected
    cells yields an empty answer instead of an IndexError/KeyError.
    """
    prediction = tapas_pipe(table=table_data, query=query)
    cells = prediction.get("cells") or []
    return cells[0] if cells else ""


# Gradio interface
def answer_query_from_csv(query, file):
    """Read the uploaded CSV and answer ``query`` with all three models.

    Args:
        query: Question typed by the user.
        file: Value supplied by the ``gr.File`` input; it is passed straight
            to ``pd.read_csv``. ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(tapex_answer, tapas_wtq_answer, tapas_wikisql_answer)``.

    Raises:
        gr.Error: If no file was uploaded, the query is blank, or the CSV
            contains no data rows.
    """
    # Fail early with a readable message instead of a library traceback.
    if file is None:
        raise gr.Error("Please upload a CSV file.")
    if not query or not query.strip():
        raise gr.Error("Please enter a query.")

    # Read the file into a DataFrame
    table_data = pd.read_csv(file)
    if table_data.empty:
        raise gr.Error("The uploaded CSV file contains no rows.")

    # Convert object-type columns to lowercase (only actual string cells)
    for column in table_data.columns:
        if table_data[column].dtype == 'object':
            table_data[column] = table_data[column].apply(
                lambda x: x.lower() if isinstance(x, str) else x
            )

    # Extract year, month, day, and time components for datetime columns.
    # Done before stringification, otherwise the dtype check can never match.
    table_data = _add_datetime_parts(table_data)

    # Convert all table cells to strings for TAPEX/TAPAS compatibility
    table_data = table_data.astype(str)

    # Process the CSV file and query with TAPEX
    result_tapex = process_table_query(query, table_data)

    # Process the query using the TAPAS pipelines
    result_tapas = _first_tapas_cell(pipe_tapas, table_data, query)
    result_tapas2 = _first_tapas_cell(pipe_tapas2, table_data, query)

    return result_tapex, result_tapas, result_tapas2


# Create Gradio interface
with gr.Blocks() as interface:
    gr.Markdown("# Table Question Answering with TAPEX and TAPAS Models")

    # Add a notice about the token limit
    gr.Markdown("### Note: Only the first 1024 tokens (query + table data) will be considered. If your table is too large, it will be truncated to fit within this limit.")

    # Two-column layout (input on the left, output on the right)
    with gr.Row():
        with gr.Column():
            # Input fields for the query and file
            query_input = gr.Textbox(label="Enter your query:")
            csv_input = gr.File(label="Upload your CSV file")

        with gr.Column():
            # Output textboxes for the answers
            result_tapex = gr.Textbox(label="TAPEX Answer")
            result_tapas = gr.Textbox(label="TAPAS (WikiTableQuestions) Answer")
            result_tapas2 = gr.Textbox(label="TAPAS (WikiSQL) Answer")

    # Submit button
    submit_btn = gr.Button("Submit")

    # Action when submit button is clicked
    submit_btn.click(
        fn=answer_query_from_csv,
        inputs=[query_input, csv_input],
        outputs=[result_tapex, result_tapas, result_tapas2]
    )

# Launch the Gradio interface
interface.launch(share=True)