A Reinforcement Learning Riddle

I proved $1=0$ starting from the formula for the on-policy distribution in episodic tasks. Obviously there is a mistake; can you spot it? 🤔

The proof

From page 199 of Sutton's book:

Let $h(s)$ denote the probability that an episode begins in each state $s$, and let $\eta(s)$ denote the number of time steps spent, on average, in state $s$ in a single episode. Time is spent in a state $s$ if episodes start in $s$, or if transitions are made into $s$ from a preceding state $\bar s$ in which time is spent:

$$\eta(s) = h(s) + \sum_{\bar s} \eta(\bar s) \sum_a \pi(a|\bar s)p(s|\bar s,a), \text{ for all } s\in\mathcal S.$$

This system of equations can be solved for the expected number of visits $\eta(s)$. The on-policy distribution is then the fraction of time spent in each state normalized to sum to one:

$$\mu(s) = \frac{\eta(s)}{\sum_{\bar s}\eta(\bar s)}, \text{ for all } s\in\mathcal S.$$

The denominator in the previous formula is simply $T$, the average number of time steps in an episode.

To make our formulas a little shorter, let us define $p^\pi(s|\bar s)$ as the probability of going from state $\bar s$ to state $s$ under policy $\pi$:

$$p^\pi(s|\bar s) \doteq \sum_a \pi(a|\bar s)p(s|\bar s,a)$$
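If the policy and the dynamics are stored as arrays, $p^\pi$ is a single contraction over the action index. The sketch below is a minimal illustration assuming NumPy and a hypothetical array layout (`pi[s_bar, a]` for $\pi(a|\bar s)$, `p[s_bar, a, s]` for $p(s|\bar s,a)$); the function name `policy_transition_matrix` and the toy numbers are made up for this example.

```python
import numpy as np


def policy_transition_matrix(pi: np.ndarray, p: np.ndarray) -> np.ndarray:
    """Return P_pi with P_pi[s_bar, s] = sum_a pi(a | s_bar) * p(s | s_bar, a).

    Hypothetical layout (an assumption of this sketch):
      pi: shape (n_states, n_actions), pi[s_bar, a] = pi(a | s_bar)
      p:  shape (n_states, n_actions, n_states), p[s_bar, a, s] = p(s | s_bar, a)
    """
    # Contract over the action index: b = s_bar, a = action, s = next state.
    return np.einsum("ba,bas->bs", pi, p)


# Toy example (assumed): in state 0, action 0 stays put and action 1 moves to
# state 1; in state 1, both actions stay in state 1.
pi = np.array([[0.3, 0.7],
               [0.5, 0.5]])
p = np.zeros((2, 2, 2))
p[0, 0, 0] = 1.0
p[0, 1, 1] = 1.0
p[1, :, 1] = 1.0

P_pi = policy_transition_matrix(pi, p)
print(P_pi)  # rows: p^pi(. | 0) = [0.3, 0.7] and p^pi(. | 1) = [0, 1]
```

Each row of `P_pi` sums to one whenever every row of `pi` and every `p[s_bar, a]` is a probability distribution; that row-sum property is exactly what the derivation below relies on.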

Using this notation, we get:

$$
\begin{aligned}
\eta(s) &= h(s) + \sum_{\bar s} \eta(\bar s) \sum_a \pi(a|\bar s)p(s|\bar s,a) \\
&= h(s) + \sum_{\bar s} \eta(\bar s) p^\pi(s|\bar s) & \implies \\
\sum_s h(s) &= \sum_s \left[ \eta(s) - \sum_{\bar s} \eta(\bar s) p^\pi(s|\bar s) \right] & \iff \\
1 &= \sum_s \eta(s) - \sum_s \sum_{\bar s} \eta(\bar s) p^\pi(s|\bar s) \\
&= \sum_s \eta(s) - \sum_{\bar s} \eta(\bar s) \xcancel{\sum_s p^\pi(s|\bar s)} \\
&= \sum_s \eta(s) - \sum_{\bar s} \eta(\bar s) = 0
\end{aligned}
$$

Wat. Can you spot the problem with the previous equations? Try to figure it out, or keep reading for the answer.

The problem

To see what is going on, let us take a very simple problem with two states, $A$ and $B$, where the latter is terminal. Starting from $A$, under $p^\pi$, we stay in $A$ with probability $\alpha$ and move to $B$ with probability $1-\alpha$.

Note that the number of timesteps $T$ of an episode is not fixed, because the problem is not deterministic, but it is finite, because (for $\alpha<1$) we are always going to reach $B$ at some point.

Let us assume that all episodes start in $A$ and compute $\eta$ for each state:

$$
\begin{aligned}
\eta(A) &= h(A) + \eta(A)p^\pi(A|A) + \eta(B)p^\pi(A|B) \\
&= 1 + \eta(A)\alpha + \eta(B)\cdot 0 \iff \\
\eta(A) &= \frac{1}{1-\alpha} \\
\eta(B) &= h(B) + \eta(A)p^\pi(B|A) + \eta(B)p^\pi(B|B) \\
&= 0 + \frac{1-\alpha}{1-\alpha} + \eta(B)p^\pi(B|B) \\
&= 1 + \eta(B)p^\pi(B|B) \iff \\
\eta(B) &= \frac{1}{1-p^\pi(B|B)}
\end{aligned}
$$
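The closed form for $\eta(A)$ can be checked by simulation. The plain-Python sketch below works under the assumptions of the example: every episode starts in $A$, stays in $A$ with probability $\alpha$, and otherwise moves to $B$, where reaching $B$ counts as one visit and ends the episode. The function `average_visits` and its parameters are hypothetical names chosen for this illustration.

```python
import random


def average_visits(alpha: float, n_episodes: int = 100_000,
                   seed: int = 0) -> tuple[float, float]:
    """Estimate (eta(A), eta(B)) for the two-state example by simulation.

    Assumptions of this sketch: every episode starts in A; from A we stay in A
    with probability alpha and move to the terminal state B otherwise; reaching
    B counts as one visit to B and ends the episode.
    """
    if not 0.0 <= alpha < 1.0:
        raise ValueError("alpha must be in [0, 1) for episodes to terminate")
    rng = random.Random(seed)
    visits_a = 0
    visits_b = 0
    for _ in range(n_episodes):
        visits_a += 1                # the episode starts in A
        while rng.random() < alpha:  # stay in A with probability alpha
            visits_a += 1
        visits_b += 1                # moved to B: one visit, then the episode ends
    return visits_a / n_episodes, visits_b / n_episodes


eta_a, eta_b = average_visits(alpha=0.5)
print(eta_a, 1 / (1 - 0.5))  # the estimate should be close to 2
print(eta_b)                 # exactly 1.0 by construction
```

The episode length varies from run to run, but every simulated episode ends, which matches the remark that $T$ is random yet finite.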

The expression for $\eta(B)$ is where the problem begins. What is the value of $p^\pi(B|B)$?

First attempt ($T<\infty$)

We know that all episodes terminate as soon as we reach $B$. Therefore, at least intuitively, we would like the value of the average count $\eta(B)$ to be $1$, since we are always going to be in $B$ exactly once.

The only way to make this happen, starting from our previous formula,

$$\eta(B) = 1 + \eta(B)p^\pi(B|B),$$

is by defining the probability of going from $B$ to $B$ to be $0$:

$$\eta(B) = 1 + \eta(B)\cdot 0 = 1$$

This seems to work nicely: both $\eta(A)$ and $\eta(B)$ are what we expect. Why, then, is the formula at the beginning of the article not working?

...

Because our probabilities do not add up! The sum of the probabilities of going from $B$ to every state should be $1$, but it is actually $0$:

$$\sum_s p^\pi(s|B) = p^\pi(A|B) + p^\pi(B|B) = 0 + 0 = 0$$

This, quite obviously, breaks the cancellation step of our original proof, which assumes that, for every state $\bar s$,

$$\sum_s p^\pi(s|\bar s) = 1.$$
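The first attempt can be reproduced by solving the linear system $(I - P^\top)\eta = h$, where row $\bar s$ of the matrix `P` holds $p^\pi(\cdot|\bar s)$. This is a minimal NumPy sketch assuming $\alpha = 0.5$ and the state order $(A, B)$; `expected_visits` is a hypothetical helper name. It also evaluates the line of the proof just before the cancellation, $\sum_s \eta(s) - \sum_{\bar s}\eta(\bar s)\sum_s p^\pi(s|\bar s)$, which comes out as 1, not 0, precisely because the row of $B$ sums to 0.

```python
import numpy as np


def expected_visits(P: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Solve eta = h + P^T eta, i.e. (I - P^T) eta = h.

    P[s_bar, s] = p^pi(s | s_bar) and h[s] is the start-state probability.
    """
    n = len(h)
    return np.linalg.solve(np.eye(n) - P.T, h)


alpha = 0.5
# First attempt: the row of the terminal state B is all zeros.
P_first = np.array([[alpha, 1 - alpha],   # from A: stay, or move to B
                    [0.0,   0.0]])        # from B: p(A|B) = p(B|B) = 0
h = np.array([1.0, 0.0])                  # every episode starts in A

eta = expected_visits(P_first, h)
row_sums = P_first.sum(axis=1)

print(eta)       # eta(A) = 1 / (1 - alpha) = 2, eta(B) = 1
print(row_sums)  # 1 for A, but 0 for B
# The proof's expression just before the cancellation:
print(eta.sum() - eta @ row_sums)  # 1, not 0
```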

Umhh ... can we find another way?

Second attempt ($T=\infty$)

Let us try another option. If we consider $B$ to be an absorbing state, then $p^\pi(B|B)$ should be $1$. However, if we substitute that value, we obtain something weird:

$$\eta(B) = \frac{1}{1-p^\pi(B|B)} = \frac{1}{0}$$

To see what is going on, remember that $\eta(s)$ represents "the number of time steps spent, on average, in state $s$". By definition, once we reach an absorbing state like $B$, we stay there forever. Thus, as soon as we reach $B$, we start looping, indefinitely increasing both the total number of timesteps $T$ and our count $\eta(B)$.

This can be seen by taking the limit of our previous formula as $p^\pi(B|B)$ approaches $1$ from below:

$$\lim_{p^\pi(B|B)\to 1^-} \frac{1}{1-p^\pi(B|B)} = +\infty$$

Well, then, our formula is not wrong ... at least in the limit. However, our average count $\eta(B)$ goes to infinity, making our on-policy distribution $\mu(s)$ not so useful anymore. This is because $B$ would take all the weight, leading to $\mu(B)=1$ and $\mu(s)=0$ everywhere else.
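The limit can also be watched numerically by setting $p^\pi(B|B) = 1 - \varepsilon$ and shrinking $\varepsilon$. This NumPy sketch keeps the assumptions of the example ($\alpha = 0.5$, every episode starts in $A$, $p^\pi(A|B) = 0$); `on_policy_distribution` and `eps` are hypothetical names. With these numbers, $\eta(A) = 2$ and $\eta(B) = 1/\varepsilon$, so $\mu(B) = 1/(1 + 2\varepsilon)$, which tends to 1.

```python
import numpy as np


def on_policy_distribution(P: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Solve (I - P^T) eta = h, then normalize eta to sum to one."""
    eta = np.linalg.solve(np.eye(len(h)) - P.T, h)
    return eta / eta.sum()


alpha = 0.5
h = np.array([1.0, 0.0])  # every episode starts in A

mus = {}
for eps in (1e-1, 1e-3, 1e-6):
    # B is "almost absorbing": p^pi(B|B) = 1 - eps and p^pi(A|B) = 0.
    P = np.array([[alpha, 1 - alpha],
                  [0.0,   1 - eps]])
    mus[eps] = on_policy_distribution(P, h)
    print(f"eps={eps:g}  mu(A)={mus[eps][0]:.6f}  mu(B)={mus[eps][1]:.6f}")
```

Setting `eps = 0` makes the second column of $I - P^\top$ zero, so the matrix is singular and `np.linalg.solve` raises `np.linalg.LinAlgError`: the linear-algebra counterpart of $1/0$.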

This outcome, too, is not satisfying. So, what should we conclude from all this?

Conclusion

  • In our first case, where the number of timesteps $T$ is finite, we cannot treat terminal states like every other state; otherwise their probabilities would not add up.
  • Similarly, in the second case, where the number of timesteps $T$ is infinite, we cannot treat absorbing states like every other state (unless we introduce discounting); otherwise the expected number of visits $\eta(s)$ becomes infinite.

One solution that works for both cases is to distinguish two sets of states: $\mathcal S$, containing all states except terminal states, and $\mathcal S^+$, containing every state, including terminal ones. Then we rewrite our original formula so that it sums only over nonterminal predecessors:

$$\eta(s) = h(s) + \sum_{\bar s\in \mathcal S} \eta(\bar s) p^\pi(s|\bar s), \text{ for all } s\in\mathcal S^+.$$

Finally, let us revisit our original proof:

$$
\begin{aligned}
1 &= \sum_{s\in\mathcal S^+} \eta(s) - \sum_{s\in\mathcal S^+} \sum_{\bar s\in\mathcal S} \eta(\bar s) p^\pi(s|\bar s) \\
&= \sum_{s\in\mathcal S^+} \eta(s) - \sum_{\bar s\in\mathcal S} \eta(\bar s) \xcancel{\sum_{s\in\mathcal S^+} p^\pi(s|\bar s)} \\
&= \sum_{s\in\mathcal S^+} \eta(s) - \sum_{\bar s\in\mathcal S} \eta(\bar s) \\
&= \sum_{s\in\mathcal S^+\setminus\mathcal S} \eta(s)
\end{aligned}
$$

This says that, on average, each episode spends one timestep in terminal states.
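The corrected formula can be turned into a small solver: summing only over $\bar s \in \mathcal S$ amounts to ignoring the rows of terminal states in the matrix of $p^\pi$. The NumPy sketch below assumes a boolean mask `terminal` marking $\mathcal S^+ \setminus \mathcal S$; the function name and the four-state example (two nonterminal and two terminal states, with terminal rows deliberately written as absorbing) are made up for illustration. The terminal visit counts add up to 1, as the derivation predicts.

```python
import numpy as np


def expected_visits_episodic(P: np.ndarray, h: np.ndarray,
                             terminal: np.ndarray) -> np.ndarray:
    """Solve eta(s) = h(s) + sum_{s_bar in S} eta(s_bar) p^pi(s | s_bar), s in S+.

    P[s_bar, s] = p^pi(s | s_bar) for every state in S+; the rows of terminal
    states are ignored, whatever they contain. terminal is a boolean mask.
    """
    # Zero out the rows of terminal states: only nonterminal predecessors count.
    P_nonterminal = np.where(terminal[:, None], 0.0, P)
    n = len(h)
    return np.linalg.solve(np.eye(n) - P_nonterminal.T, h)


# Hypothetical example with states (A, B, T1, T2); T1 and T2 are terminal and
# their rows are written as absorbing to show that they do not matter.
P = np.array([[0.2, 0.5, 0.3, 0.0],
              [0.4, 0.0, 0.0, 0.6],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])
h = np.array([0.5, 0.5, 0.0, 0.0])
terminal = np.array([False, False, True, True])

eta = expected_visits_episodic(P, h, terminal)
print(eta)                  # approx. [1.1667, 1.0833, 0.35, 0.65]
print(eta[terminal].sum())  # 1.0 up to rounding: one timestep in terminal states
```

For the two-state example, this returns $\eta(A) = 1/(1-\alpha)$ and $\eta(B) = 1$ whether the row of $B$ is written as all zeros or as absorbing, since that row never enters the sum.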

It works! 🎉

That's it! In the end we did not manage to prove $1=0$, but hopefully we learned something interesting along the way!