
Stable Baselines3

Stable Baselines3 (SB3) is a Python library that provides reliable implementations of reinforcement learning algorithms. Built on PyTorch, it offers clean, consistent, and well-tested implementations, making it a tool of choice for reinforcement learning research and applications.

Introduction to SB3

Why Choose SB3?

  • Reliability: every algorithm is rigorously tested and validated
  • Ease of use: a clean, unified API; training starts in a few lines of code
  • Extensibility: supports custom policy networks and environments
  • Documentation: detailed docs with plenty of examples
  • Active community: continuously maintained and updated

Installation

pip install stable-baselines3

# Optional add-ons
pip install sb3-contrib   # experimental algorithms
pip install rl_zoo3       # training framework
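
A quick way to confirm the install succeeded is to print the package version:

import stable_baselines3
print(stable_baselines3.__version__)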

Quick Start

Basic Usage

import gymnasium as gym
from stable_baselines3 import PPO

# Create the environment
env = gym.make('CartPole-v1')

# Instantiate the agent with an MLP policy
model = PPO('MlpPolicy', env, verbose=1)

# Train for 10,000 timesteps
model.learn(total_timesteps=10000)

# Run the trained policy
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()

env.close()

Saving and Loading Models

from stable_baselines3 import PPO

# The environment can also be passed as a string id
model = PPO('MlpPolicy', 'CartPole-v1')
model.learn(total_timesteps=10000)

# Save weights and hyperparameters to ppo_cartpole.zip
model.save('ppo_cartpole')

del model  # discard the in-memory model to demonstrate loading

# Restore the model from disk
model = PPO.load('ppo_cartpole')
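
To sanity-check the reloaded model, SB3's built-in evaluation helper can be used (a minimal sketch over 10 episodes):

import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy

eval_env = gym.make('CartPole-v1')
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
print(f'mean_reward={mean_reward:.2f} +/- {std_reward:.2f}')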

Supported Algorithms

On-Policy Algorithms

| Algorithm | Description | Use Cases |
| --- | --- | --- |
| A2C | Synchronous Actor-Critic | Simple tasks, quick prototyping |
| PPO | Proximal Policy Optimization | General-purpose, stable |
| TRPO | Trust Region Policy Optimization (in sb3-contrib) | Scenarios requiring stable updates |

Off-Policy Algorithms

| Algorithm | Description | Use Cases |
| --- | --- | --- |
| DQN | Deep Q-Network | Discrete action spaces |
| DDPG | Deep Deterministic Policy Gradient | Continuous action spaces |
| TD3 | Twin Delayed DDPG | Continuous action spaces, more stable than DDPG |
| SAC | Soft Actor-Critic | Continuous action spaces, high sample efficiency |

Algorithm Selection Guide

from stable_baselines3 import PPO, DQN, SAC, TD3, DDPG, A2C

# Discrete action space
model = DQN('MlpPolicy', 'CartPole-v1')

# Continuous action space
model = SAC('MlpPolicy', 'Pendulum-v1')

# General-purpose default for either
model = PPO('MlpPolicy', 'CartPole-v1')
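
When in doubt, inspect the environment's action space before picking an algorithm (an illustrative check):

import gymnasium as gym

print(gym.make('CartPole-v1').action_space)   # Discrete(2) -> DQN, PPO, A2C
print(gym.make('Pendulum-v1').action_space)   # Box(...)    -> SAC, TD3, DDPG, PPO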

Policy Networks

Built-in Policies

from stable_baselines3 import PPO

# MLP policy (fully connected network)
model = PPO('MlpPolicy', env)

# CNN policy (convolutional network, for image observations)
model = PPO('CnnPolicy', env)

# MultiInput policy (for dict observations)
model = PPO('MultiInputPolicy', env)

Custom Network Architecture

import torch.nn as nn
from stable_baselines3 import PPO

policy_kwargs = dict(
    net_arch=[64, 64],        # two hidden layers of 64 units
    activation_fn=nn.ReLU,    # activation function
    normalize_images=False    # skip image normalization (non-image input)
)

model = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs)

Separate Actor-Critic Architecture

policy_kwargs = dict(
    net_arch=dict(
        pi=[64, 64],   # policy (actor) network layers
        vf=[64, 64]    # value function (critic) network layers
    )
)

model = PPO('MlpPolicy', env, policy_kwargs=policy_kwargs)
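
To confirm the architecture was applied, print the underlying PyTorch module:

print(model.policy)   # shows the separate pi (actor) and vf (critic) sub-networks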

Hyperparameter Configuration

PPO Configuration

from stable_baselines3 import PPO

model = PPO(
    'MlpPolicy',
    env,
    learning_rate=3e-4,    # step size for gradient updates
    n_steps=2048,          # rollout length per environment before each update
    batch_size=64,         # minibatch size for SGD
    n_epochs=10,           # passes over the rollout buffer per update
    gamma=0.99,            # discount factor
    gae_lambda=0.95,       # GAE smoothing parameter
    clip_range=0.2,        # PPO clipping range
    ent_coef=0.01,         # entropy bonus coefficient
    vf_coef=0.5,           # value loss coefficient
    max_grad_norm=0.5,     # gradient clipping threshold
    verbose=1
)

DQN Configuration

from stable_baselines3 import DQN

model = DQN(
    'MlpPolicy',
    env,
    learning_rate=1e-4,
    buffer_size=100000,            # replay buffer capacity
    learning_starts=1000,          # steps collected before training begins
    batch_size=32,
    tau=1.0,                       # hard target-network update
    gamma=0.99,
    train_freq=4,                  # train every 4 environment steps
    gradient_steps=1,
    target_update_interval=1000,   # steps between target-network updates
    exploration_fraction=0.1,      # fraction of training spent annealing epsilon
    exploration_final_eps=0.05,    # final epsilon value
    verbose=1
)

SAC Configuration

from stable_baselines3 import SAC

model = SAC(
    'MlpPolicy',
    env,
    learning_rate=3e-4,
    buffer_size=1000000,   # replay buffer capacity
    learning_starts=100,
    batch_size=256,
    tau=0.005,             # soft target-network update coefficient
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    ent_coef='auto',       # automatic entropy coefficient tuning
    verbose=1
)
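
All of these constructors also accept a callable for learning_rate; SB3 calls it with the remaining training progress (1 at the start, 0 at the end), so a linear decay can be sketched as:

def linear_schedule(initial_value):
    # progress_remaining goes from 1 (start of training) to 0 (end)
    def schedule(progress_remaining):
        return progress_remaining * initial_value
    return schedule

model = PPO('MlpPolicy', env, learning_rate=linear_schedule(3e-4))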

Callbacks

Basic Callback

from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # Called after every environment step; return False to abort training
        if self.n_calls % 1000 == 0:
            print(f"Steps: {self.n_calls}")
        return True

model = PPO('MlpPolicy', env)
model.learn(total_timesteps=10000, callback=CustomCallback())

Evaluation Callback

from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env('CartPole-v1', n_envs=4)
eval_env = gym.make('CartPole-v1')

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path='./logs/best_model/',   # where the best model is stored
    log_path='./logs/results/',                  # where evaluation results are logged
    eval_freq=10000,                             # evaluate every 10,000 steps
    n_eval_episodes=5,
    deterministic=True,
    render=False
)

model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000, callback=eval_callback)

Checkpoint Callback

from stable_baselines3.common.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_freq=10000,                  # save every 10,000 steps
    save_path='./logs/checkpoints/',
    name_prefix='rl_model'
)

model.learn(total_timesteps=100000, callback=checkpoint_callback)

Combining Callbacks

from stable_baselines3.common.callbacks import CallbackList

callback = CallbackList([
    eval_callback,
    checkpoint_callback,
    CustomCallback()
])

model.learn(total_timesteps=100000, callback=callback)
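
SB3 also ships ready-made callbacks. For example, StopTrainingOnRewardThreshold halts training once the evaluation reward crosses a threshold (475 here is just an illustrative value, CartPole-v1's solve score):

from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

stop_callback = StopTrainingOnRewardThreshold(reward_threshold=475, verbose=1)
eval_callback = EvalCallback(eval_env, callback_on_new_best=stop_callback, eval_freq=10000)

model.learn(total_timesteps=1000000, callback=eval_callback)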

Vectorized Environments

Creating Vectorized Environments

from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env('CartPole-v1', n_envs=4)   # 4 env copies (DummyVecEnv by default)

model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000)

Using SubprocVecEnv

from stable_baselines3.common.vec_env import SubprocVecEnv

def make_env(env_id, rank):
    # Factory returning a thunk so each worker process builds its own env
    def _init():
        env = gym.make(env_id)
        env.reset(seed=rank)   # give each copy a different seed
        return env
    return _init

num_cpu = 4
# Each env runs in its own process; on spawn-based platforms (Windows/macOS),
# guard this with `if __name__ == '__main__':`
env = SubprocVecEnv([make_env('CartPole-v1', i) for i in range(num_cpu)])

model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000)
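
When observations or rewards vary widely in scale, wrapping the vectorized env in VecNormalize often helps (a sketch; the running statistics must be saved alongside the model to reuse them later):

from stable_baselines3.common.vec_env import VecNormalize

env = VecNormalize(env, norm_obs=True, norm_reward=True)
model = PPO('MlpPolicy', env)
model.learn(total_timesteps=100000)

env.save('vec_normalize.pkl')   # persist normalization statistics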

Monitoring and Visualization

TensorBoard Integration

from stable_baselines3 import PPO

model = PPO(
    'MlpPolicy',
    env,
    tensorboard_log='./logs/tensorboard/'   # directory for TensorBoard event files
)

model.learn(total_timesteps=100000)

Launch TensorBoard:

tensorboard --logdir ./logs/tensorboard/

Logging Custom Metrics

from stable_baselines3.common.callbacks import BaseCallback

class TensorboardCallback(BaseCallback):
    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self):
        # get_attr works only if the wrapped env actually exposes a `reward`
        # attribute; substitute any quantity your environment tracks
        self.logger.record('custom/reward', self.training_env.get_attr('reward')[0])
        return True

model.learn(total_timesteps=10000, callback=TensorboardCallback())

Custom Environments

Using a Custom Environment

import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO

class CustomEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        return self.observation_space.sample(), {}

    def step(self, action):
        # Dummy dynamics: random observation, constant reward
        obs = self.observation_space.sample()
        reward = 1.0
        terminated = False
        truncated = False
        info = {}
        return obs, reward, terminated, truncated, info

env = CustomEnv()
model = PPO('MlpPolicy', env)
model.learn(total_timesteps=10000)
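
Before training on a custom environment, it is worth validating it against the Gymnasium API with SB3's environment checker:

from stable_baselines3.common.env_checker import check_env

check_env(CustomEnv())   # warns or raises if the env violates the API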

Advanced Features

Pretrained Models

from stable_baselines3 import PPO

# Load a pretrained model
model = PPO.load('pretrained_model.zip')

# Attach an environment with matching observation/action spaces, then fine-tune
model.set_env(new_env)
model.learn(total_timesteps=10000)

Continuing Training

# Passing env= at load time re-attaches an environment in one step
model = PPO.load('ppo_cartpole', env=env)
model.learn(total_timesteps=10000)

Getting and Setting Parameters

params = model.get_parameters()   # dict of state_dicts (policy, optimizer, ...)
model.set_parameters(params)      # load them back into a compatible model

Exporting Models

from stable_baselines3 import PPO

model = PPO('MlpPolicy', 'CartPole-v1')
model.learn(total_timesteps=10000)

# The .zip archive bundles network weights and hyperparameters
model.save('ppo_cartpole')
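
For deployment outside Python, the policy network itself can be exported, e.g. to ONNX. The sketch below follows the approach in the SB3 export guide; the OnnxablePolicy wrapper and the CartPole-sized dummy input are illustrative assumptions:

import torch
from stable_baselines3 import PPO

model = PPO.load('ppo_cartpole', device='cpu')

class OnnxablePolicy(torch.nn.Module):
    # thin wrapper so ONNX sees a plain forward(obs) -> (action, value, log_prob)
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        return self.policy(observation, deterministic=True)

dummy_input = torch.randn(1, 4)   # CartPole observation shape: (4,)
torch.onnx.export(OnnxablePolicy(model.policy), dummy_input, 'ppo_cartpole.onnx', opset_version=17)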

Complete Example

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

# Training environment: 4 vectorized copies of CartPole
env = make_vec_env('CartPole-v1', n_envs=4)

# Separate evaluation environment, wrapped in Monitor for episode statistics
eval_env = Monitor(gym.make('CartPole-v1'))

eval_callback = EvalCallback(
    eval_env,
    best_model_save_path='./logs/best/',
    log_path='./logs/results/',
    eval_freq=10000,
    deterministic=True,
    render=False
)

model = PPO(
    'MlpPolicy',
    env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    verbose=1,
    tensorboard_log='./logs/tensorboard/'
)

model.learn(total_timesteps=100000, callback=eval_callback)

model.save('ppo_cartpole_final')

# Roll out the trained policy
obs, info = eval_env.reset()
for _ in range(1000):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = eval_env.step(action)
    if terminated or truncated:
        obs, info = eval_env.reset()

eval_env.close()

Troubleshooting

Unstable Training

  • Lower the learning rate
  • Increase the batch size
  • Use gradient clipping
  • Review the reward design

Low Sample Efficiency

  • Use off-policy algorithms (SAC, TD3)
  • Increase the number of parallel environments
  • Tune the exploration parameters

Out of Memory

  • Reduce buffer_size
  • Reduce batch_size
  • Reduce the network size (see the sketch after this list)
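
As a sketch, these knobs map directly onto constructor arguments (values here are illustrative, not recommendations):

from stable_baselines3 import DQN

model = DQN(
    'MlpPolicy',
    env,
    buffer_size=50000,                      # smaller replay buffer
    batch_size=16,                          # smaller minibatches
    policy_kwargs=dict(net_arch=[32, 32])   # smaller network
)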

Summary

Stable Baselines3 is an essential tool for practical reinforcement learning:

  • Unified API: every algorithm shares the same interface
  • Rich algorithm suite: covers the mainstream reinforcement learning algorithms
  • Easy to use: training starts in a few lines of code
  • Extensible: supports custom policies and environments
  • Solid monitoring: TensorBoard integration and a callback system

The next chapter provides a reinforcement learning cheat sheet.