DebatableMiracle committed on
Commit 31bef6c · 1 Parent(s): 6d5cd51
Files changed (1)
  1. app.py +701 -2
app.py CHANGED
@@ -1,4 +1,703 @@
  import streamlit as st
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import numpy as np
+ import gymnasium as gym
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import pandas as pd
+ import plotly.express as px
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ import time

+ # Use session state to persist data across reruns
+ if 'trained_qtable' not in st.session_state:
+     st.session_state.trained_qtable = None
+ if 'agent_videos' not in st.session_state:
+     st.session_state.agent_videos = {}
+ if 'training_completed' not in st.session_state:
+     st.session_state.training_completed = False
+ if 'training_params' not in st.session_state:
+     st.session_state.training_params = {}
+ if 'final_metrics' not in st.session_state:
+     st.session_state.final_metrics = {}
+
+ # Set page configuration for a cleaner look
+ st.set_page_config(
+     page_title="Taxi-v3 Q-Learning Dashboard",
+     page_icon="🚕",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+ # Custom CSS to make the dashboard look cleaner
+ st.markdown("""
+ <style>
+ .main .block-container {
+ padding-top: 2rem;
+ padding-bottom: 2rem;
+ }
+ .stTabs [data-baseweb="tab-list"] {
+ gap: 10px;
+ }
+ .stTabs [data-baseweb="tab"] {
+ background-color: #f0f2f6;
+ border-radius: 4px 4px 0px 0px;
+ padding: 10px 20px;
+ font-weight: 600;
+ }
+ .stTabs [aria-selected="true"] {
+ background-color: #e6f0ff;
+ border-bottom: 2px solid #4e8df5;
+ }
+ .reportview-container .main .block-container {
+ max-width: 1200px;
+ }
+ div[data-testid="stSidebarNav"] li div a {
+ margin-left: 1rem;
+ padding: 1rem;
+ width: 300px;
+ border-radius: 0.5rem;
+ }
+ div[data-testid="stSidebarNav"] li div:focus-visible {
+ background-color: rgba(151, 166, 195, 0.15);
+ }
+ .stMetric {
+ background-color: #f0f2f6;
+ padding: 15px 20px;
+ border-radius: 6px;
+ margin-bottom: 10px;
+ }
+ .css-12w0qpk {
+ background-color: #f8f9fa;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # Header
+ st.markdown("""
+ <div style="text-align: center; margin-bottom: 30px;">
+ <h1 style="color: #1e3a8a; margin-bottom:0;">🚕 Taxi-v3 Q-Learning Dashboard</h1>
+ <p style="color: #64748b; font-size: 1.2em;">Interactive Reinforcement Learning Visualization</p>
+ </div>
+ """, unsafe_allow_html=True)
+
+ # Create a two-column layout for the main dashboard
+ col1, col2 = st.columns([3, 2])
+
+ with col2:
+     st.markdown("### 🎮 Environment Preview")
+
+     # Fix: Create a proper environment preview by resetting first
+     preview_env = gym.make("Taxi-v3", render_mode="rgb_array")
+     preview_env.reset()  # Reset the environment first
+     env_preview = preview_env.render()
+     st.image(env_preview, caption="Taxi-v3 Environment", use_column_width=True)
+
+     st.markdown("""
+ <div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-top: 20px;">
+ <h4 style="margin-top: 0;">📝 About this Environment</h4>
+ <p>The Taxi-v3 task involves navigating a taxi to pick up a passenger and drop them off at a destination.</p>
+ <ul>
+ <li><b>Yellow</b>: taxi (turns green once the passenger is on board)</li>
+ <li><b>Blue letter</b>: passenger pick-up location</li>
+ <li><b>Purple letter</b>: drop-off destination</li>
+ <li><b>Letters (R, G, Y, B)</b>: the four fixed locations</li>
+ </ul>
+ </div>
+     """, unsafe_allow_html=True)
+
+ with col1:
+     st.markdown("### ⚙️ Training Parameters")
+
+     # Only show parameters if training hasn't completed yet
+     if not st.session_state.training_completed:
+         # Create a cleaner parameter input section
+         col_a, col_b = st.columns(2)
+
+         with col_a:
+             n_episodes = st.number_input("Training Episodes", min_value=1000, max_value=100000, value=25000, step=1000)
+             learning_rate = st.slider("Learning Rate (α)", 0.01, 1.0, 0.7, 0.01,
+                                       format="%.2f", help="Controls how much new information overrides old information")
+             gamma = st.slider("Discount Factor (γ)", 0.80, 0.99, 0.95, 0.01,
+                               format="%.2f", help="Determines the importance of future rewards")
+             max_steps = st.slider("Max Steps per Episode", 50, 500, 99)
+
+         with col_b:
+             min_epsilon = st.slider("Min Exploration Rate (ε)", 0.01, 0.5, 0.05, 0.01,
+                                     format="%.2f", help="Minimum probability of random action")
+             max_epsilon = st.slider("Max Exploration Rate (ε)", 0.5, 1.0, 1.0, 0.01,
+                                     format="%.2f", help="Starting probability of random action")
+             decay_rate = st.slider("Epsilon Decay Rate", 0.0001, 0.01, 0.001, 0.0001,
+                                    format="%.4f", help="How quickly exploration decreases")
+             n_eval_episodes = st.slider("Evaluation Episodes", 10, 200, 100,
+                                         help="Number of episodes to evaluate performance")
+
+         # Additional parameters in a collapsed section
+         with st.expander("Advanced Settings"):
+             log_freq = st.slider("Q-table Update Frequency (every N episodes)", 1, 1000, 500)
+             eval_every = st.slider("Evaluation Frequency (% of training)", 5, 50, 10,
+                                    help="How often to evaluate agent performance")
+             video_length = st.slider("Evaluation Video Length (steps)", 10, 200, 50,
+                                      help="Maximum steps to show in visualization videos")
+     else:
+         # If training is completed, show the parameters that were used
+         st.info("Training completed with the following parameters:")
+         params = st.session_state.training_params
+
+         col_a, col_b = st.columns(2)
+         with col_a:
+             st.write(f"**Training Episodes**: {params['n_episodes']}")
+             st.write(f"**Learning Rate (α)**: {params['learning_rate']}")
+             st.write(f"**Discount Factor (γ)**: {params['gamma']}")
+             st.write(f"**Max Steps per Episode**: {params['max_steps']}")
+
+         with col_b:
+             st.write(f"**Min Exploration Rate (ε)**: {params['min_epsilon']}")
+             st.write(f"**Max Exploration Rate (ε)**: {params['max_epsilon']}")
+             st.write(f"**Epsilon Decay Rate**: {params['decay_rate']}")
+             st.write(f"**Evaluation Episodes**: {params['n_eval_episodes']}")
+
+         # Option to reset and train again
+         if st.button("Reset and Train Again", type="secondary"):
+             st.session_state.training_completed = False
+             st.session_state.trained_qtable = None
+             st.session_state.agent_videos = {}
+             st.session_state.training_params = {}
+             st.session_state.final_metrics = {}
+             st.experimental_rerun()
+
+ # Initialize Q-table
+ def initialize_q_table(state_space, action_space):
+     return np.zeros((state_space, action_space))
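+ # For Taxi-v3 this is a 500 x 6 table: 500 discrete states
+ # (25 taxi positions x 5 passenger locations x 4 destinations) and 6 actions.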
+
+ # Policies
+ def greedy_policy(Qtable, state):
+     return np.argmax(Qtable[state, :])
+
+ def epsilon_greedy_policy(Qtable, state, epsilon):
+     if np.random.uniform(0, 1) > epsilon:
+         return greedy_policy(Qtable, state)
+     else:
+         # Random action sampled from the module-level `env` defined further down
+         return env.action_space.sample()
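+ # With epsilon = 1.0 every action is drawn uniformly at random (pure exploration);
+ # with epsilon = 0.0 the policy reduces to greedy_policy (pure exploitation).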
+
+ # Function to create animation of agent behavior
+ def create_agent_video(env, Q, max_steps=100, seed=None):
+     frames = []
+     state, info = env.reset(seed=seed)
+
+     # Add the initial frame
+     frames.append(env.render())
+
+     for _ in range(max_steps):
+         # Choose action based on greedy policy
+         action = greedy_policy(Q, state)
+
+         # Step in environment
+         state, reward, terminated, truncated, _ = env.step(action)
+
+         # Render the frame after taking action
+         frames.append(env.render())
+
+         # Break if episode is done
+         if terminated or truncated:
+             break
+
+     return frames
+
+ # Evaluation function
+ def evaluate_agent(env, max_steps, n_eval_episodes, Q, seed=None):
+     rewards = []
+     steps = []
+     success_count = 0
+
+     for episode in range(n_eval_episodes):
+         state, info = env.reset(seed=seed[episode] if seed else None)
+         total_rewards_ep = 0
+         num_steps = 0
+
+         for step in range(max_steps):
+             action = greedy_policy(Q, state)
+             state, reward, terminated, truncated, _ = env.step(action)
+             total_rewards_ep += reward
+             num_steps += 1
+
+             if terminated or truncated:
+                 if reward > 0:  # Successfully completed the task (+20 dropoff is the only positive reward)
+                     success_count += 1
+                 break
+
+         rewards.append(total_rewards_ep)
+         steps.append(num_steps)
+
+     success_rate = success_count / n_eval_episodes * 100
+     return np.mean(rewards), np.std(rewards), np.mean(steps), success_rate
+
+ # Function to convert frames to HTML video
+ def frames_to_html_video(frames, fps=5):
+     if not frames:
+         return "<p>No frames available</p>"
+
+     try:
+         # Create a PIL image from each frame
+         pil_images = [Image.fromarray(frame) for frame in frames]
+
+         # Save as animated GIF to a BytesIO object
+         buffer = BytesIO()
+         pil_images[0].save(
+             buffer,
+             format='GIF',
+             save_all=True,
+             append_images=pil_images[1:],
+             duration=1000 / fps,
+             loop=0
+         )
+         buffer.seek(0)
+
+         # Encode as base64
+         encoded = base64.b64encode(buffer.read()).decode("utf-8")
+
+         # Embed in HTML
+         html = f'<img src="data:image/gif;base64,{encoded}" alt="agent behavior" style="width:100%">'
+         return html
+     except Exception as e:
+         return f"<p>Error generating video: {str(e)}</p>"
+
+ # Training function
+ def train_agent(env, eval_env, params):
+     # Unpack parameters
+     n_episodes = params["n_episodes"]
+     learning_rate = params["learning_rate"]
+     gamma = params["gamma"]
+     max_steps = params["max_steps"]
+     min_epsilon = params["min_epsilon"]
+     max_epsilon = params["max_epsilon"]
+     decay_rate = params["decay_rate"]
+     log_freq = params["log_freq"]
+     eval_every = params["eval_every"]
+     n_eval_episodes = params["n_eval_episodes"]
+     video_length = params["video_length"]
+
+     # Store parameters in session state
+     st.session_state.training_params = params
+
+     # Initialize Q-table
+     Qtable = initialize_q_table(env.observation_space.n, env.action_space.n)
+
+     # Training metrics
+     reward_log = []
+     steps_log = []
+     qtable_snapshots = {}
+     videos = {}
+     epsilons = []
+
+     # Calculate checkpoints (ensure at least one checkpoint at the beginning)
+     num_checkpoints = 10  # Number of checkpoints to create
+     checkpoint_episodes = [int(n_episodes * i / num_checkpoints) for i in range(1, num_checkpoints + 1)]
+     checkpoint_episodes[0] = max(1, checkpoint_episodes[0])  # Ensure first checkpoint is at least at episode 1
+
+     # Progress tracking
+     progress_bar = st.progress(0)
+     status_text = st.empty()
+
+     # Dashboard components
+     tab1, tab2, tab3 = st.tabs(["📊 Training Progress", "🎬 Agent Evolution", "📋 Q-Table Visualization"])
+
+     with tab1:
+         col_metrics1, col_metrics2, col_metrics3, col_metrics4 = st.columns(4)
+         with col_metrics1:
+             current_episode_metric = st.empty()
+         with col_metrics2:
+             avg_reward_metric = st.empty()
+         with col_metrics3:
+             avg_steps_metric = st.empty()
+         with col_metrics4:
+             success_rate_metric = st.empty()
+
+         metrics_chart = st.empty()
+         epsilon_chart = st.empty()
+
+     with tab2:
+         video_placeholder = st.empty()
+
+     with tab3:
+         qtable_visualization = st.empty()
+
+     # Training loop
+     start_time = time.time()
+
+     for episode in range(n_episodes):
+         # Calculate epsilon for this episode
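+         # (exponential schedule: epsilon starts near max_epsilon and decays toward
+         #  min_epsilon as exp(-decay_rate * episode) shrinks)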
+         epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
+         epsilons.append(epsilon)
+
+         state, info = env.reset()
+         total_reward = 0
+
+         # Episode steps
+         for step in range(max_steps):
+             action = epsilon_greedy_policy(Qtable, state, epsilon)
+             new_state, reward, terminated, truncated, _ = env.step(action)
+
+             # Update Q-table
+             Qtable[state][action] += learning_rate * (
+                 reward + gamma * np.max(Qtable[new_state, :]) - Qtable[state][action]
+             )
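+             # The update above is the standard tabular Q-learning (TD) rule:
+             #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))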
+
+             total_reward += reward
+             if terminated or truncated:
+                 break
+
+             state = new_state
+
+         # Evaluation at checkpoints
+         if episode in checkpoint_episodes or episode == n_episodes - 1:
+             mean_reward, std_reward, mean_steps, success_rate = evaluate_agent(
+                 env, max_steps, n_eval_episodes, Qtable
+             )
+             reward_log.append((episode, mean_reward, std_reward))
+             steps_log.append((episode, mean_steps))
+
+             # Create and store video of agent behavior
+             try:
+                 video_frames = create_agent_video(eval_env, Qtable, max_steps=video_length)
+                 videos[episode] = video_frames
+             except Exception as e:
+                 st.warning(f"Could not create video for episode {episode}: {str(e)}")
+                 videos[episode] = []
+
+             # Take Q-table snapshot
+             qtable_snapshots[episode] = Qtable.copy()
+
+             # Update metrics display
+             current_episode_metric.metric("Episodes", f"{episode}/{n_episodes}",
+                                           delta=f"{episode/n_episodes:.1%}")
+             avg_reward_metric.metric("Avg. Reward", f"{mean_reward:.2f}",
+                                      delta=f"{mean_reward - reward_log[-2][1]:.2f}" if len(reward_log) > 1 else None)
+             avg_steps_metric.metric("Avg. Steps", f"{mean_steps:.1f}")
+             success_rate_metric.metric("Success Rate", f"{success_rate:.1f}%")
+
+             # Update progress charts
+             if reward_log:
+                 # Prepare data for plots
+                 progress_df = pd.DataFrame(
+                     reward_log, columns=["Episode", "Mean Reward", "Std Reward"]
+                 )
+                 steps_df = pd.DataFrame(steps_log, columns=["Episode", "Mean Steps"])
+
+                 # Create subplots
+                 fig = make_subplots(specs=[[{"secondary_y": True}]])
+
+                 # Add reward line
+                 fig.add_trace(
+                     go.Scatter(
+                         x=progress_df["Episode"],
+                         y=progress_df["Mean Reward"],
+                         mode="lines+markers",
+                         name="Mean Reward",
+                         line=dict(color="#1f77b4", width=3),
+                         marker=dict(size=8)
+                     )
+                 )
+
+                 # Add steps line on secondary axis
+                 fig.add_trace(
+                     go.Scatter(
+                         x=steps_df["Episode"],
+                         y=steps_df["Mean Steps"],
+                         mode="lines+markers",
+                         name="Mean Steps",
+                         line=dict(color="#ff7f0e", width=3, dash="dot"),
+                         marker=dict(size=8)
+                     ),
+                     secondary_y=True
+                 )
+
+                 # Add confidence interval for reward
+                 fig.add_trace(
+                     go.Scatter(
+                         x=progress_df["Episode"].tolist() + progress_df["Episode"].tolist()[::-1],
+                         y=(progress_df["Mean Reward"] + progress_df["Std Reward"]).tolist() +
+                           (progress_df["Mean Reward"] - progress_df["Std Reward"]).tolist()[::-1],
+                         fill="toself",
+                         fillcolor="rgba(31, 119, 180, 0.2)",
+                         line=dict(color="rgba(255,255,255,0)"),
+                         hoverinfo="skip",
+                         showlegend=False
+                     )
+                 )
+
+                 # Update layout
+                 fig.update_layout(
+                     title="Agent Performance Over Training",
+                     xaxis_title="Training Episode",
+                     margin=dict(l=20, r=20, t=40, b=20),
+                     legend=dict(
+                         orientation="h",
+                         yanchor="bottom",
+                         y=1.02,
+                         xanchor="right",
+                         x=1
+                     ),
+                     height=400
+                 )
+
+                 # Set y-axes titles
+                 fig.update_yaxes(title_text="Reward", secondary_y=False)
+                 fig.update_yaxes(title_text="Steps", secondary_y=True)
+
+                 metrics_chart.plotly_chart(fig, use_container_width=True)
+
+                 # Epsilon decay chart
+                 epsilon_df = pd.DataFrame({
+                     "Episode": list(range(len(epsilons))),
+                     "Epsilon": epsilons
+                 })
+
+                 epsilon_fig = px.line(
+                     epsilon_df,
+                     x="Episode",
+                     y="Epsilon",
+                     title="Exploration Rate (Epsilon) Decay"
+                 )
+
+                 epsilon_fig.update_layout(
+                     xaxis_title="Training Episode",
+                     yaxis_title="Epsilon Value",
+                     height=250,
+                     margin=dict(l=20, r=20, t=40, b=20)
+                 )
+
+                 epsilon_chart.plotly_chart(epsilon_fig, use_container_width=True)
+
+             # Update Q-table visualization
+             qtable_fig = px.imshow(
+                 Qtable,
+                 labels=dict(x="Actions", y="States", color="Q-Value"),
+                 x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'],
+                 zmin=Qtable.min(),
+                 zmax=Qtable.max(),
+                 color_continuous_scale="Viridis"
+             )
+
+             qtable_fig.update_layout(
+                 title=f"Q-table at Episode {episode}",
+                 height=600,
+                 margin=dict(l=20, r=20, t=40, b=20)
+             )
+
+             qtable_visualization.plotly_chart(qtable_fig, use_container_width=True)
+
+         # Q-table snapshot at regular intervals
+         if episode % log_freq == 0:
+             qtable_snapshots[episode] = Qtable.copy()
+
+         # Update progress
+         if episode % 100 == 0:
+             elapsed = time.time() - start_time
+             estimated = elapsed / (episode + 1) * (n_episodes - episode - 1) if episode > 0 else 0
+             status_text.text(f"Training in progress... Time elapsed: {elapsed:.1f}s | Estimated time remaining: {estimated:.1f}s")
+             progress_bar.progress((episode + 1) / n_episodes)
+
+     # Training complete
+     progress_bar.progress(1.0)
+     status_text.success(f"✅ Training completed in {time.time() - start_time:.1f} seconds!")
+
+     # Store results in session state for persistence
+     st.session_state.trained_qtable = Qtable
+     st.session_state.agent_videos = videos
+     st.session_state.training_completed = True
+
+     # Final evaluation
+     final_mean, final_std, final_steps, final_success = evaluate_agent(
+         env, max_steps, n_eval_episodes * 2, Qtable  # Double evaluation episodes for final eval
+     )
+
+     # Store final metrics
+     st.session_state.final_metrics = {
+         "mean_reward": final_mean,
+         "std_reward": final_std,
+         "mean_steps": final_steps,
+         "success_rate": final_success,
+         "q_min": Qtable.min(),
+         "q_max": Qtable.max()
+     }
+
+     return Qtable, videos
+
+ # Environment setup
+ env = gym.make("Taxi-v3")
+ eval_env = gym.make("Taxi-v3", render_mode="rgb_array")
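+ # `env` is used for the training and evaluation rollouts; `eval_env` renders RGB
+ # frames so create_agent_video can build the behavior GIFs shown in the dashboard.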
+
+ # Create training button if training hasn't completed
+ if not st.session_state.training_completed:
+     train_col1, train_col2 = st.columns([3, 1])
+     with train_col1:
+         st.write("")  # For spacing
+     with train_col2:
+         start_training = st.button("🚀 Start Training", type="primary", use_container_width=True)
+
+     # Start training when button is clicked
+     if start_training:
+         params = {
+             "n_episodes": n_episodes,
+             "learning_rate": learning_rate,
+             "gamma": gamma,
+             "max_steps": max_steps,
+             "min_epsilon": min_epsilon,
+             "max_epsilon": max_epsilon,
+             "decay_rate": decay_rate,
+             "log_freq": log_freq,
+             "eval_every": eval_every,
+             "n_eval_episodes": n_eval_episodes,
+             "video_length": video_length if 'video_length' in locals() else 50
+         }
+
+         trained_qtable, agent_videos = train_agent(env, eval_env, params)
+
+ # If training is completed, show results
+ if st.session_state.training_completed:
+     # Create tabs for different visualizations
+     tab1, tab2, tab3 = st.tabs(["📊 Training Results", "🎬 Agent Evolution", "📋 Q-Table Visualization"])
+
+     with tab1:
+         # Summary metrics in nice boxes
+         st.markdown("### 📊 Final Performance")
+
+         metrics = st.session_state.final_metrics
+         metric_cols = st.columns(4)
+         with metric_cols[0]:
+             st.metric("Final Average Reward", f"{metrics['mean_reward']:.2f}", delta=f"±{metrics['std_reward']:.2f}")
+         with metric_cols[1]:
+             st.metric("Average Steps to Complete", f"{metrics['mean_steps']:.1f}")
+         with metric_cols[2]:
+             st.metric("Success Rate", f"{metrics['success_rate']:.1f}%")
+         with metric_cols[3]:
+             st.metric("Q-values Range", f"{metrics['q_min']:.2f} to {metrics['q_max']:.2f}")
+
+         # Download trained Q-table
+         st.subheader("Export Model")
+
+         # Convert Q-table to bytes for download
+         def get_table_download_link(array):
+             buffer = BytesIO()
+             np.save(buffer, array)
+             b64 = base64.b64encode(buffer.getvalue()).decode()
+             href = f'<a href="data:application/octet-stream;base64,{b64}" download="qtable.npy">Download Q-table (.npy)</a>'
+             return href
+
+         st.markdown(get_table_download_link(st.session_state.trained_qtable), unsafe_allow_html=True)
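+         # The exported array can be reloaded elsewhere with, for example,
+         # Q = np.load("qtable.npy") and int(np.argmax(Q[state])) to pick an action.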
+
+     with tab2:
+         # Create video selection slider
+         video_episodes = sorted(list(st.session_state.agent_videos.keys()))
+
+         if video_episodes:
+             selected_episode = st.select_slider(
+                 "Select checkpoint to view agent behavior:",
+                 options=video_episodes,
+                 format_func=lambda x: f"Episode {x} ({x/st.session_state.training_params['n_episodes']:.0%})"
+             )
+
+             # Display the selected video
+             st.markdown(f"### Agent Behavior at Episode {selected_episode} ({selected_episode/st.session_state.training_params['n_episodes']:.0%})")
+
+             video_html = frames_to_html_video(st.session_state.agent_videos[selected_episode])
+             st.markdown(video_html, unsafe_allow_html=True)
+
+             # Add explanation of agent behavior
+             st.markdown("""
+ #### What Am I Looking At?
+
+ This animated visualization shows how the trained agent behaves at different stages of training.
+ You can observe:
+
+ - The **yellow square** represents the taxi
+ - The **letters (R, G, Y, B)** represent four fixed locations
+ - The **blue letter** represents the passenger pickup location
+ - The **purple letter** represents the passenger dropoff destination
+ - When the passenger is in the taxi, the taxi turns green
+
+ The agent makes decisions based on its learned Q-values. Early in training, movements may appear random as the agent explores.
+ Later in training, the agent should take more direct routes to complete the task efficiently.
+             """)
+
+     with tab3:
+         # Q-table visualization
+         st.markdown("### Q-Table Visualization")
+         st.info("This heatmap shows the learned Q-values that guide the agent's decision making.")
+
+         # Generate a heatmap of the Q-table
+         qtable_fig = px.imshow(
+             st.session_state.trained_qtable,
+             labels=dict(x="Actions", y="States", color="Q-Value"),
+             x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'],
+             zmin=st.session_state.final_metrics["q_min"],
+             zmax=st.session_state.final_metrics["q_max"],
+             color_continuous_scale="Viridis"
+         )
+
+         qtable_fig.update_layout(
+             title="Final Q-table",
+             height=600,
+             margin=dict(l=20, r=20, t=40, b=20)
+         )
+
+         st.plotly_chart(qtable_fig, use_container_width=True)
+
+         # Add Q-table explanation
+         st.markdown("""
+ #### Understanding the Q-Table
+
+ The Q-table is the heart of the Q-learning algorithm:
+
+ - Each **row** represents a different state (there are 500 possible states in Taxi-v3)
+ - Each **column** represents an action (South, North, East, West, Pickup, Dropoff)
+ - The **values** (colors) represent the expected future reward for taking that action in that state
+ - **Brighter colors** indicate higher expected rewards
+
+ The agent selects actions by choosing the action with the highest value (brightest color) for its current state.
+         """)
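+         # In code, "choosing the highest value" is exactly greedy_policy above:
+         # np.argmax(Qtable[state, :]) picks the column of the brightest cell in that row.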
+
+ # Add educational resources at the bottom
+ with st.expander("📚 Learn More About Q-Learning"):
+     st.markdown("""
+ ### Key Concepts in Q-Learning
+
+ * **Q-Value**: The expected future reward for taking action A in state S
+ * **Exploration vs Exploitation**: Balancing trying new actions against using actions already known to be good
+ * **Learning Rate (α)**: Controls how much new information overrides old information
+ * **Discount Factor (γ)**: Determines the importance of future rewards
+ * **Epsilon-greedy Policy**: A strategy that balances exploration and exploitation
+
+ ### Taxi-v3 Environment Details
+
+ The Taxi environment consists of a 5x5 grid world where a taxi needs to:
+ 1. Navigate to the passenger's location
+ 2. Pick up the passenger
+ 3. Navigate to the destination
+ 4. Drop off the passenger
+
+ **Actions**:
+ - Move South (0)
+ - Move North (1)
+ - Move East (2)
+ - Move West (3)
+ - Pickup passenger (4)
+ - Dropoff passenger (5)
+
+ **Rewards**:
+ - -1 per time step
+ - +20 for successful dropoff
+ - -10 for illegal pickup/dropoff actions
+     """)
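+     # Worked example of the reward scheme above: an episode that succeeds on its
+     # 12th step earns 11 * (-1) + 20 = +9 total reward.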
+
+ # Footer
+ st.markdown("""
+ <div style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #eee;">
+ <p style="color: #64748b;">Interactive Q-Learning Dashboard for Reinforcement Learning Education</p>
+ </div>
+ """, unsafe_allow_html=True)