Python 测试
测试是软件开发中至关重要的环节,确保代码质量和功能正确性。本章将介绍 Python 测试的完整知识体系。
为什么需要测试?
测试的重要性
- 保证代码质量:及早发现和修复 bug
- 重构信心:有测试覆盖,重构时不怕破坏功能
- 文档作用:测试用例展示了代码的预期行为
- 设计驱动:编写测试帮助设计更好的 API
测试类型
┌─────────────────────────────────────────────────────────────┐
│ 测试金字塔 │
├─────────────────────────────────────────────────────────────┤
│ / \ │
│ / E2E测试 \ 数量少,执行慢 │
│ / \ │
│ / 集成测试 \ 数量中等 │
│ / \ │
│ / 单元测试 \ 数量多,执行快 │
│ / \ │
└─────────────────────────────────────────────────────────────┘
解释:
- 单元测试:测试单个函数或方法,执行速度快
- 集成测试:测试多个模块之间的交互
- 端到端测试(E2E):测试整个应用流程,最接近真实用户场景
unittest 模块
Python 内置的测试框架,基于 xUnit 架构。
基本用法
import unittest
# 被测试的函数
def add(a, b):
return a + b
def divide(a, b):
if b == 0:
raise ValueError("不能除以零")
return a / b
# 测试类
class TestMathFunctions(unittest.TestCase):
def test_add_positive_numbers(self):
"""测试正数相加"""
result = add(2, 3)
self.assertEqual(result, 5)
def test_add_negative_numbers(self):
"""测试负数相加"""
result = add(-2, -3)
self.assertEqual(result, -5)
def test_add_zero(self):
"""测试与零相加"""
self.assertEqual(add(5, 0), 5)
self.assertEqual(add(0, 5), 5)
def test_divide_normal(self):
"""测试正常除法"""
self.assertEqual(divide(10, 2), 5)
self.assertAlmostEqual(divide(7, 3), 2.333, places=2)
def test_divide_by_zero(self):
"""测试除以零抛出异常"""
with self.assertRaises(ValueError) as context:
divide(10, 0)
self.assertEqual(str(context.exception), "不能除以零")
# 运行测试
if __name__ == '__main__':
unittest.main()
解释:
- 测试类必须继承
unittest.TestCase - 测试方法必须以
test_开头 - 使用
assertEqual、assertRaises等断言方法
常用断言方法
import unittest
class TestAssertions(unittest.TestCase):
def test_equality(self):
"""相等性断言"""
self.assertEqual(1 + 1, 2)
self.assertNotEqual(1 + 1, 3)
def test_boolean(self):
"""布尔值断言"""
self.assertTrue(True)
self.assertFalse(False)
def test_comparison(self):
"""比较断言"""
self.assertGreater(5, 3)
self.assertLess(3, 5)
self.assertGreaterEqual(5, 5)
self.assertLessEqual(3, 5)
def test_membership(self):
"""成员断言"""
self.assertIn(1, [1, 2, 3])
self.assertNotIn(4, [1, 2, 3])
def test_type(self):
"""类型断言"""
self.assertIsInstance("hello", str)
self.assertNotIsInstance(123, str)
def test_none(self):
"""None 断言"""
self.assertIsNone(None)
self.assertIsNotNone("value")
def test_approximate(self):
"""近似值断言"""
self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7)
def test_exception(self):
"""异常断言"""
with self.assertRaises(ValueError):
int("abc")
with self.assertRaisesRegex(ValueError, "invalid literal"):
int("abc")
测试固件(Fixture)
import unittest
import tempfile
import os
class TestFileOperations(unittest.TestCase):
"""文件操作测试示例"""
@classmethod
def setUpClass(cls):
"""整个测试类开始前执行一次"""
print("开始测试文件操作")
cls.test_dir = tempfile.mkdtemp()
@classmethod
def tearDownClass(cls):
"""整个测试类结束后执行一次"""
print("测试结束")
import shutil
shutil.rmtree(cls.test_dir)
def setUp(self):
"""每个测试方法执行前运行"""
self.test_file = os.path.join(self.test_dir, "test.txt")
with open(self.test_file, 'w') as f:
f.write("test content")
def tearDown(self):
"""每个测试方法执行后运行"""
if os.path.exists(self.test_file):
os.remove(self.test_file)
def test_read_file(self):
"""测试读取文件"""
with open(self.test_file, 'r') as f:
content = f.read()
self.assertEqual(content, "test content")
def test_file_exists(self):
"""测试文件存在"""
self.assertTrue(os.path.exists(self.test_file))
解释:
setUpClass/tearDownClass:类级别固件,整个测试类只执行一次setUp/tearDown:方法级别固件,每个测试方法前后执行
跳过测试
import unittest
import sys
class TestSkipping(unittest.TestCase):
@unittest.skip("跳过此测试")
def test_skip(self):
self.fail("不应该执行")
@unittest.skipIf(sys.version_info < (3, 10), "需要 Python 3.10+")
def test_skip_if(self):
# Python 3.10+ 特性
self.assertTrue(True)
@unittest.skipUnless(sys.platform == "linux", "仅 Linux 平台")
def test_skip_unless(self):
self.assertTrue(True)
@unittest.expectedFailure
def test_expected_failure(self):
"""预期失败的测试"""
self.assertEqual(1, 2) # 预期失败,标记为通过
if __name__ == '__main__':
unittest.main()
pytest 框架
pytest 是 Python 最流行的第三方测试框架,比 unittest 更简洁强大。
安装
pip install pytest pytest-cov
基本用法
# test_calc.py
def add(a, b):
"""被测试的函数"""
return a + b
def subtract(a, b):
return a - b
def multiply(a, b):
return a * b
def divide(a, b):
if b == 0:
raise ValueError("不能除以零")
return a / b
# ===== 测试函数 =====
def test_add():
"""测试加法"""
assert add(2, 3) == 5
assert add(-1, 1) == 0
assert add(0, 0) == 0
def test_subtract():
"""测试减法"""
assert subtract(5, 3) == 2
assert subtract(3, 5) == -2
def test_multiply():
"""测试乘法"""
assert multiply(2, 3) == 6
assert multiply(-2, 3) == -6
assert multiply(0, 100) == 0
def test_divide():
"""测试除法"""
assert divide(10, 2) == 5
# 测试异常
import pytest
with pytest.raises(ValueError) as excinfo:
divide(10, 0)
assert "不能除以零" in str(excinfo.value)
解释:
- pytest 使用简单的
assert语句,不需要特殊方法 - 测试文件名以
test_开头或_test.py结尾 - 测试函数名以
test_开头 - pytest 会自动发现并运行测试
运行测试
# 运行当前目录所有测试
pytest
# 运行指定文件
pytest test_calc.py
# 运行指定测试函数
pytest test_calc.py::test_add
# 运行指定测试类中的方法
pytest test_calc.py::TestMath::test_add
# 显示详细输出
pytest -v
# 显示 print 输出
pytest -s
# 只运行失败的测试
pytest --lf
# 先运行失败的,再运行其他的
pytest --ff
# 并行运行测试(需要 pytest-xdist)
pytest -n 4
# 生成覆盖率报告
pytest --cov=myapp --cov-report=html
pytest 断言
import pytest
def test_assertions():
"""pytest 断言示例"""
# 基本断言
assert 1 + 1 == 2
assert "hello" in "hello world"
assert 5 > 3
# 断言近似值
assert 0.1 + 0.2 == pytest.approx(0.3)
# 断言异常
with pytest.raises(ValueError):
int("abc")
# 断言异常消息
with pytest.raises(ValueError, match="invalid literal"):
int("abc")
# 断言警告
with pytest.warns(DeprecationWarning):
import warnings
warnings.warn("deprecated", DeprecationWarning)
def test_approx_comparison():
"""近似比较"""
# 浮点数近似
assert 0.1 + 0.2 == pytest.approx(0.3)
# 指定精度
assert 0.12345 == pytest.approx(0.123, abs=0.001)
# 列表近似
assert [0.1 + 0.2, 0.2 + 0.3] == pytest.approx([0.3, 0.5])
def test_exception_details():
"""异常详情测试"""
with pytest.raises(ValueError) as excinfo:
raise ValueError("错误消息")
assert excinfo.type is ValueError
assert "错误消息" in str(excinfo.value)
Fixture(测试固件)
pytest 的 fixture 比 unittest 更强大灵活。
基本 Fixture
import pytest
@pytest.fixture
def sample_data():
"""提供测试数据"""
return {"name": "张三", "age": 25}
def test_with_fixture(sample_data):
"""使用 fixture"""
assert sample_data["name"] == "张三"
assert sample_data["age"] == 25
Fixture 作用域
import pytest
# function 级别(默认):每个测试函数执行一次
@pytest.fixture(scope="function")
def func_fixture():
print("\n设置 function fixture")
yield "function"
print("\n清理 function fixture")
# class 级别:每个测试类执行一次
@pytest.fixture(scope="class")
def class_fixture():
print("\n设置 class fixture")
yield "class"
print("\n清理 class fixture")
# module 级别:每个模块执行一次
@pytest.fixture(scope="module")
def module_fixture():
print("\n设置 module fixture")
yield "module"
print("\n清理 module fixture")
# session 级别:整个测试会话执行一次
@pytest.fixture(scope="session")
def session_fixture():
print("\n设置 session fixture")
yield "session"
print("\n清理 session fixture")
def test_fixture_example(func_fixture, module_fixture):
"""测试 fixture 作用域"""
assert func_fixture == "function"
assert module_fixture == "module"
Fixture 清理操作
import pytest
import tempfile
import os
@pytest.fixture
def temp_file():
"""创建临时文件,测试后自动清理"""
# setup
fd, path = tempfile.mkstemp()
os.close(fd)
with open(path, 'w') as f:
f.write("test content")
# 提供给测试使用
yield path
# teardown(清理)
if os.path.exists(path):
os.remove(path)
def test_temp_file(temp_file):
"""测试使用临时文件"""
assert os.path.exists(temp_file)
with open(temp_file, 'r') as f:
content = f.read()
assert content == "test content"
内置 Fixture
import pytest
def test_tmp_path(tmp_path):
"""使用临时目录"""
# tmp_path 是 pathlib.Path 对象
test_file = tmp_path / "test.txt"
test_file.write_text("hello")
assert test_file.read_text() == "hello"
def test_tmp_path_factory(tmp_path_factory):
"""创建多个临时目录"""
dir1 = tmp_path_factory.mktemp("dir1")
dir2 = tmp_path_factory.mktemp("dir2")
assert dir1 != dir2
def test_capfd(capfd):
"""捕获标准输出"""
print("Hello, World!")
captured = capfd.readouterr()
assert captured.out == "Hello, World!\n"
def test_caplog(caplog):
"""捕获日志"""
import logging
logging.warning("warning message")
assert "warning message" in caplog.text
def test_monkeypatch(monkeypatch):
"""修改环境变量和属性"""
import os
# 设置环境变量
monkeypatch.setenv("TEST_VAR", "test_value")
assert os.environ["TEST_VAR"] == "test_value"
# 删除环境变量
monkeypatch.delenv("TEST_VAR", raising=False)
assert "TEST_VAR" not in os.environ
参数化测试
import pytest
# 单参数测试
@pytest.mark.parametrize("input,expected", [
(1, 2),
(2, 4),
(3, 6),
(10, 20),
])
def test_double(input, expected):
"""测试加倍函数"""
assert input * 2 == expected
# 多参数测试
@pytest.mark.parametrize("a,b,expected", [
(1, 2, 3),
(5, 5, 10),
(-1, 1, 0),
(0, 0, 0),
])
def test_add(a, b, expected):
"""测试加法"""
assert a + b == expected
# 参数化标记
@pytest.mark.parametrize("value,expected", [
(1, 1),
pytest.param(2, 4, id="two_squared"),
pytest.param(3, 9, marks=pytest.mark.xfail(reason="故意失败")),
pytest.param(4, 16, marks=pytest.mark.skip("跳过")),
])
def test_square(value, expected):
assert value ** 2 == expected
# 组合参数化
@pytest.mark.parametrize("x", [1, 2])
@pytest.mark.parametrize("y", [10, 20])
def test_combination(x, y):
"""测试参数组合:会生成 4 个测试"""
assert x + y in [11, 12, 21, 22]
标记(Marks)
import pytest
# 跳过测试
@pytest.mark.skip(reason="尚未实现")
def test_not_implemented():
pass
# 条件跳过
@pytest.mark.skipif(sys.version_info < (3, 10), reason="需要 Python 3.10+")
def test_python310_feature():
pass
# 预期失败
@pytest.mark.xfail(reason="已知 bug")
def test_known_bug():
assert 1 == 2
# 自定义标记
@pytest.mark.slow
def test_slow_operation():
import time
time.sleep(1)
assert True
@pytest.mark.integration
def test_database():
"""集成测试"""
pass
# 运行指定标记的测试
# pytest -m slow 运行标记为 slow 的测试
# pytest -m "not slow" 运行未标记为 slow 的测试
# pytest -m "integration or slow"
测试类
import pytest
class TestCalculator:
"""计算器测试类"""
@pytest.fixture(autouse=True)
def setup(self):
"""每个测试方法前自动执行"""
self.calculator = Calculator()
def test_add(self):
assert self.calculator.add(1, 2) == 3
def test_subtract(self):
assert self.calculator.subtract(5, 3) == 2
@pytest.mark.parametrize("a,b,expected", [
(2, 3, 6),
(4, 5, 20),
])
def test_multiply(self, a, b, expected):
assert self.calculator.multiply(a, b) == expected
class Calculator:
def add(self, a, b): return a + b
def subtract(self, a, b): return a - b
def multiply(self, a, b): return a * b
Mock 和 Patch
使用 unittest.mock 模拟外部依赖,隔离测试。
基本 Mock
from unittest.mock import Mock, MagicMock
def test_mock_basic():
"""基本 Mock 用法"""
# 创建 Mock 对象
mock = Mock()
# 设置返回值
mock.method.return_value = 42
assert mock.method() == 42
# 验证调用
mock.method.assert_called_once()
# 验证调用参数
mock.method_with_args("hello")
mock.method_with_args.assert_called_with("hello")
# 检查调用次数
assert mock.method.call_count == 1
使用 patch
from unittest.mock import patch, MagicMock
import requests
# 被测试的函数
def get_user_info(user_id):
"""获取用户信息"""
response = requests.get(f"https://api.example.com/users/{user_id}")
return response.json()
# 测试
@patch('requests.get')
def test_get_user_info(mock_get):
"""测试获取用户信息"""
# 设置 mock 返回值
mock_response = MagicMock()
mock_response.json.return_value = {"id": 1, "name": "张三"}
mock_get.return_value = mock_response
# 调用函数
result = get_user_info(1)
# 验证结果
assert result == {"id": 1, "name": "张三"}
# 验证 requests.get 被正确调用
mock_get.assert_called_once_with("https://api.example.com/users/1")
# 使用上下文管理器
def test_with_context():
"""使用上下文管理器的 patch"""
with patch('requests.get') as mock_get:
mock_get.return_value.json.return_value = {"id": 2}
result = get_user_info(2)
assert result["id"] == 2
Mock 类方法
from unittest.mock import patch, MagicMock
class Database:
def query(self, sql):
"""真实的数据库查询"""
pass
class UserService:
def __init__(self, db):
self.db = db
def get_user(self, user_id):
return self.db.query(f"SELECT * FROM users WHERE id = {user_id}")
def test_user_service():
"""测试用户服务"""
# 创建 mock 数据库
mock_db = MagicMock()
mock_db.query.return_value = {"id": 1, "name": "张三"}
# 注入 mock
service = UserService(mock_db)
result = service.get_user(1)
# 验证
assert result["name"] == "张三"
mock_db.query.assert_called_once_with("SELECT * FROM users WHERE id = 1")
@patch('__main__.Database')
def test_with_patch(MockDatabase):
"""使用 patch 替换类"""
mock_instance = MockDatabase.return_value
mock_instance.query.return_value = {"id": 1}
db = Database()
result = db.query("SELECT 1")
assert result["id"] == 1
pytest-mock 插件
# pip install pytest-mock
def test_with_mocker(mocker):
"""使用 pytest-mock"""
# mocker.patch 更简洁
mock_get = mocker.patch('requests.get')
mock_get.return_value.json.return_value = {"id": 1}
result = get_user_info(1)
assert result["id"] == 1
# mocker.spy 监视方法调用但不替换
# spy = mocker.spy(SomeClass, 'method')
测试覆盖率
使用 pytest-cov 测量代码覆盖率。
# 安装
pip install pytest-cov
# 运行测试并生成覆盖率报告
pytest --cov=myapp
# 生成 HTML 报告
pytest --cov=myapp --cov-report=html
# 指定覆盖率要求
pytest --cov=myapp --cov-fail-under=80
配置文件
# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --cov=myapp --cov-report=html --cov-fail-under=80
markers =
slow: 慢速测试
integration: 集成测试
# pyproject.toml
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
addopts = "-v --cov=myapp --cov-report=term-missing"
markers = [
"slow: marks tests as slow",
"integration: marks tests as integration tests",
]
[tool.coverage.run]
source = ["myapp"]
omit = ["tests/*", "*/__pycache__/*"]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise NotImplementedError",
"if __name__ == .__main__.:",
]
测试最佳实践
1. 测试命名规范
# 好的命名:描述性、清晰
def test_add_should_return_sum_of_two_numbers():
pass
def test_divide_should_raise_error_when_dividing_by_zero():
pass
def test_user_should_be_able_to_login_with_valid_credentials():
pass
# 不好的命名
def test_1():
pass
def test_something():
pass
2. AAA 模式
def test_with_aaa_pattern():
"""Arrange-Act-Assert 模式"""
# Arrange(准备):设置测试数据和条件
calculator = Calculator()
a, b = 2, 3
# Act(执行):调用被测试的代码
result = calculator.add(a, b)
# Assert(断言):验证结果
assert result == 5
3. 测试隔离
import pytest
import tempfile
import os
@pytest.fixture
def isolated_file_system():
"""隔离的文件系统环境"""
# 创建临时目录
temp_dir = tempfile.mkdtemp()
original_dir = os.getcwd()
os.chdir(temp_dir)
yield temp_dir
# 清理
os.chdir(original_dir)
import shutil
shutil.rmtree(temp_dir)
def test_file_operations(isolated_file_system):
"""测试文件操作,使用隔离环境"""
# 在临时目录中操作,不影响真实文件系统
with open("test.txt", "w") as f:
f.write("hello")
assert os.path.exists("test.txt")
4. 测试边界条件
import pytest
def process_age(age):
"""处理年龄"""
if not isinstance(age, int):
raise TypeError("年龄必须是整数")
if age < 0:
raise ValueError("年龄不能为负")
if age > 150:
raise ValueError("年龄不合理")
return age
class TestProcessAge:
"""测试年龄处理函数"""
def test_normal_age(self):
assert process_age(25) == 25
@pytest.mark.parametrize("age", [0, 1, 100, 149])
def test_valid_boundary(self, age):
"""测试有效边界"""
assert process_age(age) == age
@pytest.mark.parametrize("age", [-1, -100, 151, 200])
def test_invalid_boundary(self, age):
"""测试无效边界"""
with pytest.raises(ValueError):
process_age(age)
def test_non_integer(self):
"""测试非整数输入"""
with pytest.raises(TypeError):
process_age("25")
with pytest.raises(TypeError):
process_age(25.5)
5. 避免测试实现细节
# 不好的测试:测试实现细节
def test_internal_state():
user = User("张三")
assert user._internal_name == "张三" # 测试私有属性
# 好的测试:测试公共行为
def test_user_name():
user = User("张三")
assert user.get_name() == "张三" # 测试公共接口
测试实战示例
测试 Flask 应用
# app.py
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/api/users/<int:user_id>')
def get_user(user_id):
# 模拟数据库查询
users = {1: {"id": 1, "name": "张三"}, 2: {"id": 2, "name": "李四"}}
user = users.get(user_id)
if user:
return jsonify(user)
return jsonify({"error": "User not found"}), 404
# test_app.py
import pytest
from app import app
@pytest.fixture
def client():
"""创建测试客户端"""
app.config['TESTING'] = True
with app.test_client() as client:
yield client
def test_get_user_success(client):
"""测试获取用户成功"""
response = client.get('/api/users/1')
assert response.status_code == 200
data = response.get_json()
assert data["name"] == "张三"
def test_get_user_not_found(client):
"""测试用户不存在"""
response = client.get('/api/users/999')
assert response.status_code == 404
data = response.get_json()
assert "error" in data
测试数据库操作
import pytest
import sqlite3
from contextlib import contextmanager
# 被测试的代码
class UserRepository:
def __init__(self, db_path):
self.db_path = db_path
@contextmanager
def get_connection(self):
conn = sqlite3.connect(self.db_path)
try:
yield conn
finally:
conn.close()
def create_user(self, name, email):
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT INTO users (name, email) VALUES (?, ?)",
(name, email)
)
conn.commit()
return cursor.lastrowid
def get_user(self, user_id):
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
return cursor.fetchone()
# 测试
@pytest.fixture
def temp_db():
"""创建临时数据库"""
import tempfile
fd, path = tempfile.mkstemp(suffix=".db")
# 创建表
conn = sqlite3.connect(path)
conn.execute("""
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL
)
""")
conn.commit()
conn.close()
yield path
# 清理
import os
os.close(fd)
os.unlink(path)
def test_create_and_get_user(temp_db):
"""测试创建和获取用户"""
repo = UserRepository(temp_db)
# 创建用户
user_id = repo.create_user("张三", "[email protected]")
assert user_id == 1
# 获取用户
user = repo.get_user(user_id)
assert user is not None
assert user[1] == "张三"
assert user[2] == "[email protected]"
def test_get_nonexistent_user(temp_db):
"""测试获取不存在的用户"""
repo = UserRepository(temp_db)
user = repo.get_user(999)
assert user is None
小结
本章我们学习了 Python 测试的完整知识体系:
- unittest:Python 内置测试框架
- pytest:更强大的第三方测试框架
- Fixture:测试固件,提供测试资源
- 参数化测试:使用不同参数运行同一测试
- Mock 和 Patch:模拟外部依赖
- 测试覆盖率:测量代码覆盖情况
- 最佳实践:命名规范、AAA 模式、测试隔离
练习
- 为一个计算器类编写完整的测试用例
- 使用 fixture 创建数据库测试环境
- 使用 mock 测试 HTTP 请求
- 使用参数化测试测试字符串处理函数
- 配置 pytest.ini 和覆盖率报告