跳到主要内容

Python 测试

测试是软件开发中至关重要的环节,确保代码质量和功能正确性。本章将介绍 Python 测试的完整知识体系。

为什么需要测试?

测试的重要性

  1. 保证代码质量:及早发现和修复 bug
  2. 重构信心:有测试覆盖,重构时不怕破坏功能
  3. 文档作用:测试用例展示了代码的预期行为
  4. 设计驱动:编写测试帮助设计更好的 API

测试类型

┌─────────────────────────────────────────────────────────────┐
│ 测试金字塔 │
├─────────────────────────────────────────────────────────────┤
│ / \ │
│ / E2E测试 \ 数量少,执行慢 │
│ / \ │
│ / 集成测试 \ 数量中等 │
│ / \ │
│ / 单元测试 \ 数量多,执行快 │
│ / \ │
└─────────────────────────────────────────────────────────────┘

解释

  • 单元测试:测试单个函数或方法,执行速度快
  • 集成测试:测试多个模块之间的交互
  • 端到端测试(E2E):测试整个应用流程,最接近真实用户场景

unittest 模块

Python 内置的测试框架,基于 xUnit 架构。

基本用法

import unittest

# 被测试的函数
def add(a, b):
return a + b

def divide(a, b):
if b == 0:
raise ValueError("不能除以零")
return a / b

# 测试类
class TestMathFunctions(unittest.TestCase):

def test_add_positive_numbers(self):
"""测试正数相加"""
result = add(2, 3)
self.assertEqual(result, 5)

def test_add_negative_numbers(self):
"""测试负数相加"""
result = add(-2, -3)
self.assertEqual(result, -5)

def test_add_zero(self):
"""测试与零相加"""
self.assertEqual(add(5, 0), 5)
self.assertEqual(add(0, 5), 5)

def test_divide_normal(self):
"""测试正常除法"""
self.assertEqual(divide(10, 2), 5)
self.assertAlmostEqual(divide(7, 3), 2.333, places=2)

def test_divide_by_zero(self):
"""测试除以零抛出异常"""
with self.assertRaises(ValueError) as context:
divide(10, 0)
self.assertEqual(str(context.exception), "不能除以零")

# 运行测试
if __name__ == '__main__':
unittest.main()

解释

  • 测试类必须继承 unittest.TestCase
  • 测试方法必须以 test_ 开头
  • 使用 assertEqualassertRaises 等断言方法

常用断言方法

import unittest

class TestAssertions(unittest.TestCase):

def test_equality(self):
"""相等性断言"""
self.assertEqual(1 + 1, 2)
self.assertNotEqual(1 + 1, 3)

def test_boolean(self):
"""布尔值断言"""
self.assertTrue(True)
self.assertFalse(False)

def test_comparison(self):
"""比较断言"""
self.assertGreater(5, 3)
self.assertLess(3, 5)
self.assertGreaterEqual(5, 5)
self.assertLessEqual(3, 5)

def test_membership(self):
"""成员断言"""
self.assertIn(1, [1, 2, 3])
self.assertNotIn(4, [1, 2, 3])

def test_type(self):
"""类型断言"""
self.assertIsInstance("hello", str)
self.assertNotIsInstance(123, str)

def test_none(self):
"""None 断言"""
self.assertIsNone(None)
self.assertIsNotNone("value")

def test_approximate(self):
"""近似值断言"""
self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7)

def test_exception(self):
"""异常断言"""
with self.assertRaises(ValueError):
int("abc")

with self.assertRaisesRegex(ValueError, "invalid literal"):
int("abc")

测试固件(Fixture)

import unittest
import tempfile
import os

class TestFileOperations(unittest.TestCase):
"""文件操作测试示例"""

@classmethod
def setUpClass(cls):
"""整个测试类开始前执行一次"""
print("开始测试文件操作")
cls.test_dir = tempfile.mkdtemp()

@classmethod
def tearDownClass(cls):
"""整个测试类结束后执行一次"""
print("测试结束")
import shutil
shutil.rmtree(cls.test_dir)

def setUp(self):
"""每个测试方法执行前运行"""
self.test_file = os.path.join(self.test_dir, "test.txt")
with open(self.test_file, 'w') as f:
f.write("test content")

def tearDown(self):
"""每个测试方法执行后运行"""
if os.path.exists(self.test_file):
os.remove(self.test_file)

def test_read_file(self):
"""测试读取文件"""
with open(self.test_file, 'r') as f:
content = f.read()
self.assertEqual(content, "test content")

def test_file_exists(self):
"""测试文件存在"""
self.assertTrue(os.path.exists(self.test_file))

解释

  • setUpClass/tearDownClass:类级别固件,整个测试类只执行一次
  • setUp/tearDown:方法级别固件,每个测试方法前后执行

跳过测试

import unittest
import sys

class TestSkipping(unittest.TestCase):

@unittest.skip("跳过此测试")
def test_skip(self):
self.fail("不应该执行")

@unittest.skipIf(sys.version_info < (3, 10), "需要 Python 3.10+")
def test_skip_if(self):
# Python 3.10+ 特性
self.assertTrue(True)

@unittest.skipUnless(sys.platform == "linux", "仅 Linux 平台")
def test_skip_unless(self):
self.assertTrue(True)

@unittest.expectedFailure
def test_expected_failure(self):
"""预期失败的测试"""
self.assertEqual(1, 2) # 预期失败,标记为通过

if __name__ == '__main__':
unittest.main()

pytest 框架

pytest 是 Python 最流行的第三方测试框架,比 unittest 更简洁强大。

安装

pip install pytest pytest-cov

基本用法

# test_calc.py

def add(a, b):
"""被测试的函数"""
return a + b

def subtract(a, b):
return a - b

def multiply(a, b):
return a * b

def divide(a, b):
if b == 0:
raise ValueError("不能除以零")
return a / b


# ===== 测试函数 =====

def test_add():
"""测试加法"""
assert add(2, 3) == 5
assert add(-1, 1) == 0
assert add(0, 0) == 0

def test_subtract():
"""测试减法"""
assert subtract(5, 3) == 2
assert subtract(3, 5) == -2

def test_multiply():
"""测试乘法"""
assert multiply(2, 3) == 6
assert multiply(-2, 3) == -6
assert multiply(0, 100) == 0

def test_divide():
"""测试除法"""
assert divide(10, 2) == 5

# 测试异常
import pytest
with pytest.raises(ValueError) as excinfo:
divide(10, 0)
assert "不能除以零" in str(excinfo.value)

解释

  • pytest 使用简单的 assert 语句,不需要特殊方法
  • 测试文件名以 test_ 开头或 _test.py 结尾
  • 测试函数名以 test_ 开头
  • pytest 会自动发现并运行测试

运行测试

# 运行当前目录所有测试
pytest

# 运行指定文件
pytest test_calc.py

# 运行指定测试函数
pytest test_calc.py::test_add

# 运行指定测试类中的方法
pytest test_calc.py::TestMath::test_add

# 显示详细输出
pytest -v

# 显示 print 输出
pytest -s

# 只运行失败的测试
pytest --lf

# 先运行失败的,再运行其他的
pytest --ff

# 并行运行测试(需要 pytest-xdist)
pytest -n 4

# 生成覆盖率报告
pytest --cov=myapp --cov-report=html

pytest 断言

import pytest

def test_assertions():
"""pytest 断言示例"""
# 基本断言
assert 1 + 1 == 2
assert "hello" in "hello world"
assert 5 > 3

# 断言近似值
assert 0.1 + 0.2 == pytest.approx(0.3)

# 断言异常
with pytest.raises(ValueError):
int("abc")

# 断言异常消息
with pytest.raises(ValueError, match="invalid literal"):
int("abc")

# 断言警告
with pytest.warns(DeprecationWarning):
import warnings
warnings.warn("deprecated", DeprecationWarning)

def test_approx_comparison():
"""近似比较"""
# 浮点数近似
assert 0.1 + 0.2 == pytest.approx(0.3)

# 指定精度
assert 0.12345 == pytest.approx(0.123, abs=0.001)

# 列表近似
assert [0.1 + 0.2, 0.2 + 0.3] == pytest.approx([0.3, 0.5])

def test_exception_details():
"""异常详情测试"""
with pytest.raises(ValueError) as excinfo:
raise ValueError("错误消息")

assert excinfo.type is ValueError
assert "错误消息" in str(excinfo.value)

Fixture(测试固件)

pytest 的 fixture 比 unittest 更强大灵活。

基本 Fixture

import pytest

@pytest.fixture
def sample_data():
"""提供测试数据"""
return {"name": "张三", "age": 25}

def test_with_fixture(sample_data):
"""使用 fixture"""
assert sample_data["name"] == "张三"
assert sample_data["age"] == 25

Fixture 作用域

import pytest

# function 级别(默认):每个测试函数执行一次
@pytest.fixture(scope="function")
def func_fixture():
print("\n设置 function fixture")
yield "function"
print("\n清理 function fixture")

# class 级别:每个测试类执行一次
@pytest.fixture(scope="class")
def class_fixture():
print("\n设置 class fixture")
yield "class"
print("\n清理 class fixture")

# module 级别:每个模块执行一次
@pytest.fixture(scope="module")
def module_fixture():
print("\n设置 module fixture")
yield "module"
print("\n清理 module fixture")

# session 级别:整个测试会话执行一次
@pytest.fixture(scope="session")
def session_fixture():
print("\n设置 session fixture")
yield "session"
print("\n清理 session fixture")

def test_fixture_example(func_fixture, module_fixture):
"""测试 fixture 作用域"""
assert func_fixture == "function"
assert module_fixture == "module"

Fixture 清理操作

import pytest
import tempfile
import os

@pytest.fixture
def temp_file():
"""创建临时文件,测试后自动清理"""
# setup
fd, path = tempfile.mkstemp()
os.close(fd)
with open(path, 'w') as f:
f.write("test content")

# 提供给测试使用
yield path

# teardown(清理)
if os.path.exists(path):
os.remove(path)

def test_temp_file(temp_file):
"""测试使用临时文件"""
assert os.path.exists(temp_file)
with open(temp_file, 'r') as f:
content = f.read()
assert content == "test content"

内置 Fixture

import pytest

def test_tmp_path(tmp_path):
"""使用临时目录"""
# tmp_path 是 pathlib.Path 对象
test_file = tmp_path / "test.txt"
test_file.write_text("hello")
assert test_file.read_text() == "hello"

def test_tmp_path_factory(tmp_path_factory):
"""创建多个临时目录"""
dir1 = tmp_path_factory.mktemp("dir1")
dir2 = tmp_path_factory.mktemp("dir2")
assert dir1 != dir2

def test_capfd(capfd):
"""捕获标准输出"""
print("Hello, World!")
captured = capfd.readouterr()
assert captured.out == "Hello, World!\n"

def test_caplog(caplog):
"""捕获日志"""
import logging
logging.warning("warning message")
assert "warning message" in caplog.text

def test_monkeypatch(monkeypatch):
"""修改环境变量和属性"""
import os

# 设置环境变量
monkeypatch.setenv("TEST_VAR", "test_value")
assert os.environ["TEST_VAR"] == "test_value"

# 删除环境变量
monkeypatch.delenv("TEST_VAR", raising=False)
assert "TEST_VAR" not in os.environ

参数化测试

import pytest

# 单参数测试
@pytest.mark.parametrize("input,expected", [
(1, 2),
(2, 4),
(3, 6),
(10, 20),
])
def test_double(input, expected):
"""测试加倍函数"""
assert input * 2 == expected

# 多参数测试
@pytest.mark.parametrize("a,b,expected", [
(1, 2, 3),
(5, 5, 10),
(-1, 1, 0),
(0, 0, 0),
])
def test_add(a, b, expected):
"""测试加法"""
assert a + b == expected

# 参数化标记
@pytest.mark.parametrize("value,expected", [
(1, 1),
pytest.param(2, 4, id="two_squared"),
pytest.param(3, 9, marks=pytest.mark.xfail(reason="故意失败")),
pytest.param(4, 16, marks=pytest.mark.skip("跳过")),
])
def test_square(value, expected):
assert value ** 2 == expected

# 组合参数化
@pytest.mark.parametrize("x", [1, 2])
@pytest.mark.parametrize("y", [10, 20])
def test_combination(x, y):
"""测试参数组合:会生成 4 个测试"""
assert x + y in [11, 12, 21, 22]

标记(Marks)

import pytest

# 跳过测试
@pytest.mark.skip(reason="尚未实现")
def test_not_implemented():
pass

# 条件跳过
@pytest.mark.skipif(sys.version_info < (3, 10), reason="需要 Python 3.10+")
def test_python310_feature():
pass

# 预期失败
@pytest.mark.xfail(reason="已知 bug")
def test_known_bug():
assert 1 == 2

# 自定义标记
@pytest.mark.slow
def test_slow_operation():
import time
time.sleep(1)
assert True

@pytest.mark.integration
def test_database():
"""集成测试"""
pass

# 运行指定标记的测试
# pytest -m slow 运行标记为 slow 的测试
# pytest -m "not slow" 运行未标记为 slow 的测试
# pytest -m "integration or slow"

测试类

import pytest

class TestCalculator:
"""计算器测试类"""

@pytest.fixture(autouse=True)
def setup(self):
"""每个测试方法前自动执行"""
self.calculator = Calculator()

def test_add(self):
assert self.calculator.add(1, 2) == 3

def test_subtract(self):
assert self.calculator.subtract(5, 3) == 2

@pytest.mark.parametrize("a,b,expected", [
(2, 3, 6),
(4, 5, 20),
])
def test_multiply(self, a, b, expected):
assert self.calculator.multiply(a, b) == expected


class Calculator:
def add(self, a, b): return a + b
def subtract(self, a, b): return a - b
def multiply(self, a, b): return a * b

Mock 和 Patch

使用 unittest.mock 模拟外部依赖,隔离测试。

基本 Mock

from unittest.mock import Mock, MagicMock

def test_mock_basic():
"""基本 Mock 用法"""
# 创建 Mock 对象
mock = Mock()

# 设置返回值
mock.method.return_value = 42
assert mock.method() == 42

# 验证调用
mock.method.assert_called_once()

# 验证调用参数
mock.method_with_args("hello")
mock.method_with_args.assert_called_with("hello")

# 检查调用次数
assert mock.method.call_count == 1

使用 patch

from unittest.mock import patch, MagicMock
import requests

# 被测试的函数
def get_user_info(user_id):
"""获取用户信息"""
response = requests.get(f"https://api.example.com/users/{user_id}")
return response.json()

# 测试
@patch('requests.get')
def test_get_user_info(mock_get):
"""测试获取用户信息"""
# 设置 mock 返回值
mock_response = MagicMock()
mock_response.json.return_value = {"id": 1, "name": "张三"}
mock_get.return_value = mock_response

# 调用函数
result = get_user_info(1)

# 验证结果
assert result == {"id": 1, "name": "张三"}

# 验证 requests.get 被正确调用
mock_get.assert_called_once_with("https://api.example.com/users/1")

# 使用上下文管理器
def test_with_context():
"""使用上下文管理器的 patch"""
with patch('requests.get') as mock_get:
mock_get.return_value.json.return_value = {"id": 2}
result = get_user_info(2)
assert result["id"] == 2

Mock 类方法

from unittest.mock import patch, MagicMock

class Database:
def query(self, sql):
"""真实的数据库查询"""
pass

class UserService:
def __init__(self, db):
self.db = db

def get_user(self, user_id):
return self.db.query(f"SELECT * FROM users WHERE id = {user_id}")

def test_user_service():
"""测试用户服务"""
# 创建 mock 数据库
mock_db = MagicMock()
mock_db.query.return_value = {"id": 1, "name": "张三"}

# 注入 mock
service = UserService(mock_db)
result = service.get_user(1)

# 验证
assert result["name"] == "张三"
mock_db.query.assert_called_once_with("SELECT * FROM users WHERE id = 1")

@patch('__main__.Database')
def test_with_patch(MockDatabase):
"""使用 patch 替换类"""
mock_instance = MockDatabase.return_value
mock_instance.query.return_value = {"id": 1}

db = Database()
result = db.query("SELECT 1")

assert result["id"] == 1

pytest-mock 插件

# pip install pytest-mock

def test_with_mocker(mocker):
"""使用 pytest-mock"""
# mocker.patch 更简洁
mock_get = mocker.patch('requests.get')
mock_get.return_value.json.return_value = {"id": 1}

result = get_user_info(1)
assert result["id"] == 1

# mocker.spy 监视方法调用但不替换
# spy = mocker.spy(SomeClass, 'method')

测试覆盖率

使用 pytest-cov 测量代码覆盖率。

# 安装
pip install pytest-cov

# 运行测试并生成覆盖率报告
pytest --cov=myapp

# 生成 HTML 报告
pytest --cov=myapp --cov-report=html

# 指定覆盖率要求
pytest --cov=myapp --cov-fail-under=80

配置文件

# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --cov=myapp --cov-report=html --cov-fail-under=80
markers =
slow: 慢速测试
integration: 集成测试
# pyproject.toml
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
addopts = "-v --cov=myapp --cov-report=term-missing"
markers = [
"slow: marks tests as slow",
"integration: marks tests as integration tests",
]

[tool.coverage.run]
source = ["myapp"]
omit = ["tests/*", "*/__pycache__/*"]

[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise NotImplementedError",
"if __name__ == .__main__.:",
]

测试最佳实践

1. 测试命名规范

# 好的命名:描述性、清晰
def test_add_should_return_sum_of_two_numbers():
pass

def test_divide_should_raise_error_when_dividing_by_zero():
pass

def test_user_should_be_able_to_login_with_valid_credentials():
pass

# 不好的命名
def test_1():
pass

def test_something():
pass

2. AAA 模式

def test_with_aaa_pattern():
"""Arrange-Act-Assert 模式"""
# Arrange(准备):设置测试数据和条件
calculator = Calculator()
a, b = 2, 3

# Act(执行):调用被测试的代码
result = calculator.add(a, b)

# Assert(断言):验证结果
assert result == 5

3. 测试隔离

import pytest
import tempfile
import os

@pytest.fixture
def isolated_file_system():
"""隔离的文件系统环境"""
# 创建临时目录
temp_dir = tempfile.mkdtemp()
original_dir = os.getcwd()
os.chdir(temp_dir)

yield temp_dir

# 清理
os.chdir(original_dir)
import shutil
shutil.rmtree(temp_dir)

def test_file_operations(isolated_file_system):
"""测试文件操作,使用隔离环境"""
# 在临时目录中操作,不影响真实文件系统
with open("test.txt", "w") as f:
f.write("hello")

assert os.path.exists("test.txt")

4. 测试边界条件

import pytest

def process_age(age):
"""处理年龄"""
if not isinstance(age, int):
raise TypeError("年龄必须是整数")
if age < 0:
raise ValueError("年龄不能为负")
if age > 150:
raise ValueError("年龄不合理")
return age

class TestProcessAge:
"""测试年龄处理函数"""

def test_normal_age(self):
assert process_age(25) == 25

@pytest.mark.parametrize("age", [0, 1, 100, 149])
def test_valid_boundary(self, age):
"""测试有效边界"""
assert process_age(age) == age

@pytest.mark.parametrize("age", [-1, -100, 151, 200])
def test_invalid_boundary(self, age):
"""测试无效边界"""
with pytest.raises(ValueError):
process_age(age)

def test_non_integer(self):
"""测试非整数输入"""
with pytest.raises(TypeError):
process_age("25")
with pytest.raises(TypeError):
process_age(25.5)

5. 避免测试实现细节

# 不好的测试:测试实现细节
def test_internal_state():
user = User("张三")
assert user._internal_name == "张三" # 测试私有属性

# 好的测试:测试公共行为
def test_user_name():
user = User("张三")
assert user.get_name() == "张三" # 测试公共接口

测试实战示例

测试 Flask 应用

# app.py
from flask import Flask, jsonify

app = Flask(__name__)

@app.route('/api/users/<int:user_id>')
def get_user(user_id):
# 模拟数据库查询
users = {1: {"id": 1, "name": "张三"}, 2: {"id": 2, "name": "李四"}}
user = users.get(user_id)
if user:
return jsonify(user)
return jsonify({"error": "User not found"}), 404

# test_app.py
import pytest
from app import app

@pytest.fixture
def client():
"""创建测试客户端"""
app.config['TESTING'] = True
with app.test_client() as client:
yield client

def test_get_user_success(client):
"""测试获取用户成功"""
response = client.get('/api/users/1')
assert response.status_code == 200
data = response.get_json()
assert data["name"] == "张三"

def test_get_user_not_found(client):
"""测试用户不存在"""
response = client.get('/api/users/999')
assert response.status_code == 404
data = response.get_json()
assert "error" in data

测试数据库操作

import pytest
import sqlite3
from contextlib import contextmanager

# 被测试的代码
class UserRepository:
def __init__(self, db_path):
self.db_path = db_path

@contextmanager
def get_connection(self):
conn = sqlite3.connect(self.db_path)
try:
yield conn
finally:
conn.close()

def create_user(self, name, email):
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT INTO users (name, email) VALUES (?, ?)",
(name, email)
)
conn.commit()
return cursor.lastrowid

def get_user(self, user_id):
with self.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
return cursor.fetchone()

# 测试
@pytest.fixture
def temp_db():
"""创建临时数据库"""
import tempfile
fd, path = tempfile.mkstemp(suffix=".db")

# 创建表
conn = sqlite3.connect(path)
conn.execute("""
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL
)
""")
conn.commit()
conn.close()

yield path

# 清理
import os
os.close(fd)
os.unlink(path)

def test_create_and_get_user(temp_db):
"""测试创建和获取用户"""
repo = UserRepository(temp_db)

# 创建用户
user_id = repo.create_user("张三", "[email protected]")
assert user_id == 1

# 获取用户
user = repo.get_user(user_id)
assert user is not None
assert user[1] == "张三"
assert user[2] == "[email protected]"

def test_get_nonexistent_user(temp_db):
"""测试获取不存在的用户"""
repo = UserRepository(temp_db)
user = repo.get_user(999)
assert user is None

小结

本章我们学习了 Python 测试的完整知识体系:

  1. unittest:Python 内置测试框架
  2. pytest:更强大的第三方测试框架
  3. Fixture:测试固件,提供测试资源
  4. 参数化测试:使用不同参数运行同一测试
  5. Mock 和 Patch:模拟外部依赖
  6. 测试覆盖率:测量代码覆盖情况
  7. 最佳实践:命名规范、AAA 模式、测试隔离

练习

  1. 为一个计算器类编写完整的测试用例
  2. 使用 fixture 创建数据库测试环境
  3. 使用 mock 测试 HTTP 请求
  4. 使用参数化测试测试字符串处理函数
  5. 配置 pytest.ini 和覆盖率报告

参考资源