Generalized Advantage Estimation
Bootstrapping itself comes with trade-offs: a short look-ahead tends to give higher bias, while the Monte Carlo estimate has high variance (a single sampled trajectory can be far from the expectation). This is the whole starting point of Generalized Advantage Estimation: instead of relying on one fixed bootstrapping step, we look for an estimator that takes a weighted combination of all n-step bootstrapped estimates. Define the TD residual

delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)

Then the k-step advantage estimator is (page 4 of the GAE paper)

A_t^(k) = sum_{l=0}^{k-1} gamma^l * delta_{t+l} = -V(s_t) + r_t + gamma * r_{t+1} + ... + gamma^{k-1} * r_{t+k-1} + gamma^k * V(s_{t+k})
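The k-step advantage estimator equals a discounted sum of TD residuals delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); this telescoping identity is easy to check numerically. A small sketch with made-up rewards and values (all names and data here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
gamma, k = 0.99, 4
r = rng.normal(size=10)   # rewards r_t, ... (illustrative data)
V = rng.normal(size=11)   # value estimates V(s_t), one extra for the bootstrap

# TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
delta = r + gamma * V[1:] - V[:-1]

# k-step advantage as a sum of residuals: sum_{l<k} gamma^l * delta_{t+l}
lhs = sum(gamma**l * delta[l] for l in range(k))

# ...equals the telescoped form: -V(s_t) + sum_{l<k} gamma^l r_{t+l} + gamma^k V(s_{t+k})
rhs = -V[0] + sum(gamma**l * r[l] for l in range(k)) + gamma**k * V[k]

assert np.isclose(lhs, rhs)
```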
This is the k-step bootstrap for our advantage estimate. However, instead of committing to a single k, GAE introduces a parameter lambda and takes an exponentially weighted average over all k.
Note: in the infinite-horizon setting the normalization factor (1 - lambda) is essential; without it the estimator would simply accumulate the bias and variance of every term, which is very problematic.
Averaging the k-step estimators with weights lambda^(k-1) and normalizing by (1 - lambda), we get

A_t^GAE(gamma, lambda) = (1 - lambda) * ( A_t^(1) + lambda * A_t^(2) + lambda^2 * A_t^(3) + ... ) = sum_{l=0}^{inf} (gamma * lambda)^l * delta_{t+l}
When lambda = 1 we recover the Monte Carlo estimate, which has high variance:

A_t = sum_{l=0}^{inf} gamma^l * delta_{t+l} = sum_{l=0}^{inf} gamma^l * r_{t+l} - V(s_t)
And when lambda = 0 we recover the single-step bootstrap, which has high bias:

A_t = delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
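These two limiting cases can be verified on toy data; a minimal sketch at t = 0 over a finite window (variable names are mine, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(1)
gamma, T = 0.99, 8
r = rng.normal(size=T)       # illustrative rewards
V = rng.normal(size=T + 1)   # illustrative values, V[T] is the bootstrap value
delta = r + gamma * V[1:] - V[:-1]

def gae_at_0(lam):
    # GAE at t = 0: sum_l (gamma * lam)^l * delta_l over the window
    return sum((gamma * lam) ** l * delta[l] for l in range(T))

# lam = 0: the one-step TD error (high bias, low variance)
assert np.isclose(gae_at_0(0.0), delta[0])

# lam = 1: the Monte Carlo return minus the baseline (low bias, high variance)
mc = sum(gamma**l * r[l] for l in range(T)) + gamma**T * V[T] - V[0]
assert np.isclose(gae_at_0(1.0), mc)
```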
Below is a batched PyTorch implementation of GAE.
import torch

def GAE(rewards, values, dones, lam, gam):
    # rewards, dones: (B, T); values: (B, T + 1), including the bootstrap value of the final state
    B, T = rewards.shape
    dtype = rewards.dtype
    device = rewards.device
    gae = torch.zeros((B,), dtype=dtype, device=device)
    advantage = torch.zeros((B, T), dtype=dtype, device=device)
    for t in range(T - 1, -1, -1):  # sweep backwards through the rollout
        not_done = 1 - dones[:, t]
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), cut at episode ends
        delta = rewards[:, t] + gam * not_done * values[:, t + 1] - values[:, t]
        # recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        gae = delta + lam * gam * not_done * gae
        advantage[:, t] = gae
    returns = advantage + values[:, :-1]
    return advantage, returns
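The backward recursion A_t = delta_t + gamma * lambda * A_{t+1} can be sanity-checked against the direct definition sum_l (gamma * lambda)^l * delta_{t+l}; a minimal single-trajectory NumPy sketch with no episode boundaries (names and data are mine):

```python
import numpy as np

rng = np.random.default_rng(2)
gamma, lam, T = 0.99, 0.95, 6
r = rng.normal(size=T)
V = rng.normal(size=T + 1)
delta = r + gamma * V[1:] - V[:-1]

# backward recursion: A_t = delta_t + gamma * lam * A_{t+1}
adv = np.zeros(T)
acc = 0.0
for t in reversed(range(T)):
    acc = delta[t] + gamma * lam * acc
    adv[t] = acc

# direct definition: A_t = sum_l (gamma * lam)^l * delta_{t+l}
direct = np.array([
    sum((gamma * lam) ** l * delta[t + l] for l in range(T - t))
    for t in range(T)
])

assert np.allclose(adv, direct)
```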
In practical research implementations, however, we often assume the problem has a finite horizon, which requires modifying the normalization term: the nice closed-form (1 - lambda) factor from the geometric series is gone, and we instead need to accumulate the appropriate lam_coef_sum (see the derivation of GAE above). The following is the CleanRL implementation of finite-horizon GAE.
if args.finite_horizon_gae:
    """
    See GAE paper equation (16) line 1; we will compute the GAE based on this line only:
    1        * ( -V(s_t) + r_t + gamma * V(s_{t+1}) )
    lambda   * ( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * V(s_{t+2}) )
    lambda^2 * ( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + gamma^3 * V(s_{t+3}) )
    lambda^3 * ( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + gamma^3 * r_{t+3} + ... )
    We then normalize it by the sum of the lambda^i (instead of 1 - lambda)
    """
    if t == args.num_steps - 1:  # initialize
        lam_coef_sum = 0.
        reward_term_sum = 0.  # the sum of the second term
        value_term_sum = 0.   # the sum of the third term
    lam_coef_sum = lam_coef_sum * next_not_done
    reward_term_sum = reward_term_sum * next_not_done
    value_term_sum = value_term_sum * next_not_done
    lam_coef_sum = 1 + args.gae_lambda * lam_coef_sum
    reward_term_sum = args.gae_lambda * args.gamma * reward_term_sum + lam_coef_sum * rewards[t]
    value_term_sum = args.gae_lambda * args.gamma * value_term_sum + args.gamma * real_next_values
    advantages[t] = (reward_term_sum + value_term_sum) / lam_coef_sum - values[t]
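This recursion computes a lambda^i-weighted average of all k-step returns still available before the horizon, normalized by the running sum of the weights. It can be checked against that definition directly; a scalar single-trajectory sketch without episode boundaries (i.e. next_not_done = 1 throughout; names are mine):

```python
import numpy as np

rng = np.random.default_rng(3)
gamma, lam, T = 0.99, 0.95, 5
r = rng.normal(size=T)
V = rng.normal(size=T + 1)   # V[T] is the bootstrap value of the final state

# CleanRL-style backward recursion, specialized to one trajectory with no dones
adv = np.zeros(T)
lam_coef_sum = reward_term_sum = value_term_sum = 0.0
for t in reversed(range(T)):
    lam_coef_sum = 1.0 + lam * lam_coef_sum
    reward_term_sum = lam * gamma * reward_term_sum + lam_coef_sum * r[t]
    value_term_sum = lam * gamma * value_term_sum + gamma * V[t + 1]
    adv[t] = (reward_term_sum + value_term_sum) / lam_coef_sum - V[t]

# direct definition: normalized lambda^i-weighted average of the available k-step returns
def k_step_return(t, k):
    return sum(gamma**l * r[t + l] for l in range(k)) + gamma**k * V[t + k]

direct = np.array([
    sum(lam**i * k_step_return(t, i + 1) for i in range(T - t))
    / sum(lam**i for i in range(T - t)) - V[t]
    for t in range(T)
])

assert np.allclose(adv, direct)
```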