import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util
import faiss
import numpy as np
from transformers import pipeline
import time
import ast
import re
# --- 1. DATA LOADING AND INITIALIZATION ---
print("===== Application Startup =====")
start_time = time.time()
# Load the TravelPlanner test split (the same data the FAISS index was built from)
print("Loading TravelPlanner dataset...")
dataset = load_dataset("osunlp/TravelPlanner", "test")
print("Dataset ready.")
# --- 2. EMBEDDING AND RECOMMENDATION ENGINE ---
print("Loading embedding model...")
model_name = "all-mpnet-base-v2"
embedding_model = SentenceTransformer(f"sentence-transformers/{model_name}")
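# Note: all-mpnet-base-v2 produces 768-dimensional embeddings. The FAISS index
# loaded below is assumed to have been built with this same model; searching an
# index built with a different model would return meaningless neighbours.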
index_file = "trip_index.faiss"
print(f"Loading FAISS index from {index_file}...")
try:
    index = faiss.read_index(index_file)
    print(f"Index is ready. Total vectors in index: {index.ntotal}")
except RuntimeError:
    print(f"Error: FAISS index file '{index_file}' not found.")
    print("Please run the `build_index.py` script first to create the index.")
    exit()
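# For reference, a minimal sketch of what `build_index.py` is assumed to do
# (the actual script is not part of this file; the text fields chosen here are illustrative):
#
#   texts = [f"a {d}-day trip to {dest}"
#            for d, dest in zip(dataset['test']['days'], dataset['test']['dest'])]
#   vectors = np.array(embedding_model.encode(texts), dtype=np.float32)
#   index = faiss.IndexFlatL2(vectors.shape[1])
#   index.add(vectors)
#   faiss.write_index(index, index_file)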
# --- 3. PLAN FORMATTING AND SYNTHETIC GENERATION ---
def format_plan_details(plan_string):
    """
    Parses and formats the raw plan string from the dataset into readable Markdown.
    """
    # If the plan is not in the expected list-of-dicts format, return it as is.
    if not plan_string or not plan_string.strip().startswith('['):
        return plan_string
    try:
        # Safely parse the string representation of a list of dictionaries
        plan_list = ast.literal_eval(plan_string)
    except (ValueError, SyntaxError):
        # If parsing fails, return the original string to avoid crashing
        return plan_string
    formatted_sections = []
    for section in plan_list:
        description = section.get('Description', 'Details')
        content = section.get('Content', '').strip()
        # Add a heading for each section
        formatted_sections.append(f"#### {description}")
        # Use specific formatting based on the section's description
        if any(keyword in description for keyword in ['Attractions', 'Restaurants', 'Accommodations', 'Flight']):
            lines = content.split('\n')
            if lines:
                # Treat the first line as a bold column header
                formatted_sections.append(f"**{lines[0]}**")
                # Format the remaining lines as a clean, bulleted list
                for item in lines[1:]:
                    clean_item = ' '.join(item.split())  # Collapse extra whitespace
                    if clean_item:
                        formatted_sections.append(f"- {clean_item}")
        elif 'Self-driving' in description or 'Taxi' in description:
            # Make simple travel descriptions more readable
            mode_emoji = "🚗" if 'Self-driving' in description else "🚕"
            formatted_sections.append(f"- {mode_emoji} {content}")
        else:
            # Default formatting for any other type of content
            formatted_sections.append(content)
        # Add a blank line for spacing between sections
        formatted_sections.append("")
    return "\n".join(formatted_sections)
def get_recommendations_and_generate(query_text, k=3):
    # 1. Get recommendations from the existing dataset via FAISS search
    query_vector = embedding_model.encode([query_text])
    query_vector = np.array(query_vector, dtype=np.float32)
    distances, indices = index.search(query_vector, k)
    results = []
    for idx_numpy in indices[0]:
        idx = int(idx_numpy)
        trip_plan = {
            "dest": dataset['test']['dest'][idx],
            "days": dataset['test']['days'][idx],
            "reference_information": dataset['test']['reference_information'][idx]
        }
        results.append(trip_plan)
    # Pad with placeholders so the UI always receives three recommendations
    while len(results) < 3:
        results.append({"dest": "No trip plan found", "days": "", "reference_information": ""})
    # 2. Create a prompt for the generative model
    prompt = f"Write a complete travel plan that includes a title and a day-by-day itinerary. The trip must be about: {query_text}."
    # Note: the pipeline is loaded on every call; caching it at module level would be faster.
    print("Loading generative model...")
    generator = pipeline('text-generation', model='gpt2')
    # 3. Generate 10 new, creative trip ideas
    print("Generating 10 synthetic trip ideas...")
    generated_outputs = generator(
        prompt,
        max_new_tokens=250,  # Longer generations for more detailed plans
        num_return_sequences=10,
        do_sample=True,  # Sampling is required to get 10 distinct sequences
        pad_token_id=50256
    )
    # 4. Find the best trip out of the 10 generated
    print("Finding the most relevant generated trip...")
    generated_texts = [output['generated_text'].replace(prompt, "").strip() for output in generated_outputs]
    # Embed all 10 generated texts
    generated_embeddings = embedding_model.encode(generated_texts)
    # Calculate cosine similarity between the user's query and each generated text
    similarities = util.cos_sim(query_vector, generated_embeddings)
    # Keep the generated trip most similar to the query
    best_trip_index = int(np.argmax(similarities))
    best_generated_trip = generated_texts[best_trip_index]
    return results[0], results[1], results[2], best_generated_trip
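# Illustrative usage of the function above (not executed at import time):
#   rec1, rec2, rec3, generated = get_recommendations_and_generate("a 3-day trip to Paris")
# rec1..rec3 are dicts taken from the dataset; `generated` is the GPT-2 text
# whose embedding is closest to the query.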
# --- 4. GRADIO USER INTERFACE ---
def format_trip_plan(trip):
    # Formats a recommended trip from the dataset as Markdown
    if not trip or not trip.get('reference_information'):
        return "### No similar trip plan found."
    formatted_plan = format_plan_details(trip['reference_information'])
    return f"### {trip['days']}-day trip to {trip['dest'].upper()}\n**Suggested Plan:**\n{formatted_plan}"

def format_generated_trip(trip_text):
    # The generated text is shown verbatim in a plain textbox
    return trip_text
def trip_planner_wizard(destination, days):
    # Combine the user inputs into a single natural-language query
    days = int(days)  # gr.Number returns a float, so cast to int for the f-string
    query_text = f"a {days}-day trip to {destination}"
    rec1, rec2, rec3, gen_rec_text = get_recommendations_and_generate(query_text)
    return format_trip_plan(rec1), format_trip_plan(rec2), format_trip_plan(rec3), format_generated_trip(gen_rec_text)
end_time = time.time()
print(f"Models and data loaded in {end_time - start_time:.2f} seconds.")
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✈️ TripPlanner AI")
    gr.Markdown("Enter your destination and desired trip length, and get plan recommendations plus a new AI-generated idea!")
    with gr.Row():
        destination_input = gr.Textbox(label="Destination", placeholder="e.g., Paris")
        days_input = gr.Number(label="Number of Days", value=3)
    with gr.Row():
        submit_btn = gr.Button("Get Trip Plans", variant="primary")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Recommended Trip Plans from Dataset")
            output_rec1 = gr.Markdown()
            output_rec2 = gr.Markdown()
            output_rec3 = gr.Markdown()
        with gr.Column(scale=1):
            gr.Markdown("### ✨ New AI-Generated Idea")
            output_gen = gr.Textbox(label="AI Generated Trip Plan", lines=20, interactive=False)
    submit_btn.click(
        fn=trip_planner_wizard,
        inputs=[destination_input, days_input],
        outputs=[output_rec1, output_rec2, output_rec3, output_gen]
    )
    gr.Examples(
        examples=[
            ["Paris", 3],
            ["Orlando", 7],
            ["Tokyo", 5],
            ["the Greek Islands", 10]
        ],
        inputs=[destination_input, days_input]
    )
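# ssr_mode=False below disables Gradio's server-side rendering (a Gradio 5 option);
# this is a common workaround when SSR causes startup issues on hosted Spaces.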
demo.launch(ssr_mode=False)