
PPO: Proximal Policy Optimization

PPO (Proximal Policy Optimization) is a reinforcement learning algorithm proposed by OpenAI in 2017. It addresses a core difficulty of policy gradient methods: the policy update step size is hard to control. PPO is one of the most popular and stable policy gradient algorithms today and is widely used across reinforcement learning tasks.

The Core Idea of PPO

The Problem with Policy Gradients

Traditional policy gradient methods have a key problem: the policy update step size is hard to control.

  • Step too large: the policy can change drastically, performance can collapse, and it may never recover
  • Step too small: learning is too slow and requires a huge number of samples

PPO's Solution

PPO solves this problem by limiting the magnitude of each policy update:

  1. Importance sampling: use data collected with the old policy to estimate the new policy's gradient
  2. Clipped objective: limit how far the new policy may deviate from the old policy

Importance Sampling

Basic Principle

Importance sampling lets us use data collected with the old policy \pi_{\theta_{old}} to estimate expectations under the new policy \pi_\theta:

\mathbb{E}_{\pi_\theta}[f(x)] = \mathbb{E}_{\pi_{\theta_{old}}}\left[\frac{\pi_\theta(x)}{\pi_{\theta_{old}}(x)} f(x)\right]
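
To make the principle concrete, here is a small self-contained sketch (the two distributions are illustrative assumptions, not from the text): it estimates E_p[x^2] for a target distribution p = N(1, 1) using only samples drawn from a proposal q = N(0, 2), reweighted by p(x)/q(x).

import numpy as np

rng = np.random.default_rng(0)

def p_pdf(x):
    # Target density: normal with mean 1 and std 1.
    return np.exp(-0.5 * (x - 1.0) ** 2) / np.sqrt(2 * np.pi)

def q_pdf(x):
    # Proposal density: normal with mean 0 and std 2.
    return np.exp(-0.5 * (x / 2.0) ** 2) / (2.0 * np.sqrt(2 * np.pi))

x = rng.normal(0.0, 2.0, size=100_000)   # samples from q, not from p
weights = p_pdf(x) / q_pdf(x)            # importance weights p(x) / q(x)
estimate = np.mean(weights * x ** 2)     # estimate of E_p[x^2]
print(estimate)                          # close to 2.0 (= variance + mean^2 under p)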

Importance Weight

Define the importance weight:

r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}

The policy gradient objective can then be rewritten as:

L^{CPI}(\theta) = \mathbb{E}_t\left[r_t(\theta) \hat{A}_t\right]

where \hat{A}_t is an estimate of the advantage function and CPI stands for Conservative Policy Iteration.
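
In code the ratio is usually computed from log-probabilities for numerical stability. Below is a minimal PyTorch sketch (the tensor values are placeholders; in a real implementation old_log_probs comes from the frozen behavior policy and new_log_probs from the policy being optimized):

import torch

old_log_probs = torch.tensor([-1.20, -0.70, -2.30])  # log pi_old(a_t | s_t)
new_log_probs = torch.tensor([-1.00, -0.90, -2.00])  # log pi_theta(a_t | s_t)
advantages = torch.tensor([0.5, -0.3, 1.2])           # advantage estimates A_hat_t

ratio = torch.exp(new_log_probs - old_log_probs)      # r_t(theta)
surrogate_cpi = (ratio * advantages).mean()           # sample estimate of L^{CPI}
print(ratio, surrogate_cpi)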

The PPO-Clip Objective

The core of PPO is the clipped objective:

L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t\right)\right]

Understanding the Clipping Mechanism

The clip function \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) restricts the importance weight to the interval [1-\epsilon, 1+\epsilon].

When the advantage \hat{A}_t > 0:

  • The action is better than average, so its selection probability should increase
  • But once r_t(\theta) > 1+\epsilon, clipping prevents any further increase

When the advantage \hat{A}_t < 0:

  • The action is worse than average, so its selection probability should decrease
  • But once r_t(\theta) < 1-\epsilon, clipping prevents any further decrease

Why Does It Work?

The clipping mechanism ensures that:

  • The new policy never drifts too far from the old policy
  • Even an inaccurate advantage estimate cannot cause a catastrophic update
  • Training is much more stable (a small numeric example follows below)
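
The following sketch (with made-up ratio and advantage values, not taken from the text above) shows how the min of the unclipped and clipped terms caps the per-sample objective once the ratio leaves [1-\epsilon, 1+\epsilon]:

import torch

epsilon = 0.2
ratio = torch.tensor([0.5, 1.0, 1.5])       # r_t(theta); 1.5 lies outside [0.8, 1.2]
advantage = torch.tensor([1.0, 1.0, 1.0])   # positive advantages for all three samples

unclipped = ratio * advantage
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
objective = torch.min(unclipped, clipped)
print(objective)  # tensor([0.5000, 1.0000, 1.2000]) -- the 1.5 ratio is capped at 1.2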

The Complete PPO Algorithm

Algorithm Outline

Initialize policy parameters θ and value parameters φ
For each iteration:
    Collect a batch of trajectories with the current policy
    Compute the advantage estimates Â_t
    Run several update epochs:
        Compute the importance weights r_t(θ)
        Compute the PPO-Clip objective L^{CLIP}
        Compute the value function loss L^{VF}
        Compute the entropy bonus S
        Update the parameters: θ, φ = argmax(L^{CLIP} - c_1 L^{VF} + c_2 S)

Python Implementation

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.distributions import Categorical

class PPOActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()

        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        action_probs = self.actor(x)
        state_value = self.critic(x)
        return action_probs, state_value

    def get_action(self, state):
        probs, value = self.forward(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value

class PPO:
    def __init__(self, state_dim, action_dim, hidden_dim=64, lr=3e-4,
                 gamma=0.99, lam=0.95, clip_epsilon=0.2, value_coef=0.5,
                 entropy_coef=0.01, max_grad_norm=0.5, update_epochs=10,
                 mini_batch_size=64):
        self.model = PPOActorCritic(state_dim, action_dim, hidden_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.gamma = gamma
        self.lam = lam
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.mini_batch_size = mini_batch_size

    def compute_gae(self, rewards, values, dones, next_value):
        # Generalized Advantage Estimation, computed backwards through the rollout.
        advantages = []
        gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_val = next_value
            else:
                next_val = values[t + 1]

            delta = rewards[t] + self.gamma * next_val * (1 - dones[t]) - values[t]
            gae = delta + self.gamma * self.lam * (1 - dones[t]) * gae
            advantages.insert(0, gae)

        returns = [a + v for a, v in zip(advantages, values)]
        return advantages, returns

    def update(self, states, actions, log_probs, returns, advantages):
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        old_log_probs = torch.FloatTensor(log_probs)
        returns = torch.FloatTensor(returns)
        advantages = torch.FloatTensor(advantages)

        # Normalize advantages to reduce the variance of the policy gradient.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.update_epochs):
            for i in range(0, len(states), self.mini_batch_size):
                batch_states = states[i:i+self.mini_batch_size]
                batch_actions = actions[i:i+self.mini_batch_size]
                batch_old_log_probs = old_log_probs[i:i+self.mini_batch_size]
                batch_returns = returns[i:i+self.mini_batch_size]
                batch_advantages = advantages[i:i+self.mini_batch_size]

                probs, values = self.model(batch_states)
                dist = Categorical(probs)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # Importance ratio r_t(theta), computed from log-probabilities.
                ratio = torch.exp(new_log_probs - batch_old_log_probs)

                # Clipped surrogate objective (negated because we minimize).
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon,
                                    1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = F.mse_loss(values.squeeze(), batch_returns)

                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()

        return loss.item()

    def collect_trajectories(self, env, num_steps=2048):
        states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []

        state, _ = env.reset()

        for _ in range(num_steps):
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                action, log_prob, _, value = self.model.get_action(state_tensor)

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            states.append(state)
            actions.append(action.item())
            rewards.append(reward)
            dones.append(done)
            log_probs.append(log_prob.item())
            values.append(value.item())

            state = next_state
            if done:
                state, _ = env.reset()

        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            _, _, _, next_value = self.model.get_action(state_tensor)

        advantages, returns = self.compute_gae(rewards, values, dones, next_value.item())

        return states, actions, log_probs, returns, advantages

    def train(self, env, total_timesteps=100000, num_steps=2048):
        rewards_history = []

        timestep = 0
        while timestep < total_timesteps:
            states, actions, log_probs, returns, advantages = self.collect_trajectories(env, num_steps)

            loss = self.update(states, actions, log_probs, returns, advantages)

            timestep += num_steps

            # Periodically evaluate the greedy policy and record the result.
            if (timestep // num_steps) % 10 == 0:
                test_reward = self.evaluate(env)
                rewards_history.append(test_reward)
                print(f"Timestep {timestep}, Test Reward: {test_reward:.2f}")

        return rewards_history

    def evaluate(self, env, num_episodes=10):
        total_rewards = []

        for _ in range(num_episodes):
            state, _ = env.reset()
            episode_reward = 0
            done = False

            while not done:
                state_tensor = torch.FloatTensor(state).unsqueeze(0)
                with torch.no_grad():
                    probs, _ = self.model(state_tensor)
                action = probs.argmax().item()

                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                episode_reward += reward

            total_rewards.append(episode_reward)

        return np.mean(total_rewards)

Key Techniques in PPO

Generalized Advantage Estimation (GAE)

GAE balances bias and variance:

\hat{A}_t^{GAE} = \sum_{l=0}^{\infty}(\gamma \lambda)^l \delta_{t+l}

where \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t).
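
In practice the infinite sum is evaluated with an equivalent backward recursion, which is what compute_gae above implements (the episode-termination indicator d_t is notation added here, with d_t = 1 when the episode ends at step t):

\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t)\, \hat{A}_{t+1}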

Advantage Normalization

advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

Gradient Clipping

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

Value Function Clipping

PPO can also clip the value function update:

L^{VCLIP} = \max\left((V_\theta(s_t) - R_t)^2, (V_{clip} - R_t)^2\right)

where V_{clip} = V_{\theta_{old}}(s_t) + \text{clip}(V_\theta(s_t) - V_{\theta_{old}}(s_t), -\epsilon, \epsilon)
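
A minimal sketch of this clipped value loss (the tensor names values, values_old, and returns are illustrative assumptions; the PPO class above uses a plain MSE value loss instead):

import torch

def clipped_value_loss(values, values_old, returns, clip_epsilon=0.2):
    # Keep the new value prediction within epsilon of the old one, then take the
    # element-wise maximum of the two squared errors, as in L^{VCLIP}.
    values_clipped = values_old + torch.clamp(values - values_old, -clip_epsilon, clip_epsilon)
    loss_unclipped = (values - returns) ** 2
    loss_clipped = (values_clipped - returns) ** 2
    return torch.max(loss_unclipped, loss_clipped).mean()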

Hyperparameter Recommendations

Hyperparameter | Recommended value | Notes
Learning rate | 3e-4 | A smaller learning rate is more stable
Clip parameter ε | 0.1 ~ 0.3 | Controls the magnitude of policy updates
GAE λ | 0.95 | Bias-variance trade-off for the advantage estimate
Discount factor γ | 0.99 | Suitable for most tasks
Update epochs | 3 ~ 10 | Number of update passes per batch of data
Mini-batch size | 64 ~ 256 | Adjust to available memory
Value function coefficient | 0.5 | Weight of the value loss
Entropy coefficient | 0.01 | Encourages exploration

PPO with Stable Baselines3

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env('CartPole-v1', n_envs=4)

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    verbose=1
)

model.learn(total_timesteps=100000)

obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

PPO Variants

PPO-Penalty

Replaces clipping with a KL-divergence penalty:

L^{KLPEN}(\theta) = \mathbb{E}_t\left[r_t(\theta) \hat{A}_t - \beta\, \text{KL}[\pi_{\theta_{old}}, \pi_\theta]\right]
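
A minimal sketch of this penalized objective (all tensor names are illustrative; the adaptive rule for β follows the schedule suggested in the PPO paper, and kl_target is an assumed hyperparameter):

import torch

def ppo_penalty_loss(new_log_probs, old_log_probs, advantages, beta):
    # Surrogate objective with a KL penalty instead of clipping (negated for minimization).
    ratio = torch.exp(new_log_probs - old_log_probs)
    # Sample-based estimate of KL[pi_old || pi_new] over the actions actually taken.
    approx_kl = (old_log_probs - new_log_probs).mean()
    return -(ratio * advantages).mean() + beta * approx_kl

def update_beta(beta, kl, kl_target=0.01):
    # Adaptive penalty coefficient: shrink beta when the KL is well below target,
    # grow it when the KL overshoots.
    if kl < kl_target / 1.5:
        beta = beta / 2
    elif kl > kl_target * 1.5:
        beta = beta * 2
    return beta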

PPO-Continuous

For continuous action spaces:

class PPOContinuous(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()

        self.actor_mean = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim)
        )

        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))

        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        action_mean = self.actor_mean(x)
        action_std = torch.exp(self.actor_log_std)
        state_value = self.critic(x)
        return action_mean, action_std, state_value

    def get_action(self, state):
        mean, std, value = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(-1)
        entropy = dist.entropy().sum(-1)
        return action, log_prob, entropy, value

Summary

PPO is the most practical policy gradient algorithm in use today:

  • Core idea: limit the size of each policy update to keep training stable
  • Clipped objective: \min(r_t \hat{A}_t, \text{clip}(r_t, 1-\epsilon, 1+\epsilon) \hat{A}_t)
  • Importance sampling: reuse collected data to improve sample efficiency
  • GAE: advantage estimation that balances bias and variance
  • Easy to tune: relatively insensitive to hyperparameters

PPO performs strongly across a wide range of tasks and is one of the first algorithms to reach for in reinforcement learning practice. The next chapter introduces the Gymnasium environment interface.