import streamlit as st import numpy as np import gymnasium as gym import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import plotly.express as px import plotly.graph_objects as go from plotly.subplots import make_subplots import base64 from io import BytesIO from PIL import Image import time # Use session state to persist data across reruns if 'trained_qtable' not in st.session_state: st.session_state.trained_qtable = None if 'agent_videos' not in st.session_state: st.session_state.agent_videos = {} if 'training_completed' not in st.session_state: st.session_state.training_completed = False if 'training_params' not in st.session_state: st.session_state.training_params = {} if 'final_metrics' not in st.session_state: st.session_state.final_metrics = {} # Set page configuration for a cleaner look st.set_page_config( page_title="Taxi-v3 Q-Learning Dashboard", page_icon="🚕", layout="wide", initial_sidebar_state="expanded" ) # Custom CSS to make the dashboard look cleaner st.markdown(""" """, unsafe_allow_html=True) # Header st.markdown("""

🚕 Taxi-v3 Q-Learning Dashboard

Interactive Reinforcement Learning Visualization

""", unsafe_allow_html=True) # Create a two-column layout for the main dashboard col1, col2 = st.columns([3, 2]) with col2: st.markdown("### 🎮 Environment Preview") # Fix: Create a proper environment preview by resetting first preview_env = gym.make("Taxi-v3", render_mode="rgb_array") preview_env.reset() # Reset the environment first env_preview = preview_env.render() st.image(env_preview, caption="Taxi-v3 Environment", use_column_width=True) st.markdown("""

📝 About this Environment

The Taxi-v3 task involves navigating a taxi to pick up a passenger and drop them off at a destination.

""", unsafe_allow_html=True) with col1: st.markdown("### ⚙️ Training Parameters") # Only show parameters if training hasn't completed yet if not st.session_state.training_completed: # Create a cleaner parameter input section col_a, col_b = st.columns(2) with col_a: n_episodes = st.number_input("Training Episodes", min_value=1000, max_value=100000, value=25000, step=1000) learning_rate = st.slider("Learning Rate (α)", 0.01, 1.0, 0.7, 0.01, format="%.2f", help="Controls how much new information overrides old information") gamma = st.slider("Discount Factor (γ)", 0.80, 0.99, 0.95, 0.01, format="%.2f", help="Determines the importance of future rewards") max_steps = st.slider("Max Steps per Episode", 50, 500, 99) with col_b: min_epsilon = st.slider("Min Exploration Rate (ε)", 0.01, 0.5, 0.05, 0.01, format="%.2f", help="Minimum probability of random action") max_epsilon = st.slider("Max Exploration Rate (ε)", 0.5, 1.0, 1.0, 0.01, format="%.2f", help="Starting probability of random action") decay_rate = st.slider("Epsilon Decay Rate", 0.0001, 0.01, 0.001, 0.0001, format="%.4f", help="How quickly exploration decreases") n_eval_episodes = st.slider("Evaluation Episodes", 10, 200, 100, help="Number of episodes to evaluate performance") # Additional parameters in a collapsed section with st.expander("Advanced Settings"): log_freq = st.slider("Q-table Update Frequency (every N episodes)", 1, 1000, 500) eval_every = st.slider("Evaluation Frequency (% of training)", 5, 50, 10, help="How often to evaluate agent performance") video_length = st.slider("Evaluation Video Length (steps)", 10, 200, 50, help="Maximum steps to show in visualization videos") else: # If training is completed, show the parameters that were used st.info("Training completed with the following parameters:") params = st.session_state.training_params col_a, col_b = st.columns(2) with col_a: st.write(f"**Training Episodes**: {params['n_episodes']}") st.write(f"**Learning Rate (α)**: {params['learning_rate']}") st.write(f"**Discount Factor (γ)**: {params['gamma']}") st.write(f"**Max Steps per Episode**: {params['max_steps']}") with col_b: st.write(f"**Min Exploration Rate (ε)**: {params['min_epsilon']}") st.write(f"**Max Exploration Rate (ε)**: {params['max_epsilon']}") st.write(f"**Epsilon Decay Rate**: {params['decay_rate']}") st.write(f"**Evaluation Episodes**: {params['n_eval_episodes']}") # Option to reset and train again if st.button("Reset and Train Again", type="secondary"): st.session_state.training_completed = False st.session_state.trained_qtable = None st.session_state.agent_videos = {} st.session_state.training_params = {} st.session_state.final_metrics = {} st.experimental_rerun() # Initialize Q-table def initialize_q_table(state_space, action_space): return np.zeros((state_space, action_space)) # Policies def greedy_policy(Qtable, state): return np.argmax(Qtable[state, :]) def epsilon_greedy_policy(Qtable, state, epsilon): if np.random.uniform(0, 1) > epsilon: return greedy_policy(Qtable, state) else: return env.action_space.sample() # Function to create animation of agent behavior def create_agent_video(env, Q, max_steps=100, seed=None): frames = [] state, info = env.reset(seed=seed) # Add the initial frame frames.append(env.render()) for _ in range(max_steps): # Choose action based on greedy policy action = greedy_policy(Q, state) # Step in environment state, reward, terminated, truncated, _ = env.step(action) # Render the frame after taking action frames.append(env.render()) # Break if episode is done if terminated or truncated: break return frames # Evaluation function def evaluate_agent(env, max_steps, n_eval_episodes, Q, seed=None): rewards = [] steps = [] success_count = 0 for episode in range(n_eval_episodes): state, info = env.reset(seed=seed[episode] if seed else None) total_rewards_ep = 0 num_steps = 0 for step in range(max_steps): action = greedy_policy(Q, state) state, reward, terminated, truncated, _ = env.step(action) total_rewards_ep += reward num_steps += 1 if terminated or truncated: if reward > 0: # Successfully completed the task success_count += 1 break rewards.append(total_rewards_ep) steps.append(num_steps) success_rate = success_count / n_eval_episodes * 100 return np.mean(rewards), np.std(rewards), np.mean(steps), success_rate # Function to convert frames to HTML video def frames_to_html_video(frames, fps=5): if not frames: return "

No frames available

" try: # Create a PIL image from each frame pil_images = [Image.fromarray(frame) for frame in frames] # Save as animated GIF to a BytesIO object buffer = BytesIO() pil_images[0].save( buffer, format='GIF', save_all=True, append_images=pil_images[1:], duration=1000/fps, loop=0 ) buffer.seek(0) # Encode as base64 encoded = base64.b64encode(buffer.read()).decode("utf-8") # Embed in HTML html = f'agent behavior' return html except Exception as e: return f"

Error generating video: {str(e)}

" # Training function def train_agent(env, eval_env, params): # Unpack parameters n_episodes = params["n_episodes"] learning_rate = params["learning_rate"] gamma = params["gamma"] max_steps = params["max_steps"] min_epsilon = params["min_epsilon"] max_epsilon = params["max_epsilon"] decay_rate = params["decay_rate"] log_freq = params["log_freq"] eval_every = params["eval_every"] n_eval_episodes = params["n_eval_episodes"] video_length = params["video_length"] # Store parameters in session state st.session_state.training_params = params # Initialize Q-table Qtable = initialize_q_table(env.observation_space.n, env.action_space.n) # Training metrics reward_log = [] steps_log = [] qtable_snapshots = {} videos = {} epsilons = [] # Calculate checkpoints (ensure at least one checkpoint at the beginning) num_checkpoints = 10 # Number of checkpoints to create checkpoint_episodes = [int(n_episodes * i / num_checkpoints) for i in range(1, num_checkpoints + 1)] checkpoint_episodes[0] = max(1, checkpoint_episodes[0]) # Ensure first checkpoint is at least at episode 1 # Progress tracking progress_bar = st.progress(0) status_text = st.empty() # Dashboard components tab1, tab2, tab3 = st.tabs(["📊 Training Progress", "🎬 Agent Evolution", "📋 Q-Table Visualization"]) with tab1: col_metrics1, col_metrics2, col_metrics3, col_metrics4 = st.columns(4) with col_metrics1: current_episode_metric = st.empty() with col_metrics2: avg_reward_metric = st.empty() with col_metrics3: avg_steps_metric = st.empty() with col_metrics4: success_rate_metric = st.empty() metrics_chart = st.empty() epsilon_chart = st.empty() with tab2: video_placeholder = st.empty() with tab3: qtable_visualization = st.empty() # Training loop start_time = time.time() for episode in range(n_episodes): # Calculate epsilon for this episode epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode) epsilons.append(epsilon) state, info = env.reset() total_reward = 0 # Episode steps for step in range(max_steps): action = epsilon_greedy_policy(Qtable, state, epsilon) new_state, reward, terminated, truncated, _ = env.step(action) # Update Q-table Qtable[state][action] += learning_rate * ( reward + gamma * np.max(Qtable[new_state, :]) - Qtable[state][action] ) total_reward += reward if terminated or truncated: break state = new_state # Evaluation at checkpoints if episode in checkpoint_episodes or episode == n_episodes - 1: mean_reward, std_reward, mean_steps, success_rate = evaluate_agent( env, max_steps, n_eval_episodes, Qtable ) reward_log.append((episode, mean_reward, std_reward)) steps_log.append((episode, mean_steps)) # Create and store video of agent behavior try: video_frames = create_agent_video(eval_env, Qtable, max_steps=video_length) videos[episode] = video_frames except Exception as e: st.warning(f"Could not create video for episode {episode}: {str(e)}") videos[episode] = [] # Take Q-table snapshot qtable_snapshots[episode] = Qtable.copy() # Update metrics display current_episode_metric.metric("Episodes", f"{episode}/{n_episodes}", delta=f"{episode/n_episodes:.1%}") avg_reward_metric.metric("Avg. Reward", f"{mean_reward:.2f}", delta=f"{mean_reward - reward_log[-2][1]:.2f}" if len(reward_log) > 1 else None) avg_steps_metric.metric("Avg. Steps", f"{mean_steps:.1f}") success_rate_metric.metric("Success Rate", f"{success_rate:.1f}%") # Update progress charts if reward_log: # Prepare data for plots progress_df = pd.DataFrame( reward_log, columns=["Episode", "Mean Reward", "Std Reward"] ) steps_df = pd.DataFrame(steps_log, columns=["Episode", "Mean Steps"]) # Create subplots fig = make_subplots(specs=[[{"secondary_y": True}]]) # Add reward line fig.add_trace( go.Scatter( x=progress_df["Episode"], y=progress_df["Mean Reward"], mode="lines+markers", name="Mean Reward", line=dict(color="#1f77b4", width=3), marker=dict(size=8) ) ) # Add steps line on secondary axis fig.add_trace( go.Scatter( x=steps_df["Episode"], y=steps_df["Mean Steps"], mode="lines+markers", name="Mean Steps", line=dict(color="#ff7f0e", width=3, dash="dot"), marker=dict(size=8) ), secondary_y=True ) # Add confidence interval for reward fig.add_trace( go.Scatter( x=progress_df["Episode"].tolist() + progress_df["Episode"].tolist()[::-1], y=(progress_df["Mean Reward"] + progress_df["Std Reward"]).tolist() + (progress_df["Mean Reward"] - progress_df["Std Reward"]).tolist()[::-1], fill="toself", fillcolor="rgba(31, 119, 180, 0.2)", line=dict(color="rgba(255,255,255,0)"), hoverinfo="skip", showlegend=False ) ) # Update layout fig.update_layout( title="Agent Performance Over Training", xaxis_title="Training Episode", margin=dict(l=20, r=20, t=40, b=20), legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 ), height=400 ) # Set y-axes titles fig.update_yaxes(title_text="Reward", secondary_y=False) fig.update_yaxes(title_text="Steps", secondary_y=True) metrics_chart.plotly_chart(fig, use_container_width=True) # Epsilon decay chart epsilon_df = pd.DataFrame({ "Episode": list(range(len(epsilons))), "Epsilon": epsilons }) epsilon_fig = px.line( epsilon_df, x="Episode", y="Epsilon", title="Exploration Rate (Epsilon) Decay" ) epsilon_fig.update_layout( xaxis_title="Training Episode", yaxis_title="Epsilon Value", height=250, margin=dict(l=20, r=20, t=40, b=20) ) epsilon_chart.plotly_chart(epsilon_fig, use_container_width=True) # Update Q-table visualization qtable_fig = px.imshow( Qtable, labels=dict(x="Actions", y="States", color="Q-Value"), x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'], zmin=Qtable.min(), zmax=Qtable.max(), color_continuous_scale="Viridis" ) qtable_fig.update_layout( title=f"Q-table at Episode {episode}", height=600, margin=dict(l=20, r=20, t=40, b=20) ) qtable_visualization.plotly_chart(qtable_fig, use_container_width=True) # Q-table snapshot at regular intervals if episode % log_freq == 0: qtable_snapshots[episode] = Qtable.copy() # Update progress if episode % 100 == 0: elapsed = time.time() - start_time estimated = elapsed / (episode + 1) * (n_episodes - episode - 1) if episode > 0 else 0 status_text.text(f"Training in progress... Time elapsed: {elapsed:.1f}s | Estimated time remaining: {estimated:.1f}s") progress_bar.progress((episode + 1) / n_episodes) # Training complete progress_bar.progress(1.0) status_text.success(f"✅ Training completed in {time.time() - start_time:.1f} seconds!") # Store results in session state for persistence st.session_state.trained_qtable = Qtable st.session_state.agent_videos = videos st.session_state.training_completed = True # Final evaluation final_mean, final_std, final_steps, final_success = evaluate_agent( env, max_steps, n_eval_episodes * 2, Qtable # Double evaluation episodes for final eval ) # Store final metrics st.session_state.final_metrics = { "mean_reward": final_mean, "std_reward": final_std, "mean_steps": final_steps, "success_rate": final_success, "q_min": Qtable.min(), "q_max": Qtable.max() } return Qtable, videos # Environment setup env = gym.make("Taxi-v3") eval_env = gym.make("Taxi-v3", render_mode="rgb_array") # Create training button if training hasn't completed if not st.session_state.training_completed: train_col1, train_col2 = st.columns([3, 1]) with train_col1: st.write("") # For spacing with train_col2: start_training = st.button("🚀 Start Training", type="primary", use_container_width=True) # Start training when button is clicked if start_training: params = { "n_episodes": n_episodes, "learning_rate": learning_rate, "gamma": gamma, "max_steps": max_steps, "min_epsilon": min_epsilon, "max_epsilon": max_epsilon, "decay_rate": decay_rate, "log_freq": log_freq, "eval_every": eval_every, "n_eval_episodes": n_eval_episodes, "video_length": video_length if 'video_length' in locals() else 50 } trained_qtable, agent_videos = train_agent(env, eval_env, params) # If training is completed, show results if st.session_state.training_completed: # Create tabs for different visualizations tab1, tab2, tab3 = st.tabs(["📊 Training Results", "🎬 Agent Evolution", "📋 Q-Table Visualization"]) with tab1: # Summary metrics in nice boxes st.markdown("### 📊 Final Performance") metrics = st.session_state.final_metrics metric_cols = st.columns(4) with metric_cols[0]: st.metric("Final Average Reward", f"{metrics['mean_reward']:.2f}", delta=f"±{metrics['std_reward']:.2f}") with metric_cols[1]: st.metric("Average Steps to Complete", f"{metrics['mean_steps']:.1f}") with metric_cols[2]: st.metric("Success Rate", f"{metrics['success_rate']:.1f}%") with metric_cols[3]: st.metric("Q-values Range", f"{metrics['q_min']:.2f} to {metrics['q_max']:.2f}") # Download trained Q-table st.subheader("Export Model") # Convert Q-table to bytes for download def get_table_download_link(array): csvfile = BytesIO() np.save(csvfile, array) b64 = base64.b64encode(csvfile.getvalue()).decode() href = f'Download Q-table (.npy)' return href st.markdown(get_table_download_link(st.session_state.trained_qtable), unsafe_allow_html=True) with tab2: # Create video selection slider video_episodes = sorted(list(st.session_state.agent_videos.keys())) if video_episodes: selected_episode = st.select_slider( "Select checkpoint to view agent behavior:", options=video_episodes, format_func=lambda x: f"Episode {x} ({x/st.session_state.training_params['n_episodes']:.0%})" ) # Display the selected video st.markdown(f"### Agent Behavior at Episode {selected_episode} ({selected_episode/st.session_state.training_params['n_episodes']:.0%})") video_html = frames_to_html_video(st.session_state.agent_videos[selected_episode]) st.markdown(video_html, unsafe_allow_html=True) # Add explanation of agent behavior st.markdown(""" #### What Am I Looking At? This animated visualization shows how the trained agent behaves at different stages of training. You can observe: - The **yellow square** represents the taxi - The **letters (R, G, Y, B)** represent four fixed locations - The **blue letter** represents the passenger pickup location - The **purple letter** represents the passenger dropoff destination - When the passenger is in the taxi, the taxi turns green The agent makes decisions based on its learned Q-values. Early in training, movements may appear random as the agent explores. Later in training, the agent should take more direct routes to complete the task efficiently. """) with tab3: # Q-table visualization st.markdown("### Q-Table Visualization") st.info("This heatmap shows the learned Q-values that guide the agent's decision making.") # Generate a heatmap of the Q-table qtable_fig = px.imshow( st.session_state.trained_qtable, labels=dict(x="Actions", y="States", color="Q-Value"), x=['South', 'North', 'East', 'West', 'Pickup', 'Dropoff'], zmin=st.session_state.final_metrics["q_min"], zmax=st.session_state.final_metrics["q_max"], color_continuous_scale="Viridis" ) qtable_fig.update_layout( title="Final Q-table", height=600, margin=dict(l=20, r=20, t=40, b=20) ) st.plotly_chart(qtable_fig, use_container_width=True) # Add Q-table explanation st.markdown(""" #### Understanding the Q-Table The Q-table is the heart of the Q-learning algorithm: - Each **row** represents a different state (there are 500 possible states in Taxi-v3) - Each **column** represents an action (South, North, East, West, Pickup, Dropoff) - The **values** (colors) represent the expected future reward for taking that action in that state - **Brighter colors** indicate higher expected rewards The agent selects actions by choosing the highest value (brightest color) for its current state. """) # Add educational resources at the bottom with st.expander("📚 Learn More About Q-Learning"): st.markdown(""" ### Key Concepts in Q-Learning * **Q-Value**: Represents the expected future reward for taking action A in state S * **Exploration vs Exploitation**: Balancing between trying new actions and using known good actions * **Learning Rate (α)**: Controls how much new information overrides old information * **Discount Factor (γ)**: Determines the importance of future rewards * **Epsilon-greedy Policy**: A strategy that balances exploration and exploitation ### Taxi-v3 Environment Details The Taxi environment consists of a 5x5 grid world where a taxi needs to: 1. Navigate to the passenger's location 2. Pick up the passenger 3. Navigate to the destination 4. Drop off the passenger **Actions**: - Move South (0) - Move North (1) - Move East (2) - Move West (3) - Pickup passenger (4) - Dropoff passenger (5) **Rewards**: - -1 per time step - +20 for successful dropoff - -10 for illegal pickup/dropoff actions """) # Footer st.markdown("""

Interactive Q-Learning Dashboard for Reinforcement Learning Education

""", unsafe_allow_html=True)