import matplotlib.cm as cm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colormaps, patheffects
def _draw_panel(ax, Q, states, terminals, vmin, vmax, cmap, panel_title):
    """Draw one row of Q-value cells for a two-action chain onto ``ax``.

    For each position ``i`` in ``states``:

    * if ``states[i]`` is in ``terminals``, a grey box labelled "T" is drawn;
    * otherwise the box is split into a left half showing ``Q[i, 0]`` and a
      right half showing ``Q[i, 1]``. Each half is filled with ``cmap``
      evaluated at ``(q - vmin) / (vmax - vmin)``. If either value is
      positive, a bold arrow above the box points toward the larger half
      (a tie points right).

    The state label is written under every box. The axes limits and title
    are overwritten, and the axis frame is hidden.

    Args:
        ax: Matplotlib axes to draw on.
        Q: Table indexed as ``Q[i, 0]`` / ``Q[i, 1]``; rows belonging to
            terminal states are never read.
        states: State labels, one cell per entry, drawn left to right.
        terminals: Labels of the states drawn as terminal cells.
        vmin, vmax: Value range rescaled onto ``[0, 1]`` before calling
            ``cmap``.
        cmap: Colormap callable applied to the rescaled value.
        panel_title: Title shown above the row.
    """
    ax.set_xlim(-0.5, len(states) - 0.5)
    ax.set_ylim(-1.1, 1.3)
    ax.axis("off")
    ax.set_title(panel_title, fontsize=12)

    # White halo keeps the black numbers readable on dark colormap entries.
    halo = [patheffects.withStroke(linewidth=2.5, foreground="white")]
    # Floor the width so a degenerate value range cannot divide by zero.
    value_span = max(vmax - vmin, 1e-9)
    border = {"edgecolor": "black", "linewidth": 0.8}

    for pos, label in enumerate(states):
        if label in terminals:
            ax.add_patch(mpatches.Rectangle(
                (pos - 0.45, -0.5), 0.9, 1.0, facecolor="#d9d9d9", **border))
            ax.text(pos, 0, "T", ha="center", va="center",
                    fontsize=13, color="#555")
        else:
            left_q = float(Q[pos, 0])
            right_q = float(Q[pos, 1])
            # One entry per half: (box left edge, text centre, value, glyph).
            halves = (
                (pos - 0.45, pos - 0.225, left_q, "←"),
                (pos, pos + 0.225, right_q, "→"),
            )
            # Artists are added in this order: both boxes, then both
            # direction glyphs, then both numeric labels.
            for edge, _, q, _ in halves:
                ax.add_patch(mpatches.Rectangle(
                    (edge, -0.5), 0.45, 1.0,
                    facecolor=cmap((q - vmin) / value_span), **border))
            for _, centre, _, glyph in halves:
                ax.text(centre, 0.32, glyph, ha="center", va="center",
                        fontsize=9, color="#444")
            for _, centre, q, _ in halves:
                ax.text(centre, -0.18, f"{q:.2f}", ha="center", va="center",
                        fontsize=10, color="black", path_effects=halo)
            # Greedy-action arrow, only once some value is positive.
            if left_q > 0 or right_q > 0:
                if right_q >= left_q:
                    tip, tail = pos + 0.3, pos - 0.3
                else:
                    tip, tail = pos - 0.3, pos + 0.3
                ax.annotate("", xy=(tip, 0.85), xytext=(tail, 0.85),
                            arrowprops=dict(arrowstyle="->", color="black",
                                            lw=2.5))
        ax.text(pos, -0.85, f"{label}", ha="center", va="center", fontsize=11)
def plot_q_panel(Q, states, terminals, vmin, vmax, title):
    """Show one Q-table as a row of cells with a shared colour bar beneath it.

    Opens an 11x2.6-inch figure and draws the cells with ``_draw_panel``
    using the "viridis" colormap. It adds a horizontal colour bar labelled
    "Q(s, a)" spanning ``[vmin, vmax]`` and then displays the figure.

    Args:
        Q: Q-table passed through to ``_draw_panel``.
        states: State labels, one cell each.
        terminals: State labels drawn as terminal cells.
        vmin, vmax: Colour-scale limits shared by the cells and the colour bar.
        title: Title of the panel.
    """
    figure, axes = plt.subplots(figsize=(11, 2.6))
    palette = colormaps["viridis"]
    _draw_panel(axes, Q, states, terminals, vmin, vmax, palette, title)

    # The cells are plain patches, so the colour bar needs its own mappable
    # built from the same colormap and value range.
    scale = plt.Normalize(vmin=vmin, vmax=vmax)
    mappable = cm.ScalarMappable(norm=scale, cmap=palette)
    mappable.set_array([])
    figure.colorbar(
        mappable,
        ax=axes,
        orientation="horizontal",
        fraction=0.06,
        pad=0.12,
        shrink=0.6,
        label="Q(s, a)",
    )

    plt.tight_layout()
    plt.show()
def deterministic_robot_cleaning_traced(*, gamma=0.5, max_sweeps=15,
                                        epsilon=0.001, reward_fn=None,
                                        model_fn=None):
    """Run Q-value iteration on the 6-state robot chain, recording each sweep.

    States are ``1..6`` and actions are ``-1`` (left) and ``+1`` (right).
    Each sweep applies the Bellman optimality backup

        Q(s, a) <- r(s, a) + gamma * max_a' Q(s', a'),   s' = model(s, a)

    to every state-action pair. Updates are written in place, so pairs later
    in a sweep already see values updated earlier in that same sweep.
    Iteration stops after ``max_sweeps`` sweeps, or as soon as the total
    absolute change of the table over one sweep drops below ``epsilon``.

    Args:
        gamma: Discount factor.
        max_sweeps: Upper bound on the number of sweeps.
        epsilon: Convergence threshold on ``sum(|Q_new - Q_old|)``.
        reward_fn: ``reward_fn(state, action) -> float``. Defaults to the
            module-level ``reward``.
        model_fn: ``model_fn(state, action) -> next_state``. It must return
            one of the states 1..6. Defaults to the module-level ``model``.

    Returns:
        list[numpy.ndarray]: ``history[0]`` is the all-zero initial table and
        ``history[k]`` is a copy of the table after sweep ``k``. Row ``i``
        holds state ``states[i]``; column 0 is action -1, column 1 is +1.

    Raises:
        KeyError: if ``model_fn`` returns a state that is not in 1..6. Without
            this check, the value would index the wrong row without any error.
    """
    # Resolve the module-level helpers lazily so callers can inject their own.
    if reward_fn is None:
        reward_fn = reward
    if model_fn is None:
        model_fn = model

    states = [1, 2, 3, 4, 5, 6]
    actions = [-1, 1]
    row_of = {s: i for i, s in enumerate(states)}

    Q = np.zeros((len(states), len(actions)))
    history = [Q.copy()]
    for _ in range(max_sweeps):
        Q_prev = Q.copy()
        for i, s in enumerate(states):
            for j, a in enumerate(actions):
                next_row = row_of[model_fn(s, a)]
                # Optimality backup: bootstrap from the best action in s'.
                Q[i, j] = reward_fn(s, a) + gamma * Q[next_row].max()
        history.append(Q.copy())
        # Sum of absolute changes, so opposite-sign changes cannot cancel.
        if np.sum(np.abs(Q - Q_prev)) < epsilon:
            break
    return history
# Run Q-iteration once and keep every intermediate table for plotting.
history = deterministic_robot_cleaning_traced()
# Cell labels for the plot. Presumably these must match the state list used
# inside deterministic_robot_cleaning_traced (one label per Q-table row);
# keep the two in sync.
states_list = [1, 2, 3, 4, 5, 6]
terminals = {1, 6}  # drawn as grey "T" cells instead of Q-value halves
suptitle = "Q-value iteration (deterministic cleaning robot)"
# One colour scale shared by every panel, so colours are comparable across
# iterations. The bound is the largest Q-value in any recorded table, rounded
# up to one decimal. The 1e-6 floor keeps vmax above vmin even when no
# Q-value is positive.
vmax = max(1e-6, float(np.ceil(max(Q.max() for Q in history) * 10) / 10))
vmin = 0.0
# The overall title is printed to stdout; it is not attached to any figure.
print(suptitle)
# One figure per recorded table; k == 0 is the initial all-zero table.
for k, Q in enumerate(history):
    panel_title = "Initial (Q = 0)" if k == 0 else f"Iteration {k}"
    plot_q_panel(Q, states=states_list, terminals=terminals,
                 vmin=vmin, vmax=vmax, title=panel_title)