# Grid dimensions (rows x columns) used by the plots below.
N_ROWS = 4
N_COLS = 12
# (row, col) of the start and goal cells: first and last column of the last row.
START = (N_ROWS - 1, 0)
GOAL = (N_ROWS - 1, N_COLS - 1)
# Columns of the cliff cells: every column strictly between start and goal.
CLIFF_COLS = range(1, N_COLS - 1)
# --- learning curve
# Plot the raw per-episode returns together with a moving average, save the
# figure under images_dir, then display it.
plt.figure(figsize=(9, 3.5))
window = 20
plt.plot(episode_returns, alpha=0.3, label="per-episode return")
# Only draw the moving average when at least `window` episodes exist.  With
# fewer, np.convolve(mode="valid") swaps its operands and yields
# window - len + 1 points, which no longer matches the (empty) x-range below
# and makes plt.plot raise; an empty input raises inside np.convolve itself.
if len(episode_returns) >= window:
    smoothed = np.convolve(episode_returns, np.ones(window) / window, mode="valid")
    # smoothed[i] averages episodes i .. i + window - 1, so it is plotted at
    # the index of the last episode in each window.
    plt.plot(
        np.arange(window - 1, len(episode_returns)),
        smoothed,
        label=f"moving average (window={window})",
    )
plt.xlabel("Episode")
plt.ylabel("Return")
plt.title("SARSA learning curve on CliffWalking-v1")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(f"{images_dir}/sarsa_cliff_learning_curve.png", dpi=150, bbox_inches="tight")
plt.show()
# --- value-function heatmap
# Per-cell summaries of the learned Q-table: V holds the largest action value,
# policy the index of the greedy (argmax) action.  Cells without a Q-table
# entry keep their fill values (NaN / -1).
V = np.full((N_ROWS, N_COLS), np.nan)
policy = np.full((N_ROWS, N_COLS), -1, dtype=int)
# NOTE(review): assumes q_table values are NumPy arrays of per-action values
# -- confirm against the agent implementation.
for state, action_values in agent.q_table.items():
    # Keys are decoded as flat row-major indices: index = row * N_COLS + col.
    flat_index = int(state)
    cell = (flat_index // N_COLS, flat_index % N_COLS)
    V[cell] = action_values.max()
    policy[cell] = int(np.argmax(action_values))

fig, ax = plt.subplots(figsize=(12, 3.5))
sns.heatmap(
    V,
    annot=True,
    fmt=".1f",
    cmap="coolwarm",
    cbar=True,
    ax=ax,
    xticklabels=range(N_COLS),
    yticklabels=range(N_ROWS),
)

# Outline the special cells, in drawing order: start, goal, then the cliff.
outlined_cells = [(START, "green"), (GOAL, "gold")]
outlined_cells.extend(((3, cliff_col), "red") for cliff_col in CLIFF_COLS)
for (cell_row, cell_col), edge_color in outlined_cells:
    # Rectangle takes its corner as (x, y), i.e. (col, row).
    outline = plt.Rectangle((cell_col, cell_row), 1, 1, fill=False, edgecolor=edge_color, lw=3)
    ax.add_patch(outline)

ax.set_title("Max Q-value per state, start (green), cliff (red), goal (gold)")
ax.set_xlabel("Column")
ax.set_ylabel("Row")
plt.savefig(f"{images_dir}/sarsa_cliff_q_heatmap.png", dpi=150, bbox_inches="tight")
plt.show()
# --- greedy policy arrows
# Draw one glyph per grid cell: S / G / skull for start, goal and cliff cells,
# otherwise an arrow for the greedy action stored in `policy`.  Cells whose
# policy entry is negative are left blank.
arrow_chars = ["\u2191", "\u2192", "\u2193", "\u2190"]  # indexed by action: up, right, down, left
fig, ax = plt.subplots(figsize=(12, 3.5))
for r in range(N_ROWS):
    for c in range(N_COLS):
        cell = (r, c)
        # Pick the glyph and its text styling; the first matching case wins.
        if cell == START:
            glyph = "S"
            style = {"color": "green", "fontsize": 14, "fontweight": "bold"}
        elif cell == GOAL:
            glyph = "G"
            style = {"color": "goldenrod", "fontsize": 14, "fontweight": "bold"}
        elif r == 3 and c in CLIFF_COLS:
            glyph = "\u2620"
            style = {"color": "red", "fontsize": 14}
        elif policy[r, c] >= 0:
            glyph = arrow_chars[policy[r, c]]
            style = {"fontsize": 18}
        else:
            continue
        # Centre the glyph inside the unit square of cell (r, c).
        ax.text(c + 0.5, r + 0.5, glyph, ha="center", va="center", **style)

ax.set_xlim(0, N_COLS)
# Reversed y-limits put row 0 at the top of the axes.
ax.set_ylim(N_ROWS, 0)
ax.set_xticks(range(N_COLS + 1))
ax.set_yticks(range(N_ROWS + 1))
ax.set_aspect("equal")
ax.grid(True, alpha=0.3)
ax.set_title("Greedy policy from learned Q-table (S=start, G=goal, \u2620=cliff)")
plt.savefig(f"{images_dir}/sarsa_cliff_policy.png", dpi=150, bbox_inches="tight")
plt.show()