obichimav's picture
Update app.py
d136a46 verified
# #!/usr/bin/env python
# """
# Gradio App for NYC Taxi Fare Prediction & Road Route Visualization using OSRM
# Requirements:
# pip install torch gradio requests polyline folium pandas numpy
# """
# import torch
# import torch.nn as nn
# import numpy as np
# import pandas as pd
# import requests
# import polyline
# import folium
# import gradio as gr
# # -----------------------------
# # Model Definition (TabularModel)
# # -----------------------------
# class TabularModel(nn.Module):
# def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):
# """
# Model for tabular data combining embeddings for categorical variables and
# a feed-forward network for continuous features.
# """
# super().__init__()
# self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs])
# self.emb_drop = nn.Dropout(p)
# self.bn_cont = nn.BatchNorm1d(n_cont)
# n_emb = sum([nf for _, nf in emb_szs])
# n_in = n_emb + n_cont
# layerlist = []
# for i in layers:
# layerlist.append(nn.Linear(n_in, i))
# layerlist.append(nn.ReLU(inplace=True))
# layerlist.append(nn.BatchNorm1d(i))
# layerlist.append(nn.Dropout(p))
# n_in = i
# layerlist.append(nn.Linear(layers[-1], out_sz))
# self.layers = nn.Sequential(*layerlist)
# def forward(self, x_cat, x_cont):
# embeddings = []
# for i, e in enumerate(self.embeds):
# embeddings.append(e(x_cat[:, i]))
# x = torch.cat(embeddings, 1)
# x = self.emb_drop(x)
# x_cont = self.bn_cont(x_cont)
# x = torch.cat([x, x_cont], 1)
# x = self.layers(x)
# return x
# # -----------------------------
# # Load the trained model
# # -----------------------------
# # These parameters must match those used during training.
# emb_szs = [(24, 12), (2, 1), (7, 4)]
# n_cont = 6 # [pickup_lat, pickup_long, dropoff_lat, dropoff_long, passenger_count, dist_km]
# out_sz = 1
# layers = [200, 100]
# p = 0.4
# model = TabularModel(emb_szs, n_cont, out_sz, layers, p)
# # Load model state (using weights_only=True to address the warning)
# model.load_state_dict(torch.load("TaxiFareRegrModel.pt", map_location=torch.device("cpu"), weights_only=True))
# model.eval()
# # -----------------------------
# # Helper Function: Haversine
# # -----------------------------
# def haversine_distance_coords(lat1, lon1, lat2, lon2):
# """Compute haversine distance (in km) between two coordinate pairs."""
# r = 6371 # Earth's radius in kilometers
# phi1 = np.radians(lat1)
# phi2 = np.radians(lat2)
# delta_phi = np.radians(lat2 - lat1)
# delta_lambda = np.radians(lon2 - lon1)
# a = np.sin(delta_phi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(delta_lambda/2)**2
# c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
# return r * c
# # -----------------------------
# # OSRM Route Retrieval
# # -----------------------------
# def get_osrm_route(lat1, lon1, lat2, lon2):
# """
# Query OSRM for a route between (lat1, lon1) and (lat2, lon2).
# Returns:
# - route_points: list of (lat, lon) tuples for the route polyline
# - distance_m: route distance in meters (from OSRM)
# - duration_s: route duration in seconds (from OSRM)
# """
# # OSRM expects coordinates as "lon,lat;lon,lat"
# coords = f"{lon1},{lat1};{lon2},{lat2}"
# OSRM_URL = f"http://router.project-osrm.org/route/v1/driving/{coords}?overview=full&geometries=polyline"
# response = requests.get(OSRM_URL)
# response.raise_for_status()
# data = response.json()
# if data.get("code") != "Ok":
# raise Exception("Route not found")
# route = data["routes"][0]
# encoded_poly = route["geometry"]
# route_points = polyline.decode(encoded_poly)
# distance_m = route["distance"]
# duration_s = route["duration"]
# return route_points, distance_m, duration_s
# # -----------------------------
# # Main Prediction & Mapping Function
# # -----------------------------
# def predict_fare_and_map(plat, plong, dlat, dlong, psngr, dt):
# """
# 1. Process pickup datetime to extract categorical features.
# 2. Compute haversine distance for the model input.
# 3. Use the PyTorch model to predict the taxi fare.
# 4. Query OSRM for the actual road route geometry & distance.
# 5. Draw a Folium map with the OSRM route (blue line) and markers.
# 6. Return a text string with predicted fare and route distance, plus the map HTML.
# """
# # Process datetime
# try:
# pickup_dt = pd.to_datetime(dt)
# except Exception as e:
# return f"Error parsing date/time: {e}", ""
# hour = pickup_dt.hour
# am_or_pm = 0 if hour < 12 else 1
# weekday_str = pickup_dt.strftime("%a")
# weekday_map = {'Fri': 0, 'Mon': 1, 'Sat': 2, 'Sun': 3, 'Thu': 4, 'Tue': 5, 'Wed': 6}
# weekday = weekday_map.get(weekday_str, 0)
# # Prepare tensors for model input (use haversine distance)
# dist_km = haversine_distance_coords(plat, plong, dlat, dlong)
# cat_array = np.array([[hour, am_or_pm, weekday]])
# cat_tensor = torch.tensor(cat_array, dtype=torch.int64)
# cont_array = np.array([[plat, plong, dlat, dlong, psngr, dist_km]])
# cont_tensor = torch.tensor(cont_array, dtype=torch.float)
# # Predict fare
# with torch.no_grad():
# pred = model(cat_tensor, cont_tensor)
# fare_pred = pred.item()
# # Get route from OSRM
# try:
# route_points, route_distance_m, route_duration_s = get_osrm_route(plat, plong, dlat, dlong)
# except Exception as e:
# return f"Error from OSRM: {e}", ""
# # Create Folium map centered between pickup & dropoff
# mid_lat = (plat + dlat) / 2
# mid_lon = (plong + dlong) / 2
# m = folium.Map(location=[mid_lat, mid_lon], zoom_start=12)
# # Add markers
# folium.Marker([plat, plong], icon=folium.Icon(color="green"), tooltip="Pickup").add_to(m)
# folium.Marker([dlat, dlong], icon=folium.Icon(color="red"), tooltip="Dropoff").add_to(m)
# # Draw the route polyline (blue line) with popup showing OSRM distance
# folium.PolyLine(
# route_points,
# color="blue",
# weight=3,
# opacity=0.7,
# popup=f"OSRM Distance: {route_distance_m/1000:.2f} km"
# ).add_to(m)
# map_html = m._repr_html_()
# route_distance_km = route_distance_m / 1000
# output_text = (f"Predicted Fare: ${fare_pred:.2f}\n"
# f"Route Distance (OSRM): {route_distance_km:.2f} km")
# return output_text, map_html
# # -----------------------------
# # Gradio Interface
# # -----------------------------
# iface = gr.Interface(
# fn=predict_fare_and_map,
# inputs=[
# gr.Number(label="Pickup Latitude", value=40.75),
# gr.Number(label="Pickup Longitude", value=-73.99),
# gr.Number(label="Dropoff Latitude", value=40.73),
# gr.Number(label="Dropoff Longitude", value=-73.98),
# gr.Number(label="Passenger Count", value=1),
# gr.Textbox(label="Pickup Date and Time (YYYY-MM-DD HH:MM:SS)", value="2010-04-19 08:17:56")
# ],
# outputs=[
# gr.Textbox(label="Prediction & Distance"),
# gr.HTML(label="Map")
# ],
# title="NYC Taxi Fare Prediction with OSRM Road Route",
# description=(
# "Enter pickup/dropoff coordinates, passenger count, and pickup datetime to predict the taxi fare. "
# "The app displays the actual road route (blue line) from OSRM on a Folium map."
# )
# )
# if __name__ == "__main__":
# iface.launch()
#!/usr/bin/env python
"""
Gradio App for NYC Taxi Fare Prediction & Road Route Visualization using OSRM
Requirements:
pip install torch gradio requests polyline folium pandas numpy
"""
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import requests
import polyline
import folium
import gradio as gr
# -----------------------------
# Model Definition (TabularModel)
# -----------------------------
class TabularModel(nn.Module):
    """
    Model for tabular data combining embeddings for categorical variables and
    a feed-forward network for continuous features.

    Each categorical column is looked up in its own ``nn.Embedding``; the
    concatenated embeddings (after dropout) are joined with the
    batch-normalised continuous features and fed through a stack of
    ``Linear -> ReLU -> BatchNorm1d -> Dropout`` blocks, followed by a final
    ``Linear`` projection to ``out_sz`` outputs.

    Args:
        emb_szs: sequence of ``(num_categories, embedding_dim)`` pairs, one per
            categorical column, in the same order as the columns of ``x_cat``.
            May be empty (continuous features only).
        n_cont: number of continuous input features (columns of ``x_cont``).
        out_sz: number of output units.
        layers: hidden-layer widths. May be empty, in which case the final
            ``Linear`` sits directly on the embeddings + continuous features.
        p: dropout probability for the embeddings and every hidden layer.
    """

    def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):
        super().__init__()
        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs])
        self.emb_drop = nn.Dropout(p)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        n_emb = sum(nf for _, nf in emb_szs)
        n_in = n_emb + n_cont
        layerlist = []
        for width in layers:
            layerlist.append(nn.Linear(n_in, width))
            layerlist.append(nn.ReLU(inplace=True))
            layerlist.append(nn.BatchNorm1d(width))
            layerlist.append(nn.Dropout(p))
            n_in = width
        # ``n_in`` is now the width of the last hidden layer, or the raw input
        # width when ``layers`` is empty (where ``layers[-1]`` would raise).
        # Module order is unchanged, so saved state_dict keys such as
        # "layers.0.weight" still line up with existing checkpoints.
        layerlist.append(nn.Linear(n_in, out_sz))
        self.layers = nn.Sequential(*layerlist)

    def forward(self, x_cat, x_cont):
        """
        Compute the model output for a batch.

        Args:
            x_cat: integer tensor; column ``i`` holds the category codes looked
                up in ``self.embeds[i]``.
            x_cont: float tensor with ``n_cont`` columns of continuous features.

        Returns:
            Tensor with ``out_sz`` columns, one row per input row.
        """
        if len(self.embeds) > 0:
            x = torch.cat([embed(x_cat[:, i]) for i, embed in enumerate(self.embeds)], 1)
            x = self.emb_drop(x)
            x = torch.cat([x, self.bn_cont(x_cont)], 1)
        else:
            # No categorical columns: torch.cat([]) would fail, so feed the
            # normalised continuous features straight into the MLP.
            x = self.bn_cont(x_cont)
        return self.layers(x)
# -----------------------------
# Load the trained model
# -----------------------------
# These parameters must match those used during training: load_state_dict
# below is strict by default and raises if any parameter name or shape
# differs from the checkpoint.
# One (num_categories, embedding_dim) pair per categorical column, in the
# column order predict_fare_and_map builds: hour (0-23), am_or_pm (0/1),
# weekday code (0-6).
emb_szs = [(24, 12), (2, 1), (7, 4)]
n_cont = 6 # [pickup_lat, pickup_long, dropoff_lat, dropoff_long, passenger_count, dist_km]
out_sz = 1  # single output: the predicted fare
layers = [200, 100]  # hidden-layer widths
p = 0.4  # dropout probability
model = TabularModel(emb_szs, n_cont, out_sz, layers, p)
# Load the trained weights (path relative to the working directory) onto the CPU.
# weights_only=True restricts torch.load's unpickler to tensors and plain
# containers rather than arbitrary pickled objects; it also addresses the
# warning newer PyTorch versions emit when the argument is left unset.
model.load_state_dict(torch.load("TaxiFareRegrModel.pt", map_location=torch.device("cpu"), weights_only=True))
# Inference mode: Dropout is disabled and BatchNorm uses its running
# statistics, which is also what allows the single-row batches the app feeds in.
model.eval()
# -----------------------------
# Helper Function: Haversine
# -----------------------------
def haversine_distance_coords(lat1, lon1, lat2, lon2):
    """
    Compute the haversine (great-circle) distance in km between two points.

    Args:
        lat1, lon1: first point, in decimal degrees.
        lat2, lon2: second point, in decimal degrees.
        Scalars or NumPy arrays are accepted; arrays are combined element-wise.

    Returns:
        Distance in kilometres (NumPy scalar or array).
    """
    r = 6371  # Earth's mean radius in kilometres
    phi1 = np.radians(lat1)
    phi2 = np.radians(lat2)
    delta_phi = np.radians(lat2 - lat1)
    delta_lambda = np.radians(lon2 - lon1)
    a = np.sin(delta_phi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(delta_lambda / 2) ** 2
    # Floating-point rounding can push ``a`` a hair outside [0, 1] for
    # (near-)antipodal points, which would make sqrt(1 - a) NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return r * c
# -----------------------------
# OSRM Route Retrieval
# -----------------------------
def get_osrm_route(lat1, lon1, lat2, lon2, timeout=10):
    """
    Query OSRM for a route between (lat1, lon1) and (lat2, lon2).

    Args:
        lat1, lon1: start point in decimal degrees.
        lat2, lon2: end point in decimal degrees.
        timeout: seconds ``requests`` waits for the server before giving up.
            Without a timeout a stalled server would block the request forever.

    Returns:
        - route_points: list of (lat, lon) tuples for the route polyline
        - distance_m: route distance in meters (from OSRM)
        - duration_s: route duration in seconds (from OSRM)

    Raises:
        requests.RequestException: network failure, timeout, or HTTP error status.
        RuntimeError: OSRM answered but did not return a usable route.
    """
    # OSRM expects coordinates as "lon,lat;lon,lat"
    coords = f"{lon1},{lat1};{lon2},{lat2}"
    osrm_url = f"http://router.project-osrm.org/route/v1/driving/{coords}?overview=full&geometries=polyline"
    response = requests.get(osrm_url, timeout=timeout)
    response.raise_for_status()
    data = response.json()
    routes = data.get("routes") or []
    if data.get("code") != "Ok" or not routes:
        # Surface OSRM's own explanation (e.g. its error code/message) when present.
        detail = data.get("message") or data.get("code") or "no route returned"
        raise RuntimeError(f"Route not found ({detail})")
    route = routes[0]
    route_points = polyline.decode(route["geometry"])
    distance_m = route["distance"]
    duration_s = route["duration"]
    return route_points, distance_m, duration_s
# -----------------------------
# Main Prediction & Mapping Function
# -----------------------------
# Model category code for each weekday, indexed by ``Timestamp.dayofweek``
# (Monday=0 ... Sunday=6). The codes are the alphabetical order of the
# abbreviated day names: Fri=0, Mon=1, Sat=2, Sun=3, Thu=4, Tue=5, Wed=6.
_WEEKDAY_CODES = (1, 5, 6, 4, 0, 2, 3)


def _extract_time_features(pickup_dt):
    """Return the ``(hour, am_or_pm, weekday)`` categorical codes for a timestamp."""
    hour = pickup_dt.hour
    am_or_pm = 0 if hour < 12 else 1
    # dayofweek is locale-independent; strftime("%a") yields localized names
    # under a non-English locale, which previously fell back to code 0.
    weekday = _WEEKDAY_CODES[pickup_dt.dayofweek]
    return hour, am_or_pm, weekday


def _build_route_map(plat, plong, dlat, dlong, route_points, route_distance_km):
    """Draw pickup/dropoff markers and the route on a Folium map; return its HTML."""
    # Center the map between pickup & dropoff.
    m = folium.Map(location=[(plat + dlat) / 2, (plong + dlong) / 2], zoom_start=12)
    folium.Marker([plat, plong], icon=folium.Icon(color="green"), tooltip="Pickup").add_to(m)
    folium.Marker([dlat, dlong], icon=folium.Icon(color="red"), tooltip="Dropoff").add_to(m)
    # Route polyline (blue line) with a popup showing the OSRM distance.
    folium.PolyLine(
        route_points,
        color="blue",
        weight=3,
        opacity=0.7,
        popup=f"OSRM Distance: {route_distance_km:.2f} km",
    ).add_to(m)
    return m._repr_html_()


def predict_fare_and_map(plat, plong, dlat, dlong, psngr, dt):
    """
    Predict a taxi fare and draw the OSRM road route on a map.

    1. Parse the pickup datetime and derive the categorical features.
    2. Compute the haversine distance used as a model input.
    3. Run the PyTorch model to predict the taxi fare.
    4. Query OSRM for the actual road route geometry & distance.
    5. Draw a Folium map with the OSRM route (blue line) and markers.

    Args:
        plat, plong: pickup latitude / longitude (decimal degrees).
        dlat, dlong: dropoff latitude / longitude (decimal degrees).
        psngr: passenger count.
        dt: pickup date/time; anything ``pandas.to_datetime`` accepts.

    Returns:
        ``(text, map_html)``. For invalid input, ``text`` is an error message
        and ``map_html`` is ``""``. If only the OSRM lookup fails, ``text``
        still contains the predicted fare, followed by the OSRM error.
    """
    try:
        pickup_dt = pd.to_datetime(dt)
    except Exception as e:
        return f"Error parsing date/time: {e}", ""
    # Empty/None input does not raise: it parses to NaT/None, which has no
    # usable hour or weekday.
    if pd.isna(pickup_dt):
        return "Error parsing date/time: no date/time given", ""

    # A cleared gr.Number arrives as None; float() rejects it (and non-numbers).
    try:
        plat, plong, dlat, dlong, psngr = (float(v) for v in (plat, plong, dlat, dlong, psngr))
    except (TypeError, ValueError):
        return "Error: coordinates and passenger count must be numbers", ""

    hour, am_or_pm, weekday = _extract_time_features(pickup_dt)
    # The model was trained on straight-line (haversine) distance, not road distance.
    dist_km = haversine_distance_coords(plat, plong, dlat, dlong)
    cat_tensor = torch.tensor(np.array([[hour, am_or_pm, weekday]]), dtype=torch.int64)
    cont_tensor = torch.tensor(np.array([[plat, plong, dlat, dlong, psngr, dist_km]]), dtype=torch.float)
    with torch.no_grad():
        fare_pred = model(cat_tensor, cont_tensor).item()
    fare_text = f"Predicted Fare: ${fare_pred:.2f}"

    try:
        route_points, route_distance_m, _ = get_osrm_route(plat, plong, dlat, dlong)
    except Exception as e:
        # The fare does not depend on OSRM, so still report it.
        return f"{fare_text}\nError from OSRM: {e}", ""

    route_distance_km = route_distance_m / 1000
    map_html = _build_route_map(plat, plong, dlat, dlong, route_points, route_distance_km)
    return f"{fare_text}\nRoute Distance (OSRM): {route_distance_km:.2f} km", map_html
# -----------------------------
# Example Locations (Popular NYC Spots)
# Each example is a list of 6 inputs:
# [pickup_lat, pickup_lon, dropoff_lat, dropoff_lon, passenger_count, pickup_datetime]
# -----------------------------
# Every example shares the same pickup time and passenger count.
_EXAMPLE_DATETIME = "2010-04-19 08:17:56"
_EXAMPLE_PASSENGERS = 1

# Route label -> (pickup_lat, pickup_lon, dropoff_lat, dropoff_lon).
_EXAMPLE_ROUTES = {
    "Times Square to Central Park (short ride)": (40.7580, -73.9855, 40.7690, -73.9819),
    "Times Square to JFK Airport (long ride)": (40.7580, -73.9855, 40.6413, -73.7781),
    "Grand Central Terminal to Empire State Building (very short ride)": (40.7527, -73.9772, 40.7484, -73.9857),
    "Brooklyn Bridge to Wall Street (short urban ride)": (40.7061, -73.9969, 40.7069, -74.0113),
    "Yankee Stadium to Central Park (moderate ride)": (40.8296, -73.9262, 40.7829, -73.9654),
    "Columbia University area to Rockefeller Center (cross-city ride)": (40.8075, -73.9626, 40.7587, -73.9787),
    "Battery Park to Central Park (longer ride across Manhattan)": (40.7033, -74.0170, 40.7829, -73.9654),
}

# One list of input values per example, ordered like the Interface inputs:
# [pickup_lat, pickup_lon, dropoff_lat, dropoff_lon, passenger_count, pickup_datetime]
examples = [
    [*route, _EXAMPLE_PASSENGERS, _EXAMPLE_DATETIME]
    for route in _EXAMPLE_ROUTES.values()
]
# -----------------------------
# Gradio Interface
# -----------------------------
# Input widgets, in the same order as predict_fare_and_map's parameters.
_fare_inputs = [
    gr.Number(label="Pickup Latitude", value=40.75),
    gr.Number(label="Pickup Longitude", value=-73.99),
    gr.Number(label="Dropoff Latitude", value=40.73),
    gr.Number(label="Dropoff Longitude", value=-73.98),
    gr.Number(label="Passenger Count", value=1),
    gr.Textbox(label="Pickup Date and Time (YYYY-MM-DD HH:MM:SS)", value="2010-04-19 08:17:56"),
]

# Outputs: the text summary first, then the rendered route map.
_fare_outputs = [
    gr.Textbox(label="Prediction & Distance"),
    gr.HTML(label="Map"),
]

_fare_description = (
    "Enter pickup/dropoff coordinates, passenger count, and pickup datetime to predict the taxi fare. "
    "The app displays the actual road route (blue line) from OSRM on a Folium map. "
    "You can also choose from several example routes between popular locations in New York."
)

iface = gr.Interface(
    fn=predict_fare_and_map,
    inputs=_fare_inputs,
    outputs=_fare_outputs,
    examples=examples,
    title="NYC Taxi Fare Prediction with OSRM Road Route",
    description=_fare_description,
)

if __name__ == "__main__":
    iface.launch()