The Gymnasium Environment Interface
Gymnasium is the maintained fork of OpenAI Gym and provides a standardized interface for reinforcement learning environments. It ships a rich set of predefined environments and supports building custom ones. Maintained by the Farama Foundation, it is the most widely used environment library in the RL community today, and mastering it is the foundation of practical reinforcement learning work.
Introduction to Gymnasium
What is Gymnasium?
Gymnasium is a toolkit for developing and comparing reinforcement learning algorithms. It provides:
- A unified environment interface: every environment follows the same API, making it easy to move algorithms between tasks
- A rich set of predefined environments: from simple classic control to complex robotics simulation
- An environment wrapper system: flexible modification of environment behavior
- Support for custom environments: easily create your own
From Gym to Gymnasium
Gymnasium is the maintained fork of OpenAI Gym. The main changes are:
- API changes: reset() returns (observation, info), and step() returns a five-tuple
- Better type hints: full type annotation support
- Stricter specification: environments must conform to the interface contract
- Active maintenance: bugs are fixed and new features are added regularly
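To make the API change concrete, here is a minimal sketch (plain Python; the helper name is illustrative, not part of Gymnasium) that adapts a five-tuple step() result for code still expecting the old Gym-style done flag:

```python
def split_done(step_result):
    """Adapt a Gymnasium 5-tuple step result to old Gym-style (obs, reward, done, info).

    The termination reason is preserved inside info so it is not lost.
    """
    obs, reward, terminated, truncated, info = step_result
    done = terminated or truncated
    info = dict(info, terminated=terminated, truncated=truncated)
    return obs, reward, done, info

# e.g. a step that hit the time limit:
obs, reward, done, info = split_done((0.0, 1.0, False, True, {}))
# done is True, and info records that the episode was truncated, not terminated
```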
Installation
pip install gymnasium
pip install gymnasium[atari]   # Atari environments
pip install gymnasium[box2d]   # Box2D environments (LunarLander, BipedalWalker)
pip install gymnasium[mujoco]  # MuJoCo environments
pip install gymnasium[all]     # everything
Basic Usage
Creating an Environment
import gymnasium as gym
env = gym.make('CartPole-v1')
env = gym.make('CartPole-v1', render_mode='human')      # renders to a window
env = gym.make('CartPole-v1', render_mode='rgb_array')  # returns frames as arrays
The Environment Interaction Loop
The basic reinforcement learning interaction pattern looks like this:
import gymnasium as gym

env = gym.make('CartPole-v1', render_mode='human')
observation, info = env.reset(seed=42)
for _ in range(1000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        observation, info = env.reset()
env.close()
Core Methods in Detail
The reset() Method
Resets the environment to an initial state.
observation, info = env.reset(seed=42, options={})
Parameters:
- seed (optional): random seed, for reproducibility
- options (optional): environment-specific configuration options
Return values:
- observation: the initial observation
- info: a dictionary of auxiliary information
obs, info = env.reset(seed=42)
print(f"Initial observation: {obs}")
print(f"Info: {info}")
obs, info = env.reset(options={'random_start': True})  # options are environment-specific
The step() Method
Executes one action and returns the result.
observation, reward, terminated, truncated, info = env.step(action)
Parameters:
- action: the action to execute; it must belong to the action space
Return values:
- observation: the new observation
- reward: the immediate reward
- terminated: whether a terminal state was reached (e.g. game over, success)
- truncated: whether the episode was cut short (e.g. maximum step count exceeded)
- info: a dictionary of auxiliary information
The difference between terminated and truncated:
| Situation | terminated | truncated | Notes |
|---|---|---|---|
| Game won | True | False | Goal state reached |
| Game lost | True | False | e.g. fell into a trap |
| Timeout | False | True | Maximum step count exceeded |
| Still running | False | False | Episode continues |
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
if terminated:
    print("Episode over (natural termination)")
elif truncated:
    print("Episode over (truncated, e.g. by a time limit)")
The render() Method
Renders the environment.
env = gym.make('CartPole-v1', render_mode='human')        # draws to a window on each step
env = gym.make('CartPole-v1', render_mode='rgb_array')
frame = env.render()                                      # returns the current frame as an array
env = gym.make('CartPole-v1', render_mode='rgb_array_list')
frames = env.render()                                     # returns the frames collected since the last reset
The close() Method
Closes the environment and releases its resources (e.g. render windows).
env.close()
A Complete Interaction Example
import gymnasium as gym
import numpy as np

def run_episode(env_name='CartPole-v1', max_steps=500, render=False):
    render_mode = 'human' if render else None
    env = gym.make(env_name, render_mode=render_mode)
    obs, info = env.reset(seed=42)
    total_reward = 0
    steps = 0
    for step in range(max_steps):
        action = env.action_space.sample()
        next_obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        steps += 1
        obs = next_obs
        if terminated or truncated:
            break
    env.close()
    return {
        'total_reward': total_reward,
        'steps': steps,
        'terminated': terminated,
        'truncated': truncated,
    }

result = run_episode(render=True)
print(f"Total reward: {result['total_reward']}")
print(f"Steps: {result['steps']}")
print(f"Ended by: {'natural termination' if result['terminated'] else 'truncation'}")
Environment Spaces
Observation Space
The observation space defines the format and range of the observations the environment returns.
import gymnasium as gym

env = gym.make('CartPole-v1')
print(f"Observation space: {env.observation_space}")
print(f"Shape: {env.observation_space.shape}")
print(f"Dtype: {env.observation_space.dtype}")
print(f"Low: {env.observation_space.low}")
print(f"High: {env.observation_space.high}")
obs, _ = env.reset()
print(f"Observation is inside the space: {env.observation_space.contains(obs)}")
Action Space
The action space defines the format of the actions the agent can take.
env = gym.make('CartPole-v1')
print(f"Action space: {env.action_space}")
print(f"Number of actions: {env.action_space.n}")
env = gym.make('Pendulum-v1')
print(f"Action space: {env.action_space}")
print(f"Shape: {env.action_space.shape}")
print(f"Range: [{env.action_space.low}, {env.action_space.high}]")
Space Types in Detail
Box Space
Represents continuous spaces such as vectors, matrices, and images.
from gymnasium import spaces
import numpy as np

box = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
print(f"Sample: {box.sample()}")
image_space = spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)
print(f"Image space shape: {image_space.shape}")
unbounded = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)
print(f"Unbounded space: {unbounded}")
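A common pattern with Box action spaces is for the policy to output actions in [-1, 1] and rescale them to the space's bounds. A minimal NumPy sketch (the helper name is illustrative):

```python
import numpy as np

def rescale_action(a, low, high):
    """Linearly map an action from [-1, 1] into the Box range [low, high]."""
    a = np.clip(a, -1.0, 1.0)
    return low + 0.5 * (a + 1.0) * (high - low)

# Pendulum-v1-style bounds: torque in [-2, 2]
print(rescale_action(np.array([0.0]), -2.0, 2.0))  # midpoint of the range
print(rescale_action(np.array([1.0]), -2.0, 2.0))  # upper bound
```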
Discrete Space
Represents a discrete choice among n actions.
from gymnasium import spaces

discrete = spaces.Discrete(4)
print(f"Number of actions: {discrete.n}")
print(f"Sample: {discrete.sample()}")
print(f"Contains 2: {discrete.contains(2)}")
print(f"Contains 5: {discrete.contains(5)}")
MultiDiscrete Space
Represents several discrete actions at once.
from gymnasium import spaces

multi_discrete = spaces.MultiDiscrete([3, 4, 2])
print(f"Actions per dimension: {multi_discrete.nvec}")
print(f"Sample: {multi_discrete.sample()}")
MultiBinary Space
Represents several binary choices.
from gymnasium import spaces

multi_binary = spaces.MultiBinary(5)
print(f"Shape: {multi_binary.shape}")
print(f"Sample: {multi_binary.sample()}")
Dict Space
Combines observations of several types.
from gymnasium import spaces
import numpy as np

dict_space = spaces.Dict({
    'position': spaces.Box(low=-10, high=10, shape=(2,), dtype=np.float32),
    'velocity': spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32),
    'target': spaces.Discrete(4),
})
print(f"Sample: {dict_space.sample()}")
print(f"Position space: {dict_space['position']}")
Tuple Space
Combines several spaces in order.
from gymnasium import spaces
import numpy as np

tuple_space = spaces.Tuple((
    spaces.Discrete(2),
    spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32),
))
print(f"Sample: {tuple_space.sample()}")
Space Type Summary
| Space type | Description | Typical use |
|---|---|---|
| Box | Continuous space, n-dimensional array | Images, positions, velocities |
| Discrete | Discrete space, integers 0 to n-1 | Action selection |
| MultiDiscrete | Multiple discrete dimensions | Several discrete actions |
| MultiBinary | Multi-dimensional binary space | Switches, feature vectors |
| Dict | Dictionary space | Combining observation types |
| Tuple | Tuple space | Combining several spaces |
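Composite observations from a Dict space usually need flattening into one vector before they reach a network. Gymnasium provides gymnasium.spaces.flatten and flatten_space for exactly this; here is the idea hand-rolled in NumPy (key order and names are illustrative):

```python
import numpy as np

def flatten_dict_obs(obs, keys):
    """Concatenate chosen fields of a dict observation into a single flat vector."""
    parts = [np.asarray(obs[k], dtype=np.float32).ravel() for k in keys]
    return np.concatenate(parts)

obs = {'position': np.array([1.0, 2.0]), 'velocity': np.array([0.5, -0.5])}
vec = flatten_dict_obs(obs, keys=['position', 'velocity'])
print(vec)  # [ 1.   2.   0.5 -0.5]
```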
Classic Environments in Detail
Classic Control Tasks
CartPole
Goal: keep the pole upright by moving the cart left and right.
import gymnasium as gym

env = gym.make('CartPole-v1')
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
obs, _ = env.reset()
cart_pos, cart_vel, pole_angle, pole_vel = obs
print(f"Cart position: {cart_pos:.3f}")
print(f"Cart velocity: {cart_vel:.3f}")
print(f"Pole angle: {pole_angle:.3f}")
print(f"Pole angular velocity: {pole_vel:.3f}")
Observation meanings:
| Index | Observation | Range | Notes |
|---|---|---|---|
| 0 | Cart position | -4.8 ~ 4.8 | Position of the cart on the track |
| 1 | Cart velocity | -Inf ~ Inf | Velocity of the cart |
| 2 | Pole angle | -0.418 ~ 0.418 | Angle (rad) between the pole and vertical |
| 3 | Pole angular velocity | -Inf ~ Inf | Rotation speed of the pole |
Termination conditions:
- The pole angle exceeds ±12°
- The cart position exceeds ±2.4
- 500 steps are reached (v1; this counts as truncation, not termination)
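The observation layout above is enough to hand-craft a simple baseline policy: push the cart toward the side the pole is falling. A hedged sketch (the gain 0.5 is an arbitrary illustrative choice, not a tuned value):

```python
def heuristic_action(obs):
    """CartPole heuristic: action 0 pushes the cart left, action 1 pushes it right."""
    cart_pos, cart_vel, pole_angle, pole_ang_vel = obs
    # steer toward the fall, anticipating it via the angular velocity
    return 1 if pole_angle + 0.5 * pole_ang_vel > 0 else 0

print(heuristic_action((0.0, 0.0, 0.1, 0.0)))   # pole leaning right -> push right (1)
print(heuristic_action((0.0, 0.0, -0.1, 0.0)))  # pole leaning left -> push left (0)
```

This typically keeps the pole up far longer than random actions and is a useful sanity check before training a learned policy.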
LunarLander
Goal: control the spacecraft to land safely in the target zone.
env = gym.make('LunarLander-v2')  # renamed LunarLander-v3 in Gymnasium 1.0+
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
Action meanings:
| Action | Meaning |
|---|---|
| 0 | Do nothing |
| 1 | Fire left orientation engine |
| 2 | Fire main engine |
| 3 | Fire right orientation engine |
Atari Games
Atari games are the classic benchmark suite for deep reinforcement learning.
import gymnasium as gym

env = gym.make('ALE/Breakout-v5')
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")

from gymnasium.wrappers import AtariPreprocessing

# frameskip=1 on the base env, because AtariPreprocessing applies its own frame skipping
env = gym.make('ALE/Breakout-v5', frameskip=1)
env = AtariPreprocessing(
    env,
    noop_max=30,
    frame_skip=4,
    screen_size=84,
    grayscale_obs=True,
    grayscale_newaxis=False,
    scale_obs=False,
)
Common Atari preprocessing:
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation

def make_atari_env(env_id, render_mode=None):
    # frameskip=1 here: AtariPreprocessing applies its own frame skipping
    env = gym.make(env_id, render_mode=render_mode, frameskip=1)
    env = AtariPreprocessing(
        env,
        frame_skip=4,
        screen_size=84,
        grayscale_obs=True,
    )
    env = FrameStackObservation(env, stack_size=4)
    return env

env = make_atari_env('ALE/Breakout-v5')
print(f"Observation space after preprocessing: {env.observation_space}")
MuJoCo Environments
MuJoCo provides high-fidelity physics simulation environments.
import gymnasium as gym

env = gym.make('HalfCheetah-v4')
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
Environment Wrappers
Wrappers are a powerful tool for modifying environment behavior without touching the environment's own code.
Wrapper Types
gym.Wrapper
├── gym.ObservationWrapper
├── gym.ActionWrapper
└── gym.RewardWrapper
Common Built-in Wrappers
TimeLimit
Limits the maximum number of steps per episode.
from gymnasium.wrappers import TimeLimit

env = gym.make('CartPole-v1')
env = TimeLimit(env, max_episode_steps=200)
RecordVideo
Records videos of episodes.
from gymnasium.wrappers import RecordVideo

env = gym.make('CartPole-v1', render_mode='rgb_array')
env = RecordVideo(
    env,
    video_folder='./videos',
    episode_trigger=lambda x: x % 10 == 0,  # record every 10th episode
    name_prefix='cartpole',
)
NormalizeObservation
Normalizes observations with a running mean and variance.
from gymnasium.wrappers import NormalizeObservation

env = gym.make('CartPole-v1')
env = NormalizeObservation(env)
NormalizeReward
Normalizes rewards.
from gymnasium.wrappers import NormalizeReward

env = gym.make('CartPole-v1')
env = NormalizeReward(env, gamma=0.99)
FrameStackObservation
Stacks several consecutive frames of observations.
from gymnasium.wrappers import FrameStackObservation

env = gym.make('CartPole-v1')
env = FrameStackObservation(env, stack_size=4)
print(f"Observation space after stacking: {env.observation_space}")
TransformObservation
Applies a custom transformation to observations.
from gymnasium.wrappers import TransformObservation
import numpy as np

env = gym.make('CartPole-v1')
env = TransformObservation(
    env,
    func=lambda obs: obs.astype(np.float32) / 10.0,
    observation_space=env.observation_space,
)
Custom Wrappers
A Basic Wrapper
import gymnasium as gym

class CustomWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.total_steps = 0

    def reset(self, **kwargs):
        self.total_steps = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        self.total_steps += 1
        obs, reward, terminated, truncated, info = self.env.step(action)
        info['total_steps'] = self.total_steps
        return obs, reward, terminated, truncated, info

env = gym.make('CartPole-v1')
env = CustomWrapper(env)
An Observation Wrapper
import gymnasium as gym
import numpy as np

class ScaleObservation(gym.ObservationWrapper):
    def __init__(self, env, scale=1.0):
        super().__init__(env)
        self.scale = scale
        self.observation_space = gym.spaces.Box(
            low=env.observation_space.low * scale,
            high=env.observation_space.high * scale,
            shape=env.observation_space.shape,
            dtype=np.float32,
        )

    def observation(self, obs):
        return obs * self.scale
An Action Wrapper
import gymnasium as gym

class StickyAction(gym.ActionWrapper):
    """With probability sticky_prob, repeat the previous action."""
    def __init__(self, env, sticky_prob=0.25):
        super().__init__(env)
        self.sticky_prob = sticky_prob
        self.last_action = 0

    def action(self, action):
        if self.np_random.random() < self.sticky_prob:
            return self.last_action
        self.last_action = action
        return action
Reward Wrappers
import gymnasium as gym
import numpy as np

class ClipReward(gym.RewardWrapper):
    def __init__(self, env, min_reward=-1.0, max_reward=1.0):
        super().__init__(env)
        self.min_reward = min_reward
        self.max_reward = max_reward

    def reward(self, reward):
        return float(np.clip(reward, self.min_reward, self.max_reward))

class RewardScaler(gym.RewardWrapper):
    def __init__(self, env, scale=0.01):
        super().__init__(env)
        self.scale = scale

    def reward(self, reward):
        return reward * self.scale
Combining Multiple Wrappers
import gymnasium as gym
from gymnasium.wrappers import (
    TimeLimit, NormalizeObservation, NormalizeReward, RecordVideo
)

def make_wrapped_env(env_name='CartPole-v1', record_video=False):
    env = gym.make(env_name, render_mode='rgb_array' if record_video else None)
    env = TimeLimit(env, max_episode_steps=500)
    env = NormalizeObservation(env)
    env = NormalizeReward(env, gamma=0.99)
    if record_video:
        env = RecordVideo(env, video_folder='./videos')
    return env
Custom Environments
Creating a Custom Environment
A custom environment subclasses gym.Env and implements the required methods.
import gymnasium as gym
import numpy as np

class GridWorldEnv(gym.Env):
    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': 4}

    def __init__(self, render_mode=None, size=5):
        self.size = size
        self.render_mode = render_mode
        self.window_size = 512
        self.window = None
        self.clock = None
        self.observation_space = gym.spaces.Dict({
            'agent': gym.spaces.Box(0, size - 1, shape=(2,), dtype=int),
            'target': gym.spaces.Box(0, size - 1, shape=(2,), dtype=int),
        })
        self.action_space = gym.spaces.Discrete(4)
        # map the discrete actions to grid directions
        self._action_to_direction = {
            0: np.array([1, 0]),
            1: np.array([0, 1]),
            2: np.array([-1, 0]),
            3: np.array([0, -1]),
        }

    def _get_obs(self):
        return {'agent': self._agent_location, 'target': self._target_location}

    def _get_info(self):
        return {
            'distance': np.linalg.norm(
                self._agent_location - self._target_location, ord=1
            )
        }

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self._agent_location = self.np_random.integers(0, self.size, size=2)
        # resample the target until it differs from the agent's start cell
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(0, self.size, size=2)
        observation = self._get_obs()
        info = self._get_info()
        if self.render_mode == 'human':
            self._render_frame()
        return observation, info

    def step(self, action):
        direction = self._action_to_direction[action]
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )
        terminated = np.array_equal(self._agent_location, self._target_location)
        reward = 1 if terminated else 0
        observation = self._get_obs()
        info = self._get_info()
        if self.render_mode == 'human':
            self._render_frame()
        return observation, reward, terminated, False, info

    def render(self):
        if self.render_mode == 'rgb_array':
            return self._render_frame()

    def _render_frame(self):
        # imported lazily so the environment can be used without pygame installed
        import pygame

        if self.window is None and self.render_mode == 'human':
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(
                (self.window_size, self.window_size)
            )
        if self.clock is None and self.render_mode == 'human':
            self.clock = pygame.time.Clock()
        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        pix_square_size = self.window_size / self.size
        # target: red square
        pygame.draw.rect(
            canvas,
            (255, 0, 0),
            pygame.Rect(
                pix_square_size * self._target_location,
                (pix_square_size, pix_square_size),
            ),
        )
        # agent: blue circle
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            (self._agent_location + 0.5) * pix_square_size,
            pix_square_size / 3,
        )
        # grid lines
        for x in range(self.size + 1):
            pygame.draw.line(
                canvas,
                0,
                (0, pix_square_size * x),
                (self.window_size, pix_square_size * x),
                width=3,
            )
            pygame.draw.line(
                canvas,
                0,
                (pix_square_size * x, 0),
                (pix_square_size * x, self.window_size),
                width=3,
            )
        if self.render_mode == 'human':
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata['render_fps'])
        return np.transpose(
            np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
        )

    def close(self):
        if self.window is not None:
            import pygame
            pygame.display.quit()
            pygame.quit()
Registering a Custom Environment
from gymnasium.envs.registration import register

register(
    id='GridWorld-v0',
    entry_point='my_env:GridWorldEnv',  # 'module_path:ClassName'
    max_episode_steps=300,
)
env = gym.make('GridWorld-v0')
Vectorized Environments
Vectorized environments run several environment instances in parallel, which improves sample throughput during training.
SyncVectorEnv
A synchronous vectorized environment; all environments are stepped sequentially in the current process.
import gymnasium as gym
from gymnasium.vector import SyncVectorEnv

def make_env(env_id, seed=0):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)  # seed the per-env action sampler
        return env
    return thunk

envs = SyncVectorEnv([
    make_env('CartPole-v1', seed=i) for i in range(4)
])
observations, infos = envs.reset()
print(f"Observations shape: {observations.shape}")  # (4, 4) for CartPole
actions = envs.action_space.sample()
observations, rewards, terminateds, truncateds, infos = envs.step(actions)
print(f"Rewards: {rewards}")
AsyncVectorEnv
An asynchronous vectorized environment; uses multiple processes to step environments in parallel.
from gymnasium.vector import AsyncVectorEnv

envs = AsyncVectorEnv([
    make_env('CartPole-v1', seed=i) for i in range(8)
])
observations, infos = envs.reset()
actions = envs.action_space.sample()
observations, rewards, terminateds, truncateds, infos = envs.step(actions)
gym.make_vec
A shortcut for creating vectorized environments. (Note: `make_vec_env` is a Stable Baselines3 helper; the Gymnasium equivalent is `gymnasium.make_vec`.)
import gymnasium as gym

envs = gym.make_vec('CartPole-v1', num_envs=4, vectorization_mode='async')
Special Handling for Vectorized Environments
When an episode ends, vectorized environments reset that slot automatically; before Gymnasium 1.0, the last observation of the finished episode is stored under infos['final_observation'] (1.0 changed the autoreset timing, so check your version's docs).
import numpy as np

envs = AsyncVectorEnv([make_env('CartPole-v1', i) for i in range(4)])
observations, infos = envs.reset()
actions = np.array([0, 1, 0, 1])
observations, rewards, terminateds, truncateds, infos = envs.step(actions)
if any(terminateds) or any(truncateds):
    final_observations = infos['final_observation']
    for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
        if terminated or truncated:
            print(f"Env {i} finished, final observation: {final_observations[i]}")
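The per-environment bookkeeping behind such loops can live in one NumPy array; this is roughly what RecordEpisodeStatistics does internally (a sketch; the names are illustrative):

```python
import numpy as np

def update_returns(running, rewards, terminateds, truncateds):
    """Add the step rewards to per-env running returns and reset finished slots.

    Returns (updated running returns, returns of the episodes that just finished).
    """
    running = running + np.asarray(rewards, dtype=np.float64)
    finished = np.logical_or(terminateds, truncateds)
    episode_returns = running[finished].copy()
    running[finished] = 0.0
    return running, episode_returns

running = np.zeros(4)
running, done_returns = update_returns(
    running, [1.0, 1.0, 1.0, 1.0], [False, True, False, False], [False] * 4
)
print(running)       # [1. 0. 1. 1.]
print(done_returns)  # [1.]
```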
Environment Checking
Gymnasium provides an environment checker that verifies a custom environment conforms to the API.
from gymnasium.utils.env_checker import check_env

env = GridWorldEnv()
check_env(env, warn=True)
print("Environment check passed!")
The checks include:
- reset() returns the correct format
- step() returns the correct format
- Observation and action spaces are defined correctly
- Rewards are numeric
- Render methods work properly
Practical Tips
Inspecting Environment Metadata
import gymnasium as gym

env = gym.make('CartPole-v1')
print(f"Environment ID: {env.spec.id}")
print(f"Max episode steps: {env.spec.max_episode_steps}")
print(f"Reward threshold: {env.spec.reward_threshold}")
print(f"Nondeterministic: {env.spec.nondeterministic}")
Reproducible Experiments
import gymnasium as gym
import numpy as np
import random
import torch

def set_seed(seed=42):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
env = gym.make('CartPole-v1')
obs, _ = env.reset(seed=42)
env.action_space.seed(42)  # the action sampler has its own RNG
Monitoring Episodes
from gymnasium.wrappers import RecordEpisodeStatistics

env = gym.make('CartPole-v1')
env = RecordEpisodeStatistics(env)
for episode in range(10):
    obs, _ = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
    print(f"Episode {episode + 1}: reward = {info['episode']['r']:.2f}, "
          f"length = {info['episode']['l']}, "
          f"time = {info['episode']['t']:.2f}s")
Summary
Gymnasium is the foundation of practical reinforcement learning work:
- Unified interface: reset(), step(), render(), close()
- Space definitions: Box, Discrete, MultiDiscrete, Dict, and more
- Environment wrappers: a flexible, powerful way to modify environment behavior
- Custom environments: subclass gym.Env to build your own
- Vectorized environments: run many environment instances in parallel
The next chapter introduces Stable Baselines3, which provides reliable implementations of reinforcement learning algorithms and works seamlessly with Gymnasium environments.