Reinforcement Learning Cheat Sheet
This document provides a quick reference for commonly used reinforcement learning concepts, formulas, and APIs.
Core Concepts
The MDP 5-Tuple
| Element | Symbol | Description |
|---|---|---|
| State space | $\mathcal{S}$ | Set of all possible states |
| Action space | $\mathcal{A}$ | Set of all possible actions |
| Transition probability | $P(s' \mid s, a)$ | Probability of reaching state $s'$ after taking action $a$ in state $s$ |
| Reward function | $R(s, a)$ | Immediate reward |
| Discount factor | $\gamma$ | Discount applied to future rewards |
Value Functions
State-value function:
$$V^\pi(s) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t r_t \,\middle|\, s_0 = s\right]$$
Action-value function:
$$Q^\pi(s, a) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t r_t \,\middle|\, s_0 = s, a_0 = a\right]$$
Advantage function:
$$A^\pi(s, a) = Q^\pi(s, a) - V^\pi(s)$$
Bellman Equations
Bellman expectation equation:
$$V^\pi(s) = \sum_{a} \pi(a \mid s) \left[ R(s, a) + \gamma \sum_{s'} P(s' \mid s, a)\, V^\pi(s') \right]$$
Bellman optimality equation:
$$V^*(s) = \max_{a} \left[ R(s, a) + \gamma \sum_{s'} P(s' \mid s, a)\, V^*(s') \right]$$
Algorithm Formulas
Q-Learning
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$
SARSA
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma Q(s', a') - Q(s, a) \right]$$
Policy Gradient
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[ \nabla_\theta \log \pi_\theta(a \mid s)\, Q^{\pi_\theta}(s, a) \right]$$
PPO-Clip
$$L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\left[ \min\left( r_t(\theta) \hat{A}_t,\ \mathrm{clip}\big(r_t(\theta), 1 - \epsilon, 1 + \epsilon\big) \hat{A}_t \right) \right], \quad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}$$
GAE
$$\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$$
where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$
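The Q-Learning and SARSA updates above translate directly into a tabular implementation. A minimal sketch, assuming a small discrete problem with a NumPy Q-table (the state/action sizes, alpha, and gamma are illustrative):

import numpy as np

# Illustrative sizes; in practice use env.observation_space.n / env.action_space.n
n_states, n_actions = 16, 4
Q = np.zeros((n_states, n_actions))

def q_learning_update(Q, s, a, r, s_next, done, alpha=0.1, gamma=0.99):
    # Off-policy TD target: bootstrap with the max over next-state actions
    target = r if done else r + gamma * np.max(Q[s_next])
    Q[s, a] += alpha * (target - Q[s, a])

def sarsa_update(Q, s, a, r, s_next, a_next, done, alpha=0.1, gamma=0.99):
    # On-policy TD target: bootstrap with the action actually taken next
    target = r if done else r + gamma * Q[s_next, a_next]
    Q[s, a] += alpha * (target - Q[s, a])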
Gymnasium API
Basic Usage
import gymnasium as gym
env = gym.make('CartPole-v1')
obs, info = env.reset()
action = env.action_space.sample()  # replace with your policy's action
obs, reward, terminated, truncated, info = env.step(action)
env.close()
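Putting the calls above together into a full episode loop (random actions for illustration); an episode ends when either terminated or truncated is True:

obs, info = env.reset()
episode_reward = 0.0
while True:
    action = env.action_space.sample()  # placeholder for a policy
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    if terminated or truncated:
        break
env.close()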
Space Types
from gymnasium import spaces
spaces.Box(low=0, high=1, shape=(4,)) # continuous space
spaces.Discrete(2) # discrete space
spaces.MultiDiscrete([2, 3, 2]) # multi-dimensional discrete
spaces.MultiBinary(4) # multi-dimensional binary
spaces.Dict({'obs': spaces.Box(...), ...}) # dict space
spaces.Tuple((spaces.Box(...), ...)) # tuple space
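Every space supports sample() and contains(), which is handy for quick sanity checks (using the spaces import above):

space = spaces.Box(low=0, high=1, shape=(4,))
x = space.sample()                  # random element drawn from the space
print(space.contains(x))            # True
print(spaces.Discrete(2).sample())  # 0 or 1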
Wrappers
from gymnasium.wrappers import (
TimeLimit, RecordVideo, NormalizeObservation,
NormalizeReward, FlattenObservation
)
env = TimeLimit(env, max_episode_steps=500)      # truncate episodes after 500 steps
env = RecordVideo(env, video_folder='./videos')  # save episode videos to disk
env = NormalizeObservation(env)                  # running normalization of observations
env = NormalizeReward(env)                       # running normalization of rewards
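Custom wrappers subclass gymnasium.ObservationWrapper, RewardWrapper, or ActionWrapper. A minimal sketch that rescales observations (the scale factor is illustrative):

import gymnasium as gym

class ScaleObservation(gym.ObservationWrapper):
    def __init__(self, env, scale=0.1):
        super().__init__(env)
        self.scale = scale

    def observation(self, obs):
        # Called on every observation returned by reset() and step()
        return obs * self.scale

env = ScaleObservation(gym.make('CartPole-v1'))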
Stable Baselines3 API
Creating a Model
from stable_baselines3 import PPO, DQN, SAC, TD3, A2C
model = PPO('MlpPolicy', env)  # on-policy; discrete or continuous actions
model = DQN('MlpPolicy', env)  # off-policy; discrete actions only
model = SAC('MlpPolicy', env)  # off-policy; continuous actions only
model = TD3('MlpPolicy', env)  # off-policy; continuous actions only
model = A2C('MlpPolicy', env)  # on-policy; discrete or continuous actions
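SB3 models can also be trained on vectorized environments, which is how the multi-process setups in the algorithm-selection table below are usually run. A minimal sketch (n_envs=4 is an illustrative choice):

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

vec_env = make_vec_env('CartPole-v1', n_envs=4)  # 4 copies in one process (DummyVecEnv)
# vec_env = make_vec_env('CartPole-v1', n_envs=4, vec_env_cls=SubprocVecEnv)  # true multi-processing
model = PPO('MlpPolicy', vec_env)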
Training and Prediction
model.learn(total_timesteps=100000)
action, _ = model.predict(observation, deterministic=True)
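To measure how well the trained model performs, SB3 provides an evaluation helper (model and env as above):

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f'mean_reward={mean_reward:.2f} +/- {std_reward:.2f}')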
Saving and Loading
model.save('model_path')
model = PPO.load('model_path')
Callbacks
from stable_baselines3.common.callbacks import (
EvalCallback, CheckpointCallback, CallbackList
)
eval_callback = EvalCallback(
eval_env,
best_model_save_path='./best/',
eval_freq=10000
)
checkpoint_callback = CheckpointCallback(
save_freq=10000,
save_path='./checkpoints/'
)
model.learn(total_timesteps=100000, callback=CallbackList([
eval_callback, checkpoint_callback
]))
Recommended Hyperparameters
PPO
| Parameter | Recommended value |
|---|---|
| learning_rate | 3e-4 |
| n_steps | 2048 |
| batch_size | 64 |
| n_epochs | 10 |
| gamma | 0.99 |
| gae_lambda | 0.95 |
| clip_range | 0.2 |
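These values map one-to-one onto the PPO constructor arguments (a sketch; env and the PPO import are reused from above):

model = PPO(
    'MlpPolicy', env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
)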
DQN
| Parameter | Recommended value |
|---|---|
| learning_rate | 1e-4 |
| buffer_size | 100000 |
| batch_size | 32 |
| gamma | 0.99 |
| target_update_interval | 1000 |
| exploration_fraction | 0.1 |
SAC
| Parameter | Recommended value |
|---|---|
| learning_rate | 3e-4 |
| buffer_size | 1000000 |
| batch_size | 256 |
| gamma | 0.99 |
| tau | 0.005 |
| ent_coef | auto |
Common PyTorch Code
Policy Network
import torch
import torch.nn as nn
from torch.distributions import Categorical
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
def forward(self, x):
return self.network(x)
def get_action(self, x):
probs = self.forward(x)
dist = Categorical(probs)
action = dist.sample()
return action, dist.log_prob(action), dist.entropy()
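The outputs of get_action plug straight into a REINFORCE-style loss, implementing the policy-gradient formula above. A minimal sketch using the PolicyNetwork class from this section; the batch of states and returns is placeholder rollout data:

policy = PolicyNetwork(state_dim=4, action_dim=2)
optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

states = torch.randn(32, 4)   # placeholder: states from a collected episode
returns = torch.randn(32)     # placeholder: discounted returns for those states

actions, log_probs, entropy = policy.get_action(states)
loss = -(log_probs * returns).mean()  # maximize E[log pi(a|s) * G]
optimizer.zero_grad()
loss.backward()
optimizer.step()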
Value Network
class ValueNetwork(nn.Module):
def __init__(self, state_dim, hidden_dim=128):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.network(x)
Experience Replay
from collections import deque
import random
import numpy as np
class ReplayBuffer:
def __init__(self, capacity=10000):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
        # Stack into single NumPy arrays first to avoid slow list-of-arrays conversion
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(np.array(actions)),
            torch.FloatTensor(np.array(rewards)),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(np.array(dones, dtype=np.float32))
        )
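A sketch of how the buffer feeds a DQN-style TD update; q_net and target_net stand for two networks with one output per action, and the hyperparameters are illustrative (this is not a full training loop):

buffer = ReplayBuffer(capacity=10000)
# during interaction: buffer.push(state, action, reward, next_state, done)

def dqn_update(q_net, target_net, optimizer, batch_size=32, gamma=0.99):
    states, actions, rewards, next_states, dones = buffer.sample(batch_size)
    # Q(s, a) for the actions that were actually taken
    q_values = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # TD target from the (frozen) target network
        next_q = target_net(next_states).max(dim=1).values
        target = rewards + gamma * next_q * (1 - dones)
    loss = nn.functional.mse_loss(q_values, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()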
GAE Computation
def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
advantages = []
gae = 0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_val = next_value
else:
next_val = values[t + 1]
delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
gae = delta + gamma * lam * (1 - dones[t]) * gae
advantages.insert(0, gae)
returns = [a + v for a, v in zip(advantages, values)]
return advantages, returns
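In practice the advantages are usually normalized before being fed into a PPO or actor-critic loss; a common one-liner (advantages as returned by compute_gae above):

advantages = torch.tensor(advantages, dtype=torch.float32)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)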
Common Environments
Classic Control
| Environment | Action space | Description |
|---|---|---|
| CartPole-v1 | Discrete(2) | Cart-pole balancing |
| MountainCar-v0 | Discrete(3) | Mountain car |
| Acrobot-v1 | Discrete(3) | Underactuated two-link pendulum |
| Pendulum-v1 | Box(1) | Pendulum swing-up |
Box2D
| Environment | Action space | Description |
|---|---|---|
| LunarLander-v2 | Discrete(4) | Lunar lander |
| BipedalWalker-v3 | Box(4) | Bipedal walker |
MuJoCo
| Environment | Action space | Description |
|---|---|---|
| HalfCheetah-v4 | Box(6) | Half-cheetah runner |
| Humanoid-v4 | Box(17) | Humanoid robot |
| Ant-v4 | Box(8) | Quadruped ant |
| Reacher-v4 | Box(2) | Robotic reaching arm |
Algorithm Selection Guide
| Scenario | Recommended algorithm |
|---|---|
| Discrete actions, simple task | DQN |
| Discrete actions, complex task | PPO |
| Continuous actions, simple task | SAC |
| Continuous actions, complex task | PPO / SAC / TD3 |
| High sample-efficiency requirements | SAC |
| High training-stability requirements | PPO |
| Multi-process training | PPO / A2C |
Debugging Tips
Checking the Environment
from gymnasium.utils.env_checker import check_env
check_env(env, warn=True)
Monitoring Training
from stable_baselines3.common.callbacks import BaseCallback
class DebugCallback(BaseCallback):
def _on_step(self):
print(f"Step: {self.num_timesteps}")
print(f"Episode reward: {self.locals.get('episode_reward', 'N/A')}")
return True
Gradient Check
# `model` here is a torch nn.Module (e.g., the PolicyNetwork above);
# for an SB3 model, iterate over model.policy.named_parameters() instead
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad_mean={param.grad.mean():.6f}")