The SARSA Algorithm

SARSA (State-Action-Reward-State-Action) is an on-policy temporal-difference (TD) learning algorithm. Unlike Q-Learning, SARSA learns the value function of the current policy rather than that of the optimal policy.

The Core Idea of SARSA

SARSA takes its name from the quintuple involved in each update, $(S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1})$: the current state, current action, reward, next state, and next action.

On-Policy vs. Off-Policy

Understanding SARSA requires first distinguishing on-policy from off-policy learning:

  • On-policy: the policy being learned is the same as the policy being executed. SARSA learns the value of the policy it is currently following.
  • Off-policy: the learned and executed policies may differ. Q-Learning learns the value of the optimal policy even while executing an exploratory policy.

Why Choose SARSA?

SARSA is a better fit than Q-Learning in some scenarios:

  1. Safety: SARSA accounts for the risk introduced by exploratory actions
  2. Policy evaluation: when the goal is to evaluate the value of a specific policy
  3. Stability: on-policy learning is often more stable

The SARSA Update Rule

The SARSA update rule is:

$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma\, Q(s',a') - Q(s,a) \right]$$

The difference from Q-Learning:

  • Q-Learning uses $\max_{a'} Q(s',a')$ (the value of the best next action)
  • SARSA uses $Q(s',a')$ (the value of the next action actually chosen)

Understanding the Difference

This seemingly small difference has important consequences:

  1. Q-Learning assumes the best action will always be taken in the future, even though the current policy explores randomly
  2. SARSA accounts for the current policy's actual behavior, including its exploratory actions
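The two update targets can be compared side by side in a minimal sketch (the Q values below are made up purely for illustration):

```python
import numpy as np

q_next = np.array([1.0, 5.0, 2.0])  # hypothetical Q(s', ·) for three actions
gamma, r = 0.9, -1.0

# Q-Learning: bootstrap from the greedy action in s'
q_learning_target = r + gamma * np.max(q_next)

# SARSA: bootstrap from the action a' the policy actually chose in s'
a_prime = 2  # e.g. an exploratory, non-greedy action
sarsa_target = r + gamma * q_next[a_prime]

print(q_learning_target)  # 3.5
print(sarsa_target)       # 0.8
```

Whenever the sampled a' is not the greedy action, the SARSA target is pulled down by the exploratory choice; this is exactly how the risk of exploration enters the learned values.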

SARSA Algorithm Procedure

Pseudocode

Initialize Q(s,a) arbitrarily
For each episode:
    Initialize state s
    Choose action a in s according to the policy (e.g. ε-greedy)
    For each time step:
        Take action a, observe reward r and next state s'
        Choose next action a' in s' according to the policy
        Q(s,a) ← Q(s,a) + α[r + γ Q(s',a') - Q(s,a)]
        s ← s', a ← a'
    Until s is terminal

Python Implementation

import numpy as np

class SARSA:
    def __init__(self, num_states, num_actions, learning_rate=0.1,
                 discount_factor=0.99, epsilon=0.1):
        self.num_states = num_states
        self.num_actions = num_actions
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.q_table = np.zeros((num_states, num_actions))

    def select_action(self, state):
        # epsilon-greedy action selection
        if np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        return np.argmax(self.q_table[state])

    def update(self, state, action, reward, next_state, next_action, done):
        current_q = self.q_table[state, action]

        if done:
            target = reward
        else:
            # bootstrap from the action actually chosen in the next state
            target = reward + self.gamma * self.q_table[next_state, next_action]

        td_error = target - current_q
        self.q_table[state, action] += self.lr * td_error

    def train(self, env, num_episodes=1000):
        rewards_history = []

        for episode in range(num_episodes):
            state = env.reset()
            action = self.select_action(state)
            total_reward = 0
            done = False

            while not done:
                next_state, reward, done, _ = env.step(action)
                next_action = self.select_action(next_state)
                self.update(state, action, reward, next_state, next_action, done)
                state = next_state
                action = next_action
                total_reward += reward

            rewards_history.append(total_reward)

        return rewards_history

Q-Learning vs. SARSA: The Cliff Walking Comparison

Cliff walking is the classic example for illustrating the difference between the two algorithms:

import numpy as np
import matplotlib.pyplot as plt

class CliffWalking:
    def __init__(self, width=12, height=4):
        self.width = width
        self.height = height
        self.start = (height - 1, 0)
        self.goal = (height - 1, width - 1)
        self.cliff = [(height - 1, i) for i in range(1, width - 1)]
        self.state = self.start

    def reset(self):
        self.state = self.start
        return self._state_to_idx(self.state)

    def _state_to_idx(self, state):
        return state[0] * self.width + state[1]

    def step(self, action):
        row, col = self.state

        if action == 0:    # up
            row = max(0, row - 1)
        elif action == 1:  # down
            row = min(self.height - 1, row + 1)
        elif action == 2:  # left
            col = max(0, col - 1)
        elif action == 3:  # right
            col = min(self.width - 1, col + 1)

        self.state = (row, col)

        if self.state in self.cliff:
            return self.reset(), -100, True
        elif self.state == self.goal:
            return self._state_to_idx(self.state), -1, True
        else:
            return self._state_to_idx(self.state), -1, False

def train_qlearning(env, num_episodes=500, epsilon=0.1):
    q_table = np.zeros((env.width * env.height, 4))
    rewards = []

    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        done = False

        while not done:
            # epsilon-greedy behavior policy
            if np.random.random() < epsilon:
                action = np.random.randint(4)
            else:
                action = np.argmax(q_table[state])

            next_state, reward, done = env.step(action)

            # off-policy target: bootstrap from the greedy action in s'
            target = reward + 0.99 * np.max(q_table[next_state]) * (1 - done)
            q_table[state, action] += 0.5 * (target - q_table[state, action])

            state = next_state
            total_reward += reward

        rewards.append(total_reward)

    return q_table, rewards

def train_sarsa(env, num_episodes=500, epsilon=0.1):
    q_table = np.zeros((env.width * env.height, 4))
    rewards = []

    for episode in range(num_episodes):
        state = env.reset()
        if np.random.random() < epsilon:
            action = np.random.randint(4)
        else:
            action = np.argmax(q_table[state])

        total_reward = 0
        done = False

        while not done:
            next_state, reward, done = env.step(action)

            if np.random.random() < epsilon:
                next_action = np.random.randint(4)
            else:
                next_action = np.argmax(q_table[next_state])

            # on-policy target: bootstrap from the action actually chosen
            if done:
                target = reward
            else:
                target = reward + 0.99 * q_table[next_state, next_action]

            q_table[state, action] += 0.5 * (target - q_table[state, action])

            state = next_state
            action = next_action
            total_reward += reward

        rewards.append(total_reward)

    return q_table, rewards

env = CliffWalking()
q_table_q, rewards_q = train_qlearning(env)
q_table_sarsa, rewards_sarsa = train_sarsa(env)

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(rewards_q, label='Q-Learning', alpha=0.7)
plt.plot(rewards_sarsa, label='SARSA', alpha=0.7)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Learning Curves')
plt.legend()

plt.subplot(1, 2, 2)
plt.bar(['Q-Learning', 'SARSA'],
        [np.mean(rewards_q[-50:]), np.mean(rewards_sarsa[-50:])])
plt.ylabel('Average Reward (last 50 episodes)')
plt.title('Final Performance')

plt.tight_layout()
plt.show()

Analyzing the Results

In the cliff walking environment:

  • Q-Learning learns the "optimal" path along the cliff edge, but during execution its exploratory moves can send it over the cliff
  • SARSA learns a safer path away from the cliff; the path is longer, but performance is more stable

SARSA(λ): Eligibility Traces

SARSA(λ) extends SARSA with eligibility traces, which enable multi-step updates.

Forward View

The forward view uses the n-step return:

$$G_t^{(n)} = R_{t+1} + \gamma R_{t+2} + \dots + \gamma^{n-1} R_{t+n} + \gamma^n Q(S_{t+n}, A_{t+n})$$

The λ-return is a weighted average of the n-step returns:

$$G_t^{\lambda} = (1-\lambda) \sum_{n=1}^{\infty} \lambda^{n-1} G_t^{(n)}$$
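As a quick sanity check on this weighting (plain arithmetic, independent of any environment), the weights $(1-\lambda)\lambda^{n-1}$ form a geometric series summing to 1, with the most recent n-step returns weighted most heavily:

```python
lambda_ = 0.8
# truncate the infinite sum at n = 200; the tail is negligible for lambda < 1
weights = [(1 - lambda_) * lambda_ ** (n - 1) for n in range(1, 201)]

print(round(sum(weights), 6))              # 1.0
print([round(w, 3) for w in weights[:3]])  # [0.2, 0.16, 0.128]
```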

Backward View

The backward view uses eligibility traces for an efficient implementation:

import numpy as np

class SARSALambda:
    def __init__(self, num_states, num_actions, lr=0.1, gamma=0.99,
                 epsilon=0.1, lambda_=0.8):
        self.num_states = num_states
        self.num_actions = num_actions
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.lambda_ = lambda_
        self.q_table = np.zeros((num_states, num_actions))

    def select_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.num_actions)
        return np.argmax(self.q_table[state])

    def train(self, env, num_episodes=1000):
        rewards_history = []

        for episode in range(num_episodes):
            # traces are reset at the start of every episode
            eligibility = np.zeros((self.num_states, self.num_actions))
            state = env.reset()
            action = self.select_action(state)
            total_reward = 0
            done = False

            while not done:
                next_state, reward, done, _ = env.step(action)
                next_action = self.select_action(next_state)

                if done:
                    td_error = reward - self.q_table[state, action]
                else:
                    td_error = (reward
                                + self.gamma * self.q_table[next_state, next_action]
                                - self.q_table[state, action])

                # accumulating trace: bump the visited pair, then update
                # every pair in proportion to its trace
                eligibility[state, action] += 1
                self.q_table += self.lr * td_error * eligibility
                eligibility *= self.gamma * self.lambda_

                state = next_state
                action = next_action
                total_reward += reward

            rewards_history.append(total_reward)

        return rewards_history

The Role of λ

  • λ = 0: reduces to one-step SARSA
  • λ = 1: behaves like a Monte Carlo method, using the full-episode return
  • 0 < λ < 1: balances one-step and multi-step updates
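The λ = 0 case can be checked directly against the trace logic above: after the decay step `eligibility *= gamma * lambda_`, a zero λ wipes the trace, so each TD update touches only the most recently visited (s, a) pair, exactly as in one-step SARSA. A tiny sketch of one iteration:

```python
import numpy as np

gamma, lambda_ = 0.99, 0.0
eligibility = np.zeros((3, 2))  # toy problem: 3 states, 2 actions

# visit (state 0, action 1), then decay -- one iteration of the trace update
eligibility[0, 1] += 1
touched = eligibility.copy()    # cells affected by this TD update
eligibility *= gamma * lambda_  # with lambda = 0 the trace is wiped

print(touched.sum())      # 1.0 -> only the visited pair was updated
print(eligibility.sum())  # 0.0 -> no credit carries back to earlier pairs
```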

Expected SARSA

Expected SARSA is a variant of SARSA that uses the expected value over all actions in the next state:

$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \sum_{a'} \pi(a'|s')\, Q(s',a') - Q(s,a) \right]$$

import numpy as np

def expected_sarsa_update(q_table, state, action, reward, next_state, done,
                          epsilon, gamma, lr):
    if done:
        target = reward
    else:
        # expectation of Q(s', a') under the epsilon-greedy policy
        num_actions = q_table.shape[1]
        best_action = np.argmax(q_table[next_state])
        expected_value = 0
        for a in range(num_actions):
            if a == best_action:
                prob = 1 - epsilon + epsilon / num_actions
            else:
                prob = epsilon / num_actions
            expected_value += prob * q_table[next_state, a]
        target = reward + gamma * expected_value

    td_error = target - q_table[state, action]
    q_table[state, action] += lr * td_error
    return q_table

Expected SARSA combines the strengths of Q-Learning and SARSA:

  • Lower variance than SARSA (it uses an expectation instead of a single sample)
  • More stable than Q-Learning (it accounts for the policy's actual behavior)
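The variance claim can be illustrated numerically. This sketch (with made-up Q values) compares the sampled SARSA target against the Expected SARSA target under the same ε-greedy policy: both share the same mean, but only the sampled one fluctuates:

```python
import numpy as np

np.random.seed(0)
q_next = np.array([1.0, 5.0, 2.0])   # hypothetical Q(s', ·)
epsilon, gamma, r = 0.3, 0.9, -1.0
n_actions = len(q_next)

# epsilon-greedy action probabilities in s'
probs = np.full(n_actions, epsilon / n_actions)
probs[np.argmax(q_next)] += 1 - epsilon

# SARSA: the target depends on which a' happened to be sampled
sampled = np.random.choice(n_actions, size=10_000, p=probs)
sarsa_targets = r + gamma * q_next[sampled]

# Expected SARSA: the target is a fixed expectation -- zero variance
expected_target = r + gamma * np.dot(probs, q_next)

print(np.var(sarsa_targets) > 0)                            # True
print(abs(np.mean(sarsa_targets) - expected_target) < 0.1)  # True
```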

Algorithm Comparison Summary

| Algorithm | Policy type | Update target | Characteristics |
| --- | --- | --- | --- |
| Q-Learning | Off-policy | $\max_{a'} Q(s',a')$ | Learns the optimal policy; may take risks |
| SARSA | On-policy | $Q(s',a')$ | Learns the current policy; safer |
| SARSA(λ) | On-policy | Multi-step return | Propagates rewards faster |
| Expected SARSA | Off-policy | $\mathbb{E}_\pi[Q(s',a')]$ | Lower variance |

Complete Example: Windy Grid World

import numpy as np

class WindGridWorld:
    def __init__(self, width=10, height=7):
        self.width = width
        self.height = height
        self.start = (3, 0)
        self.goal = (3, 7)
        # upward wind strength per column (classic windy gridworld layout)
        self.wind = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]
        self.state = self.start

    def reset(self):
        self.state = self.start
        return self._state_to_idx(self.state)

    def _state_to_idx(self, state):
        return state[0] * self.width + state[1]

    def step(self, action):
        row, col = self.state
        wind_effect = self.wind[col]  # wind of the column being departed

        if action == 0:    # up
            row = max(0, row - 1)
        elif action == 1:  # down
            row = min(self.height - 1, row + 1)
        elif action == 2:  # left
            col = max(0, col - 1)
        elif action == 3:  # right
            col = min(self.width - 1, col + 1)

        # the wind pushes the agent upward
        row = max(0, row - wind_effect)

        self.state = (row, col)
        done = self.state == self.goal
        # return a 4-tuple to match the interface expected by SARSA.train
        return self._state_to_idx(self.state), -1, done, {}

env = WindGridWorld()
agent = SARSA(
    num_states=env.width * env.height,
    num_actions=4,
    learning_rate=0.5,
    discount_factor=0.99,
    epsilon=0.1
)

rewards = agent.train(env, num_episodes=1000)
print(f"Average steps (last 100 episodes): {-np.mean(rewards[-100:]):.1f}")

Summary

SARSA is an important on-policy reinforcement learning algorithm:

  • Core idea: learn the value function of the current policy
  • Update rule: $Q(s,a) \leftarrow Q(s,a) + \alpha [r + \gamma Q(s',a') - Q(s,a)]$
  • Difference from Q-Learning: bootstraps from actual behavior rather than optimal behavior
  • Best suited for: settings that call for safety, policy evaluation, or stable learning
  • Extensions: SARSA(λ), Expected SARSA

The next chapter introduces Deep Q-Networks (DQN), which combine Q-Learning with deep learning to handle problems with high-dimensional state spaces.