# app.py - Taxi-v3 Q-Learning Dashboard (repo: DebatableMiracle, commit 31bef6c)
import streamlit as st
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import base64
from io import BytesIO
from PIL import Image
import time
# Use session state to persist data across reruns
if 'trained_qtable' not in st.session_state:
    st.session_state.trained_qtable = None
if 'agent_videos' not in st.session_state:
    st.session_state.agent_videos = {}
if 'training_completed' not in st.session_state:
    st.session_state.training_completed = False
if 'training_params' not in st.session_state:
    st.session_state.training_params = {}
if 'final_metrics' not in st.session_state:
    st.session_state.final_metrics = {}
# Set page configuration for a cleaner look
st.set_page_config(
    page_title="Taxi-v3 Q-Learning Dashboard",
    page_icon="🚕",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS to make the dashboard look cleaner
st.markdown("""
<style>
.main .block-container {
padding-top: 2rem;
padding-bottom: 2rem;
}
.stTabs [data-baseweb="tab-list"] {
gap: 10px;
}
.stTabs [data-baseweb="tab"] {
background-color: #f0f2f6;
border-radius: 4px 4px 0px 0px;
padding: 10px 20px;
font-weight: 600;
}
.stTabs [aria-selected="true"] {
background-color: #e6f0ff;
border-bottom: 2px solid #4e8df5;
}
.reportview-container .main .block-container {
max-width: 1200px;
}
div[data-testid="stSidebarNav"] li div a {
margin-left: 1rem;
padding: 1rem;
width: 300px;
border-radius: 0.5rem;
}
div[data-testid="stSidebarNav"] li div:focus-visible {
background-color: rgba(151, 166, 195, 0.15);
}
.stMetric {
background-color: #f0f2f6;
padding: 15px 20px;
border-radius: 6px;
margin-bottom: 10px;
}
.css-12w0qpk {
background-color: #f8f9fa;
}
</style>
""", unsafe_allow_html=True)
# Header
st.markdown("""
<div style="text-align: center; margin-bottom: 30px;">
<h1 style="color: #1e3a8a; margin-bottom:0;">🚕 Taxi-v3 Q-Learning Dashboard</h1>
<p style="color: #64748b; font-size: 1.2em;">Interactive Reinforcement Learning Visualization</p>
</div>
""", unsafe_allow_html=True)
# Create a two-column layout for the main dashboard
col1, col2 = st.columns([3, 2])
with col2:
st.markdown("### 🎮 Environment Preview")
# Fix: Create a proper environment preview by resetting first
preview_env = gym.make("Taxi-v3", render_mode="rgb_array")
preview_env.reset() # Reset the environment first
env_preview = preview_env.render()
st.image(env_preview, caption="Taxi-v3 Environment", use_column_width=True)
st.markdown("""
<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-top: 20px;">
<h4 style="margin-top: 0;">📝 About this Environment</h4>
<p>The Taxi-v3 task involves navigating a taxi to pick up a passenger and drop them off at a destination.</p>
<ul>
<li><b>Yellow</b>: taxi</li>
<li><b>Blue</b>: pick-up location</li>
<li><b>Purple</b>: drop-off location</li>
<li><b>Green</b>: passenger</li>
<li><b>Letters (R, G, Y, B)</b>: locations</li>
</ul>
</div>
""", unsafe_allow_html=True)
with col1:
st.markdown("### ⚙️ Training Parameters")
# Only show parameters if training hasn't completed yet
if not st.session_state.training_completed:
# Create a cleaner parameter input section
col_a, col_b = st.columns(2)
with col_a:
n_episodes = st.number_input("Training Episodes", min_value=1000, max_value=100000, value=25000, step=1000)
learning_rate = st.slider("Learning Rate (α)", 0.01, 1.0, 0.7, 0.01,
format="%.2f", help="Controls how much new information overrides old information")
gamma = st.slider("Discount Factor (γ)", 0.80, 0.99, 0.95, 0.01,
format="%.2f", help="Determines the importance of future rewards")
max_steps = st.slider("Max Steps per Episode", 50, 500, 99)
with col_b:
min_epsilon = st.slider("Min Exploration Rate (ε)", 0.01, 0.5, 0.05, 0.01,
format="%.2f", help="Minimum probability of random action")
max_epsilon = st.slider("Max Exploration Rate (ε)", 0.5, 1.0, 1.0, 0.01,
format="%.2f", help="Starting probability of random action")
decay_rate = st.slider("Epsilon Decay Rate", 0.0001, 0.01, 0.001, 0.0001,
format="%.4f", help="How quickly exploration decreases")
n_eval_episodes = st.slider("Evaluation Episodes", 10, 200, 100,
help="Number of episodes to evaluate performance")
# Additional parameters in a collapsed section
with st.expander("Advanced Settings"):
log_freq = st.slider("Q-table Update Frequency (every N episodes)", 1, 1000, 500)
eval_every = st.slider("Evaluation Frequency (% of training)", 5, 50, 10,
help="How often to evaluate agent performance")
video_length = st.slider("Evaluation Video Length (steps)", 10, 200, 50,
help="Maximum steps to show in visualization videos")
else:
# If training is completed, show the parameters that were used
st.info("Training completed with the following parameters:")
params = st.session_state.training_params
col_a, col_b = st.columns(2)
with col_a:
st.write(f"**Training Episodes**: {params['n_episodes']}")
st.write(f"**Learning Rate (α)**: {params['learning_rate']}")
st.write(f"**Discount Factor (γ)**: {params['gamma']}")
st.write(f"**Max Steps per Episode**: {params['max_steps']}")
with col_b:
st.write(f"**Min Exploration Rate (ε)**: {params['min_epsilon']}")
st.write(f"**Max Exploration Rate (ε)**: {params['max_epsilon']}")
st.write(f"**Epsilon Decay Rate**: {params['decay_rate']}")
st.write(f"**Evaluation Episodes**: {params['n_eval_episodes']}")
# Option to reset and train again
if st.button("Reset and Train Again", type="secondary"):
st.session_state.training_completed = False
st.session_state.trained_qtable = None
st.session_state.agent_videos = {}
st.session_state.training_params = {}
st.session_state.final_metrics = {}
st.experimental_rerun()
# Initialize Q-table
def initialize_q_table(state_space, action_space):
    return np.zeros((state_space, action_space))
# Policies
def greedy_policy(Qtable, state):
    # Exploit: pick the action with the highest Q-value for this state
    return np.argmax(Qtable[state, :])
def epsilon_greedy_policy(Qtable, state, epsilon, env):
    # Explore with probability epsilon, otherwise act greedily
    if np.random.uniform(0, 1) > epsilon:
        return greedy_policy(Qtable, state)
    else:
        return env.action_space.sample()
# Function to create animation of agent behavior
def create_agent_video(env, Q, max_steps=100, seed=None):
    frames = []
    state, info = env.reset(seed=seed)
    # Add the initial frame
    frames.append(env.render())
    for _ in range(max_steps):
        # Choose action based on greedy policy
        action = greedy_policy(Q, state)
        # Step in environment
        state, reward, terminated, truncated, _ = env.step(action)
        # Render the frame after taking action
        frames.append(env.render())
        # Break if episode is done
        if terminated or truncated:
            break
    return frames
# Evaluation function
def evaluate_agent(env, max_steps, n_eval_episodes, Q, seed=None):
    rewards = []
    steps = []
    success_count = 0
    for episode in range(n_eval_episodes):
        state, info = env.reset(seed=seed[episode] if seed else None)
        total_rewards_ep = 0
        num_steps = 0
        for step in range(max_steps):
            action = greedy_policy(Q, state)
            state, reward, terminated, truncated, _ = env.step(action)
            total_rewards_ep += reward
            num_steps += 1
            if terminated or truncated:
                if reward > 0:  # Successfully completed the task
                    success_count += 1
                break
        rewards.append(total_rewards_ep)
        steps.append(num_steps)
    success_rate = success_count / n_eval_episodes * 100
    return np.mean(rewards), np.std(rewards), np.mean(steps), success_rate
# Function to convert frames to HTML video
def frames_to_html_video(frames, fps=5):
    if not frames:
        return "<p>No frames available</p>"
    try:
        # Create a PIL image from each frame
        pil_images = [Image.fromarray(frame) for frame in frames]
        # Save as animated GIF to a BytesIO object
        buffer = BytesIO()
        pil_images[0].save(
            buffer,
            format='GIF',
            save_all=True,
            append_images=pil_images[1:],
            duration=1000 / fps,
            loop=0
        )
        buffer.seek(0)
        # Encode as base64
        encoded = base64.b64encode(buffer.read()).decode("utf-8")
        # Embed in HTML
        html = f'<img src="data:image/gif;base64,{encoded}" alt="agent behavior" style="width:100%">'
        return html
    except Exception as e:
        return f"<p>Error generating video: {str(e)}</p>"
# Training function
def train_agent(env, eval_env, params):
    # Unpack parameters
    n_episodes = params["n_episodes"]
    learning_rate = params["learning_rate"]
    gamma = params["gamma"]
    max_steps = params["max_steps"]
    min_epsilon = params["min_epsilon"]
    max_epsilon = params["max_epsilon"]
    decay_rate = params["decay_rate"]
    log_freq = params["log_freq"]
    eval_every = params["eval_every"]  # note: checkpoints below are fixed at 10% intervals, so this value is not used yet
    n_eval_episodes = params["n_eval_episodes"]
    video_length = params["video_length"]
    # Store parameters in session state
    st.session_state.training_params = params
    # Initialize Q-table
    Qtable = initialize_q_table(env.observation_space.n, env.action_space.n)
    # Training metrics
    reward_log = []
    steps_log = []
    qtable_snapshots = {}
    videos = {}
    epsilons = []
    # Calculate checkpoints (ensure at least one checkpoint at the beginning)
    num_checkpoints = 10  # Number of checkpoints to create
    checkpoint_episodes = [int(n_episodes * i / num_checkpoints) for i in range(1, num_checkpoints + 1)]
    checkpoint_episodes[0] = max(1, checkpoint_episodes[0])  # Ensure first checkpoint is at least at episode 1
    # Progress tracking
    progress_bar = st.progress(0)
    status_text = st.empty()
    # Dashboard components
    tab1, tab2, tab3 = st.tabs(["📊 Training Progress", "🎬 Agent Evolution", "📋 Q-Table Visualization"])
    with tab1:
        col_metrics1, col_metrics2, col_metrics3, col_metrics4 = st.columns(4)
        with col_metrics1:
            current_episode_metric = st.empty()
        with col_metrics2:
            avg_reward_metric = st.empty()
        with col_metrics3:
            avg_steps_metric = st.empty()
        with col_metrics4:
            success_rate_metric = st.empty()
        metrics_chart = st.empty()
        epsilon_chart = st.empty()
    with tab2:
        video_placeholder = st.empty()
    with tab3:
        qtable_visualization = st.empty()
    # Training loop
    start_time = time.time()
    for episode in range(n_episodes):
        # Calculate epsilon for this episode (exponential decay towards min_epsilon)
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        epsilons.append(epsilon)
        state, info = env.reset()
        total_reward = 0
        # Episode steps
        for step in range(max_steps):
            action = epsilon_greedy_policy(Qtable, state, epsilon, env)
            new_state, reward, terminated, truncated, _ = env.step(action)
            # Q-learning update: Q(s, a) += α * (r + γ * max_a' Q(s', a') - Q(s, a))
            Qtable[state][action] += learning_rate * (
                reward + gamma * np.max(Qtable[new_state, :]) - Qtable[state][action]
            )
            total_reward += reward
            if terminated or truncated:
                break
            state = new_state
        # Evaluation at checkpoints
        if episode in checkpoint_episodes or episode == n_episodes - 1:
            mean_reward, std_reward, mean_steps, success_rate = evaluate_agent(
                env, max_steps, n_eval_episodes, Qtable
            )
            reward_log.append((episode, mean_reward, std_reward))
            steps_log.append((episode, mean_steps))
            # Create and store video of agent behavior
            try:
                video_frames = create_agent_video(eval_env, Qtable, max_steps=video_length)
                videos[episode] = video_frames
            except Exception as e:
                st.warning(f"Could not create video for episode {episode}: {str(e)}")
                videos[episode] = []
            # Take Q-table snapshot
            qtable_snapshots[episode] = Qtable.copy()
            # Update metrics display
            current_episode_metric.metric("Episodes", f"{episode}/{n_episodes}",
                                          delta=f"{episode/n_episodes:.1%}")
            avg_reward_metric.metric("Avg. Reward", f"{mean_reward:.2f}",
                                     delta=f"{mean_reward - reward_log[-2][1]:.2f}" if len(reward_log) > 1 else None)
            avg_steps_metric.metric("Avg. Steps", f"{mean_steps:.1f}")
            success_rate_metric.metric("Success Rate", f"{success_rate:.1f}%")
            # Update progress charts
            if reward_log:
                # Prepare data for plots
                progress_df = pd.DataFrame(
                    reward_log, columns=["Episode", "Mean Reward", "Std Reward"]
                )
                steps_df = pd.DataFrame(steps_log, columns=["Episode", "Mean Steps"])
                # Create subplots
                fig = make_subplots(specs=[[{"secondary_y": True}]])
                # Add reward line
                fig.add_trace(
                    go.Scatter(
                        x=progress_df["Episode"],
                        y=progress_df["Mean Reward"],
                        mode="lines+markers",
                        name="Mean Reward",
                        line=dict(color="#1f77b4", width=3),
                        marker=dict(size=8)
                    )
                )
                # Add steps line on secondary axis
                fig.add_trace(
                    go.Scatter(
                        x=steps_df["Episode"],
                        y=steps_df["Mean Steps"],
                        mode="lines+markers",
                        name="Mean Steps",
                        line=dict(color="#ff7f0e", width=3, dash="dot"),
                        marker=dict(size=8)
                    ),
                    secondary_y=True
                )
                # Add confidence interval for reward
                fig.add_trace(
                    go.Scatter(
                        x=progress_df["Episode"].tolist() + progress_df["Episode"].tolist()[::-1],
                        y=(progress_df["Mean Reward"] + progress_df["Std Reward"]).tolist() +
                          (progress_df["Mean Reward"] - progress_df["Std Reward"]).tolist()[::-1],
                        fill="toself",
                        fillcolor="rgba(31, 119, 180, 0.2)",
                        line=dict(color="rgba(255,255,255,0)"),
                        hoverinfo="skip",
                        showlegend=False
                    )
                )
                # Update layout
                fig.update_layout(
                    title="Agent Performance Over Training",
                    xaxis_title="Training Episode",
                    margin=dict(l=20, r=20, t=40, b=20),
                    legend=dict(
                        orientation="h",
                        yanchor="bottom",
                        y=1.02,
                        xanchor="right",
                        x=1
                    ),
                    height=400
                )
                # Set y-axes titles
                fig.update_yaxes(title_text="Reward", secondary_y=False)
                fig.update_yaxes(title_text="Steps", secondary_y=True)
                metrics_chart.plotly_chart(fig, use_container_width=True)
                # Epsilon decay chart
                epsilon_df = pd.DataFrame({
                    "Episode": list(range(len(epsilons))),
                    "Epsilon": epsilons
                })
                epsilon_fig = px.line(
                    epsilon_df,
                    x="Episode",
                    y="Epsilon",
                    title="Exploration Rate (Epsilon) Decay"
                )
                epsilon_fig.update_layout(
                    xaxis_title="Training Episode",
                    yaxis_title="Epsilon Value",
                    height=250,
                    margin=dict(l=20, r=20, t=40, b=20)
                )
                epsilon_chart.plotly_chart(epsilon_fig, use_container_width=True)
            # Update Q-table visualization
            qtable_fig = px.imshow(
                Qtable,
                labels=dict(x="Actions", y="States", color="Q-Value"),
                x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'],
                zmin=Qtable.min(),
                zmax=Qtable.max(),
                color_continuous_scale="Viridis"
            )
            qtable_fig.update_layout(
                title=f"Q-table at Episode {episode}",
                height=600,
                margin=dict(l=20, r=20, t=40, b=20)
            )
            qtable_visualization.plotly_chart(qtable_fig, use_container_width=True)
        # Q-table snapshot at regular intervals
        if episode % log_freq == 0:
            qtable_snapshots[episode] = Qtable.copy()
        # Update progress (throttled to every 100 episodes to keep the UI responsive)
        if episode % 100 == 0:
            elapsed = time.time() - start_time
            estimated = elapsed / (episode + 1) * (n_episodes - episode - 1) if episode > 0 else 0
            status_text.text(f"Training in progress... Time elapsed: {elapsed:.1f}s | Estimated time remaining: {estimated:.1f}s")
            progress_bar.progress((episode + 1) / n_episodes)
    # Training complete
    progress_bar.progress(1.0)
    status_text.success(f"✅ Training completed in {time.time() - start_time:.1f} seconds!")
    # Store results in session state for persistence
    st.session_state.trained_qtable = Qtable
    st.session_state.agent_videos = videos
    st.session_state.training_completed = True
    # Final evaluation
    final_mean, final_std, final_steps, final_success = evaluate_agent(
        env, max_steps, n_eval_episodes * 2, Qtable  # Double evaluation episodes for final eval
    )
    # Store final metrics
    st.session_state.final_metrics = {
        "mean_reward": final_mean,
        "std_reward": final_std,
        "mean_steps": final_steps,
        "success_rate": final_success,
        "q_min": Qtable.min(),
        "q_max": Qtable.max()
    }
    return Qtable, videos
# Environment setup
env = gym.make("Taxi-v3")
eval_env = gym.make("Taxi-v3", render_mode="rgb_array")
# Create training button if training hasn't completed
if not st.session_state.training_completed:
    train_col1, train_col2 = st.columns([3, 1])
    with train_col1:
        st.write("")  # For spacing
    with train_col2:
        start_training = st.button("🚀 Start Training", type="primary", use_container_width=True)
    # Start training when button is clicked
    if start_training:
        params = {
            "n_episodes": n_episodes,
            "learning_rate": learning_rate,
            "gamma": gamma,
            "max_steps": max_steps,
            "min_epsilon": min_epsilon,
            "max_epsilon": max_epsilon,
            "decay_rate": decay_rate,
            "log_freq": log_freq,
            "eval_every": eval_every,
            "n_eval_episodes": n_eval_episodes,
            "video_length": video_length if 'video_length' in locals() else 50
        }
        trained_qtable, agent_videos = train_agent(env, eval_env, params)
# If training is completed, show results
if st.session_state.training_completed:
    # Create tabs for different visualizations
    tab1, tab2, tab3 = st.tabs(["📊 Training Results", "🎬 Agent Evolution", "📋 Q-Table Visualization"])
    with tab1:
        # Summary metrics in nice boxes
        st.markdown("### 📊 Final Performance")
        metrics = st.session_state.final_metrics
        metric_cols = st.columns(4)
        with metric_cols[0]:
            st.metric("Final Average Reward", f"{metrics['mean_reward']:.2f}", delta=f"±{metrics['std_reward']:.2f}")
        with metric_cols[1]:
            st.metric("Average Steps to Complete", f"{metrics['mean_steps']:.1f}")
        with metric_cols[2]:
            st.metric("Success Rate", f"{metrics['success_rate']:.1f}%")
        with metric_cols[3]:
            st.metric("Q-values Range", f"{metrics['q_min']:.2f} to {metrics['q_max']:.2f}")
        # Download trained Q-table
        st.subheader("Export Model")
        # Serialize the Q-table to .npy bytes for download
        def get_table_download_link(array):
            buffer = BytesIO()
            np.save(buffer, array)
            b64 = base64.b64encode(buffer.getvalue()).decode()
            href = f'<a href="data:application/octet-stream;base64,{b64}" download="qtable.npy">Download Q-table (.npy)</a>'
            return href
        st.markdown(get_table_download_link(st.session_state.trained_qtable), unsafe_allow_html=True)
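        # Hedged usage hint: assumes the exported file is saved locally as "qtable.npy".
        # It shows how the downloaded array can be reloaded and queried for a greedy action.
        st.caption("After downloading, the Q-table can be reloaded and queried like this:")
        st.code(
            "import numpy as np\n"
            "qtable = np.load('qtable.npy')             # shape (500, 6) for Taxi-v3\n"
            "state = 42                                  # any encoded Taxi-v3 state\n"
            "best_action = int(qtable[state].argmax())   # greedy action for that state",
            language="python",
        )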
    with tab2:
        # Create video selection slider
        video_episodes = sorted(list(st.session_state.agent_videos.keys()))
        if video_episodes:
            selected_episode = st.select_slider(
                "Select checkpoint to view agent behavior:",
                options=video_episodes,
                format_func=lambda x: f"Episode {x} ({x/st.session_state.training_params['n_episodes']:.0%})"
            )
            # Display the selected video
            st.markdown(f"### Agent Behavior at Episode {selected_episode} ({selected_episode/st.session_state.training_params['n_episodes']:.0%})")
            video_html = frames_to_html_video(st.session_state.agent_videos[selected_episode])
            st.markdown(video_html, unsafe_allow_html=True)
        # Add explanation of agent behavior
        st.markdown("""
#### What Am I Looking At?
This animated visualization shows how the trained agent behaves at different stages of training.
You can observe:
- The **yellow square** represents the taxi
- The **letters (R, G, Y, B)** represent four fixed locations
- The **blue letter** represents the passenger pickup location
- The **purple letter** represents the passenger dropoff destination
- When the passenger is in the taxi, the taxi turns green

The agent makes decisions based on its learned Q-values. Early in training, movements may appear random as the agent explores.
Later in training, the agent should take more direct routes to complete the task efficiently.
""")
    with tab3:
        # Q-table visualization
        st.markdown("### Q-Table Visualization")
        st.info("This heatmap shows the learned Q-values that guide the agent's decision making.")
        # Generate a heatmap of the Q-table
        qtable_fig = px.imshow(
            st.session_state.trained_qtable,
            labels=dict(x="Actions", y="States", color="Q-Value"),
            x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'],
            zmin=st.session_state.final_metrics["q_min"],
            zmax=st.session_state.final_metrics["q_max"],
            color_continuous_scale="Viridis"
        )
        qtable_fig.update_layout(
            title="Final Q-table",
            height=600,
            margin=dict(l=20, r=20, t=40, b=20)
        )
        st.plotly_chart(qtable_fig, use_container_width=True)
        # Add Q-table explanation
        st.markdown("""
#### Understanding the Q-Table
The Q-table is the heart of the Q-learning algorithm:
- Each **row** represents a different state (there are 500 possible states in Taxi-v3)
- Each **column** represents an action (South, North, East, West, Pickup, Dropoff)
- The **values** (colors) represent the expected future reward for taking that action in that state
- **Brighter colors** indicate higher expected rewards

The agent selects actions by choosing the highest value (brightest color) for its current state.
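In code, that choice is just `np.argmax(Qtable[state, :])`, which is what the `greedy_policy` helper in this app computes.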
""")
# Add educational resources at the bottom
with st.expander("📚 Learn More About Q-Learning"):
st.markdown("""
### Key Concepts in Q-Learning
* **Q-Value**: Represents the expected future reward for taking action A in state S (the rule that learns these values is shown below)
* **Exploration vs Exploitation**: Balancing between trying new actions and using known good actions
* **Learning Rate (α)**: Controls how much new information overrides old information
* **Discount Factor (γ)**: Determines the importance of future rewards
* **Epsilon-greedy Policy**: A strategy that balances exploration and exploitation
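
For reference, the update applied after every environment step in the training loop above is the standard tabular Q-learning rule:

    Q(s, a) ← Q(s, a) + α * [ r + γ * max_a' Q(s', a') - Q(s, a) ]
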
### Taxi-v3 Environment Details
The Taxi environment consists of a 5x5 grid world where a taxi needs to:
1. Navigate to the passenger's location
2. Pick up the passenger
3. Navigate to the destination
4. Drop off the passenger

**Actions**:
- Move South (0)
- Move North (1)
- Move East (2)
- Move West (3)
- Pickup passenger (4)
- Dropoff passenger (5)

**Rewards**:
- -1 per time step
- +20 for successful dropoff
- -10 for illegal pickup/dropoff actions
""")
# Footer
st.markdown("""
<div style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #eee;">
<p style="color: #64748b;">Interactive Q-Learning Dashboard for Reinforcement Learning Education</p>
</div>
""", unsafe_allow_html=True)