Spaces:
DebatableMiracle committed
Commit 31bef6c · 1 Parent(s): 6d5cd51
app.py CHANGED
@@ -1,4 +1,703 @@
import streamlit as st
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import base64
from io import BytesIO
from PIL import Image
import time

# Use session state to persist data across reruns
if 'trained_qtable' not in st.session_state:
    st.session_state.trained_qtable = None
if 'agent_videos' not in st.session_state:
    st.session_state.agent_videos = {}
if 'training_completed' not in st.session_state:
    st.session_state.training_completed = False
if 'training_params' not in st.session_state:
    st.session_state.training_params = {}
if 'final_metrics' not in st.session_state:
    st.session_state.final_metrics = {}

# Set page configuration for a cleaner look
st.set_page_config(
    page_title="Taxi-v3 Q-Learning Dashboard",
    page_icon="🚕",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS to make the dashboard look cleaner
st.markdown("""
<style>
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
    .stTabs [data-baseweb="tab-list"] {
        gap: 10px;
    }
    .stTabs [data-baseweb="tab"] {
        background-color: #f0f2f6;
        border-radius: 4px 4px 0px 0px;
        padding: 10px 20px;
        font-weight: 600;
    }
    .stTabs [aria-selected="true"] {
        background-color: #e6f0ff;
        border-bottom: 2px solid #4e8df5;
    }
    .reportview-container .main .block-container {
        max-width: 1200px;
    }
    div[data-testid="stSidebarNav"] li div a {
        margin-left: 1rem;
        padding: 1rem;
        width: 300px;
        border-radius: 0.5rem;
    }
    div[data-testid="stSidebarNav"] li div::focus-visible {
        background-color: rgba(151, 166, 195, 0.15);
    }
    .stMetric {
        background-color: #f0f2f6;
        padding: 15px 20px;
        border-radius: 6px;
        margin-bottom: 10px;
    }
    .css-12w0qpk {
        background-color: #f8f9fa;
    }
</style>
""", unsafe_allow_html=True)

# Header
st.markdown("""
<div style="text-align: center; margin-bottom: 30px;">
    <h1 style="color: #1e3a8a; margin-bottom:0;">🚕 Taxi-v3 Q-Learning Dashboard</h1>
    <p style="color: #64748b; font-size: 1.2em;">Interactive Reinforcement Learning Visualization</p>
</div>
""", unsafe_allow_html=True)

# Create a two-column layout for the main dashboard
col1, col2 = st.columns([3, 2])

with col2:
    st.markdown("### 🎮 Environment Preview")

    # Fix: Create a proper environment preview by resetting first
    preview_env = gym.make("Taxi-v3", render_mode="rgb_array")
    preview_env.reset()  # Reset the environment first
    env_preview = preview_env.render()
    st.image(env_preview, caption="Taxi-v3 Environment", use_column_width=True)

    st.markdown("""
    <div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-top: 20px;">
    <h4 style="margin-top: 0;">📝 About this Environment</h4>
    <p>The Taxi-v3 task involves navigating a taxi to pick up a passenger and drop them off at a destination.</p>
    <ul>
        <li><b>Yellow</b>: taxi</li>
        <li><b>Blue</b>: pick-up location</li>
        <li><b>Purple</b>: drop-off location</li>
        <li><b>Green</b>: passenger</li>
        <li><b>Letters (R, G, Y, B)</b>: locations</li>
    </ul>
    </div>
    """, unsafe_allow_html=True)

with col1:
    st.markdown("### ⚙️ Training Parameters")

    # Only show parameters if training hasn't completed yet
    if not st.session_state.training_completed:
        # Create a cleaner parameter input section
        col_a, col_b = st.columns(2)

        with col_a:
            n_episodes = st.number_input("Training Episodes", min_value=1000, max_value=100000, value=25000, step=1000)
            learning_rate = st.slider("Learning Rate (α)", 0.01, 1.0, 0.7, 0.01,
                                      format="%.2f", help="Controls how much new information overrides old information")
            gamma = st.slider("Discount Factor (γ)", 0.80, 0.99, 0.95, 0.01,
                              format="%.2f", help="Determines the importance of future rewards")
            max_steps = st.slider("Max Steps per Episode", 50, 500, 99)

        with col_b:
            min_epsilon = st.slider("Min Exploration Rate (ε)", 0.01, 0.5, 0.05, 0.01,
                                    format="%.2f", help="Minimum probability of random action")
            max_epsilon = st.slider("Max Exploration Rate (ε)", 0.5, 1.0, 1.0, 0.01,
                                    format="%.2f", help="Starting probability of random action")
            decay_rate = st.slider("Epsilon Decay Rate", 0.0001, 0.01, 0.001, 0.0001,
                                   format="%.4f", help="How quickly exploration decreases")
            n_eval_episodes = st.slider("Evaluation Episodes", 10, 200, 100,
                                        help="Number of episodes to evaluate performance")

        # Additional parameters in a collapsed section
        with st.expander("Advanced Settings"):
            log_freq = st.slider("Q-table Update Frequency (every N episodes)", 1, 1000, 500)
            eval_every = st.slider("Evaluation Frequency (% of training)", 5, 50, 10,
                                   help="How often to evaluate agent performance")
            video_length = st.slider("Evaluation Video Length (steps)", 10, 200, 50,
                                     help="Maximum steps to show in visualization videos")
    else:
        # If training is completed, show the parameters that were used
        st.info("Training completed with the following parameters:")
        params = st.session_state.training_params

        col_a, col_b = st.columns(2)
        with col_a:
            st.write(f"**Training Episodes**: {params['n_episodes']}")
            st.write(f"**Learning Rate (α)**: {params['learning_rate']}")
            st.write(f"**Discount Factor (γ)**: {params['gamma']}")
            st.write(f"**Max Steps per Episode**: {params['max_steps']}")

        with col_b:
            st.write(f"**Min Exploration Rate (ε)**: {params['min_epsilon']}")
            st.write(f"**Max Exploration Rate (ε)**: {params['max_epsilon']}")
            st.write(f"**Epsilon Decay Rate**: {params['decay_rate']}")
            st.write(f"**Evaluation Episodes**: {params['n_eval_episodes']}")

        # Option to reset and train again
        if st.button("Reset and Train Again", type="secondary"):
            st.session_state.training_completed = False
            st.session_state.trained_qtable = None
            st.session_state.agent_videos = {}
            st.session_state.training_params = {}
            st.session_state.final_metrics = {}
            st.experimental_rerun()

# Initialize Q-table
def initialize_q_table(state_space, action_space):
    return np.zeros((state_space, action_space))

# Policies
def greedy_policy(Qtable, state):
    return np.argmax(Qtable[state, :])

def epsilon_greedy_policy(Qtable, state, epsilon):
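    # Note: when exploring, this samples a random action from the module-level
    # training environment `env` (created in the "Environment setup" section
    # further below) rather than from an environment passed in as an argument.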
    if np.random.uniform(0, 1) > epsilon:
        return greedy_policy(Qtable, state)
    else:
        return env.action_space.sample()

# Function to create animation of agent behavior
def create_agent_video(env, Q, max_steps=100, seed=None):
    frames = []
    state, info = env.reset(seed=seed)

    # Add the initial frame
    frames.append(env.render())

    for _ in range(max_steps):
        # Choose action based on greedy policy
        action = greedy_policy(Q, state)

        # Step in environment
        state, reward, terminated, truncated, _ = env.step(action)

        # Render the frame after taking action
        frames.append(env.render())

        # Break if episode is done
        if terminated or truncated:
            break

    return frames

# Evaluation function
def evaluate_agent(env, max_steps, n_eval_episodes, Q, seed=None):
    rewards = []
    steps = []
    success_count = 0

    for episode in range(n_eval_episodes):
        state, info = env.reset(seed=seed[episode] if seed else None)
        total_rewards_ep = 0
        num_steps = 0

        for step in range(max_steps):
            action = greedy_policy(Q, state)
            state, reward, terminated, truncated, _ = env.step(action)
            total_rewards_ep += reward
            num_steps += 1

            if terminated or truncated:
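                # In Taxi-v3 the only positive reward is the +20 for a successful
                # dropoff (each step costs -1, illegal pickups/dropoffs cost -10),
                # so a positive terminal reward marks a solved episode.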
                if reward > 0:  # Successfully completed the task
                    success_count += 1
                break

        rewards.append(total_rewards_ep)
        steps.append(num_steps)

    success_rate = success_count / n_eval_episodes * 100
    return np.mean(rewards), np.std(rewards), np.mean(steps), success_rate

# Function to convert frames to HTML video
def frames_to_html_video(frames, fps=5):
    if not frames:
        return "<p>No frames available</p>"

    try:
        # Create a PIL image from each frame
        pil_images = [Image.fromarray(frame) for frame in frames]

        # Save as animated GIF to a BytesIO object
        buffer = BytesIO()
        pil_images[0].save(
            buffer,
            format='GIF',
            save_all=True,
            append_images=pil_images[1:],
            duration=1000/fps,
            loop=0
        )
        buffer.seek(0)

        # Encode as base64
        encoded = base64.b64encode(buffer.read()).decode("utf-8")

        # Embed in HTML
        html = f'<img src="data:image/gif;base64,{encoded}" alt="agent behavior" style="width:100%">'
        return html
    except Exception as e:
        return f"<p>Error generating video: {str(e)}</p>"

# Training function
def train_agent(env, eval_env, params):
    # Unpack parameters
    n_episodes = params["n_episodes"]
    learning_rate = params["learning_rate"]
    gamma = params["gamma"]
    max_steps = params["max_steps"]
    min_epsilon = params["min_epsilon"]
    max_epsilon = params["max_epsilon"]
    decay_rate = params["decay_rate"]
    log_freq = params["log_freq"]
    eval_every = params["eval_every"]
    n_eval_episodes = params["n_eval_episodes"]
    video_length = params["video_length"]

    # Store parameters in session state
    st.session_state.training_params = params

    # Initialize Q-table
    Qtable = initialize_q_table(env.observation_space.n, env.action_space.n)

    # Training metrics
    reward_log = []
    steps_log = []
    qtable_snapshots = {}
    videos = {}
    epsilons = []

    # Calculate checkpoints (ensure at least one checkpoint at the beginning)
    num_checkpoints = 10  # Number of checkpoints to create
    checkpoint_episodes = [int(n_episodes * i / num_checkpoints) for i in range(1, num_checkpoints + 1)]
    checkpoint_episodes[0] = max(1, checkpoint_episodes[0])  # Ensure first checkpoint is at least at episode 1
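    # With the default 25000 episodes this yields checkpoints at 2500, 5000, ..., 25000.
    # The last entry equals n_episodes and is never reached by the 0-based episode loop,
    # which is why the final episode (n_episodes - 1) is also evaluated explicitly below.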

    # Progress tracking
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Dashboard components
    tab1, tab2, tab3 = st.tabs(["📊 Training Progress", "🎬 Agent Evolution", "📋 Q-Table Visualization"])

    with tab1:
        col_metrics1, col_metrics2, col_metrics3, col_metrics4 = st.columns(4)
        with col_metrics1:
            current_episode_metric = st.empty()
        with col_metrics2:
            avg_reward_metric = st.empty()
        with col_metrics3:
            avg_steps_metric = st.empty()
        with col_metrics4:
            success_rate_metric = st.empty()

        metrics_chart = st.empty()
        epsilon_chart = st.empty()

    with tab2:
        video_placeholder = st.empty()

    with tab3:
        qtable_visualization = st.empty()

    # Training loop
    start_time = time.time()

    for episode in range(n_episodes):
        # Calculate epsilon for this episode
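        # Exponential decay: epsilon = min_epsilon + (max_epsilon - min_epsilon) * exp(-decay_rate * episode),
        # so exploration starts near max_epsilon and gradually approaches min_epsilon.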
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        epsilons.append(epsilon)

        state, info = env.reset()
        total_reward = 0

        # Episode steps
        for step in range(max_steps):
            action = epsilon_greedy_policy(Qtable, state, epsilon)
            new_state, reward, terminated, truncated, _ = env.step(action)

            # Update Q-table
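            # Temporal-difference update:
            #   Q(s, a) <- Q(s, a) + alpha * [r + gamma * max_a' Q(s', a') - Q(s, a)]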
            Qtable[state][action] += learning_rate * (
                reward + gamma * np.max(Qtable[new_state, :]) - Qtable[state][action]
            )

            total_reward += reward
            if terminated or truncated:
                break

            state = new_state

        # Evaluation at checkpoints
        if episode in checkpoint_episodes or episode == n_episodes - 1:
            mean_reward, std_reward, mean_steps, success_rate = evaluate_agent(
                env, max_steps, n_eval_episodes, Qtable
            )
            reward_log.append((episode, mean_reward, std_reward))
            steps_log.append((episode, mean_steps))

            # Create and store video of agent behavior
            try:
                video_frames = create_agent_video(eval_env, Qtable, max_steps=video_length)
                videos[episode] = video_frames
            except Exception as e:
                st.warning(f"Could not create video for episode {episode}: {str(e)}")
                videos[episode] = []

            # Take Q-table snapshot
            qtable_snapshots[episode] = Qtable.copy()

            # Update metrics display
            current_episode_metric.metric("Episodes", f"{episode}/{n_episodes}",
                                          delta=f"{episode/n_episodes:.1%}")
            avg_reward_metric.metric("Avg. Reward", f"{mean_reward:.2f}",
                                     delta=f"{mean_reward - reward_log[-2][1]:.2f}" if len(reward_log) > 1 else None)
            avg_steps_metric.metric("Avg. Steps", f"{mean_steps:.1f}")
            success_rate_metric.metric("Success Rate", f"{success_rate:.1f}%")

            # Update progress charts
            if reward_log:
                # Prepare data for plots
                progress_df = pd.DataFrame(
                    reward_log, columns=["Episode", "Mean Reward", "Std Reward"]
                )
                steps_df = pd.DataFrame(steps_log, columns=["Episode", "Mean Steps"])

                # Create subplots
                fig = make_subplots(specs=[[{"secondary_y": True}]])

                # Add reward line
                fig.add_trace(
                    go.Scatter(
                        x=progress_df["Episode"],
                        y=progress_df["Mean Reward"],
                        mode="lines+markers",
                        name="Mean Reward",
                        line=dict(color="#1f77b4", width=3),
                        marker=dict(size=8)
                    )
                )

                # Add steps line on secondary axis
                fig.add_trace(
                    go.Scatter(
                        x=steps_df["Episode"],
                        y=steps_df["Mean Steps"],
                        mode="lines+markers",
                        name="Mean Steps",
                        line=dict(color="#ff7f0e", width=3, dash="dot"),
                        marker=dict(size=8)
                    ),
                    secondary_y=True
                )

                # Add confidence interval for reward
                fig.add_trace(
                    go.Scatter(
                        x=progress_df["Episode"].tolist() + progress_df["Episode"].tolist()[::-1],
                        y=(progress_df["Mean Reward"] + progress_df["Std Reward"]).tolist() +
                          (progress_df["Mean Reward"] - progress_df["Std Reward"]).tolist()[::-1],
                        fill="toself",
                        fillcolor="rgba(31, 119, 180, 0.2)",
                        line=dict(color="rgba(255,255,255,0)"),
                        hoverinfo="skip",
                        showlegend=False
                    )
                )

                # Update layout
                fig.update_layout(
                    title="Agent Performance Over Training",
                    xaxis_title="Training Episode",
                    margin=dict(l=20, r=20, t=40, b=20),
                    legend=dict(
                        orientation="h",
                        yanchor="bottom",
                        y=1.02,
                        xanchor="right",
                        x=1
                    ),
                    height=400
                )

                # Set y-axes titles
                fig.update_yaxes(title_text="Reward", secondary_y=False)
                fig.update_yaxes(title_text="Steps", secondary_y=True)

                metrics_chart.plotly_chart(fig, use_container_width=True)

                # Epsilon decay chart
                epsilon_df = pd.DataFrame({
                    "Episode": list(range(len(epsilons))),
                    "Epsilon": epsilons
                })

                epsilon_fig = px.line(
                    epsilon_df,
                    x="Episode",
                    y="Epsilon",
                    title="Exploration Rate (Epsilon) Decay"
                )

                epsilon_fig.update_layout(
                    xaxis_title="Training Episode",
                    yaxis_title="Epsilon Value",
                    height=250,
                    margin=dict(l=20, r=20, t=40, b=20)
                )

                epsilon_chart.plotly_chart(epsilon_fig, use_container_width=True)

            # Update Q-table visualization
            qtable_fig = px.imshow(
                Qtable,
                labels=dict(x="Actions", y="States", color="Q-Value"),
                x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'],
                zmin=Qtable.min(),
                zmax=Qtable.max(),
                color_continuous_scale="Viridis"
            )

            qtable_fig.update_layout(
                title=f"Q-table at Episode {episode}",
                height=600,
                margin=dict(l=20, r=20, t=40, b=20)
            )

            qtable_visualization.plotly_chart(qtable_fig, use_container_width=True)

        # Q-table snapshot at regular intervals
        if episode % log_freq == 0:
            qtable_snapshots[episode] = Qtable.copy()

        # Update progress
        if episode % 100 == 0:
            elapsed = time.time() - start_time
            estimated = elapsed / (episode + 1) * (n_episodes - episode - 1) if episode > 0 else 0
            status_text.text(f"Training in progress... Time elapsed: {elapsed:.1f}s | Estimated time remaining: {estimated:.1f}s")
            progress_bar.progress((episode + 1) / n_episodes)

    # Training complete
    progress_bar.progress(1.0)
    status_text.success(f"✅ Training completed in {time.time() - start_time:.1f} seconds!")

    # Store results in session state for persistence
    st.session_state.trained_qtable = Qtable
    st.session_state.agent_videos = videos
    st.session_state.training_completed = True

    # Final evaluation
    final_mean, final_std, final_steps, final_success = evaluate_agent(
        env, max_steps, n_eval_episodes * 2, Qtable  # Double evaluation episodes for final eval
    )

    # Store final metrics
    st.session_state.final_metrics = {
        "mean_reward": final_mean,
        "std_reward": final_std,
        "mean_steps": final_steps,
        "success_rate": final_success,
        "q_min": Qtable.min(),
        "q_max": Qtable.max()
    }

    return Qtable, videos

# Environment setup
env = gym.make("Taxi-v3")
eval_env = gym.make("Taxi-v3", render_mode="rgb_array")

# Create training button if training hasn't completed
if not st.session_state.training_completed:
    train_col1, train_col2 = st.columns([3, 1])
    with train_col1:
        st.write("")  # For spacing
    with train_col2:
        start_training = st.button("🚀 Start Training", type="primary", use_container_width=True)

    # Start training when button is clicked
    if start_training:
        params = {
            "n_episodes": n_episodes,
            "learning_rate": learning_rate,
            "gamma": gamma,
            "max_steps": max_steps,
            "min_epsilon": min_epsilon,
            "max_epsilon": max_epsilon,
            "decay_rate": decay_rate,
            "log_freq": log_freq,
            "eval_every": eval_every,
            "n_eval_episodes": n_eval_episodes,
            "video_length": video_length if 'video_length' in locals() else 50
        }

        trained_qtable, agent_videos = train_agent(env, eval_env, params)

# If training is completed, show results
if st.session_state.training_completed:
    # Create tabs for different visualizations
    tab1, tab2, tab3 = st.tabs(["📊 Training Results", "🎬 Agent Evolution", "📋 Q-Table Visualization"])

    with tab1:
        # Summary metrics in nice boxes
        st.markdown("### 📊 Final Performance")

        metrics = st.session_state.final_metrics
        metric_cols = st.columns(4)
        with metric_cols[0]:
            st.metric("Final Average Reward", f"{metrics['mean_reward']:.2f}", delta=f"±{metrics['std_reward']:.2f}")
        with metric_cols[1]:
            st.metric("Average Steps to Complete", f"{metrics['mean_steps']:.1f}")
        with metric_cols[2]:
            st.metric("Success Rate", f"{metrics['success_rate']:.1f}%")
        with metric_cols[3]:
            st.metric("Q-values Range", f"{metrics['q_min']:.2f} to {metrics['q_max']:.2f}")

        # Download trained Q-table
        st.subheader("Export Model")

        # Convert Q-table to bytes for download
        def get_table_download_link(array):
            csvfile = BytesIO()
            np.save(csvfile, array)
            b64 = base64.b64encode(csvfile.getvalue()).decode()
            href = f'<a href="data:application/octet-stream;base64,{b64}" download="qtable.npy">Download Q-table (.npy)</a>'
            return href

        st.markdown(get_table_download_link(st.session_state.trained_qtable), unsafe_allow_html=True)

    with tab2:
        # Create video selection slider
        video_episodes = sorted(list(st.session_state.agent_videos.keys()))

        if video_episodes:
            selected_episode = st.select_slider(
                "Select checkpoint to view agent behavior:",
                options=video_episodes,
                format_func=lambda x: f"Episode {x} ({x/st.session_state.training_params['n_episodes']:.0%})"
            )

            # Display the selected video
            st.markdown(f"### Agent Behavior at Episode {selected_episode} ({selected_episode/st.session_state.training_params['n_episodes']:.0%})")

            video_html = frames_to_html_video(st.session_state.agent_videos[selected_episode])
            st.markdown(video_html, unsafe_allow_html=True)

            # Add explanation of agent behavior
            st.markdown("""
            #### What Am I Looking At?

            This animated visualization shows how the trained agent behaves at different stages of training.
            You can observe:

            - The **yellow square** represents the taxi
            - The **letters (R, G, Y, B)** represent four fixed locations
            - The **blue letter** represents the passenger pickup location
            - The **purple letter** represents the passenger dropoff destination
            - When the passenger is in the taxi, the taxi turns green

            The agent makes decisions based on its learned Q-values. Early in training, movements may appear random as the agent explores.
            Later in training, the agent should take more direct routes to complete the task efficiently.
            """)

    with tab3:
        # Q-table visualization
        st.markdown("### Q-Table Visualization")
        st.info("This heatmap shows the learned Q-values that guide the agent's decision making.")

        # Generate a heatmap of the Q-table
        qtable_fig = px.imshow(
            st.session_state.trained_qtable,
            labels=dict(x="Actions", y="States", color="Q-Value"),
            x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'],
            zmin=st.session_state.final_metrics["q_min"],
            zmax=st.session_state.final_metrics["q_max"],
            color_continuous_scale="Viridis"
        )

        qtable_fig.update_layout(
            title="Final Q-table",
            height=600,
            margin=dict(l=20, r=20, t=40, b=20)
        )

        st.plotly_chart(qtable_fig, use_container_width=True)

        # Add Q-table explanation
        st.markdown("""
        #### Understanding the Q-Table

        The Q-table is the heart of the Q-learning algorithm:

        - Each **row** represents a different state (there are 500 possible states in Taxi-v3)
        - Each **column** represents an action (South, North, East, West, Pickup, Dropoff)
        - The **values** (colors) represent the expected future reward for taking that action in that state
        - **Brighter colors** indicate higher expected rewards

        The agent selects actions by choosing the highest value (brightest color) for its current state.
        """)

# Add educational resources at the bottom
with st.expander("📚 Learn More About Q-Learning"):
    st.markdown("""
    ### Key Concepts in Q-Learning

    * **Q-Value**: Represents the expected future reward for taking action A in state S
    * **Exploration vs Exploitation**: Balancing between trying new actions and using known good actions
    * **Learning Rate (α)**: Controls how much new information overrides old information
    * **Discount Factor (γ)**: Determines the importance of future rewards
    * **Epsilon-greedy Policy**: A strategy that balances exploration and exploitation

    ### Taxi-v3 Environment Details

    The Taxi environment consists of a 5x5 grid world where a taxi needs to:
    1. Navigate to the passenger's location
    2. Pick up the passenger
    3. Navigate to the destination
    4. Drop off the passenger

    **Actions**:
    - Move South (0)
    - Move North (1)
    - Move East (2)
    - Move West (3)
    - Pickup passenger (4)
    - Dropoff passenger (5)

    **Rewards**:
    - -1 per time step
    - +20 for successful dropoff
    - -10 for illegal pickup/dropoff actions
    """)

# Footer
st.markdown("""
<div style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #eee;">
    <p style="color: #64748b;">Interactive Q-Learning Dashboard for Reinforcement Learning Education</p>
</div>
""", unsafe_allow_html=True)
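For reference, a minimal standalone sketch (not part of app.py): it loads the qtable.npy exported via the dashboard's download link and replays one episode with the same greedy policy. The environment id, table shape, and file name come from the app above; the seed and script name are illustrative.

# replay_qtable.py: minimal sketch, assumes "qtable.npy" saved from the dashboard's
# "Download Q-table (.npy)" link sits in the working directory.
import numpy as np
import gymnasium as gym

qtable = np.load("qtable.npy")                 # shape (500, 6): states x actions
env = gym.make("Taxi-v3", render_mode="ansi")  # "ansi" renders the grid as text

state, info = env.reset(seed=42)               # seed chosen arbitrarily
total_reward = 0
done = False
while not done:
    action = int(np.argmax(qtable[state]))     # same rule as greedy_policy() above
    state, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated             # Taxi-v3 also truncates at its default step limit

print(env.render())
print(f"Episode return: {total_reward}")
env.close()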