import os
import sys
import torch
import pandas as pd
import streamlit as st
from datetime import datetime
from transformers import (
    T5ForConditionalGeneration, 
    T5Tokenizer,
    Trainer, 
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
import random

# Environment setup: tell the Intel OpenMP runtime to tolerate a second copy
# of itself being loaded instead of aborting.
# NOTE(review): torch (and pandas) are already imported above; if the
# duplicate runtime is loaded at import time this assignment comes too late.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Reproducibility: pin the Python and torch random number generators.
random.seed(42)
torch.manual_seed(42)

class TravelDataset(Dataset):
    """Seq2seq training examples built from a travel-planning DataFrame.

    Each item pairs the ``query`` column (encoder input) with the
    ``reference_information`` column (decoder target). Both sides are padded
    and truncated to ``max_length`` tokens.
    """

    def __init__(self, data, tokenizer, max_length=512):
        """
        Args:
            data: pandas DataFrame with at least ``query`` and
                ``reference_information`` columns (rows are read with ``iloc``).
            tokenizer: callable tokenizer (e.g. ``T5Tokenizer``) returning
                ``input_ids`` / ``attention_mask`` tensors and exposing
                ``pad_token_id``.
            max_length: token length every input and target is padded or
                truncated to.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length

        print(f"Dataset loaded with {len(data)} samples")
        print("Columns:", list(data.columns))

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string into a ``(1, max_length)`` tensor encoding."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for row *idx*.

        Padding positions in ``labels`` are set to ``-100`` so the
        cross-entropy loss ignores them.
        """
        row = self.data.iloc[idx]

        input_encodings = self._encode(row['query'])
        target_encodings = self._encode(row['reference_information'])

        # squeeze(0) removes only the batch dimension added by
        # return_tensors='pt'; a bare squeeze() would also collapse a
        # length-1 sequence into a 0-d tensor.
        labels = target_encodings['input_ids'].squeeze(0).clone()
        pad_id = self.tokenizer.pad_token_id
        if pad_id is not None:
            # Without this the model is trained to predict pad tokens for
            # every padded target position.
            labels[labels == pad_id] = -100

        return {
            'input_ids': input_encodings['input_ids'].squeeze(0),
            'attention_mask': input_encodings['attention_mask'].squeeze(0),
            'labels': labels,
        }

def load_dataset(path="hf://datasets/osunlp/TravelPlanner/train.csv"):
    """
    Load the travel planning dataset from CSV.

    Args:
        path: any location ``pandas.read_csv`` accepts. Defaults to the
            TravelPlanner training split on the Hugging Face Hub.

    Returns:
        The loaded DataFrame, guaranteed to contain the ``query`` and
        ``reference_information`` columns.

    Raises:
        ValueError: if any required column is missing (all missing columns
            are listed in the message).
        Whatever ``pandas.read_csv`` raises if the file cannot be read.

    Errors are printed and then re-raised rather than ending the process via
    ``sys.exit``. ``SystemExit`` is not an ``Exception``, so it would slip past
    ``train_model``'s ``except Exception`` handler and its failure reporting.
    """
    required_columns = ['query', 'reference_information']
    try:
        data = pd.read_csv(path)
        missing = [col for col in required_columns if col not in data.columns]
        if missing:
            raise ValueError(f"Missing required column(s): {', '.join(missing)}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise

    print(f"Dataset loaded successfully with {len(data)} rows.")
    return data

def _split_train_val(data, train_fraction=0.8, seed=42):
    """Shuffle *data* reproducibly and split it into ``(train, val)`` frames.

    The CSV's row order is not known to be random (it may, for example, be
    grouped by difficulty), so a plain head/tail cut could leave the
    validation set unrepresentative.

    Raises:
        ValueError: if either split would be empty (evaluation and training
            both need at least one row).
    """
    shuffled = data.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    train_size = int(train_fraction * len(shuffled))
    train_data = shuffled.iloc[:train_size]
    val_data = shuffled.iloc[train_size:]
    if train_data.empty or val_data.empty:
        raise ValueError(
            f"Need at least one training and one validation row; got "
            f"{len(train_data)} train / {len(val_data)} val from {len(data)} rows."
        )
    return train_data, val_data


def _eval_strategy_kwargs(strategy="steps"):
    """Return the evaluation-strategy keyword this transformers version accepts.

    Newer releases renamed ``evaluation_strategy`` to ``eval_strategy``; the
    signature check keeps both old and new versions working.
    """
    import inspect

    params = inspect.signature(TrainingArguments).parameters
    key = 'eval_strategy' if 'eval_strategy' in params else 'evaluation_strategy'
    return {key: strategy}


def train_model():
    """
    Fine-tune ``t5-base`` on the travel dataset and save it to disk.

    The model and tokenizer are saved to ``./trained_travel_planner``, which is
    also the Trainer output directory.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` if any step
        raises an ``Exception``. The error and its traceback are printed, not
        re-raised.
    """
    import traceback

    output_dir = "./trained_travel_planner"
    try:
        data = load_dataset()

        print("Initializing T5 model and tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained('t5-base', legacy=False)
        model = T5ForConditionalGeneration.from_pretrained('t5-base')

        train_data, val_data = _split_train_val(data)
        train_dataset = TravelDataset(train_data, tokenizer)
        val_dataset = TravelDataset(val_data, tokenizer)

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            eval_steps=50,
            # load_best_model_at_end needs save_steps to be a round multiple
            # of eval_steps (100 = 2 * 50).
            save_steps=100,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            load_best_model_at_end=True,
            **_eval_strategy_kwargs("steps"),
        )

        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            padding=True
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator
        )

        print("Training model...")
        trainer.train()

        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        print("Model training complete!")
        return model, tokenizer
    except Exception as e:
        print(f"Training error: {e}")
        traceback.print_exc()
        return None, None

def generate_travel_plan(query, model, tokenizer):
    """
    Generate a travel plan using the trained model.

    Args:
        query: free-text trip request.
        model: seq2seq model exposing ``generate`` (e.g. the fine-tuned T5),
            or ``None`` if no model is loaded.
        tokenizer: tokenizer matching *model*, or ``None``.

    Returns:
        The decoded plan text. On any failure it returns a string starting with
        ``"Error generating travel plan: "`` instead of raising.
    """
    if model is None or tokenizer is None:
        return "Error generating travel plan: model is not loaded"
    try:
        # A single query needs no padding. Padding to 512 only adds masked
        # positions that the encoder still has to process.
        inputs = tokenizer(
            query,
            return_tensors="pt",
            max_length=512,
            truncation=True
        )

        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
            model = model.cuda()

        # A model handed over straight from Trainer may still be in training
        # mode; disable dropout so generation is deterministic.
        model.eval()

        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            no_repeat_ngram_size=3,
            num_return_sequences=1
        )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating travel plan: {e}"

def _load_or_train_model(model_dir="./trained_travel_planner"):
    """
    Return ``(model, tokenizer)``, preferring weights already saved in *model_dir*.

    It falls back to ``train_model()`` when no finished model is on disk or
    loading it fails. This way a new session does not retrain from scratch
    when a trained model already exists. Returns ``(None, None)`` if training
    also fails.
    """
    # save_pretrained() writes config.json at the top of the directory.
    # Trainer checkpoint subfolders alone do not count as a finished model.
    if os.path.isfile(os.path.join(model_dir, "config.json")):
        try:
            tokenizer = T5Tokenizer.from_pretrained(model_dir, legacy=False)
            model = T5ForConditionalGeneration.from_pretrained(model_dir)
            print(f"Loaded saved model from {model_dir}")
            return model, tokenizer
        except Exception as e:
            print(f"Could not load saved model from {model_dir}: {e}")
    return train_model()


def main():
    """Render the AI Travel Planner Streamlit UI."""
    st.set_page_config(
        page_title="AI Travel Planner",
        page_icon="✈️",
        layout="wide"
    )
    st.title("✈️ AI Travel Planner")

    # Sidebar to (re)train the model on demand.
    with st.sidebar:
        st.header("Model Management")
        if st.button("Retrain Model"):
            with st.spinner("Training the model..."):
                model, tokenizer = train_model()
                if model is not None:
                    st.session_state['model'] = model
                    st.session_state['tokenizer'] = tokenizer
                    st.success("Model retrained successfully!")
                else:
                    st.error("Model retraining failed.")

    # Load once per session. Saved weights are reused when present; otherwise
    # the model is trained. A failed attempt is stored as None so reruns do
    # not retry training automatically; the sidebar button can retry.
    if 'model' not in st.session_state:
        with st.spinner("Loading model..."):
            model, tokenizer = _load_or_train_model()
            st.session_state['model'] = model
            st.session_state['tokenizer'] = tokenizer
        if model is None:
            st.error("Model could not be loaded or trained. Use 'Retrain Model' in the sidebar to try again.")

    # Input query
    st.subheader("Plan Your Trip")
    query = st.text_area("Enter your trip query (e.g., 'Plan a 3-day trip to Paris focusing on culture and food')")

    if st.button("Generate Plan"):
        if not query:
            st.error("Please enter a query.")
        elif st.session_state.get('model') is None:
            st.error("No model is available. Use 'Retrain Model' in the sidebar.")
        else:
            with st.spinner("Generating your travel plan..."):
                travel_plan = generate_travel_plan(
                    query,
                    st.session_state['model'],
                    st.session_state['tokenizer']
                )
                st.subheader("Your Travel Plan")
                st.write(travel_plan)

# Script entry point: build the UI defined in main().
# NOTE(review): main() uses st.* APIs, so this is presumably launched with
# `streamlit run <this file>` rather than plain `python` — confirm.
if __name__ == "__main__":
    main()