Stable Baselines3
Stable Baselines3 (SB3) is a Python library that provides reliable implementations of reinforcement learning algorithms. Built on PyTorch, it offers clean, consistent, and well-tested implementations, making it a go-to tool for reinforcement learning research and applications.
Introduction to SB3
Why SB3?
- Reliability: every algorithm is rigorously tested and validated
- Ease of use: a concise, unified API; training starts in a few lines of code
- Extensibility: supports custom policy networks and environments
- Well documented: detailed documentation and rich examples
- Active community: continuously maintained and updated
Installation
pip install stable-baselines3
# Optional extras
pip install sb3-contrib # experimental algorithms
pip install rl_zoo3 # the RL Zoo training framework
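After installation, you can verify which version you have from Python:
import stable_baselines3
print(stable_baselines3.__version__)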
Quick Start
Basic Usage
import gymnasium as gym
from stable_baselines3 import PPO

# Create the environment and train a PPO agent
env = gym.make('CartPole-v1')
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

# Run the trained policy
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()
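Instead of hand-rolling the rollout loop above, you can use SB3's built-in evaluate_policy helper, which runs the policy for a number of episodes and reports the mean and standard deviation of the episode reward:
from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")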
Saving and Loading Models
from stable_baselines3 import PPO

model = PPO('MlpPolicy', 'CartPole-v1')  # an env id string also works
model.learn(total_timesteps=10000)
model.save('ppo_cartpole')

del model  # remove the model from memory to demonstrate loading
model = PPO.load('ppo_cartpole')
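For the off-policy algorithms (DQN, DDPG, TD3, SAC), save/load does not include the replay buffer; it can be persisted separately when you plan to resume training. A minimal sketch with SAC:
from stable_baselines3 import SAC

model = SAC('MlpPolicy', 'Pendulum-v1')
model.learn(total_timesteps=10000)
model.save('sac_pendulum')
model.save_replay_buffer('sac_pendulum_buffer')  # persist collected transitions

model = SAC.load('sac_pendulum')
model.load_replay_buffer('sac_pendulum_buffer')  # resume with the old buffer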
Supported Algorithms
On-Policy Algorithms
| Algorithm | Description | Typical use |
|---|---|---|
| A2C | Synchronous Advantage Actor-Critic | Simple tasks, quick prototyping |
| PPO | Proximal Policy Optimization | General-purpose, stable |
| TRPO | Trust Region Policy Optimization (in sb3-contrib) | Scenarios that need conservative, stable updates |
Off-Policy Algorithms
| Algorithm | Description | Typical use |
|---|---|---|
| DQN | Deep Q-Network | Discrete action spaces |
| DDPG | Deep Deterministic Policy Gradient | Continuous action spaces |
| TD3 | Twin Delayed DDPG | Continuous action spaces, more stable than DDPG |
| SAC | Soft Actor-Critic | Continuous action spaces, high sample efficiency |
Algorithm Selection Guide
from stable_baselines3 import PPO, DQN, SAC, TD3, DDPG, A2C

# Discrete action spaces
model = DQN('MlpPolicy', 'CartPole-v1')

# Continuous action spaces
model = SAC('MlpPolicy', 'Pendulum-v1')

# General-purpose default
model = PPO('MlpPolicy', 'CartPole-v1')
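As a rule of thumb, the action space drives the choice. The helper below is not part of SB3; it is just an illustrative sketch of the tables above:
import gymnasium as gym
from stable_baselines3 import DQN, PPO, SAC

def pick_algorithm(env):
    # Hypothetical helper: map the action space to a reasonable default
    if isinstance(env.action_space, gym.spaces.Discrete):
        return DQN   # or PPO if you prefer on-policy training
    if isinstance(env.action_space, gym.spaces.Box):
        return SAC   # strong default for continuous control
    return PPO       # general-purpose fallback

env = gym.make('Pendulum-v1')
model = pick_algorithm(env)('MlpPolicy', env)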
Policy Networks
Predefined Policies
from stable_baselines3 import PPO

# MLP policy (fully connected network)
model = PPO('MlpPolicy', env)

# CNN policy (convolutional network for image observations)
model = PPO('CnnPolicy', env)

# MultiInput policy (multiple inputs, e.g. dict observations)
model = PPO('MultiInputPolicy', env)
Custom Network Architecture
import torch.nn as nn
from stable_baselines3 import PPO

policy_kwargs = dict(
    net_arch=[64, 64],       # two hidden layers of 64 units
    activation_fn=nn.ReLU,
    normalize_images=False
)
model = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs)
Separate Actor-Critic Networks
policy_kwargs = dict(
    net_arch=dict(
        pi=[64, 64],  # policy (actor) network
        vf=[64, 64]   # value function (critic) network
    )
)
model = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs)
Hyperparameter Configuration
PPO Configuration
from stable_baselines3 import PPO

model = PPO(
    'MlpPolicy',
    env,
    learning_rate=3e-4,
    n_steps=2048,          # rollout length per environment
    batch_size=64,
    n_epochs=10,           # optimization epochs per rollout
    gamma=0.99,            # discount factor
    gae_lambda=0.95,       # GAE smoothing parameter
    clip_range=0.2,        # PPO clipping range
    ent_coef=0.01,         # entropy bonus coefficient
    vf_coef=0.5,           # value loss coefficient
    max_grad_norm=0.5,     # gradient clipping
    verbose=1
)
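learning_rate (and clip_range) also accept a callable that maps the remaining training progress, which decreases from 1 to 0, to a value; this enables learning-rate schedules. A linear-decay sketch:
def linear_schedule(initial_value):
    def schedule(progress_remaining):
        # progress_remaining goes from 1 (start of training) to 0 (end)
        return progress_remaining * initial_value
    return schedule

model = PPO('MlpPolicy', env, learning_rate=linear_schedule(3e-4))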
DQN Configuration
from stable_baselines3 import DQN

model = DQN(
    'MlpPolicy',
    env,
    learning_rate=1e-4,
    buffer_size=100000,            # replay buffer capacity
    learning_starts=1000,          # steps collected before learning begins
    batch_size=32,
    tau=1.0,                       # 1.0 = hard target network update
    gamma=0.99,
    train_freq=4,                  # update the model every 4 steps
    gradient_steps=1,
    target_update_interval=1000,   # steps between target network updates
    exploration_fraction=0.1,      # fraction of training spent annealing epsilon
    exploration_final_eps=0.05,    # final epsilon for epsilon-greedy exploration
    verbose=1
)
SAC Configuration
from stable_baselines3 import SAC

model = SAC(
    'MlpPolicy',
    env,
    learning_rate=3e-4,
    buffer_size=1000000,
    learning_starts=100,
    batch_size=256,
    tau=0.005,           # soft target network update coefficient
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    ent_coef='auto',     # learn the entropy temperature automatically
    verbose=1
)
Callbacks
A Basic Callback
from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # Called at every environment step; return False to stop training
        if self.n_calls % 1000 == 0:
            print(f"Steps: {self.n_calls}")
        return True

model = PPO('MlpPolicy', env)
model.learn(total_timesteps=10000, callback=CustomCallback())
Evaluation Callback
import gymnasium as gym
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor

env = make_vec_env('CartPole-v1', n_envs=4)
eval_env = Monitor(gym.make('CartPole-v1'))  # Monitor records episode statistics

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path='./logs/best_model/',
    log_path='./logs/results/',
    eval_freq=10000,
    n_eval_episodes=5,
    deterministic=True,
    render=False
)
model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000, callback=eval_callback)
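EvalCallback can also run a child callback each time a new best model is found. Combined with the built-in StopTrainingOnRewardThreshold, this stops training early once the task is solved:
from stable_baselines3.common.callbacks import StopTrainingOnRewardThreshold

# Stop once the evaluation mean reward reaches 475
# (the score at which CartPole-v1 counts as solved)
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=475, verbose=1)
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,
    eval_freq=10000,
    deterministic=True
)
model.learn(total_timesteps=1000000, callback=eval_callback)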
Checkpoint Callback
from stable_baselines3.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_freq=10000,                  # save every 10000 steps
    save_path='./logs/checkpoints/',
    name_prefix='rl_model'
)
model.learn(total_timesteps=100000, callback=checkpoint_callback)
Combining Callbacks
from stable_baselines3.common.callbacks import CallbackList

callback = CallbackList([
    eval_callback,
    checkpoint_callback,
    CustomCallback()
])
model.learn(total_timesteps=100000, callback=callback)
Vectorized Environments
Creating Vectorized Environments
from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env('CartPole-v1', n_envs=4)  # 4 environments in parallel
model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000)
Parallelism with SubprocVecEnv
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv

def make_env(env_id, rank):
    def _init():
        env = gym.make(env_id)
        env.reset(seed=rank)  # give each worker a distinct seed
        return env
    return _init

if __name__ == '__main__':  # required on platforms that spawn subprocesses
    num_cpu = 4
    env = SubprocVecEnv([make_env('CartPole-v1', i) for i in range(num_cpu)])
    model = PPO('MlpPolicy', env)
    model.learn(total_timesteps=100000)
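A vectorized environment can additionally be wrapped in VecNormalize, which keeps running statistics to normalize observations and rewards; this often helps continuous-control tasks. The statistics live in the wrapper, so save them alongside the model:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

env = make_vec_env('Pendulum-v1', n_envs=4)
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)
model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000)

env.save('vec_normalize.pkl')  # save the normalization statistics for reuse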
Monitoring and Visualization
TensorBoard Integration
from stable_baselines3 import PPO

model = PPO(
    'MlpPolicy',
    env,
    tensorboard_log='./logs/tensorboard/'
)
model.learn(total_timesteps=100000)
Launch TensorBoard:
tensorboard --logdir ./logs/tensorboard/
Logging Custom Metrics
from stable_baselines3.common.callbacks import BaseCallback

class TensorboardCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # Assumes each sub-environment exposes a `reward` attribute
        self.logger.record('custom/reward', self.training_env.get_attr('reward')[0])
        return True
Custom Environments
Defining and Using a Custom Environment
import gymnasium as gym
from stable_baselines3 import PPO

class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4,))
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # seed the internal RNG
        return self.observation_space.sample(), {}

    def step(self, action):
        # Dummy dynamics: a random observation and a constant reward
        obs = self.observation_space.sample()
        reward = 1.0
        terminated = False
        truncated = False
        info = {}
        return obs, reward, terminated, truncated, info

env = CustomEnv()
model = PPO('MlpPolicy', env)
model.learn(total_timesteps=10000)
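Before training on a custom environment, it is worth validating it against the Gymnasium API with SB3's environment checker:
from stable_baselines3.common.env_checker import check_env

check_env(CustomEnv())  # raises an error or prints warnings on API violations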
Advanced Features
Pretrained Models
from stable_baselines3 import PPO

model = PPO.load('pretrained_model.zip')
model.set_env(new_env)  # new_env: the environment to continue training in
model.learn(total_timesteps=10000)
Continuing Training
model = PPO.load('ppo_cartpole', env=env)  # pass the env directly when loading
model.learn(total_timesteps=10000)
Getting and Setting Parameters
params = model.get_parameters()  # dict of network and optimizer parameters
model.set_parameters(params)
Exporting Models
from stable_baselines3 import PPO
model = PPO('MlpPolicy', 'CartPole-v1')
model.learn(total_timesteps=10000)
model.save('ppo_cartpole')
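save writes a zip archive in SB3's own format. For deployment outside SB3, note that model.policy is a regular torch.nn.Module, so its weights can be exported with standard PyTorch tooling (a minimal sketch; loading the weights elsewhere requires rebuilding the same network):
import torch

torch.save(model.policy.state_dict(), 'ppo_cartpole_policy.pt')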
Complete Example
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

# Training and evaluation environments
env = make_vec_env('CartPole-v1', n_envs=4)
eval_env = Monitor(gym.make('CartPole-v1'))

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path='./logs/best/',
    log_path='./logs/results/',
    eval_freq=10000,
    deterministic=True,
    render=False
)

model = PPO(
    'MlpPolicy',
    env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    verbose=1,
    tensorboard_log='./logs/tensorboard/'
)

model.learn(total_timesteps=100000, callback=eval_callback)
model.save('ppo_cartpole_final')

# Run the trained policy in the evaluation environment
obs, info = eval_env.reset()
for _ in range(1000):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = eval_env.step(action)
    if terminated or truncated:
        obs, info = eval_env.reset()
eval_env.close()
Common Issues
Unstable Training
- Lower the learning rate
- Increase the batch size
- Use gradient clipping (see the sketch after this list)
- Review the reward design
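A minimal sketch of where these knobs live in the PPO constructor; the values are illustrative, not recommendations:
from stable_baselines3 import PPO

model = PPO(
    'MlpPolicy',
    'CartPole-v1',
    learning_rate=1e-4,   # lower learning rate
    batch_size=128,       # larger batch size
    max_grad_norm=0.5     # gradient clipping
)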
Low Sample Efficiency
- Use an off-policy algorithm (SAC, TD3)
- Increase the number of parallel environments
- Tune the exploration parameters
Out of Memory
- Reduce buffer_size
- Reduce batch_size
- Reduce the network size
Summary
Stable Baselines3 is a key tool for practical reinforcement learning:
- Unified API: every algorithm shares the same interface
- Rich algorithm suite: covers the mainstream reinforcement learning algorithms
- Easy to use: training starts in a few lines of code
- Extensible: supports custom policies and environments
- Solid monitoring: TensorBoard integration and a callback system
The next chapter provides a reinforcement learning cheat sheet.