import json
import os
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def load_table_schemas(tables_file):
    """
    Load table schemas from the tables.jsonl file.
    
    Args:
        tables_file: Path to the tables.jsonl file.
    
    Returns:
        A dictionary mapping table IDs to their column names.
    """
    table_schemas = {}
    with open(tables_file, 'r') as f:
        for line in f:
            table_data = json.loads(line)
            table_id = table_data["id"]
            table_columns = table_data["header"]
            table_schemas[table_id] = table_columns
    return table_schemas
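
# A hedged illustration of what this helper returns (the id and header below are
# toy values, not copied from the dataset). Each line of *.tables.jsonl is a JSON
# object carrying at least an "id" and a "header", e.g.
#     {"id": "1-10015132-1", "header": ["Player", "No.", "Nationality", "Position"]}
# which becomes the mapping
#     {"1-10015132-1": ["Player", "No.", "Nationality", "Position"]}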


# Step 1: Load and Preprocess WikiSQL Data
def load_wikisql(data_dir):
    """
    Load WikiSQL data and prepare it for training.
    Args:
        data_dir: Path to the WikiSQL dataset directory.
    Returns:
        A tuple (train_examples, dev_examples); each is a list of dicts with "input" and "target" text.
    """
    def parse_file(file_path):
        with open(file_path, 'r') as f:
            return [json.loads(line) for line in f]

    train_data = parse_file(os.path.join(data_dir, "train.jsonl"))
    dev_data = parse_file(os.path.join(data_dir, "dev.jsonl"))

    # Map table ids to their column headers for both splits.
    train_table_schemas = load_table_schemas(os.path.join(data_dir, "train.tables.jsonl"))
    dev_table_schemas = load_table_schemas(os.path.join(data_dir, "dev.tables.jsonl"))

    def format_data(data, split):
        formatted = []
        for item in data:
            table_id = item["table_id"]
            table_columns = train_table_schemas[table_id] if split == "train" else dev_table_schemas[table_id]
            question = item["question"]
            sql_query = sql_to_text(item["sql"], table_columns)
            formatted.append({"input": f"Question: {question}", "target": sql_query})
        return formatted

    return format_data(train_data, "train"), format_data(dev_data, "dev")
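
# Each formatted example is a plain {"input": ..., "target": ...} dict, e.g.
# (toy values, not drawn from the actual dataset):
#     {"input": "Question: How many players are from Spain?",
#      "target": "SELECT COUNT(Player) WHERE Nationality = 'Spain'"}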


def sql_to_text(sql, table_columns):
    """
    Convert SQL dictionary from WikiSQL to text representation.
    
    Args:
        sql: SQL dictionary from WikiSQL (e.g., {"sel": 5, "conds": [[3, 0, "value"]], "agg": 0}).
        table_columns: List of column names corresponding to the table.
        
    Returns:
        SQL query as a string.
    """
    # Aggregation functions mapping
    agg_functions = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
    operators = ["=", ">", "<"]

    # Get selected column
    sel_column = table_columns[sql["sel"]]
    agg_func = agg_functions[sql["agg"]]
    select_clause = f"SELECT {agg_func}({sel_column})" if agg_func else f"SELECT {sel_column}"

    # Get conditions
    if sql["conds"]:
        conditions = []
        for cond in sql["conds"]:
            col_idx, operator, value = cond
            col_name = table_columns[col_idx]
            conditions.append(f"{col_name} {operators[operator]} '{value}'")
        where_clause = " WHERE " + " AND ".join(conditions)
    else:
        where_clause = ""

    # Combine clauses into a full query
    return select_clause + where_clause
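
# Worked example for the conversion above (toy values, not from the dataset);
# note that the simplified target string deliberately has no FROM clause.
_example_columns = ["Player", "No.", "Nationality", "Position"]
_example_sql = {"sel": 0, "agg": 3, "conds": [[3, 0, "Guard"]]}
assert sql_to_text(_example_sql, _example_columns) == "SELECT COUNT(Player) WHERE Position = 'Guard'"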

# Step 2: Tokenize the Data
def tokenize_data(data, tokenizer, max_length=128):
    """
    Tokenize the input and target text.
    Args:
        data: List of examples with "input" and "target".
        tokenizer: Pretrained tokenizer.
        max_length: Maximum sequence length for the model.
    Returns:
        Tokenized dataset.
    """
    inputs = [item["input"] for item in data]
    targets = [item["target"] for item in data]

    tokenized = tokenizer(
        inputs,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    labels = tokenizer(
        targets,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    tokenized["labels"] = labels["input_ids"]
    return tokenized
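
# Sketch of the expected output (shapes assume the default max_length=128): a
# dict-like object with "input_ids", "attention_mask" and "labels", each of
# shape (num_examples, 128), ready for Dataset.from_dict below. For instance:
#     toy = [{"input": "Question: How many players are there?",
#             "target": "SELECT COUNT(Player)"}]
#     batch = tokenize_data(toy, tokenizer)   # batch["input_ids"].shape -> (1, 128)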


# Step 3: Load Model and Tokenizer
model_name = "t5-small"  # Use "t5-small", "t5-base", or "t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Step 4: Prepare Training and Validation Data
data_dir = "data"  # Path to the WikiSQL dataset
train_data, dev_data = load_wikisql(data_dir)

# Tokenize Data
train_dataset = tokenize_data(train_data, tokenizer)
dev_dataset = tokenize_data(dev_data, tokenizer)

# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_dict(train_dataset)
dev_dataset = Dataset.from_dict(dev_dataset)

# Steps 5-8 (training and saving) are left commented out; uncomment them to fine-tune the model.
# Step 5: Define Training Arguments
# training_args = Seq2SeqTrainingArguments(
#     output_dir="./t5_sql_finetuned",
#     evaluation_strategy="steps",
#     save_steps=1000,
#     eval_steps=100,
#     logging_steps=100,
#     per_device_train_batch_size=16,
#     per_device_eval_batch_size=16,
#     num_train_epochs=3,
#     save_total_limit=2,
#     learning_rate=5e-5,
#     predict_with_generate=True,
#     fp16=torch.cuda.is_available(),  # Enable mixed precision for faster training
#     logging_dir="./logs",
# )

# Step 6: Define Trainer
# trainer = Seq2SeqTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=dev_dataset,
#     tokenizer=tokenizer,
# )

# Step 7: Train the Model
# trainer.train()

# Step 8: Save the Model
# trainer.save_model("./t5_sql_finetuned")
# tokenizer.save_pretrained("./t5_sql_finetuned")

# Step 9: Test the Model (a quick sanity check; without running Steps 5-8, this uses the base t5-small weights)
test_question = "Find all orders with product_id greater than 5."
input_text = f"Question: {test_question}"
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

outputs = model.generate(**inputs, max_length=128)
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated SQL:", generated_sql)