# File size: 7,618 Bytes
# 124d1f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
import faiss
import numpy as np
from transformers import pipeline
import time
import ast
import re

# --- 1. DATA LOADING AND INITIALIZATION ---
print("===== Application Startup =====")
start_time = time.time()

# Load the "test" configuration of the TravelPlanner dataset. No row limit is
# applied here; rows are later looked up by the positions returned from the
# FAISS index, so the index must have been built from this same data in the
# same order.
print("Loading TravelPlanner dataset...")
dataset = load_dataset("osunlp/TravelPlanner", "test")
print("Dataset ready.")

# --- 2. EMBEDDING AND RECOMMENDATION ENGINE ---
print("Loading embedding model...")
model_name = "all-mpnet-base-v2"
embedding_model = SentenceTransformer(f"sentence-transformers/{model_name}")

index_file = "trip_index.faiss"

print(f"Loading FAISS index from {index_file}...")

try:
    index = faiss.read_index(index_file)
except RuntimeError as err:
    print(f"Error: FAISS index file '{index_file}' not found.")
    print("Please run the `build_index.py` script first to create the index.")
    # Surface the underlying reason too: the file may exist but be unreadable.
    print(f"Details: {err}")
    # Exit with a non-zero status so supervisors see the failure; the `exit()`
    # helper is injected by the `site` module and would report success (0).
    raise SystemExit(1)
else:
    print(f"Index is ready. Total vectors in index: {index.ntotal}")


# --- 3. PLAN FORMATTING, RETRIEVAL AND SYNTHETIC GENERATION ---
def format_plan_details(plan_string):
    """
    Parse the raw plan string from the dataset and format it as readable Markdown.

    The expected input is the string form of a Python list of dictionaries,
    each with a 'Description' and a 'Content' entry.

    Args:
        plan_string: Raw plan text. Anything that is not a string starting
            with '[' (after stripping whitespace) is returned unchanged.

    Returns:
        A Markdown string with one "#### <description>" section per entry, or
        the original ``plan_string`` when it cannot be parsed as a list literal.
    """
    # If the plan is not in the expected list format, return it as is.
    if not isinstance(plan_string, str) or not plan_string.strip().startswith('['):
        return plan_string

    try:
        # Safely parse the string representation of a list of dictionaries.
        plan_list = ast.literal_eval(plan_string.strip())
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # literal_eval may raise any of these on malformed input; fall back to
        # the original string rather than crashing the UI.
        return plan_string

    # Sections whose content is a header line followed by one record per line.
    tabular_keywords = ('Attractions', 'Restaurants', 'Accommodations', 'Flight')

    formatted_sections = []
    for section in plan_list:
        if not isinstance(section, dict):
            # Unexpected entry shape: show it as plain text instead of failing.
            text = str(section).strip()
            if text:
                formatted_sections.extend([text, ""])
            continue

        # Missing or empty descriptions fall back to a generic heading.
        description = str(section.get('Description') or 'Details')
        raw_content = section.get('Content')
        content = '' if raw_content is None else str(raw_content).strip()

        # Add a heading for each section.
        formatted_sections.append(f"#### {description}")

        if any(keyword in description for keyword in tabular_keywords):
            # Skip empty content so no stray "****" header is emitted.
            if content:
                header, *rows = content.split('\n')
                formatted_sections.append(f"**{header.strip()}**")
                for row in rows:
                    clean_row = ' '.join(row.split())  # Collapse extra whitespace
                    if clean_row:
                        formatted_sections.append(f"- {clean_row}")

        elif 'Self-driving' in description or 'Taxi' in description:
            # Simple travel descriptions become a single bullet with an emoji.
            if content:
                mode_emoji = "🚗" if 'Self-driving' in description else "🚕"
                formatted_sections.append(f"- {mode_emoji} {content}")

        else:
            # Default formatting for any other type of content.
            formatted_sections.append(content)

        # Blank entry gives a newline between sections.
        formatted_sections.append("")

    return "\n".join(formatted_sections)

# Lazily created text-generation pipeline, shared by all requests so the model
# is loaded once instead of on every call.
_text_generator = None


def _get_text_generator():
    """Return the GPT-2 text-generation pipeline, building it on first use."""
    global _text_generator
    if _text_generator is None:
        print("Loading generative model...")
        _text_generator = pipeline('text-generation', model='gpt2')
    return _text_generator


def get_recommendations_and_generate(query_text, k=3):
    """
    Retrieve similar trips from the dataset and generate one new trip idea.

    Args:
        query_text: Free-text description of the desired trip.
        k: Number of nearest neighbours requested from the FAISS index.

    Returns:
        A 4-tuple ``(rec1, rec2, rec3, generated_text)``. Each ``rec`` is a
        dict with 'dest', 'days' and 'reference_information' keys (a
        placeholder dict when fewer than three matches exist), and
        ``generated_text`` is the generated plan most similar to the query.
    """
    # 1. Get recommendations from existing data
    query_vector = embedding_model.encode([query_text])
    query_vector = np.array(query_vector, dtype=np.float32)
    _, indices = index.search(query_vector, k)

    trips = dataset['test']
    results = []
    for idx_numpy in indices[0]:
        idx = int(idx_numpy)
        # FAISS reports -1 when it has fewer than k neighbours; as a Python
        # index that would silently select the last row. Also skip positions
        # beyond the dataset in case the index and dataset sizes differ.
        if idx < 0 or idx >= len(trips):
            continue
        row = trips[idx]  # One row lookup instead of three full columns.
        results.append({
            "dest": row['dest'],
            "days": row['days'],
            "reference_information": row['reference_information'],
        })

    # The caller always receives three recommendation slots.
    while len(results) < 3:
        results.append({"dest": "No trip plan found", "days":"", "reference_information": ""})

    # 2. Create a prompt for the generative model
    prompt = f"Write a complete travel plan that includes a title and a day-by-day itinerary. The trip must be about: {query_text}."
    generator = _get_text_generator()

    # 3. Generate 10 new, creative trip ideas
    print("Generating 10 synthetic trip ideas...")
    generated_outputs = generator(
        prompt,
        max_new_tokens=250,  # Increased tokens for more detailed plans
        num_return_sequences=10,
        do_sample=True,  # Multiple distinct sequences require sampling
        pad_token_id=50256
    )

    # 4. Find the best trip out of the 10 generated
    print("Finding the most relevant generated trip...")
    generated_texts = [output['generated_text'].replace(prompt, "").strip() for output in generated_outputs]

    # Embed all generated texts
    generated_embeddings = embedding_model.encode(generated_texts)

    # Cosine similarity between the user's query and each generated text
    similarities = util.cos_sim(query_vector, generated_embeddings)

    # Pick the generated trip most similar to the query; int() turns the
    # array/tensor scalar into a plain list index.
    best_trip_index = int(np.argmax(similarities))
    best_generated_trip = generated_texts[best_trip_index]

    return results[0], results[1], results[2], best_generated_trip

# --- 4. GRADIO USER INTERFACE ---
def format_trip_plan(trip):
    """
    Format one recommended trip as Markdown.

    Args:
        trip: Dict with 'dest', 'days' and 'reference_information' keys, or
            an empty value / placeholder when no recommendation exists.

    Returns:
        A Markdown string with a heading and the formatted plan, or a
        "no similar trip" notice when the trip carries no plan content.
    """
    # A missing or empty 'reference_information' also covers the placeholder
    # entries used to pad the results, which would otherwise render as
    # "-days trip to NO TRIP PLAN FOUND".
    if not trip or not trip.get('reference_information'):
        return "### No similar trip plan found."
    formatted_plan = format_plan_details(trip['reference_information'])
    return f"### {trip['days']}-days trip to {trip['dest'].upper()}\n**Suggested Plan:**\n{formatted_plan}"

def format_generated_trip(trip_text):
    """Return the generated trip text unchanged, ready for display."""
    # No post-processing is applied; the text is passed straight through.
    generated_plan = trip_text
    return generated_plan

def trip_planner_wizard(destination, days):
    """
    Turn the form inputs into a query and return the four display values.

    Args:
        destination: Destination text entered by the user.
        days: Requested trip length; may be a float or None when the
            number field is cleared.

    Returns:
        A 4-tuple of strings: three Markdown recommendations and the
        generated plan text. For invalid input the first and last entries
        carry a short message and no models are run.
    """
    destination = str(destination or "").strip()
    if not destination:
        problem = "Please enter a destination."
        return f"### {problem}", "", "", problem

    try:
        days = int(days)  # Ensure days is an integer for the f-string
    except (TypeError, ValueError, OverflowError):
        # None (cleared field), NaN or infinity cannot become a day count.
        days = 0
    if days < 1:
        problem = "Please enter a number of days of at least 1."
        return f"### {problem}", "", "", problem

    # Combine user inputs into a single query for processing
    query_text = f"a {days}-day trip to {destination}"
    rec1, rec2, rec3, gen_rec_text = get_recommendations_and_generate(query_text)
    return format_trip_plan(rec1), format_trip_plan(rec2), format_trip_plan(rec3), format_generated_trip(gen_rec_text)

# Report how long startup took, measured from `start_time` set at the top of
# the script.
end_time = time.time()
print(f"Models and data loaded in {end_time - start_time:.2f} seconds.")

# Gradio Interface
# Layout: a row of inputs, a submit button, then two columns of output
# (dataset recommendations on the left, generated idea on the right).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✈️ TripPlanner AI")
    gr.Markdown("Enter your destination and desired trip length, and get plan recommendations plus a new AI-generated idea!")
    
    # User inputs
    with gr.Row():
        destination_input = gr.Textbox(label="Destination", placeholder="e.g., Paris")
        days_input = gr.Number(label="Number of Days", value=3)
        
    with gr.Row():
        submit_btn = gr.Button("Get Trip Plans", variant="primary")

    # Outputs: three Markdown panels for retrieved plans, one textbox for the
    # generated plan.
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Recommended Trip Plans from Dataset")
            output_rec1 = gr.Markdown()
            output_rec2 = gr.Markdown()
            output_rec3 = gr.Markdown()
        with gr.Column(scale=1):
            gr.Markdown("### ✨ New AI-Generated Idea")
            output_gen = gr.Textbox(label="AI Generated Trip Plan", lines=20, interactive=False)

    # The order of `outputs` must match the 4-tuple returned by
    # trip_planner_wizard.
    submit_btn.click(
        fn=trip_planner_wizard,
        inputs=[destination_input, days_input],
        outputs=[output_rec1, output_rec2, output_rec3, output_gen]
    )
    
    # Example (destination, days) pairs for the two input components.
    gr.Examples(
        examples=[
            ["Paris", 3],
            ["Orlando", 7],
            ["Tokyo", 5],
            ["the Greek Islands", 10]
        ],
        inputs=[destination_input, days_input]
    )

# NOTE(review): ssr_mode=False presumably disables Gradio's server-side
# rendering — confirm against the installed Gradio version.
demo.launch(ssr_mode=False)