强化学习速查表
本文档提供强化学习常用概念、公式和 API 的快速参考。适合在学习或实践中快速查阅。
核心概念速查
MDP 五元组
| 元素 | 符号 | 说明 | 示例 |
|---|---|---|---|
| 状态空间 | 所有可能状态的集合 | 棋盘所有局面 | |
| 动作空间 | 所有可能动作的集合 | 上下左右移动 | |
| 转移概率 | $P(s' | s,a)$ | 状态转移概率 |
| 奖励函数 | 即时奖励 | 得分、惩罚 | |
| 折扣因子 | 未来奖励的折扣 | 通常 0.99 |
回报与价值函数
折扣回报:
状态价值函数:
动作价值函数:
优势函数:
贝尔曼方程
贝尔曼期望方程(状态价值):
贝尔曼期望方程(动作价值):
贝尔曼最优方程(状态价值):
贝尔曼最优方程(动作价值):
算法公式速查
表格型方法
Q-Learning
- 类型:离线策略(Off-Policy)
- 特点:学习最优策略的价值
- 适用:离散状态和动作空间
SARSA
- 类型:在线策略(On-Policy)
- 特点:学习当前策略的价值
- 适用:需要安全性的场景
预期 SARSA
- 类型:离线策略
- 特点:方差比 SARSA 小
深度强化学习
DQN 损失函数
其中 是目标网络参数。
Double DQN
Dueling DQN
其中:
策略梯度方法
策略梯度定理
REINFORCE
带基线的策略梯度
常用基线:
Actor-Critic 损失
PPO-Clip
其中 是重要性权重。
SAC(Soft Actor-Critic)
SAC 的核心是最大熵强化学习,目标函数为:
Q 网络目标值:
策略损失:
- 类型:离线策略(Off-Policy)
- 特点:熵正则化、自动探索、样本效率高
- 适用:连续动作空间
GAE(广义优势估计)
其中 是 TD 误差。
连续动作空间
高斯策略
SAC(Soft Actor-Critic)
熵正则化目标:
Gymnasium API 速查
基本使用
import gymnasium as gym
env = gym.make('CartPole-v1')
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(action)
env.close()
step() 返回值
| 返回值 | 类型 | 说明 |
|---|---|---|
| observation | np.ndarray | 新的观测 |
| reward | float | 即时奖励 |
| terminated | bool | 是否自然终止 |
| truncated | bool | 是否被截断 |
| info | dict | 辅助信息 |
空间类型
from gymnasium import spaces
spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)
spaces.Discrete(2)
spaces.MultiDiscrete([2, 3, 2])
spaces.MultiBinary(4)
spaces.Dict({'obs': spaces.Box(...), 'action': spaces.Discrete(2)})
spaces.Tuple((spaces.Box(...), spaces.Discrete(3)))
常用包装器
from gymnasium.wrappers import (
TimeLimit, RecordVideo, NormalizeObservation,
NormalizeReward, RecordEpisodeStatistics,
FrameStackObservation, TransformObservation
)
env = TimeLimit(env, max_episode_steps=500)
env = RecordVideo(env, video_folder='./videos')
env = NormalizeObservation(env)
env = NormalizeReward(env, gamma=0.99)
env = RecordEpisodeStatistics(env)
env = FrameStackObservation(env, stack_size=4)
向量化环境
from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv, make_vec_env
envs = SyncVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])
envs = AsyncVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])
envs = make_vec_env('CartPole-v1', n_envs=4, parallel=True)
Stable Baselines3 API 速查
创建模型
from stable_baselines3 import PPO, DQN, SAC, TD3, A2C
model = PPO('MlpPolicy', env, learning_rate=3e-4, verbose=1)
model = DQN('MlpPolicy', env, learning_rate=1e-4, buffer_size=100000)
model = SAC('MlpPolicy', env, learning_rate=3e-4)
model = TD3('MlpPolicy', env, learning_rate=3e-4)
model = A2C('MlpPolicy', env, learning_rate=7e-4)
model = PPO('CnnPolicy', env)
训练和预测
model.learn(total_timesteps=100000)
action, _states = model.predict(observation, deterministic=True)
action, _states = model.predict(observation, deterministic=False)
保存和加载
model.save('ppo_cartpole')
model = PPO.load('ppo_cartpole', env=env)
model.save_replay_buffer('replay_buffer')
model.load_replay_buffer('replay_buffer')
回调函数
from stable_baselines3.common.callbacks import (
EvalCallback, CheckpointCallback, CallbackList, BaseCallback
)
eval_callback = EvalCallback(
eval_env,
best_model_save_path='./best/',
log_path='./logs/',
eval_freq=10000,
n_eval_episodes=5,
deterministic=True
)
checkpoint_callback = CheckpointCallback(
save_freq=10000,
save_path='./checkpoints/',
name_prefix='model'
)
model.learn(total_timesteps=100000, callback=CallbackList([
eval_callback, checkpoint_callback
]))
自定义策略网络
import torch.nn as nn
policy_kwargs = dict(
net_arch=[64, 64],
activation_fn=nn.ReLU,
optimizer_class=torch.optim.Adam,
optimizer_kwargs=dict(lr=3e-4)
)
model = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs)
超参数推荐
PPO
| 参数 | 推荐值 | 说明 |
|---|---|---|
| learning_rate | 3e-4 | 学习率 |
| n_steps | 2048 | 每次更新收集的步数 |
| batch_size | 64 | 小批量大小 |
| n_epochs | 10 | 每次更新的训练轮数 |
| gamma | 0.99 | 折扣因子 |
| gae_lambda | 0.95 | GAE 参数 |
| clip_range | 0.2 | PPO 裁剪参数 |
| ent_coef | 0.01 | 熵系数 |
| vf_coef | 0.5 | 价值函数系数 |
DQN
| 参数 | 推荐值 | 说明 |
|---|---|---|
| learning_rate | 1e-4 | 学习率 |
| buffer_size | 100000 | 经验回放缓冲区大小 |
| batch_size | 32 或 64 | 小批量大小 |
| gamma | 0.99 | 折扣因子 |
| target_update_interval | 1000 ~ 10000 | 目标网络更新频率 |
| exploration_fraction | 0.1 | 探索比例 |
| exploration_final_eps | 0.01 | 最终探索率 |
| train_freq | 4 | 训练频率 |
| gradient_steps | 1 | 每次训练的梯度步数 |
SAC
| 参数 | 推荐值 | 说明 |
|---|---|---|
| learning_rate | 3e-4 | 学习率 |
| buffer_size | 1000000 | 经验回放缓冲区大小 |
| batch_size | 256 | 小批量大小 |
| gamma | 0.99 | 折扣因子 |
| tau | 0.005 | 软更新系数 |
| ent_coef | 'auto' | 熵系数(自动调整) |
| target_update_interval | 1 | 目标网络更新频率 |
| gradient_steps | 1 | 每次训练的梯度步数 |
TD3
| 参数 | 推荐值 | 说明 |
|---|---|---|
| learning_rate | 3e-4 | 学习率 |
| buffer_size | 1000000 | 经验回放缓冲区大小 |
| batch_size | 256 | 小批量大小 |
| gamma | 0.99 | 折扣因子 |
| tau | 0.005 | 软更新系数 |
| policy_delay | 2 | 策略更新延迟 |
| target_policy_noise | 0.2 | 目标策略平滑噪声 |
| target_noise_clip | 0.5 | 目标噪声裁剪 |
PyTorch 常用代码
策略网络
import torch
import torch.nn as nn
from torch.distributions import Categorical
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
def forward(self, x):
return self.network(x)
def get_action(self, x):
probs = self.forward(x)
dist = Categorical(probs)
action = dist.sample()
return action, dist.log_prob(action), dist.entropy()
价值网络
class ValueNetwork(nn.Module):
def __init__(self, state_dim, hidden_dim=128):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.network(x)
经验回放
from collections import deque
import random
import numpy as np
class ReplayBuffer:
def __init__(self, capacity=10000):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (
np.array(states),
np.array(actions),
np.array(rewards, dtype=np.float32),
np.array(next_states),
np.array(dones, dtype=np.float32)
)
def __len__(self):
return len(self.buffer)
GAE 计算
def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
advantages = []
gae = 0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_val = next_value
else:
next_val = values[t + 1]
delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
gae = delta + gamma * lam * (1 - dones[t]) * gae
advantages.insert(0, gae)
returns = [a + v for a, v in zip(advantages, values)]
return advantages, returns
折扣回报计算
def compute_returns(rewards, gamma=0.99):
returns = []
R = 0
for r in reversed(rewards):
R = r + gamma * R
returns.insert(0, R)
return returns
常见环境
经典控制
| 环境 | 动作空间 | 观测空间 | 说明 |
|---|---|---|---|
| CartPole-v1 | Discrete(2) | Box(4,) | 倒立摆 |
| MountainCar-v0 | Discrete(3) | Box(2,) | 山地车 |
| Acrobot-v1 | Discrete(3) | Box(6,) | 欠驱动摆 |
| Pendulum-v1 | Box(1,) | Box(3,) | 钟摆(连续动作) |
Box2D
| 环境 | 动作空间 | 观测空间 | 说明 |
|---|---|---|---|
| LunarLander-v2 | Discrete(4) | Box(8,) | 月球着陆 |
| LunarLanderContinuous-v2 | Box(2,) | Box(8,) | 月球着陆(连续) |
| BipedalWalker-v3 | Box(4,) | Box(24,) | 双足行走 |
MuJoCo
| 环境 | 动作空间 | 观测空间 | 说明 |
|---|---|---|---|
| HalfCheetah-v4 | Box(6,) | Box(17,) | 半猎豹 |
| Humanoid-v4 | Box(17,) | Box(376,) | 人形机器人 |
| Ant-v4 | Box(8,) | Box(111,) | 蚂蚁 |
| Reacher-v4 | Box(2,) | Box(11,) | 机械臂 |
| Hopper-v4 | Box(3,) | Box(11,) | 单腿跳跃 |
算法选择指南
| 场景 | 推荐算法 | 原因 |
|---|---|---|
| 离散动作,简单任务 | DQN | 简单高效 |
| 离散动作,复杂任务 | PPO | 稳定可靠 |
| 连续动作,简单任务 | SAC | 样本效率高 |
| 连续动作,复杂任务 | PPO / SAC / TD3 | 根据具体任务选择 |
| 样本效率要求高 | SAC / TD3 | 离线策略,可重用数据 |
| 训练稳定性要求高 | PPO | 更新步长可控 |
| 多进程训练 | PPO / A2C | 支持向量化环境 |
| 稀疏奖励 | HER + DQN/SAC | 目标重标记 |
调试技巧
检查环境
from gymnasium.utils.env_checker import check_env
check_env(env, warn=True)
监控训练
from stable_baselines3.common.callbacks import BaseCallback
class TensorboardCallback(BaseCallback):
def __init__(self, verbose=0):
super().__init__(verbose)
def _on_step(self):
if self.n_calls % 100 == 0:
self.logger.record('custom/episode_reward',
self.locals.get('episode_reward', 0))
return True
梯度检查
for name, param in model.named_parameters():
if param.grad is not None:
print(f"{name}: grad_mean={param.grad.mean():.6f}, "
f"grad_std={param.grad.std():.6f}")
常见问题排查
| 问题 | 可能原因 | 解决方案 |
|---|---|---|
| 奖励不增长 | 学习率太大/太小 | 调整学习率 |
| 奖励震荡剧烈 | 更新步长太大 | 减小学习率或增加 batch_size |
| 奖励突然下降 | 策略崩溃 | 使用 PPO 或减小更新幅度 |
| 探索不足 | ε 衰减太快 | 减慢 ε 衰减 |
| 过拟合 | 训练太久 | 早停或正则化 |
参考资源
官方文档
经典教材
经典论文
- DQN: Playing Atari with Deep Reinforcement Learning (2013)
- Double DQN: Deep Reinforcement Learning with Double Q-learning (2015)
- Dueling DQN: Dueling Network Architectures for Deep RL (2016)
- PPO: Proximal Policy Optimization Algorithms (2017)
- SAC: Soft Actor-Critic (2018)
- TD3: Twin Delayed DDPG (2018)