# Grid-world layout: 4 columns x 3 rows, states numbered row by row.
# State s sits at column x = s % 4 and row y = s // 4. Row 0 is the top
# row, because moving north decreases y (see action_delta below).
state_coords = {s: (s % 4, s // 4) for s in range(12)}

# Terminal state -> its value (presumably the terminal reward; confirm
# against the solver that consumes this table).
terminal_states = {3: 1.0, 7: -1.0}
# Blocked cell at (1, 1). get_transitions keeps the agent in place
# instead of moving it here.
wall = {5}

# Every state index 0..11, including the terminals and the wall cell.
states = list(state_coords)
# Action indices: 0 = east, 1 = north, 2 = south, 3 = west.
actions = list(range(4))

# Solver constants. Their use is outside this view. By naming convention:
# the per-step reward, the discount factor, and the convergence tolerance.
step_cost = -0.04
gamma = 1.0
theta = 1e-4

# Action index -> (dx, dy) displacement on the grid.
action_delta = {
    0: (1, 0),   # east
    1: (0, -1),  # north (y grows downward)
    2: (0, 1),   # south
    3: (-1, 0),  # west
}

# Two-way lookups between state indices and (x, y) positions.
state_pos = dict(state_coords)
pos_state = {xy: s for s, xy in state_pos.items()}
def get_transitions(s, a):
    """Return the possible outcomes of taking action ``a`` in state ``s``.

    The agent moves in the intended direction with probability 0.8. It
    slips to each of the two directions perpendicular to the intended one
    with probability 0.1. A move that would leave the grid or enter the
    wall leaves the agent where it is. Terminal states and the wall cell
    map back to themselves with probability 1 under every action.

    Args:
        s: State index (a key of ``state_pos``).
        a: Action index (a key of ``action_delta``).

    Returns:
        A list of ``(probability, next_state)`` pairs. Entries are not
        merged, so the same ``next_state`` may appear more than once.
    """
    if s in terminal_states or s in wall:
        return [(1.0, s)]

    dx, dy = action_delta[a]
    # The action indices are not ordered cyclically around the compass
    # (they are east, north, south, west). Using (a + 1) % 4 and
    # (a + 3) % 4 would therefore make one "slip" the opposite direction.
    # Instead, pick the perpendicular directions geometrically: their
    # displacement has a zero dot product with the intended one.
    perpendicular = [
        d for d, (px, py) in action_delta.items() if px * dx + py * dy == 0
    ]

    x, y = state_pos[s]
    results = []
    for prob, direction in zip([0.8, 0.1, 0.1], [a, *perpendicular]):
        mx, my = action_delta[direction]
        ns = pos_state.get((x + mx, y + my), s)  # off the grid: stay put
        if ns in wall:  # bumped into the wall: stay put
            ns = s
        results.append((prob, ns))
    return results