SARSA 算法
SARSA(State-Action-Reward-State-Action)是一种在线策略的时序差分学习算法。与 Q-Learning 不同,SARSA 学习的是当前策略的价值函数,而不是最优策略的价值函数。
SARSA 的核心思想
SARSA 的名称来源于它的更新过程涉及五元组 $(s_t, a_t, r_{t+1}, s_{t+1}, a_{t+1})$,即当前状态、当前动作、奖励、下一状态和下一动作。
在线策略 vs 离线策略
理解 SARSA 需要先区分在线策略和离线策略:
- 在线策略(On-Policy):学习和执行的策略相同。SARSA 学习当前正在执行的策略的价值。
- 离线策略(Off-Policy):学习和执行的策略可以不同。Q-Learning 学习最优策略的价值,即使执行的是探索策略。
为什么选择 SARSA?
SARSA 在某些场景下比 Q-Learning 更合适:
- 安全性考虑:SARSA 会考虑探索行为的风险
- 策略评估:需要评估特定策略的价值
- 稳定性:在线策略学习通常更稳定
SARSA 更新公式
SARSA 的更新公式:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right]$$
与 Q-Learning 的区别:
- Q-Learning:使用 $\max_{a'} Q(s_{t+1}, a')$(最优动作的价值)
- SARSA:使用 $Q(s_{t+1}, a_{t+1})$(实际选择的下一动作的价值)
理解差异
这个看似微小的差异带来了重要的影响:
- Q-Learning 假设未来总是选择最优动作,即使当前策略会随机探索
- SARSA 考虑当前策略的实际行为,包括探索动作
SARSA 算法流程
算法伪代码
初始化 Q(s,a) 为任意值
对于每个回合:
初始化状态 s
根据策略选择动作 a
对于每个时间步:
执行动作 a,观察奖励 r 和下一状态 s'
根据策略选择下一动作 a'
Q(s,a) ← Q(s,a) + α[r + γ Q(s',a') - Q(s,a)]
s ← s', a ← a'
直到 s 是终止状态
Python 实现
import numpy as np
class SARSA:
    """Tabular SARSA: on-policy TD control with an epsilon-greedy policy.

    Unlike Q-Learning, the TD target bootstraps from the action that the
    behavior policy actually selects in the next state, so the learned
    values reflect the exploring policy itself.
    """

    def __init__(self, num_states, num_actions, learning_rate=0.1,
                 discount_factor=0.99, epsilon=0.1):
        self.num_states = num_states
        self.num_actions = num_actions
        self.lr = learning_rate        # step size alpha
        self.gamma = discount_factor   # discount factor gamma
        self.epsilon = epsilon         # exploration rate of epsilon-greedy
        # One row per state, one column per action, initialized to zero.
        self.q_table = np.zeros((num_states, num_actions))

    def select_action(self, state):
        """Epsilon-greedy action selection for the given state index."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        return np.argmax(self.q_table[state])

    def update(self, state, action, reward, next_state, next_action, done):
        """One SARSA update: Q(s,a) += lr * (target - Q(s,a)).

        On terminal transitions the target is just the reward, since no
        further return can be collected.
        """
        current_q = self.q_table[state, action]
        if done:
            target = reward
        else:
            target = reward + self.gamma * self.q_table[next_state, next_action]
        self.q_table[state, action] += self.lr * (target - current_q)

    def train(self, env, num_episodes=1000):
        """Run num_episodes episodes and return per-episode total rewards.

        Accepts environments whose step() returns either the gym-style
        4-tuple (next_state, reward, done, info) or a plain 3-tuple
        (next_state, reward, done); fixed-arity unpacking used to crash on
        the 3-tuple environments defined later in this file.
        """
        rewards_history = []
        for episode in range(num_episodes):
            state = env.reset()
            action = self.select_action(state)
            total_reward = 0
            done = False
            while not done:
                # Tolerate both 3- and 4-element step() returns.
                step_result = env.step(action)
                next_state, reward, done = step_result[0], step_result[1], step_result[2]
                next_action = self.select_action(next_state)
                self.update(state, action, reward, next_state, next_action, done)
                state = next_state
                action = next_action
                total_reward += reward
            rewards_history.append(total_reward)
        return rewards_history
Q-Learning vs SARSA:悬崖行走对比
悬崖行走是展示两种算法差异的经典例子:
import numpy as np
import matplotlib.pyplot as plt
class CliffWalking:
    """Cliff-walking grid: start bottom-left, goal bottom-right, cliff between."""

    def __init__(self, width=12, height=4):
        self.width = width
        self.height = height
        self.start = (height - 1, 0)            # bottom-left corner
        self.goal = (height - 1, width - 1)     # bottom-right corner
        # Every bottom-row cell strictly between start and goal is cliff.
        self.cliff = [(height - 1, c) for c in range(1, width - 1)]
        self.state = self.start

    def reset(self):
        """Move the agent back to the start and return its state index."""
        self.state = self.start
        return self._state_to_idx(self.state)

    def _state_to_idx(self, state):
        # Row-major flattening of (row, col) into a single integer.
        r, c = state
        return r * self.width + c

    def step(self, action):
        """Apply one move (0=up, 1=down, 2=left, 3=right).

        Returns (state_index, reward, done). Stepping into the cliff ends
        the episode with reward -100 and teleports back to the start; every
        other transition costs -1, and reaching the goal also terminates.
        An unknown action leaves the position unchanged.
        """
        row, col = self.state
        moves = {
            0: (max(0, row - 1), col),                # up
            1: (min(self.height - 1, row + 1), col),  # down
            2: (row, max(0, col - 1)),                # left
            3: (row, min(self.width - 1, col + 1)),   # right
        }
        self.state = moves.get(action, (row, col))
        if self.state in self.cliff:
            return self.reset(), -100, True
        if self.state == self.goal:
            return self._state_to_idx(self.state), -1, True
        return self._state_to_idx(self.state), -1, False
def train_qlearning(env, num_episodes=500, epsilon=0.1):
    """Tabular Q-Learning on env; returns (q_table, per-episode rewards).

    Off-policy: the TD target bootstraps from max_a Q(s', a) regardless of
    which action the epsilon-greedy behavior policy will actually take.
    Expects env.step to return (next_state, reward, done).
    """
    n_actions = 4
    q_table = np.zeros((env.width * env.height, n_actions))
    episode_returns = []
    for _ in range(num_episodes):
        state = env.reset()
        ret = 0
        done = False
        while not done:
            # Epsilon-greedy behavior policy (one random() draw per step).
            explore = np.random.random() < epsilon
            action = np.random.randint(n_actions) if explore else np.argmax(q_table[state])
            next_state, reward, done = env.step(action)
            # (1 - done) zeroes the bootstrap term on terminal transitions.
            target = reward + 0.99 * np.max(q_table[next_state]) * (1 - done)
            q_table[state, action] += 0.5 * (target - q_table[state, action])
            state = next_state
            ret += reward
        episode_returns.append(ret)
    return q_table, episode_returns
def train_sarsa(env, num_episodes=500, epsilon=0.1):
    """Tabular SARSA on env; returns (q_table, per-episode rewards).

    On-policy: the TD target bootstraps from Q(s', a') where a' is the
    action the epsilon-greedy policy actually picks next.
    Expects env.step to return (next_state, reward, done).
    """
    def pick(q_row):
        # Epsilon-greedy selection; one random() draw, randint only on explore.
        if np.random.random() < epsilon:
            return np.random.randint(4)
        return np.argmax(q_row)

    q_table = np.zeros((env.width * env.height, 4))
    episode_returns = []
    for _ in range(num_episodes):
        state = env.reset()
        action = pick(q_table[state])
        ret = 0
        done = False
        while not done:
            next_state, reward, done = env.step(action)
            next_action = pick(q_table[next_state])
            # Terminal target is just the reward; otherwise bootstrap from a'.
            target = reward if done else reward + 0.99 * q_table[next_state, next_action]
            q_table[state, action] += 0.5 * (target - q_table[state, action])
            state, action = next_state, next_action
            ret += reward
        episode_returns.append(ret)
    return q_table, episode_returns
# Run both algorithms on the same cliff-walking environment and compare.
env = CliffWalking()
q_table_q, rewards_q = train_qlearning(env)
q_table_sarsa, rewards_sarsa = train_sarsa(env)
# Left panel: per-episode learning curves for both algorithms.
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(rewards_q, label='Q-Learning', alpha=0.7)
plt.plot(rewards_sarsa, label='SARSA', alpha=0.7)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Learning Curves')
plt.legend()
# Right panel: mean reward over the final 50 episodes, as a bar chart.
plt.subplot(1, 2, 2)
plt.bar(['Q-Learning', 'SARSA'],
[np.mean(rewards_q[-50:]), np.mean(rewards_sarsa[-50:])])
plt.ylabel('Average Reward (last 50 episodes)')
plt.title('Final Performance')
plt.tight_layout()
plt.show()
结果分析
在悬崖行走环境中:
- Q-Learning 学到的是沿着悬崖边缘走的"最优"路径,但在执行时由于探索行为可能掉入悬崖
- SARSA 学到的是远离悬崖的安全路径,虽然路径更长但更稳定
SARSA(λ):资格迹
SARSA(λ) 是 SARSA 的扩展,通过资格迹实现多步更新。
前向视角
前向视角使用 n 步回报:

$$G_t^{(n)} = r_{t+1} + \gamma r_{t+2} + \cdots + \gamma^{n-1} r_{t+n} + \gamma^n Q(s_{t+n}, a_{t+n})$$

λ-回报是不同 n 步回报的加权平均:

$$G_t^{\lambda} = (1 - \lambda) \sum_{n=1}^{\infty} \lambda^{n-1} G_t^{(n)}$$
后向视角
后向视角使用资格迹高效实现:
class SARSALambda:
    """SARSA(lambda): SARSA with accumulating eligibility traces.

    Each visited (state, action) pair keeps a decaying trace, so a single
    TD error updates every recently visited pair at once — credit is
    propagated back faster than with one-step SARSA.
    """

    def __init__(self, num_states, num_actions, lr=0.1, gamma=0.99,
                 epsilon=0.1, lambda_=0.8):
        self.num_states = num_states
        self.num_actions = num_actions
        self.lr = lr               # step size
        self.gamma = gamma         # discount factor
        self.epsilon = epsilon     # exploration rate
        self.lambda_ = lambda_     # trace-decay parameter
        self.q_table = np.zeros((num_states, num_actions))

    def train(self, env, num_episodes=1000):
        """Run num_episodes episodes; returns the list of episode returns.

        Expects env.step to return the 4-tuple (next_state, reward, done, info).
        """
        episode_returns = []
        for _ in range(num_episodes):
            # Traces are cleared at the start of every episode.
            traces = np.zeros((self.num_states, self.num_actions))
            state = env.reset()
            action = self.select_action(state)
            ret = 0
            done = False
            while not done:
                next_state, reward, done, _ = env.step(action)
                next_action = self.select_action(next_state)
                # No bootstrap term on terminal transitions.
                bootstrap = 0.0 if done else self.gamma * self.q_table[next_state, next_action]
                td_error = reward + bootstrap - self.q_table[state, action]
                traces[state, action] += 1  # accumulating trace
                # Broadcast update: every traced pair moves toward the target.
                self.q_table += self.lr * td_error * traces
                traces *= self.gamma * self.lambda_
                state, action = next_state, next_action
                ret += reward
            episode_returns.append(ret)
        return episode_returns

    def select_action(self, state):
        """Epsilon-greedy selection over the Q-row of this state."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        return np.argmax(self.q_table[state])
λ 的作用
- λ = 0:退化为单步 SARSA
- λ = 1:类似于蒙特卡洛方法,使用完整回合的回报
- 0 < λ < 1:平衡单步和多步更新
预期 SARSA
预期 SARSA 是 SARSA 的变体,使用下一状态所有动作的期望价值作为更新目标:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \sum_{a'} \pi(a' \mid s') Q(s', a') - Q(s, a) \right]$$
def expected_sarsa_update(q_table, state, action, reward, next_state, done,
                          epsilon, gamma, lr):
    """One Expected SARSA update; mutates q_table in place and returns it.

    Instead of sampling a single next action, the bootstrap term is the
    expectation of Q(s', .) under the epsilon-greedy policy, which lowers
    the variance of the TD target.
    """
    if done:
        # Terminal transition: no future value to bootstrap from.
        target = reward
    else:
        n = q_table.shape[1]
        greedy = np.argmax(q_table[next_state])
        # Epsilon-greedy probabilities: epsilon/n everywhere, with the
        # remaining (1 - epsilon) mass added to the greedy action.
        probs = np.full(n, epsilon / n)
        probs[greedy] += 1 - epsilon
        target = reward + gamma * float(np.dot(probs, q_table[next_state]))
    q_table[state, action] += lr * (target - q_table[state, action])
    return q_table
预期 SARSA 结合了 Q-Learning 和 SARSA 的优点:
- 比 SARSA 方差更小(使用了期望)
- 比 Q-Learning 更稳定(考虑了策略行为)
算法比较总结
| 算法 | 策略类型 | 更新目标 | 特点 |
|---|---|---|---|
| Q-Learning | 离线策略 | $r + \gamma \max_{a'} Q(s', a')$ | 学习最优策略,可能冒险 |
| SARSA | 在线策略 | $r + \gamma Q(s', a')$ | 学习当前策略,更安全 |
| SARSA(λ) | 在线策略 | 多步回报 | 加速传播奖励 |
| 预期 SARSA | 在线策略 | $r + \gamma \sum_{a'} \pi(a' \mid s') Q(s', a')$ | 方差更小 |
完整示例:Wind Grid World
import numpy as np
class WindGridWorld:
    """Windy gridworld: each column has an upward wind that pushes the agent.

    Fixes over the original version in this file:
    - self.state is initialized in __init__, so step() no longer raises
      AttributeError if called before reset() (consistent with CliffWalking);
    - step() returns the gym-style 4-tuple (state_index, reward, done, info),
      matching the 4-value unpacking performed by SARSA.train / SARSALambda.train.
    """

    def __init__(self, width=10, height=7):
        self.width = width
        self.height = height
        self.start = (3, 0)   # (row, col)
        self.goal = (3, 7)
        # Upward wind strength per column.
        # NOTE(review): this profile has exactly 10 entries; a non-default
        # width would need a matching wind list — confirm before changing.
        self.wind = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]
        # Initialize position so step() is safe even before the first reset().
        self.state = self.start

    def reset(self):
        """Put the agent back at the start and return its state index."""
        self.state = self.start
        return self._state_to_idx(self.state)

    def _state_to_idx(self, state):
        # Row-major flattening of (row, col) into a single integer.
        return state[0] * self.width + state[1]

    def step(self, action):
        """Apply one move (0=up, 1=down, 2=left, 3=right), then the wind.

        The wind of the destination column pushes the agent upward. Every
        step costs -1; the episode ends when the goal is reached.
        Returns (state_index, -1, done, {}).
        """
        row, col = self.state
        if action == 0:    # up
            row = max(0, row - 1)
        elif action == 1:  # down
            row = min(self.height - 1, row + 1)
        elif action == 2:  # left
            col = max(0, col - 1)
        elif action == 3:  # right
            col = min(self.width - 1, col + 1)
        # Wind is applied after the move, based on the destination column.
        row = max(0, row - self.wind[col])
        self.state = (row, col)
        done = self.state == self.goal
        return self._state_to_idx(self.state), -1, done, {}
# Train the tabular SARSA agent on the windy gridworld.
env = WindGridWorld()
agent = SARSA(
num_states=env.width * env.height,
num_actions=4,
learning_rate=0.5,
discount_factor=0.99,
epsilon=0.1
)
rewards = agent.train(env, num_episodes=1000)
# Every step costs -1, so the negated mean return over the last 100
# episodes equals the average episode length in steps.
print(f"平均步数(最后100回合): {-np.mean(rewards[-100:]):.1f}")
小结
SARSA 是一种重要的在线策略强化学习算法:
- 核心思想:学习当前策略的价值函数
- 更新公式:$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma Q(s', a') - Q(s, a) \right]$
- 与 Q-Learning 的区别:考虑实际行为而非最优行为
- 适用场景:需要安全性、策略评估或稳定学习的场景
- 扩展:SARSA(λ)、预期 SARSA
下一章将介绍深度 Q 网络(DQN),它将 Q-Learning 与深度学习结合,能够处理高维状态空间的问题。