Reinforcement Learning Cheat Sheet

This document is a quick reference for common reinforcement learning concepts, formulas, and APIs.

Core Concepts

The MDP 5-Tuple

| Element | Symbol | Description |
| --- | --- | --- |
| State space | $S$ | Set of all possible states |
| Action space | $A$ | Set of all possible actions |
| Transition probability | $P(s' \mid s, a)$ | Probability of reaching $s'$ after taking action $a$ in state $s$ |
| Reward function | $R(s, a, s')$ | Immediate reward |
| Discount factor | $\gamma$ | Discount applied to future rewards |
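
For concreteness, a tiny MDP can be written down directly as nested dictionaries (a minimal sketch; the two states and two actions here are made up for illustration):

# Hypothetical 2-state, 2-action MDP.
# P[s][a] is a list of (probability, next_state, reward) triples.
P = {
    's0': {'stay': [(1.0, 's0', 0.0)],
           'go':   [(0.8, 's1', 1.0), (0.2, 's0', 0.0)]},
    's1': {'stay': [(1.0, 's1', 2.0)],
           'go':   [(1.0, 's0', 0.0)]},
}
gamma = 0.9  # discount factor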

Value Functions

State-value function: $V^\pi(s) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t R_t \mid S_0 = s\right]$

Action-value function: $Q^\pi(s,a) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t R_t \mid S_0 = s, A_0 = a\right]$

Advantage function: $A(s,a) = Q(s,a) - V(s)$

Bellman Equations

Bellman expectation equation: $V^\pi(s) = \sum_a \pi(a \mid s) \sum_{s'} P(s' \mid s,a)\left[R(s,a,s') + \gamma V^\pi(s')\right]$

Bellman optimality equation: $V^*(s) = \max_a \sum_{s'} P(s' \mid s,a)\left[R(s,a,s') + \gamma V^*(s')\right]$
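
Value iteration computes $V^*$ by applying the optimality backup until convergence; a minimal sketch, assuming the dictionary MDP format from the sketch above:

def value_iteration(P, gamma=0.99, tol=1e-8):
    # Repeated Bellman optimality backups converge to V*.
    V = {s: 0.0 for s in P}
    while True:
        delta = 0.0
        for s in P:
            # Expected one-step return, maximized over actions.
            v = max(sum(p * (r + gamma * V[s2]) for p, s2, r in P[s][a])
                    for a in P[s])
            delta = max(delta, abs(v - V[s]))
            V[s] = v
        if delta < tol:
            return V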

Algorithm Formulas

Q-Learning

$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$

SARSA

$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma Q(s',a') - Q(s,a)\right]$
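
In code the two updates differ by a single term: Q-learning bootstraps from the greedy next action (off-policy), while SARSA bootstraps from the next action actually taken (on-policy). A minimal sketch on a NumPy Q-table indexed by integer states and actions:

import numpy as np

def q_learning_update(Q, s, a, r, s_next, alpha=0.1, gamma=0.99):
    # Off-policy: target uses the max over next actions.
    Q[s, a] += alpha * (r + gamma * np.max(Q[s_next]) - Q[s, a])

def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.1, gamma=0.99):
    # On-policy: target uses the next action actually taken.
    Q[s, a] += alpha * (r + gamma * Q[s_next, a_next] - Q[s, a])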

Policy Gradient

$\nabla_\theta J(\theta) = \mathbb{E}\left[\nabla_\theta \log \pi_\theta(a \mid s)\, Q^{\pi_\theta}(s,a)\right]$
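
As a loss function this becomes REINFORCE: minimizing the negative weighted log-probabilities ascends $J(\theta)$. A minimal sketch, with Monte-Carlo returns standing in for $Q^{\pi_\theta}$:

import torch

def reinforce_loss(log_probs, returns):
    # The gradient of this loss equals -E[grad log pi(a|s) * return],
    # so a gradient step on it is a policy-gradient ascent step.
    # Normalizing returns is a common variance-reduction trick.
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return -(log_probs * returns).mean()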

PPO-Clip

$L^{CLIP} = \mathbb{E}\left[\min\left(r_t \hat{A}_t,\ \text{clip}(r_t, 1-\epsilon, 1+\epsilon)\, \hat{A}_t\right)\right]$

where $r_t = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$ is the probability ratio (not the reward).
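
The clipped objective translates directly into a PyTorch loss (negated for minimization); a minimal sketch, with log-probabilities and advantages assumed to come from a rollout:

import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    # Probability ratio r_t, computed in log space for stability.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    # Negated: minimizing this maximizes the clipped surrogate objective.
    return -torch.min(unclipped, clipped).mean()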

GAE

$\hat{A}_t^{GAE} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error (implemented by compute_gae in the PyTorch section below).

Gymnasium API

Basic Usage

import gymnasium as gym

env = gym.make('CartPole-v1')
obs, info = env.reset()
action = env.action_space.sample()  # any valid action; random here
obs, reward, terminated, truncated, info = env.step(action)
env.close()
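
A complete random-policy episode loop, for reference (the episode ends when either terminated or truncated is True):

import gymnasium as gym

env = gym.make('CartPole-v1')
obs, info = env.reset(seed=0)
episode_reward = 0.0
while True:
    action = env.action_space.sample()  # random policy, for illustration
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    if terminated or truncated:
        break
env.close()
print(f'episode reward: {episode_reward}')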

Space Types

from gymnasium import spaces

spaces.Box(low=0, high=1, shape=(4,))        # continuous space
spaces.Discrete(2)                           # discrete space
spaces.MultiDiscrete([2, 3, 2])              # multi-dimensional discrete
spaces.MultiBinary(4)                        # multi-dimensional binary
spaces.Dict({'obs': spaces.Box(...), ...})   # dict space
spaces.Tuple((spaces.Box(...), ...))         # tuple space
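
Every space supports sampling and membership checks, which is handy for quick smoke tests; a short sketch:

from gymnasium import spaces

box = spaces.Box(low=0, high=1, shape=(4,))
x = box.sample()              # random element of the space
print(box.contains(x))        # True: samples always lie inside the space
print(box.shape, box.dtype)   # (4,) float32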

Wrappers

from gymnasium.wrappers import (
    TimeLimit, RecordVideo, NormalizeObservation,
    NormalizeReward, FlattenObservation
)

env = TimeLimit(env, max_episode_steps=500)
env = RecordVideo(env, video_folder='./videos')
env = NormalizeObservation(env)
env = NormalizeReward(env)

Stable Baselines3 API

Creating a Model

from stable_baselines3 import PPO, DQN, SAC, TD3, A2C

model = PPO('MlpPolicy', env)
model = DQN('MlpPolicy', env)
model = SAC('MlpPolicy', env)
model = TD3('MlpPolicy', env)
model = A2C('MlpPolicy', env)

Training and Prediction

model.learn(total_timesteps=100000)

action, _ = model.predict(observation, deterministic=True)
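
For a quick quality check of a trained model, SB3's evaluate_policy helper reports the mean episode reward; a short sketch:

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(
    model, env, n_eval_episodes=10, deterministic=True
)
print(f'mean reward: {mean_reward:.1f} +/- {std_reward:.1f}')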

Saving and Loading

model.save('model_path')
model = PPO.load('model_path')

Callbacks

from stable_baselines3.common.callbacks import (
    EvalCallback, CheckpointCallback, CallbackList
)

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path='./best/',
    eval_freq=10000
)

checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path='./checkpoints/'
)

model.learn(total_timesteps=100000, callback=CallbackList([
    eval_callback, checkpoint_callback
]))

Recommended Hyperparameters

PPO

| Parameter | Recommended value |
| --- | --- |
| learning_rate | 3e-4 |
| n_steps | 2048 |
| batch_size | 64 |
| n_epochs | 10 |
| gamma | 0.99 |
| gae_lambda | 0.95 |
| clip_range | 0.2 |
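
These values map one-to-one onto the PPO constructor's keyword arguments; a sketch:

from stable_baselines3 import PPO

model = PPO(
    'MlpPolicy', env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
)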

DQN

| Parameter | Recommended value |
| --- | --- |
| learning_rate | 1e-4 |
| buffer_size | 100000 |
| batch_size | 32 |
| gamma | 0.99 |
| target_update_interval | 1000 |
| exploration_fraction | 0.1 |

SAC

| Parameter | Recommended value |
| --- | --- |
| learning_rate | 3e-4 |
| buffer_size | 1000000 |
| batch_size | 256 |
| gamma | 0.99 |
| tau | 0.005 |
| ent_coef | auto |

Common PyTorch Code

Policy Network

import torch
import torch.nn as nn
from torch.distributions import Categorical

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Outputs a probability distribution over discrete actions.
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.network(x)

    def get_action(self, x):
        # Sample an action and return its log-probability and the
        # distribution entropy (both needed for policy-gradient losses).
        probs = self.forward(x)
        dist = Categorical(probs)
        action = dist.sample()
        return action, dist.log_prob(action), dist.entropy()

Value Network

class ValueNetwork(nn.Module):
    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.network(x)

Experience Replay

from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(states),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(next_states),
            torch.FloatTensor(dones)
        )
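
Typical usage pushes transitions during the rollout and samples minibatches once the buffer has filled up; a short sketch with dummy transitions:

import numpy as np

buffer = ReplayBuffer(capacity=10000)
for _ in range(1000):
    s = np.random.randn(4)
    s_next = np.random.randn(4)
    buffer.push(s, np.random.randint(2), 1.0, s_next, False)

states, actions, rewards, next_states, dones = buffer.sample(batch_size=32)
print(states.shape)  # torch.Size([32, 4])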

GAE Computation

def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    advantages = []
    gae = 0

    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            next_val = next_value
        else:
            next_val = values[t + 1]

        delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
        gae = delta + gamma * lam * (1 - dones[t]) * gae
        advantages.insert(0, gae)

    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

Common Environments

Classic Control

| Environment | Action space | Description |
| --- | --- | --- |
| CartPole-v1 | Discrete(2) | Inverted pendulum on a cart |
| MountainCar-v0 | Discrete(3) | Mountain car |
| Acrobot-v1 | Discrete(3) | Underactuated pendulum |
| Pendulum-v1 | Box(1) | Pendulum |

Box2D

| Environment | Action space | Description |
| --- | --- | --- |
| LunarLander-v2 | Discrete(4) | Lunar landing |
| BipedalWalker-v3 | Box(4) | Bipedal walking |

MuJoCo

| Environment | Action space | Description |
| --- | --- | --- |
| HalfCheetah-v4 | Box(6) | Half-cheetah |
| Humanoid-v4 | Box(17) | Humanoid robot |
| Ant-v4 | Box(8) | Ant (quadruped) |
| Reacher-v4 | Box(2) | Robotic arm |

Algorithm Selection Guide

| Scenario | Recommended algorithm |
| --- | --- |
| Discrete actions, simple tasks | DQN |
| Discrete actions, complex tasks | PPO |
| Continuous actions, simple tasks | SAC |
| Continuous actions, complex tasks | PPO / SAC / TD3 |
| High sample efficiency needed | SAC |
| High training stability needed | PPO |
| Multi-process training (see the sketch below) | PPO / A2C |
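
For the multi-process row above, SB3's make_vec_env runs several environment copies in parallel; a minimal sketch using subprocesses:

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# Four copies of CartPole, each stepping in its own subprocess.
vec_env = make_vec_env('CartPole-v1', n_envs=4, vec_env_cls=SubprocVecEnv)
model = PPO('MlpPolicy', vec_env)
model.learn(total_timesteps=100000)

On spawn-based platforms (Windows, macOS), SubprocVecEnv requires the script body to sit under an if __name__ == '__main__': guard.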

Debugging Tips

Checking the Environment

from gymnasium.utils.env_checker import check_env

check_env(env, warn=True)

Monitoring Training

from stable_baselines3.common.callbacks import BaseCallback

class DebugCallback(BaseCallback):
    def _on_step(self):
        print(f"Step: {self.num_timesteps}")
        # self.locals exposes the training loop's local variables;
        # available keys vary by algorithm, hence the .get() fallback.
        print(f"Episode reward: {self.locals.get('episode_reward', 'N/A')}")
        return True

Gradient Check

# For an SB3 model the PyTorch parameters live on model.policy;
# for a plain nn.Module, call named_parameters() on the module itself.
for name, param in model.policy.named_parameters():
    if param.grad is not None:
        print(f"{name}: grad_mean={param.grad.mean():.6f}")

Reference Resources