Open In Colab

It is instructive to apply SARSA to a small, well-known environment so that the learning dynamics are visible end to end. We use the Gymnasium CliffWalking-v1 environment described below.

Environment: CliffWalking-v1

CliffWalking-v1 is the canonical Gymnasium gridworld for tabular control. The agent moves on a $4 \times 12$ grid: it starts at the bottom-left (state index 36) and must reach the bottom-right goal (state index 47). Every step yields reward $-1$, except stepping into one of the cells along the bottom edge between start and goal (the cliff). That step yields reward $-100$ and sends the agent back to the start without ending the episode. The action space is Discrete(4) with the convention 0 = up, 1 = right, 2 = down, 3 = left. Transitions are deterministic, and the observation is a single integer in $\{0, \ldots, 47\}$ encoding row * 12 + col. With only 48 states it is small enough to learn with a plain Q-table. It is also the textbook environment (Sutton & Barto, Example 6.6) used to contrast SARSA's longer but safer path through the upper part of the grid with Q-learning's optimal-but-risky path right next to the cliff.
import os
from collections import defaultdict

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tqdm import tqdm


class SARSAgent:
    """Tabular SARSA agent for a Gymnasium env with discrete states and actions.

    Update rule (on-policy TD control):
        Q(S_t, A_t) <- Q(S_t, A_t) + alpha * (R_{t+1} + gamma * Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t))

    On terminal transitions Q(S_{t+1}, .) is treated as 0. Non-terminal
    transitions (including time-limit truncations, if the caller passes
    ``terminated=False`` for them) bootstrap from Q(S_{t+1}, A_{t+1}).

    Args:
        n_actions: Number of discrete actions; each Q-row is a length-``n_actions`` array.
        learning_rate: Step size alpha.
        discount_factor: Discount gamma.
        epsilon: Probability that ``get_action`` returns a uniformly random action.
    """

    def __init__(self, n_actions, learning_rate=0.5, discount_factor=1.0, epsilon=0.1):
        self.n_actions = n_actions
        self.alpha = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        # States are added lazily on first lookup with all-zero action values.
        self.q_table = defaultdict(lambda: np.zeros(n_actions))

    def get_action(self, state):
        """Return an epsilon-greedy action (a Python ``int``) for ``state``.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise a greedy action is returned, with ties between equally-valued
        actions broken uniformly at random. Plain ``np.argmax`` always picks
        the lowest index on ties. With zero-initialised Q-rows, that would make
        every unexplored state greedily choose action 0.
        """
        if np.random.rand() < self.epsilon:
            return int(np.random.randint(self.n_actions))
        q_values = self.q_table[state]
        best_actions = np.flatnonzero(q_values == q_values.max())
        if best_actions.size == 1:
            # Unique maximum: no randomness needed.
            return int(best_actions[0])
        return int(np.random.choice(best_actions))

    def learn(self, s, a, r, s_next, a_next, terminated):
        """Apply one SARSA update to Q(s, a) in place.

        Args:
            s, a: State and action of the transition being updated.
            r: Reward received after taking ``a`` in ``s``.
            s_next, a_next: Next state and the action already chosen for it.
            terminated: If true, the bootstrap term is dropped (terminal next state).
        """
        # The conditional avoids touching q_table[s_next] on terminal steps.
        target = r + (0.0 if terminated else self.gamma * self.q_table[s_next][a_next])
        self.q_table[s][a] += self.alpha * (target - self.q_table[s][a])
# Build the environment and an epsilon-greedy SARSA agent.
env = gym.make("CliffWalking-v1")
n_actions = env.action_space.n
agent = SARSAgent(
    n_actions=n_actions,
    learning_rate=0.5,
    discount_factor=1.0,
    epsilon=0.1,
)

n_episodes = 500
episode_returns = np.empty(n_episodes)  # undiscounted return of each episode

for ep in tqdm(range(n_episodes), desc="Training"):
    # SARSA is on-policy: choose A_0 before the first step, then always choose
    # A_{t+1} from S_{t+1} *before* updating Q(S_t, A_t) with it.
    s, _ = env.reset()
    a = agent.get_action(s)
    total_r = 0.0
    done = False
    while not done:
        s_next, r, terminated, truncated, _ = env.step(a)
        a_next = agent.get_action(s_next)
        # Only `terminated` drops the bootstrap term; truncation still bootstraps.
        agent.learn(s, a, r, s_next, a_next, terminated)
        total_r += r
        s, a = s_next, a_next
        done = terminated or truncated
    episode_returns[ep] = total_r

env.close()
print(f"Final 50-episode mean return: {episode_returns[-50:].mean():.1f}")

Training:   0%|          | 0/500 [00:00<?, ?it/s]

Training:  51%|█████     | 253/500 [00:00<00:00, 2526.38it/s]

Training: 100%|██████████| 500/500 [00:00<00:00, 3117.58it/s]
Final 50-episode mean return: -22.6
N_ROWS, N_COLS = 4, 12
GOAL = (3, 11)
START = (3, 0)
CLIFF_COLS = range(1, 11)  # bottom-row columns strictly between START and GOAL

# `images_dir` is not defined in this cell; presumably an earlier setup cell sets
# it (TODO confirm). Fall back to a local "images" directory so the script also
# runs stand-alone, and create the directory so the plt.savefig calls cannot
# fail with FileNotFoundError.
if "images_dir" not in globals():
    images_dir = "images"
os.makedirs(images_dir, exist_ok=True)

# --- learning curve
plt.figure(figsize=(9, 3.5))
# Trailing moving average. Clamp the window so the "valid" convolution output
# and its x-positions stay the same length even with fewer than 20 episodes.
window = max(1, min(20, len(episode_returns)))
smoothed = np.convolve(episode_returns, np.ones(window) / window, mode="valid")
plt.plot(episode_returns, alpha=0.3, label="per-episode return")
# smoothed[k] averages episodes k .. k + window - 1, so plot it at the last one.
plt.plot(np.arange(window - 1, len(episode_returns)), smoothed, label=f"moving average (window={window})")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.title("SARSA learning curve on CliffWalking-v1")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(f"{images_dir}/sarsa_cliff_learning_curve.png", dpi=150, bbox_inches="tight")
plt.show()

# --- value-function heatmap
# For every state that has a Q-row, V holds max_a Q(s, a) and `policy` holds the
# greedy action; states never looked up stay NaN / -1.
V = np.full((N_ROWS, N_COLS), np.nan)
policy = np.full((N_ROWS, N_COLS), -1, dtype=int)
for s, q in agent.q_table.items():
    row, col = divmod(int(s), N_COLS)  # observation encodes row * N_COLS + col
    V[row, col] = np.max(q)
    policy[row, col] = np.argmax(q)


def _outline_cell(axis, cell, colour):
    """Draw an unfilled 1x1 square around grid cell ``cell = (row, col)``."""
    cell_row, cell_col = cell
    axis.add_patch(plt.Rectangle((cell_col, cell_row), 1, 1, fill=False, edgecolor=colour, lw=3))


fig, ax = plt.subplots(figsize=(12, 3.5))
sns.heatmap(
    V,
    annot=True,
    fmt=".1f",
    cmap="coolwarm",
    cbar=True,
    ax=ax,
    xticklabels=range(N_COLS),
    yticklabels=range(N_ROWS),
)
_outline_cell(ax, START, "green")
_outline_cell(ax, GOAL, "gold")
for c in CLIFF_COLS:
    _outline_cell(ax, (3, c), "red")
ax.set_title("Max Q-value per state, start (green), cliff (red), goal (gold)")
ax.set_xlabel("Column")
ax.set_ylabel("Row")
plt.savefig(f"{images_dir}/sarsa_cliff_q_heatmap.png", dpi=150, bbox_inches="tight")
plt.show()

# --- greedy policy arrows
arrow_chars = ["\u2191", "\u2192", "\u2193", "\u2190"]  # up, right, down, left


def _cell_label(row_idx, col_idx):
    """Return ``(text, style_kwargs)`` for one grid cell, or None to leave it blank."""
    if (row_idx, col_idx) == START:
        return "S", {"color": "green", "fontsize": 14, "fontweight": "bold"}
    if (row_idx, col_idx) == GOAL:
        return "G", {"color": "goldenrod", "fontsize": 14, "fontweight": "bold"}
    if row_idx == 3 and col_idx in CLIFF_COLS:
        return "\u2620", {"color": "red", "fontsize": 14}
    if policy[row_idx, col_idx] >= 0:
        return arrow_chars[policy[row_idx, col_idx]], {"fontsize": 18}
    return None  # state never visited: no greedy action to show


fig, ax = plt.subplots(figsize=(12, 3.5))
for r in range(N_ROWS):
    for c in range(N_COLS):
        label = _cell_label(r, c)
        if label is None:
            continue
        text, style = label
        ax.text(c + 0.5, r + 0.5, text, ha="center", va="center", **style)

# Inverted y-limits put row 0 at the top, matching the heatmap orientation.
ax.set_xlim(0, N_COLS)
ax.set_ylim(N_ROWS, 0)
ax.set_xticks(range(N_COLS + 1))
ax.set_yticks(range(N_ROWS + 1))
ax.set_aspect("equal")
ax.grid(True, alpha=0.3)
ax.set_title("Greedy policy from learned Q-table (S=start, G=goal, \u2620=cliff)")
plt.savefig(f"{images_dir}/sarsa_cliff_policy.png", dpi=150, bbox_inches="tight")
plt.show()
Output from cell 4: the learning curve, the max-Q heatmap, and the greedy-policy plot.

Key references: (Ma & Yu, 2016; Li, 2017; Bellemare et al., 2016; Jaderberg et al., 2016; Lillicrap et al., 2015)

References

  • Bellemare, M., Srinivasan, S., Ostrovski, G., Schaul, T., Saxton, D., et al. (2016). Unifying count-based exploration and intrinsic motivation.
  • Jaderberg, M., Mnih, V., Czarnecki, W., Schaul, T., Leibo, J., et al. (2016). Reinforcement Learning with Unsupervised Auxiliary Tasks.
  • Li, Y. (2017). Deep Reinforcement Learning: An Overview.
  • Lillicrap, T., Hunt, J., Pritzel, A., Heess, N., Erez, T., et al. (2015). Continuous control with deep reinforcement learning.
  • Ma, S., Yu, J. (2016). Transition-based versus State-based Reward Functions for MDPs with Value-at-Risk.