# Greedy-policy plots: trace the deterministic argmax-Q path from start to goal.
# Cliff-walking grid: N_ROWS x N_COLS cells addressed as (row, col).
N_ROWS, N_COLS = 4, 12
# Start and goal occupy the two bottom-row corners.
START = (N_ROWS - 1, 0)
GOAL = (N_ROWS - 1, N_COLS - 1)
# Every bottom-row column strictly between START and GOAL is cliff (columns 1..10).
CLIFF_COLS = list(range(START[1] + 1, GOAL[1]))
# Action encoding shared by ARROWS and ACTION_DELTAS: 0=up, 1=right, 2=down, 3=left.
ARROWS = list("\u2191\u2192\u2193\u2190")
ACTION_DELTAS = [
    (-1, 0),  # up: row index decreases
    (0, 1),   # right
    (1, 0),   # down
    (0, -1),  # left
]
def greedy_path(q_table, max_steps=200):
    """Roll out the greedy (argmax-Q) policy from START and return the visited cells.

    Args:
        q_table: Q-values indexed by flat state ``r * N_COLS + c``. Either a
            mapping keyed by state (e.g. dict / defaultdict) or a
            sequence/array holding one row of action values per state.
        max_steps: Upper bound on the number of moves taken.

    Returns:
        List of (row, col) cells, starting with START. The rollout stops after
        reaching GOAL, stepping onto a cliff cell, reaching a state with no
        Q-values, moving into an already-visited cell, or after ``max_steps``
        moves. The table is never modified.
    """
    from collections.abc import Mapping

    def _q_values(s):
        # Mapping.get() does not insert default rows into a defaultdict,
        # so plotting never mutates the caller's Q-table.
        if isinstance(q_table, Mapping):
            return q_table.get(s)
        return q_table[s] if 0 <= s < len(q_table) else None

    r, c = START
    path = [(r, c)]
    seen = {(r, c)}
    for _ in range(max_steps):
        if (r, c) == GOAL:
            break
        if r == 3 and c in CLIFF_COLS:
            break  # fell off the cliff during greedy rollout
        q_values = _q_values(r * N_COLS + c)
        if q_values is None:
            break  # no Q-values recorded for this state: no greedy action
        dr, dc = ACTION_DELTAS[int(np.argmax(q_values))]
        # Moves into the outer wall are clamped, leaving the agent in place.
        r = max(0, min(N_ROWS - 1, r + dr))
        c = max(0, min(N_COLS - 1, c + dc))
        path.append((r, c))
        if (r, c) in seen:
            # The policy is deterministic in the state, so a repeated cell
            # means the rest of the rollout would cycle forever.
            break
        seen.add((r, c))
    return path
def plot_grid(ax, q_table, title):
    """Draw the cliff grid, greedy-action arrows and the greedy rollout on *ax*.

    Args:
        ax: matplotlib Axes to draw into.
        q_table: Q-values indexed by flat state ``r * N_COLS + c``. Either a
            mapping keyed by state (cells with no entry get no arrow) or a
            sequence/array holding one row of action values per state.
        title: Title for the axes.
    """
    from collections.abc import Mapping

    cliff_row = 3  # CLIFF_COLS are columns of this row

    # Background: thin outline around every cell.
    for r in range(N_ROWS):
        for c in range(N_COLS):
            ax.add_patch(plt.Rectangle((c, r), 1, 1, fill=False, edgecolor="#cccccc", lw=1))
    # Shade cliff cells, then outline start (green) and goal (gold).
    for c in CLIFF_COLS:
        ax.add_patch(plt.Rectangle((c, cliff_row), 1, 1, color="#ffd6d6"))
    ax.add_patch(plt.Rectangle((START[1], START[0]), 1, 1, fill=False, edgecolor="green", lw=3))
    ax.add_patch(plt.Rectangle((GOAL[1], GOAL[0]), 1, 1, fill=False, edgecolor="goldenrod", lw=3))

    # Greedy-action arrow in every ordinary cell that has Q-values.
    # `s in q_table` is a key test only for mappings; on a list/ndarray it
    # compares s against the stored Q-values, so arrays are bounds-checked
    # instead. Mapping.get() also avoids inserting rows into a defaultdict.
    is_mapping = isinstance(q_table, Mapping)
    for r in range(N_ROWS):
        for c in range(N_COLS):
            if (r, c) in (START, GOAL) or (r == cliff_row and c in CLIFF_COLS):
                continue
            s = r * N_COLS + c
            if is_mapping:
                q_values = q_table.get(s)
            else:
                q_values = q_table[s] if 0 <= s < len(q_table) else None
            if q_values is None:
                continue
            a = int(np.argmax(q_values))
            ax.text(c + 0.5, r + 0.5, ARROWS[a], ha="center", va="center", fontsize=14)

    # Overlay the greedy rollout through cell centres.
    path = greedy_path(q_table)
    xs = [c + 0.5 for _, c in path]
    ys = [r + 0.5 for r, _ in path]
    ax.plot(xs, ys, color="#1f77b4", lw=2.5, alpha=0.85)

    ax.text(START[1] + 0.5, START[0] + 0.5, "S", ha="center", va="center",
            color="green", fontsize=14, fontweight="bold")
    ax.text(GOAL[1] + 0.5, GOAL[0] + 0.5, "G", ha="center", va="center",
            color="goldenrod", fontsize=14, fontweight="bold")

    # Inverted y-limits so row 0 is drawn at the top.
    ax.set_xlim(0, N_COLS)
    ax.set_ylim(N_ROWS, 0)
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
# Compare the greedy policies learned by SARSA and Q-learning side by side,
# write the figure to images_dir, then display it.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
plot_grid(ax1, sarsa_q_last, "SARSA, greedy policy")
plot_grid(ax2, qlearn_q_last, "Q-learning, greedy policy")
fig.savefig(f"{images_dir}/sarsa_vs_q_paths.png", dpi=150, bbox_inches="tight")
plt.show()