Policy improvement theorem in reinforcement learning, how to prove it?


I'm watching the DeepMind course by David Silver on reinforcement learning, lecture 3.

At around 46:00, he gives a proof for the policy improvement theorem, stated below. However, I'm struggling with the most essential step of the proof.

I won't define everything in the theorem; the elements are defined in the video series. Here is the theorem.

Given a policy $\pi$,

evaluate the policy $\pi$: $v_\pi(s) = \mathbb{E}_{\pi}[R_{t+1}+\gamma R_{t+2} + \dots \mid S_t = s]$

Define a new policy that greedily chooses the most rewarding action in every state:

$\pi'(s) = \operatorname{argmax}_{a \in A} \; q_{\pi}(s,a)$

with $q_{\pi}$ the state-action value function, computed using $v_\pi$.

By repeating these two steps iteratively, the policy converges to the optimal policy.
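The two-step loop above (evaluate, then greedily improve) can be sketched in Python. The 2-state, 2-action MDP below is entirely made up for illustration; `P`, `R`, and `gamma` are arbitrary numbers, not anything from the lecture:

```python
import numpy as np

# A tiny hypothetical MDP, chosen only to illustrate the iteration.
n_states, n_actions, gamma = 2, 2, 0.9
# P[s, a, s'] = transition probability, R[s, a] = expected reward (made up)
P = np.array([[[0.8, 0.2], [0.1, 0.9]],
              [[0.5, 0.5], [0.3, 0.7]]])
R = np.array([[1.0, 0.0],
              [0.0, 2.0]])

pi = np.zeros(n_states, dtype=int)  # start from an arbitrary deterministic policy
while True:
    # Step 1: policy evaluation -- solve the linear system v = R_pi + gamma * P_pi v
    P_pi = P[np.arange(n_states), pi]          # transitions under pi, shape (S, S)
    R_pi = R[np.arange(n_states), pi]          # rewards under pi, shape (S,)
    v = np.linalg.solve(np.eye(n_states) - gamma * P_pi, R_pi)
    # Step 2: greedy improvement -- pi'(s) = argmax_a q_pi(s, a)
    q = R + gamma * P @ v                      # q_pi[s, a] computed from v_pi
    pi_new = q.argmax(axis=1)
    if np.array_equal(pi_new, pi):
        break                                  # policy is stable: it is optimal
    pi = pi_new
```

For a finite MDP this loop terminates, since each improvement step can only raise the value function and there are finitely many deterministic policies.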

The proof:

The proof starts by showing that the state-action value of following $\pi'$ for our next action is at least as large as the state value under policy $\pi$:

By definition of the policy $\pi'$:

$q_{\pi}(s,\pi'(s)) = \max_{a \in A} \; q_{\pi}(s,a) \geq q_{\pi}(s,\pi(s)) = v_{\pi}(s)$
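As a sanity check, this inequality can be verified numerically on a small sketch MDP. All transition and reward numbers below are arbitrary illustrative values (nothing from the lecture), and the fixed policy `pi` is chosen at random:

```python
import numpy as np

# Made-up 2-state, 2-action MDP: P[s, a, s'] transitions, R[s, a] rewards.
gamma = 0.9
P = np.array([[[0.7, 0.3], [0.2, 0.8]],
              [[0.4, 0.6], [0.9, 0.1]]])
R = np.array([[0.5, 1.0],
              [2.0, 0.0]])
pi = np.array([0, 1])                          # an arbitrary fixed policy

# v_pi from the Bellman equation, then q_pi computed from v_pi.
P_pi = P[np.arange(2), pi]
R_pi = R[np.arange(2), pi]
v = np.linalg.solve(np.eye(2) - gamma * P_pi, R_pi)
q = R + gamma * P @ v

# max_a q_pi(s, a) >= q_pi(s, pi(s)) = v_pi(s) must hold in every state.
assert np.all(q.max(axis=1) >= v - 1e-12)
assert np.allclose(q[np.arange(2), pi], v)
```

The second assertion checks the identity $q_{\pi}(s,\pi(s)) = v_{\pi}(s)$, from which the inequality follows because the max over actions dominates any particular action.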

This makes sense to me; what doesn't is the following line of the proof:

$v_{\pi}(s) \leq q_{\pi}(s,\pi'(s)) = \mathbb{E}_{\pi'}[R_{t+1} + \gamma v_{\pi}(S_{t+1}) \mid S_t = s]$

This looks like the Bellman equation for the state-value function:

$v_{\pi}(s) = \mathbb{E}_{\pi}[R_{t+1} + \gamma v_{\pi}(S_{t+1}) \mid S_t = s]$

But we should apply the Bellman equation for the state-action value function:

$q_{\pi}(s,a) = \mathbb{E}_{\pi}[R_{t+1} + \gamma q_{\pi}(S_{t+1},A_{t+1}) \mid S_t = s, A_t = a]$

Applying this equation to $q_{\pi}(s,\pi'(s))$, we get:

$q_{\pi}(s,\pi'(s)) = \mathbb{E}_{\pi}[R_{t+1} + \gamma q_{\pi}(S_{t+1},A_{t+1}) \mid S_t = s, A_t = \pi'(s)]$

So here is my question:

How do we show that

$\mathbb{E}_{\pi}[R_{t+1} + \gamma q_{\pi}(S_{t+1},A_{t+1}) \mid S_t = s, A_t = \pi'(s)] = \mathbb{E}_{\pi'}[R_{t+1} + \gamma v_{\pi}(S_{t+1}) \mid S_t = s]$ ?

These look like two very different expressions and I can't connect the dots! Your help would be much appreciated!