Soft Actor-Critic (SAC)
Soft Actor-Critic(SAC)是 2018 年由 UC Berkeley 的 Haarnoja 等人提出的一种离线策略深度强化学习算法。SAC 通过引入最大熵强化学习框架,在保持高样本效率的同时实现了稳定的学习过程,目前已成为连续动作空间任务中最流行的算法之一。
为什么需要 SAC?
连续动作空间的挑战
在连续动作空间中,传统的 Q-Learning 方法面临根本性困难:无法对无限的动作空间进行穷举来计算 。DDPG 和 TD3 通过学习确定性策略解决了这个问题,但它们仍然存在一些局限:
- 探索不足:确定性策略需要额外添加噪声来探索,噪声的选择需要调参
- 过估计问题:虽然 TD3 通过双 Q 网络缓解了过估计,但问题并未完全解决
- 超参数敏感:对学习率、噪声参数等超参数较为敏感
SAC 的核心思想
SAC 的核心创新在于引入了熵正则化(Entropy Regularization)。传统强化学习只最大化期望回报:
而 SAC 同时最大化回报和策略熵:
其中 是策略的熵, 是温度参数。
熵正则化的好处
什么是熵?
熵是衡量随机变量不确定性的指标。对于一个离散随机变量 ,其熵定义为:
对于连续随机变量(如高斯分布),熵为:
熵正则化的优势:
- 更好的探索:高熵意味着策略更加随机,能够探索更多状态-动作对
- 避免局部最优:随机性有助于逃离局部最优解
- 鲁棒性:学到的策略对环境变化更加鲁棒
- 无需手动调整噪声:熵天然提供了探索机制
最大熵强化学习
熵正则化的 MDP
在最大熵强化学习框架下,价值函数的定义有所变化。状态价值函数包含未来所有时刻的熵奖励:
动作价值函数包含除第一个动作外所有时刻的熵奖励:
V 和 Q 的关系
在最大熵框架下,V 和 Q 的关系变为:
这可以重写为:
这个公式告诉我们:状态价值等于在该状态下所有动作的"软"最大值——不是取最大,而是对 Q 值减去熵惩罚后的期望。
软贝尔曼方程
最大熵框架下的贝尔曼方程:
利用熵的定义:
这个方程是 SAC 算法的理论基础。
SAC 算法详解
网络架构
SAC 同时学习以下组件:
- 策略网络 :输出动作分布的参数(均值和标准差)
- 两个 Q 网络 :估计动作价值
- 两个目标 Q 网络 :计算目标值
为什么要两个 Q 网络?这与 TD3 相同,是为了解决过估计问题,使用取最小值的方式:
Q 网络的更新
Q 网络通过最小化软贝尔曼误差来更新:
目标值 的计算:
其中 是从当前策略采样的下一动作。
与 TD3 的关键区别:
- 目标动作来源:TD3 使用目标策略网络,SAC 使用当前策略网络
- 熵项:SAC 的目标包含熵惩罚
- 策略随机性:SAC 的策略本身就是随机的,无需额外添加噪声
策略网络的更新
策略优化的目标是最大化 :
为了能够通过梯度下降优化,SAC 使用了重参数化技巧(Reparameterization Trick)。
重参数化技巧
直接对 求梯度是困难的,因为采样过程本身依赖于策略参数。重参数化技巧将采样过程改写为:
其中 和 是策略网络输出的均值和标准差, 是独立于参数的标准高斯噪声。
为了将动作限制在有界范围内,SAC 使用 tanh 函数进行压缩:
使用重参数化后,策略损失变为:
Squashed Gaussian Policy 的对数概率
使用 tanh 压缩后,动作分布不再是高斯分布。对于压缩后的动作 ,其概率密度为:
其中 是未压缩的高斯分布密度。对数概率为:
在代码中需要特别注意这个计算:
def gaussian_log_prob(noise, log_std):
return -0.5 * (noise ** 2 + 2 * log_std + np.log(2 * np.pi))
def squashed_gaussian_log_prob(mean, log_std, noise):
log_prob = gaussian_log_prob(noise, log_std)
action = torch.tanh(mean + log_std.exp() * noise)
log_prob -= torch.sum(torch.log(1 - action.pow(2) + 1e-6), dim=-1)
return log_prob
自动温度调整
手动选择温度参数 是困难的。SAC 提出了自动调整温度的方法,将温度作为一个可学习的参数,通过约束策略的熵来优化:
其中 是目标熵,通常设置为动作空间维度的负值:
温度的损失函数:
class SAC:
def __init__(self, ...):
self.log_alpha = torch.zeros(1, requires_grad=True)
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)
self.target_entropy = -action_dim
def update_alpha(self, states, log_probs):
alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
return self.log_alpha.exp()
SAC 完整实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from collections import deque
import random
class ReplayBuffer:
def __init__(self, capacity=1000000):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (
np.array(states),
np.array(actions),
np.array(rewards, dtype=np.float32),
np.array(next_states),
np.array(dones, dtype=np.float32)
)
def __len__(self):
return len(self.buffer)
class GaussianPolicy(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256, log_std_min=-20, log_std_max=2):
super().__init__()
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
self.mean_layer = nn.Linear(hidden_dim, action_dim)
self.log_std_layer = nn.Linear(hidden_dim, action_dim)
def forward(self, state):
features = self.network(state)
mean = self.mean_layer(features)
log_std = self.log_std_layer(features)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
return mean, log_std
def sample(self, state, deterministic=False):
mean, log_std = self.forward(state)
std = log_std.exp()
if deterministic:
action = torch.tanh(mean)
log_prob = None
else:
noise = torch.randn_like(mean)
x = mean + std * noise
action = torch.tanh(x)
log_prob = -0.5 * (noise ** 2 + 2 * log_std + np.log(2 * np.pi))
log_prob = log_prob.sum(dim=-1, keepdim=True)
log_prob -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1, keepdim=True)
return action, log_prob, mean, log_std
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super().__init__()
self.network = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
return self.network(x)
class SACAgent:
def __init__(self, state_dim, action_dim, hidden_dim=256, lr=3e-4,
gamma=0.99, tau=0.005, alpha=0.2, auto_entropy=True):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.gamma = gamma
self.tau = tau
self.auto_entropy = auto_entropy
# 策略网络
self.policy = GaussianPolicy(state_dim, action_dim, hidden_dim).to(self.device)
self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
# Q 网络
self.q1 = QNetwork(state_dim, action_dim, hidden_dim).to(self.device)
self.q2 = QNetwork(state_dim, action_dim, hidden_dim).to(self.device)
self.q1_optimizer = optim.Adam(self.q1.parameters(), lr=lr)
self.q2_optimizer = optim.Adam(self.q2.parameters(), lr=lr)
# 目标 Q 网络
self.target_q1 = QNetwork(state_dim, action_dim, hidden_dim).to(self.device)
self.target_q2 = QNetwork(state_dim, action_dim, hidden_dim).to(self.device)
self.target_q1.load_state_dict(self.q1.state_dict())
self.target_q2.load_state_dict(self.q2.state_dict())
# 温度参数
if auto_entropy:
self.target_entropy = -action_dim
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
else:
self.alpha = alpha
self.replay_buffer = ReplayBuffer()
@property
def alpha(self):
if self.auto_entropy:
return self.log_alpha.exp()
return self._alpha
def select_action(self, state, deterministic=False):
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
with torch.no_grad():
action, _, _, _ = self.policy.sample(state, deterministic)
return action.cpu().numpy()[0]
def update(self, batch_size=256):
if len(self.replay_buffer) < batch_size:
return {}
states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
states = torch.FloatTensor(states).to(self.device)
actions = torch.FloatTensor(actions).to(self.device)
rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
next_states = torch.FloatTensor(next_states).to(self.device)
dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)
# 更新 Q 网络
with torch.no_grad():
next_actions, next_log_probs, _, _ = self.policy.sample(next_states)
q1_next = self.target_q1(next_states, next_actions)
q2_next = self.target_q2(next_states, next_actions)
q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
target_q = rewards + self.gamma * (1 - dones) * q_next
q1_pred = self.q1(states, actions)
q2_pred = self.q2(states, actions)
q1_loss = F.mse_loss(q1_pred, target_q)
q2_loss = F.mse_loss(q2_pred, target_q)
self.q1_optimizer.zero_grad()
q1_loss.backward()
self.q1_optimizer.step()
self.q2_optimizer.zero_grad()
q2_loss.backward()
self.q2_optimizer.step()
# 更新策略网络
new_actions, log_probs, _, _ = self.policy.sample(states)
q1_new = self.q1(states, new_actions)
q2_new = self.q2(states, new_actions)
q_new = torch.min(q1_new, q2_new)
policy_loss = (self.alpha * log_probs - q_new).mean()
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
# 更新温度参数
if self.auto_entropy:
alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
# 软更新目标网络
for param, target_param in zip(self.q1.parameters(), self.target_q1.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.q2.parameters(), self.target_q2.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
return {
'q1_loss': q1_loss.item(),
'q2_loss': q2_loss.item(),
'policy_loss': policy_loss.item(),
'alpha': self.alpha.item()
}
def train(self, env, total_steps=1000000, start_steps=10000, update_every=1,
batch_size=256, eval_interval=5000):
state, _ = env.reset()
episode_reward = 0
episode_rewards = []
for step in range(total_steps):
if step < start_steps:
action = env.action_space.sample()
else:
action = self.select_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
self.replay_buffer.push(state, action, reward, next_state, float(done))
state = next_state
episode_reward += reward
if done:
state, _ = env.reset()
episode_rewards.append(episode_reward)
episode_reward = 0
if step >= start_steps and step % update_every == 0:
losses = self.update(batch_size)
if (step + 1) % eval_interval == 0:
avg_reward = np.mean(episode_rewards[-100:]) if episode_rewards else 0
print(f"Step {step + 1}, Avg Reward: {avg_reward:.2f}")
return episode_rewards
SAC 与其他算法的比较
SAC vs TD3
| 特性 | SAC | TD3 |
|---|---|---|
| 策略类型 | 随机策略 | 确定性策略 |
| 探索方式 | 熵正则化(自动) | 添加噪声(需要调参) |
| 样本效率 | 较高 | 较高 |
| 超参数敏感性 | 较低 | 较高 |
| 连续动作 | 支持 | 支持 |
| 离散动作 | 需要修改 | 不支持 |
SAC vs PPO
| 特性 | SAC | PPO |
|---|---|---|
| 策略类型 | 离线策略 | 在线策略 |
| 样本效率 | 高 | 低 |
| 数据重用 | 支持(经验回放) | 有限 |
| 稳定性 | 高 | 高 |
| 连续动作 | 支持 | 支持 |
| 离散动作 | 需要修改 | 支持 |
什么时候选择 SAC?
- 连续动作空间:SAC 是连续控制任务的首选
- 样本效率要求高:离线策略可以重用历史数据
- 希望自动探索:熵正则化自动平衡探索与利用
- 超参数敏感的任务:SAC 对超参数相对鲁棒
实践建议
网络架构
SAC 的默认架构相对简单:
policy_kwargs = dict(
net_arch=[256, 256],
activation_fn=nn.ReLU
)
对于复杂任务,可以考虑更深或更宽的网络。注意 SAC 默认使用 ReLU 而不是 tanh。
超参数选择
| 参数 | 推荐值 | 说明 |
|---|---|---|
| learning_rate | 3e-4 | 学习率 |
| buffer_size | 1e6 | 经验回放缓冲区大小 |
| batch_size | 256 | 批量大小 |
| gamma | 0.99 | 折扣因子 |
| tau | 0.005 | 软更新系数 |
| alpha | 'auto' | 温度参数,建议自动调整 |
| start_steps | 10000 | 随机探索步数 |
常见问题及解决方案
训练不稳定:
- 检查奖励尺度,过大的奖励可能导致不稳定
- 尝试调整学习率
- 确保动作空间归一化到 [-1, 1]
探索不足:
- 增大初始温度参数
- 检查自动温度调整是否正常工作
- 增加初始随机探索步数
收敛慢:
- 增大经验回放缓冲区
- 增加更新频率
- 检查网络架构是否合适
使用 Stable Baselines3 的 SAC
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
env = make_vec_env('Pendulum-v1', n_envs=1)
model = SAC(
'MlpPolicy',
env,
learning_rate=3e-4,
buffer_size=1000000,
learning_starts=10000,
batch_size=256,
tau=0.005,
gamma=0.99,
ent_coef='auto',
verbose=1,
tensorboard_log='./logs/'
)
model.learn(total_timesteps=100000)
model.save('sac_pendulum')
obs, _ = env.reset()
for _ in range(1000):
action, _ = model.predict(obs, deterministic=True)
obs, reward, terminated, truncated, _ = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
小结
SAC 是目前最优秀的连续动作空间强化学习算法之一:
- 核心创新:熵正则化框架,自动平衡探索与利用
- 关键技巧:双 Q 网络、重参数化、自动温度调整
- 主要优势:样本效率高、超参数鲁棒、无需手动调探索
- 适用场景:连续动作空间、样本效率要求高的任务
SAC 结合了离线策略的样本效率和随机策略的探索能力,是解决连续控制问题的强大工具。
参考文献
- Haarnoja, T., et al. (2018). Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. ICML.
- Haarnoja, T., et al. (2018). Soft Actor-Critic Algorithms and Applications. arXiv preprint.
- Schulman, J., et al. (2017). Equivalence Between Policy Gradients and Soft Q-Learning. arXiv preprint.