obichimav commited on
Commit
d136a46
·
verified ·
1 Parent(s): 72e9508

Update app.py

Browse files

Updated the app with more examples

Files changed (1) hide show
  1. app.py +234 -1
app.py CHANGED
@@ -1,3 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python
2
  """
3
  Gradio App for NYC Taxi Fare Prediction & Road Route Visualization using OSRM
@@ -179,6 +387,28 @@ def predict_fare_and_map(plat, plong, dlat, dlong, psngr, dt):
179
  f"Route Distance (OSRM): {route_distance_km:.2f} km")
180
  return output_text, map_html
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  # -----------------------------
183
  # Gradio Interface
184
  # -----------------------------
@@ -196,12 +426,15 @@ iface = gr.Interface(
196
  gr.Textbox(label="Prediction & Distance"),
197
  gr.HTML(label="Map")
198
  ],
 
199
  title="NYC Taxi Fare Prediction with OSRM Road Route",
200
  description=(
201
  "Enter pickup/dropoff coordinates, passenger count, and pickup datetime to predict the taxi fare. "
202
- "The app displays the actual road route (blue line) from OSRM on a Folium map."
 
203
  )
204
  )
205
 
206
  if __name__ == "__main__":
207
  iface.launch()
 
 
1
+ # #!/usr/bin/env python
2
+ # """
3
+ # Gradio App for NYC Taxi Fare Prediction & Road Route Visualization using OSRM
4
+
5
+ # Requirements:
6
+ # pip install torch gradio requests polyline folium pandas numpy
7
+ # """
8
+
9
+ # import torch
10
+ # import torch.nn as nn
11
+ # import numpy as np
12
+ # import pandas as pd
13
+ # import requests
14
+ # import polyline
15
+ # import folium
16
+ # import gradio as gr
17
+
18
+ # # -----------------------------
19
+ # # Model Definition (TabularModel)
20
+ # # -----------------------------
21
+ # class TabularModel(nn.Module):
22
+ # def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):
23
+ # """
24
+ # Model for tabular data combining embeddings for categorical variables and
25
+ # a feed-forward network for continuous features.
26
+ # """
27
+ # super().__init__()
28
+ # self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs])
29
+ # self.emb_drop = nn.Dropout(p)
30
+ # self.bn_cont = nn.BatchNorm1d(n_cont)
31
+
32
+ # n_emb = sum([nf for _, nf in emb_szs])
33
+ # n_in = n_emb + n_cont
34
+
35
+ # layerlist = []
36
+ # for i in layers:
37
+ # layerlist.append(nn.Linear(n_in, i))
38
+ # layerlist.append(nn.ReLU(inplace=True))
39
+ # layerlist.append(nn.BatchNorm1d(i))
40
+ # layerlist.append(nn.Dropout(p))
41
+ # n_in = i
42
+ # layerlist.append(nn.Linear(layers[-1], out_sz))
43
+ # self.layers = nn.Sequential(*layerlist)
44
+
45
+ # def forward(self, x_cat, x_cont):
46
+ # embeddings = []
47
+ # for i, e in enumerate(self.embeds):
48
+ # embeddings.append(e(x_cat[:, i]))
49
+ # x = torch.cat(embeddings, 1)
50
+ # x = self.emb_drop(x)
51
+
52
+ # x_cont = self.bn_cont(x_cont)
53
+ # x = torch.cat([x, x_cont], 1)
54
+ # x = self.layers(x)
55
+ # return x
56
+
57
+ # # -----------------------------
58
+ # # Load the trained model
59
+ # # -----------------------------
60
+ # # These parameters must match those used during training.
61
+ # emb_szs = [(24, 12), (2, 1), (7, 4)]
62
+ # n_cont = 6 # [pickup_lat, pickup_long, dropoff_lat, dropoff_long, passenger_count, dist_km]
63
+ # out_sz = 1
64
+ # layers = [200, 100]
65
+ # p = 0.4
66
+
67
+ # model = TabularModel(emb_szs, n_cont, out_sz, layers, p)
68
+ # # Load model state (using weights_only=True to address the warning)
69
+ # model.load_state_dict(torch.load("TaxiFareRegrModel.pt", map_location=torch.device("cpu"), weights_only=True))
70
+ # model.eval()
71
+
72
+ # # -----------------------------
73
+ # # Helper Function: Haversine
74
+ # # -----------------------------
75
+ # def haversine_distance_coords(lat1, lon1, lat2, lon2):
76
+ # """Compute haversine distance (in km) between two coordinate pairs."""
77
+ # r = 6371 # Earth's radius in kilometers
78
+ # phi1 = np.radians(lat1)
79
+ # phi2 = np.radians(lat2)
80
+ # delta_phi = np.radians(lat2 - lat1)
81
+ # delta_lambda = np.radians(lon2 - lon1)
82
+ # a = np.sin(delta_phi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(delta_lambda/2)**2
83
+ # c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
84
+ # return r * c
85
+
86
+ # # -----------------------------
87
+ # # OSRM Route Retrieval
88
+ # # -----------------------------
89
+ # def get_osrm_route(lat1, lon1, lat2, lon2):
90
+ # """
91
+ # Query OSRM for a route between (lat1, lon1) and (lat2, lon2).
92
+ # Returns:
93
+ # - route_points: list of (lat, lon) tuples for the route polyline
94
+ # - distance_m: route distance in meters (from OSRM)
95
+ # - duration_s: route duration in seconds (from OSRM)
96
+ # """
97
+ # # OSRM expects coordinates as "lon,lat;lon,lat"
98
+ # coords = f"{lon1},{lat1};{lon2},{lat2}"
99
+ # OSRM_URL = f"http://router.project-osrm.org/route/v1/driving/{coords}?overview=full&geometries=polyline"
100
+
101
+ # response = requests.get(OSRM_URL)
102
+ # response.raise_for_status()
103
+ # data = response.json()
104
+
105
+ # if data.get("code") != "Ok":
106
+ # raise Exception("Route not found")
107
+
108
+ # route = data["routes"][0]
109
+ # encoded_poly = route["geometry"]
110
+ # route_points = polyline.decode(encoded_poly)
111
+ # distance_m = route["distance"]
112
+ # duration_s = route["duration"]
113
+ # return route_points, distance_m, duration_s
114
+
115
+ # # -----------------------------
116
+ # # Main Prediction & Mapping Function
117
+ # # -----------------------------
118
+ # def predict_fare_and_map(plat, plong, dlat, dlong, psngr, dt):
119
+ # """
120
+ # 1. Process pickup datetime to extract categorical features.
121
+ # 2. Compute haversine distance for the model input.
122
+ # 3. Use the PyTorch model to predict the taxi fare.
123
+ # 4. Query OSRM for the actual road route geometry & distance.
124
+ # 5. Draw a Folium map with the OSRM route (blue line) and markers.
125
+ # 6. Return a text string with predicted fare and route distance, plus the map HTML.
126
+ # """
127
+ # # Process datetime
128
+ # try:
129
+ # pickup_dt = pd.to_datetime(dt)
130
+ # except Exception as e:
131
+ # return f"Error parsing date/time: {e}", ""
132
+
133
+ # hour = pickup_dt.hour
134
+ # am_or_pm = 0 if hour < 12 else 1
135
+ # weekday_str = pickup_dt.strftime("%a")
136
+ # weekday_map = {'Fri': 0, 'Mon': 1, 'Sat': 2, 'Sun': 3, 'Thu': 4, 'Tue': 5, 'Wed': 6}
137
+ # weekday = weekday_map.get(weekday_str, 0)
138
+
139
+ # # Prepare tensors for model input (use haversine distance)
140
+ # dist_km = haversine_distance_coords(plat, plong, dlat, dlong)
141
+ # cat_array = np.array([[hour, am_or_pm, weekday]])
142
+ # cat_tensor = torch.tensor(cat_array, dtype=torch.int64)
143
+ # cont_array = np.array([[plat, plong, dlat, dlong, psngr, dist_km]])
144
+ # cont_tensor = torch.tensor(cont_array, dtype=torch.float)
145
+
146
+ # # Predict fare
147
+ # with torch.no_grad():
148
+ # pred = model(cat_tensor, cont_tensor)
149
+ # fare_pred = pred.item()
150
+
151
+ # # Get route from OSRM
152
+ # try:
153
+ # route_points, route_distance_m, route_duration_s = get_osrm_route(plat, plong, dlat, dlong)
154
+ # except Exception as e:
155
+ # return f"Error from OSRM: {e}", ""
156
+
157
+ # # Create Folium map centered between pickup & dropoff
158
+ # mid_lat = (plat + dlat) / 2
159
+ # mid_lon = (plong + dlong) / 2
160
+ # m = folium.Map(location=[mid_lat, mid_lon], zoom_start=12)
161
+
162
+ # # Add markers
163
+ # folium.Marker([plat, plong], icon=folium.Icon(color="green"), tooltip="Pickup").add_to(m)
164
+ # folium.Marker([dlat, dlong], icon=folium.Icon(color="red"), tooltip="Dropoff").add_to(m)
165
+
166
+ # # Draw the route polyline (blue line) with popup showing OSRM distance
167
+ # folium.PolyLine(
168
+ # route_points,
169
+ # color="blue",
170
+ # weight=3,
171
+ # opacity=0.7,
172
+ # popup=f"OSRM Distance: {route_distance_m/1000:.2f} km"
173
+ # ).add_to(m)
174
+
175
+ # map_html = m._repr_html_()
176
+
177
+ # route_distance_km = route_distance_m / 1000
178
+ # output_text = (f"Predicted Fare: ${fare_pred:.2f}\n"
179
+ # f"Route Distance (OSRM): {route_distance_km:.2f} km")
180
+ # return output_text, map_html
181
+
182
+ # # -----------------------------
183
+ # # Gradio Interface
184
+ # # -----------------------------
185
+ # iface = gr.Interface(
186
+ # fn=predict_fare_and_map,
187
+ # inputs=[
188
+ # gr.Number(label="Pickup Latitude", value=40.75),
189
+ # gr.Number(label="Pickup Longitude", value=-73.99),
190
+ # gr.Number(label="Dropoff Latitude", value=40.73),
191
+ # gr.Number(label="Dropoff Longitude", value=-73.98),
192
+ # gr.Number(label="Passenger Count", value=1),
193
+ # gr.Textbox(label="Pickup Date and Time (YYYY-MM-DD HH:MM:SS)", value="2010-04-19 08:17:56")
194
+ # ],
195
+ # outputs=[
196
+ # gr.Textbox(label="Prediction & Distance"),
197
+ # gr.HTML(label="Map")
198
+ # ],
199
+ # title="NYC Taxi Fare Prediction with OSRM Road Route",
200
+ # description=(
201
+ # "Enter pickup/dropoff coordinates, passenger count, and pickup datetime to predict the taxi fare. "
202
+ # "The app displays the actual road route (blue line) from OSRM on a Folium map."
203
+ # )
204
+ # )
205
+
206
+ # if __name__ == "__main__":
207
+ # iface.launch()
208
+
209
  #!/usr/bin/env python
210
  """
211
  Gradio App for NYC Taxi Fare Prediction & Road Route Visualization using OSRM
 
387
  f"Route Distance (OSRM): {route_distance_km:.2f} km")
388
  return output_text, map_html
389
 
390
+ # -----------------------------
391
+ # Example Locations (Popular NYC Spots)
392
+ # Each example is a list of 6 inputs:
393
+ # [pickup_lat, pickup_lon, dropoff_lat, dropoff_lon, passenger_count, pickup_datetime]
394
+ # -----------------------------
395
+ examples = [
396
+ # 1. Times Square to Central Park (short ride)
397
+ [40.7580, -73.9855, 40.7690, -73.9819, 1, "2010-04-19 08:17:56"],
398
+ # 2. Times Square to JFK Airport (long ride)
399
+ [40.7580, -73.9855, 40.6413, -73.7781, 1, "2010-04-19 08:17:56"],
400
+ # 3. Grand Central Terminal to Empire State Building (very short ride)
401
+ [40.7527, -73.9772, 40.7484, -73.9857, 1, "2010-04-19 08:17:56"],
402
+ # 4. Brooklyn Bridge to Wall Street (short urban ride)
403
+ [40.7061, -73.9969, 40.7069, -74.0113, 1, "2010-04-19 08:17:56"],
404
+ # 5. Yankee Stadium to Central Park (moderate ride)
405
+ [40.8296, -73.9262, 40.7829, -73.9654, 1, "2010-04-19 08:17:56"],
406
+ # 6. Columbia University area to Rockefeller Center (cross-city ride)
407
+ [40.8075, -73.9626, 40.7587, -73.9787, 1, "2010-04-19 08:17:56"],
408
+ # 7. Battery Park to Central Park (longer ride across Manhattan)
409
+ [40.7033, -74.0170, 40.7829, -73.9654, 1, "2010-04-19 08:17:56"]
410
+ ]
411
+
412
  # -----------------------------
413
  # Gradio Interface
414
  # -----------------------------
 
426
  gr.Textbox(label="Prediction & Distance"),
427
  gr.HTML(label="Map")
428
  ],
429
+ examples=examples,
430
  title="NYC Taxi Fare Prediction with OSRM Road Route",
431
  description=(
432
  "Enter pickup/dropoff coordinates, passenger count, and pickup datetime to predict the taxi fare. "
433
+ "The app displays the actual road route (blue line) from OSRM on a Folium map. "
434
+ "You can also choose from several example routes between popular locations in New York."
435
  )
436
  )
437
 
438
  if __name__ == "__main__":
439
  iface.launch()
440
+