深度 Q 网络(DQN)
深度 Q 网络(Deep Q-Network,简称 DQN)是将深度学习与 Q-Learning 结合的里程碑式算法。它解决了传统 Q-Learning 无法处理高维状态空间的问题,是深度强化学习的开山之作。
为什么需要 DQN?
表格型 Q-Learning 的局限
传统 Q-Learning 使用表格存储 Q 值,这在以下情况下会遇到严重问题:
- 状态空间爆炸:状态数量随特征维度指数增长
- 连续状态空间:无法用有限表格表示无限状态
- 缺乏泛化能力:相似状态无法共享知识
实际例子
以 Atari 游戏为例:
- 输入:84×84×4 的图像(约 2.8 万维)
- 如果每个像素取 256 个值,状态空间大小为 $256^{84 \times 84 \times 4} = 256^{28224} \approx 10^{67970}$
显然,表格方法完全不可行。
解决思路
使用神经网络来近似 Q 函数:
$Q(s, a; \theta) \approx Q^*(s, a)$,其中 $\theta$ 是神经网络的参数。神经网络可以将高维状态映射到 Q 值,并具有强大的泛化能力。
DQN 的核心创新
DQN 的成功依赖于两个关键创新:经验回放和目标网络。
经验回放(Experience Replay)
问题:强化学习的数据是序列相关的,这会导致训练不稳定。
解决方案:将经验存储在回放缓冲区中,训练时随机采样。
import numpy as np
import random
from collections import deque
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for experience replay."""

    def __init__(self, capacity=10000):
        # A bounded deque silently evicts the oldest transition when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one (s, a, r, s', done) transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniformly sample a minibatch; return it column-wise as numpy arrays."""
        transitions = random.sample(self.buffer, batch_size)
        columns = zip(*transitions)
        return tuple(np.array(column) for column in columns)

    def __len__(self):
        return len(self.buffer)
经验回放的优点:
- 打破数据相关性:随机采样使数据更接近独立同分布
- 提高数据利用率:每条经验可以被多次使用
- 支持离线策略学习:可以存储和重用历史经验
目标网络(Target Network)
问题:如果使用同一个网络计算当前 Q 值和目标 Q 值,会导致训练不稳定。
解决方案:使用两个网络:
- 主网络:用于选择动作和计算当前 Q 值
- 目标网络:用于计算目标 Q 值,参数定期从主网络复制
import torch
import torch.nn as nn
class DQN(nn.Module):
    """MLP mapping a state vector to one Q-value per discrete action."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Two hidden ReLU layers followed by a linear Q-value head.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values with shape (batch, action_dim)."""
        return self.network(x)
class DQNAgent:
    """DQN agent: epsilon-greedy actor plus replay-based TD learner."""

    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
                 buffer_size=10000, batch_size=64, target_update=10):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Online network is trained every step; target network is a frozen
        # copy refreshed every `target_update` gradient steps.
        self.q_network = DQN(state_dim, action_dim).to(self.device)
        self.target_network = DQN(state_dim, action_dim).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        self.action_dim = action_dim
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update = target_update
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.steps = 0

    def select_action(self, state):
        """Epsilon-greedy: random action with prob. epsilon, else argmax Q."""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_row = self.q_network(state_t)
        return q_row.argmax().item()

    def update(self):
        """One gradient step on a sampled minibatch.

        Returns the scalar loss, or None while the buffer is still warming up.
        """
        if len(self.buffer) < self.batch_size:
            return None

        batch = self.buffer.sample(self.batch_size)
        s, a, r, s2, d = batch
        states = torch.as_tensor(s, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(a, dtype=torch.int64, device=self.device)
        rewards = torch.as_tensor(r, dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(s2, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(d, dtype=torch.float32, device=self.device)

        # Q(s, a) for the actions actually taken.
        chosen_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Bootstrap target from the frozen network; terminal states (done=1)
        # contribute only the immediate reward.
        with torch.no_grad():
            best_next_q = self.target_network(next_states).max(dim=1).values
            targets = rewards + self.gamma * best_next_q * (1.0 - dones)

        loss = nn.functional.mse_loss(chosen_q, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.steps += 1
        # Periodically hard-copy online weights into the target network.
        if self.steps % self.target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        return loss.item()

    def decay_epsilon(self):
        """Multiplicatively anneal epsilon down to its floor epsilon_end."""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
DQN 算法流程
算法伪代码
初始化回放缓冲区 D
初始化主网络 Q(s,a;θ) 和目标网络 Q(s,a;θ⁻),θ⁻ = θ
对于每个回合:
初始化状态 s
对于每个时间步:
以 ε 概率随机选择动作,否则选择 a = argmax Q(s,a;θ)
执行动作 a,观察 r, s'
存储 (s, a, r, s', done) 到 D
从 D 中随机采样小批量
计算目标:若 s' 为终止状态则 y = r,否则 y = r + γ max_a' Q(s',a';θ⁻)
对 (y - Q(s,a;θ))² 进行梯度下降
每 C 步更新 θ⁻ = θ
s ← s'
直到终止
完整实现
import gymnasium as gym
import torch
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
def train_dqn(env_name='CartPole-v1', num_episodes=500):
    """Train a DQNAgent on a Gymnasium environment.

    Returns the list of per-episode total rewards. Training stops early once
    the 100-episode moving average reaches 495 (CartPole-v1 "solved").
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim)

    rewards_history = []
    recent_rewards = deque(maxlen=100)

    for episode in range(num_episodes):
        state, _ = env.reset()
        episode_return = 0.0
        finished = False

        while not finished:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            finished = terminated or truncated
            # Store the transition, then learn on a random minibatch.
            agent.buffer.push(state, action, reward, next_state, float(finished))
            agent.update()
            state = next_state
            episode_return += reward

        # Exploration is annealed once per episode, not per step.
        agent.decay_epsilon()
        rewards_history.append(episode_return)
        recent_rewards.append(episode_return)

        if (episode + 1) % 10 == 0:
            avg_reward = np.mean(recent_rewards)
            print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, Epsilon: {agent.epsilon:.3f}")

        if len(recent_rewards) == 100 and np.mean(recent_rewards) >= 495:
            print(f"Environment solved in {episode + 1} episodes!")
            break

    env.close()
    return rewards_history
if __name__ == "__main__":
    # Guarded entry point: importing this module must not trigger a full
    # training run plus a blocking plot window (module-level side effects).
    rewards = train_dqn()

    plt.figure(figsize=(10, 5))
    plt.plot(rewards)
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('DQN Training on CartPole')
    plt.show()
DQN 的改进
Double DQN
问题:DQN 存在过估计问题,因为使用同一个网络选择和评估动作。
解决方案:使用主网络选择动作,目标网络评估价值。
def double_dqn_update(self):
    """Double DQN step: online net selects a*, target net evaluates it.

    Decoupling action selection from evaluation reduces Q overestimation.
    Returns the minibatch loss, or None while the buffer is too small.
    (This demonstration variant omits the periodic target sync shown in
    DQNAgent.update.)
    """
    if len(self.buffer) < self.batch_size:
        return None

    s, a, r, s2, d = self.buffer.sample(self.batch_size)
    states = torch.as_tensor(s, dtype=torch.float32, device=self.device)
    actions = torch.as_tensor(a, dtype=torch.int64, device=self.device)
    rewards = torch.as_tensor(r, dtype=torch.float32, device=self.device)
    next_states = torch.as_tensor(s2, dtype=torch.float32, device=self.device)
    dones = torch.as_tensor(d, dtype=torch.float32, device=self.device)

    # Q(s, a) of the actions actually taken, from the online network.
    chosen_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # Selection by the online network, evaluation by the target network.
        greedy_actions = self.q_network(next_states).argmax(dim=1)
        eval_q = self.target_network(next_states).gather(
            1, greedy_actions.unsqueeze(1)).squeeze(1)
        # Terminal transitions (done=1) receive no bootstrap term.
        targets = rewards + self.gamma * eval_q * (1.0 - dones)

    loss = nn.functional.mse_loss(chosen_q, targets)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()
Dueling DQN
思想:将 Q 值分解为状态价值和动作优势:
$Q(s, a) = V(s) + \left( A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a') \right)$
其中:
- $V(s)$ 是状态价值,表示状态 s 的好坏
- $A(s, a)$ 是动作优势,表示动作 a 相对于平均动作的优势
class DuelingDQN(nn.Module):
    """Dueling architecture: Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a))."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Shared feature trunk.
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )
        # Scalar state-value head V(s).
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # Per-action advantage head A(s, a).
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        """Combine the two streams into Q-values of shape (batch, action_dim)."""
        h = self.feature(x)
        v = self.value_stream(h)
        adv = self.advantage_stream(h)
        # Subtracting the mean advantage keeps V and A identifiable.
        return v + (adv - adv.mean(dim=1, keepdim=True))
Prioritized Experience Replay
思想:重要的经验应该被更频繁地采样。
优先级定义:使用 TD 误差的绝对值作为优先级,$p_i = |\delta_i| + \epsilon$(小常数 $\epsilon$ 保证每条经验都有非零采样概率)。
采样概率:$P(i) = \dfrac{p_i^\alpha}{\sum_k p_k^\alpha}$,其中 $\alpha$ 控制优先程度($\alpha = 0$ 退化为均匀采样)。
class PrioritizedReplayBuffer:
    """Ring-buffer replay memory with proportional prioritized sampling.

    Priorities are raised to the power `alpha` before normalization, so larger
    alpha concentrates sampling on high-priority transitions. Importance
    weights, annealed via `beta` toward 1, correct the sampling bias.
    """

    def __init__(self, capacity=10000, alpha=0.6, beta=0.4, beta_increment=0.001):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0  # next write slot in the ring buffer
        self.size = 0  # number of valid entries (<= capacity)

    def push(self, state, action, reward, next_state, done):
        """Insert a transition at maximum current priority (sampled soon)."""
        priority = self.priorities.max() if self.size > 0 else 1.0
        entry = (state, action, reward, next_state, done)
        if self.size < self.capacity:
            self.buffer.append(entry)
        else:
            self.buffer[self.pos] = entry
        self.priorities[self.pos] = priority
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        """Sample proportionally to priority^alpha.

        Returns the transition columns as numpy arrays, plus the sampled
        indices (for later priority updates) and normalized IS weights.
        """
        scaled = self.priorities[:self.size] ** self.alpha
        probs = scaled / scaled.sum()
        indices = np.random.choice(self.size, batch_size, p=probs)

        # Importance-sampling weights, normalized so the largest is 1.
        weights = (self.size * probs[indices]) ** (-self.beta)
        weights = weights / weights.max()
        # Anneal beta toward full bias correction.
        self.beta = min(1.0, self.beta + self.beta_increment)

        columns = zip(*(self.buffer[i] for i in indices))
        arrays = [np.array(column) for column in columns]
        return (*arrays, indices, weights)

    def update_priorities(self, indices, td_errors):
        """Set priority = |TD error| + small epsilon (keeps it nonzero)."""
        self.priorities[indices] = np.abs(td_errors) + 1e-6
DQN for Atari Games
DQN 最著名的应用是在 Atari 游戏上达到人类水平:
import torch
import torch.nn as nn
import gymnasium as gym
import numpy as np
class AtariDQN(nn.Module):
    """Nature-DQN convnet: 4 stacked 84x84 frames -> one Q-value per action."""

    def __init__(self, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),   # 84 -> 20
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # 20 -> 9
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # 9 -> 7
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def forward(self, x):
        """Expects (batch, 4, 84, 84) pixel input; returns (batch, num_actions)."""
        # Scale raw uint8 pixels into [0, 1] before the conv stack.
        scaled = x.float() / 255.0
        features = self.conv(scaled)
        return self.fc(features.view(features.size(0), -1))
class FrameStack:
    """Keeps the most recent N frames and stacks them for the network input."""

    def __init__(self, num_frames=4):
        self.num_frames = num_frames
        # Bounded deque: pushing beyond capacity evicts the oldest frame.
        self.frames = deque(maxlen=num_frames)

    def reset(self):
        """Drop every stored frame (call at episode start)."""
        self.frames.clear()

    def push(self, frame):
        """Append one frame, evicting the oldest once the stack is full."""
        self.frames.append(frame)

    def get(self):
        """Return the stored frames stacked along a new leading axis."""
        return np.array(self.frames)
超参数调优建议
| 超参数 | 推荐值 | 说明 |
|---|---|---|
| 学习率 | 1e-4 ~ 1e-3 | 较小的学习率更稳定 |
| 折扣因子 | 0.99 | 大多数任务适用 |
| 回放缓冲区大小 | 10^5 ~ 10^6 | 足够存储多样经验 |
| 批量大小 | 32 ~ 128 | 根据内存调整 |
| 目标网络更新频率 | 1000 ~ 10000 步 | 太频繁会不稳定 |
| ε 衰减 | 0.99 ~ 0.999 | 逐渐减少探索 |
小结
DQN 是深度强化学习的里程碑:
- 核心思想:用神经网络近似 Q 函数
- 关键创新:经验回放和目标网络
- 改进版本:Double DQN、Dueling DQN、Prioritized Replay
- 应用:Atari 游戏、机器人控制等
DQN 为后续的深度强化学习算法奠定了基础。下一章将介绍策略梯度方法,它直接优化策略函数,能够处理连续动作空间的问题。