TD(0) evaluation of terminal states


I'm going through Sutton and Barto's *Reinforcement Learning: An Introduction* and am currently reading about temporal-difference learning, specifically TD(0) methods. The textbook uses a random walk as a toy example, which works like this:

You have a chain of 5 states (1, 2, 3, 4, 5), where 1 and 5 are the terminal states. An agent moves either left or right along the chain (e.g. 2 -> 3, 4 -> 3, 4 -> 5) until it reaches a terminal state. If the agent reaches terminal state 1, the episode ends and the agent receives a reward of 0. If the agent reaches terminal state 5, the episode ends and the agent receives a reward of 1.
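For concreteness, here is a minimal sketch of that chain as a step function (the state numbering and rewards follow the question; the function name and return convention are my own):

```python
def step(state, action):
    """Move left (action=-1) or right (action=+1) along the chain.
    Returns (next_state, reward, done)."""
    next_state = state + action
    if next_state == 1:
        return next_state, 0.0, True   # left terminal: reward 0
    if next_state == 5:
        return next_state, 1.0, True   # right terminal: reward 1
    return next_state, 0.0, False      # non-terminal: reward 0, keep going
```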

We want to create an evaluation method that determines the value of each state for a given agent using the update rule $V(S_t) \leftarrow V(S_t) + \alpha \, [R_{t+1} + \lambda V(S_{t+1}) - V(S_t)]$.

The evaluation algorithm that the textbook gives on Page 145 is:

  • Input: policy $\pi$ to be evaluated
  • Initialize $V(s)$ arbitrarily (e.g., $V(s) = 0 \; \forall s \in S$)
  • Repeat (for each episode):
    • Initialize $S$
    • Repeat (for each step of episode):
      • $A$ = action given by $\pi$ for $S$
      • Take action $A$, observe reward $R$ and next state $S'$
      • $V(S) = V(S) + \alpha[R + \lambda V(S') - V(S)]$
      • $S = S'$
    • Until $S$ is terminal
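The boxed algorithm can be sketched in a few lines for this chain. The random policy, the start state 3, and the hyperparameters are my own choices for illustration; note how the terminal values never receive an update, which is exactly the behaviour in question:

```python
import random

def td0_evaluate(episodes=10000, alpha=0.1, lam=1.0, seed=0):
    """TD(0) prediction for the equiprobable random policy on the
    5-state chain. Terminal values V(1) and V(5) are never updated."""
    rng = random.Random(seed)
    V = {s: 0.0 for s in range(1, 6)}    # V(s) = 0 for all s
    for _ in range(episodes):
        s = 3                             # start each episode in the centre
        while s not in (1, 5):
            a = rng.choice((-1, 1))       # random policy: left/right equiprobable
            s_next = s + a
            r = 1.0 if s_next == 5 else 0.0
            # TD(0) update: V(S) <- V(S) + alpha[R + lam*V(S') - V(S)]
            V[s] += alpha * (r + lam * V[s_next] - V[s])
            s = s_next
    return V
```

With these settings the non-terminal estimates hover around 1/4, 1/2, and 3/4 (they keep fluctuating because $\alpha$ is constant), while V(1) and V(5) stay exactly 0.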

For a random agent that chooses to go left or right with equal probability, the correct evaluation of that agent is: $ V(1) = 0, V(2) = \frac{1}{4}, V(3) = \frac{2}{4}, V(4) = \frac{3}{4}, V(5) = 1$.

One strange thing about this algorithm is that the terminal states always remain at value 0 (in particular, it will always terminate with $V(5) = 0$).

If the value of a terminal state is instead set to anything $> 0$, it biases the estimates of the other states upward.

For example, if we have $\alpha = 0.1$, $\lambda = 1.0$, $V(5) = 1.0$, and $V(4) = 1.0$, and observe an episode where we transition from state 4 to state 5, the update will be:

$V(4) = V(4) + \alpha[R + \lambda V(5) - V(4)] = 1.0 + 0.1\,[(1.0 + 1.0) - 1.0] = 1.1$

So when $V(5)$ is set to its "correct" value of 1.0, the TD(0) update rule pulls $V(4)$ toward 2.0, which doesn't make sense. So for the values of all non-terminal states to be correct, we need to set the terminal state values to 0, which seems equally incorrect.
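The worked example above can be checked numerically (variable names are mine; this just replays the single update from the question):

```python
# With V(5) left at its "intuitive" value of 1.0, the TD target for
# state 4 becomes R + lam * V(5) = 2.0, dragging V(4) above its true value.
alpha, lam = 0.1, 1.0
V4, V5, R = 1.0, 1.0, 1.0
V4_new = V4 + alpha * (R + lam * V5 - V4)
print(V4_new)  # 1.1
```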

This issue is not really mentioned in the book, and I'm wondering if I'm misunderstanding something about how this works.

Thanks for the help!