测试
测试是软件开发中不可或缺的环节。良好的测试不仅能确保代码按预期工作,还能在重构时提供安全网,防止引入新的 bug。Flask 提供了完善的测试支持,配合 pytest 框架可以轻松构建全面的测试体系。
理解测试的价值
在深入技术细节之前,让我们先理解为什么测试如此重要。
测试的层次
软件测试通常分为几个层次,每个层次关注不同的方面:
- 单元测试(Unit Tests):测试单个函数或方法的行为,是最小粒度的测试
- 集成测试(Integration Tests):测试多个组件协同工作的行为
- 功能测试(Functional Tests):从用户角度测试完整的功能流程
- 端到端测试(End-to-End Tests):模拟真实用户操作,测试整个系统
Flask 的测试工具主要支持前三类测试,端到端测试通常需要额外的工具如 Selenium 或 Playwright。
测试金字塔
测试金字塔是一种测试策略模型,建议:
- 编写大量快速的单元测试作为基础
- 编写适量的集成测试
- 编写少量耗时的端到端测试
这种策略既能保证代码质量,又能保持测试效率。
安装测试依赖
# 核心测试工具
pip install pytest pytest-cov
# Flask 测试插件(可选但推荐)
pip install pytest-flask
# Mock 工具
pip install pytest-mock
# 代码覆盖率
pip install coverage
创建 requirements-test.txt:
pytest>=7.0.0
pytest-cov>=4.0.0
pytest-flask>=1.2.0
pytest-mock>=3.10.0
coverage>=7.0.0
测试配置
Flask 测试配置
为测试环境创建专门的配置类:
# config.py
import os
class Config:
"""基础配置"""
SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-key')
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL', 'sqlite:///app.db')
SQLALCHEMY_TRACK_MODIFICATIONS = False
class TestConfig(Config):
"""测试配置"""
TESTING = True # 启用测试模式
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' # 内存数据库
WTF_CSRF_ENABLED = False # 禁用 CSRF(测试时方便)
SERVER_NAME = 'localhost:5000' # 设置服务器名(用于 url_for)
# 禁用邮件发送
MAIL_SUPPRESS_SEND = True
# 更快的密码哈希(仅用于测试)
SECURITY_PASSWORD_HASH = 'plaintext'
TESTING = True 会启用以下行为:
- 异常传播不被捕获(便于调试)
- 模板错误不被捕获
- 测试客户端可用
项目结构
推荐的项目结构:
myapp/
├── app/
│ ├── __init__.py # 应用工厂
│ ├── models/
│ ├── routes/
│ ├── services/
│ └── utils/
├── migrations/
├── tests/
│ ├── __init__.py
│ ├── conftest.py # pytest 配置和共享 fixtures
│ ├── fixtures/ # 测试数据
│ │ └── users.json
│ ├── unit/ # 单元测试
│ │ ├── test_models.py
│ │ ├── test_utils.py
│ │ └── test_services.py
│ ├── integration/ # 集成测试
│ │ ├── test_auth.py
│ │ └── test_api.py
│ └── functional/ # 功能测试
│ └── test_user_flow.py
├── config.py
├── requirements.txt
├── requirements-test.txt
└── pytest.ini # pytest 配置文件
pytest 配置文件
创建 pytest.ini 配置测试运行方式:
[pytest]
# 测试文件和目录的匹配模式
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# 测试目录
testpaths = tests
# 命令行选项
addopts =
-v # 详细输出
--tb=short # 简短的回溯信息
--strict-markers # 严格标记模式
-ra # 显示所有测试摘要
# 标记定义
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
integration: marks tests as integration tests
unit: marks tests as unit tests
api: marks tests for API endpoints
# 环境变量
env =
FLASK_ENV=testing
TESTING=True
pytest Fixtures
Fixtures 是 pytest 最强大的功能之一,用于提供测试所需的资源和环境。
基础 Fixtures
# tests/conftest.py
import pytest
from app import create_app
from app.extensions import db
from config import TestConfig
@pytest.fixture
def app():
"""创建测试应用实例"""
# 创建应用
app = create_app(TestConfig)
# 创建应用上下文
with app.app_context():
# 创建数据库表
db.create_all()
# 返回应用实例
yield app
# 清理:删除所有表
db.session.remove()
db.drop_all()
@pytest.fixture
def client(app):
"""创建测试客户端"""
return app.test_client()
@pytest.fixture
def runner(app):
"""创建 CLI 测试运行器"""
return app.test_cli_runner()
@pytest.fixture
def app_context(app):
"""创建应用上下文"""
with app.app_context():
yield
数据库 Fixtures
使用事务回滚确保测试之间相互隔离:
# tests/conftest.py
from app.models import User, Post
@pytest.fixture
def db_session(app):
"""数据库会话 fixture,每个测试后回滚"""
with app.app_context():
# 开始事务
connection = db.engine.connect()
transaction = connection.begin()
# 绑定会话到连接
options = dict(bind=connection, binds={})
session = db.create_scoped_session(options=options)
db.session = session
yield session
# 回滚事务
transaction.rollback()
connection.close()
session.remove()
@pytest.fixture
def sample_user(db_session):
"""创建示例用户"""
user = User(
username='testuser',
email='[email protected]',
is_active=True
)
user.set_password('password123')
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def admin_user(db_session):
"""创建管理员用户"""
user = User(
username='admin',
email='[email protected]',
is_active=True,
is_admin=True
)
user.set_password('admin123')
db_session.add(user)
db_session.commit()
return user
@pytest.fixture
def sample_posts(db_session, sample_user):
"""创建示例文章"""
posts = []
for i in range(5):
post = Post(
title=f'测试文章 {i+1}',
content=f'这是第 {i+1} 篇测试文章的内容',
author=sample_user
)
db_session.add(post)
posts.append(post)
db_session.commit()
return posts
认证 Fixtures
简化需要登录的测试:
# tests/conftest.py
@pytest.fixture
def logged_in_client(client, sample_user):
"""已登录的测试客户端"""
with client.session_transaction() as session:
session['user_id'] = sample_user.id
session['_fresh'] = True
return client
@pytest.fixture
def admin_client(client, admin_user):
"""管理员登录的测试客户端"""
with client.session_transaction() as session:
session['user_id'] = admin_user.id
session['_fresh'] = True
return client
# 使用 Flask-Login 时
@pytest.fixture
def auth_client(client, sample_user):
"""使用 Flask-Login 登录的客户端"""
from flask_login import login_user
with client:
# 通过登录路由登录
client.post('/auth/login', data={
'username': 'testuser',
'password': 'password123'
})
yield client
工厂 Fixtures
使用工厂模式创建灵活的测试数据:
# tests/conftest.py
@pytest.fixture
def user_factory(db_session):
"""用户工厂 fixture"""
created_users = []
def create_user(**kwargs):
defaults = {
'username': f'user_{len(created_users)}',
'email': f'user_{len(created_users)}@example.com',
'is_active': True
}
defaults.update(kwargs)
user = User(**defaults)
if 'password' in kwargs:
user.set_password(kwargs['password'])
else:
user.set_password('default_password')
db_session.add(user)
db_session.commit()
created_users.append(user)
return user
yield create_user
# 清理创建的用户
for user in created_users:
db_session.delete(user)
db_session.commit()
@pytest.fixture
def post_factory(db_session, user_factory):
"""文章工厂 fixture"""
created_posts = []
def create_post(author=None, **kwargs):
if author is None:
author = user_factory()
defaults = {
'title': f'文章 {len(created_posts)}',
'content': '这是测试内容',
'author': author
}
defaults.update(kwargs)
post = Post(**defaults)
db_session.add(post)
db_session.commit()
created_posts.append(post)
return post
yield create_post
使用工厂 fixture:
def test_user_posts(post_factory, user_factory):
# 创建用户和文章
user = user_factory(username='alice')
post1 = post_factory(author=user, title='第一篇文章')
post2 = post_factory(author=user, title='第二篇文章')
# 测试...
assert user.posts.count() == 2
测试客户端
Flask 的测试客户端是测试 Web 应用的核心工具,它允许在不运行服务器的情况下发送 HTTP 请求。
发送请求
def test_get_request(client):
"""测试 GET 请求"""
response = client.get('/')
# 检查状态码
assert response.status_code == 200
# 检查响应内容(字节)
assert b'Hello' in response.data
# 检查响应内容(文本)
assert 'Hello' in response.text
# 检查 Content-Type
assert response.content_type == 'text/html; charset=utf-8'
def test_post_request(client):
"""测试 POST 请求"""
response = client.post('/api/users', json={
'username': 'newuser',
'email': '[email protected]'
})
assert response.status_code == 201
assert response.json['username'] == 'newuser'
def test_with_headers(client):
"""测试带请求头的请求"""
response = client.get(
'/api/profile',
headers={
'Authorization': 'Bearer test-token',
'Accept': 'application/json'
}
)
assert response.status_code == 200
def test_with_query_params(client):
"""测试带查询参数的请求"""
response = client.get('/api/users', query_string={
'page': 2,
'per_page': 10,
'search': 'john'
})
assert response.status_code == 200
表单数据测试
def test_form_submission(client):
"""测试表单提交"""
response = client.post('/auth/register', data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
}, follow_redirects=True)
assert response.status_code == 200
assert b'注册成功' in response.data
def test_file_upload(client):
"""测试文件上传"""
import io
data = {
'title': '测试文件',
'file': (io.BytesIO(b'文件内容'), 'test.txt', 'text/plain')
}
response = client.post(
'/upload',
data=data,
content_type='multipart/form-data'
)
assert response.status_code == 200
JSON API 测试
import json
def test_json_api(client):
"""测试 JSON API"""
# 发送 JSON 数据
response = client.post(
'/api/users',
json={
'username': 'apiuser',
'email': '[email protected]',
'password': 'password123'
}
)
assert response.status_code == 201
# 访问 JSON 响应
data = response.json
assert data['username'] == 'apiuser'
assert 'id' in data
def test_api_error_response(client):
"""测试 API 错误响应"""
response = client.post('/api/users', json={
'username': '', # 空用户名应触发错误
'email': 'invalid-email'
})
assert response.status_code == 400
assert 'error' in response.json
assert 'username' in response.json['errors']
跟随重定向
def test_redirect_follow(client):
"""测试跟随重定向"""
# 不跟随重定向
response = client.get('/logout')
assert response.status_code == 302
# 跟随重定向
response = client.get('/logout', follow_redirects=True)
assert response.status_code == 200
# 检查重定向历史
assert len(response.history) == 1
assert response.history[0].status_code == 302
assert response.request.path == '/login'
def test_login_redirect(client):
"""测试登录后重定向"""
# 访问需要认证的页面
response = client.get('/dashboard')
assert response.status_code == 302
assert '/login' in response.headers['Location']
# 登录后重定向回原页面
response = client.post('/login?next=/dashboard', data={
'username': 'testuser',
'password': 'password123'
}, follow_redirects=True)
assert response.request.path == '/dashboard'
Session 操作
from flask import session
def test_access_session(client):
"""访问 session"""
with client:
# 发送请求
client.post('/login', data={
'username': 'testuser',
'password': 'password123'
})
# 在 with 块内可以访问 session
assert 'user_id' in session
assert session['user_id'] == 1
def test_modify_session(client):
"""修改 session"""
# 预设 session
with client.session_transaction() as sess:
sess['user_id'] = 1
sess['username'] = 'testuser'
# 现在发送请求,session 已设置
response = client.get('/dashboard')
assert response.status_code == 200
def test_clear_session(client):
"""清除 session"""
with client.session_transaction() as sess:
sess['user_id'] = 1
# 登出
response = client.get('/logout', follow_redirects=True)
with client:
client.get('/') # 触发请求上下文
assert 'user_id' not in session
测试模型
模型创建和验证
# tests/unit/test_models.py
import pytest
from app.models import User, Post
from app.extensions import db
from datetime import datetime
class TestUserModel:
"""用户模型测试"""
def test_create_user(self, db_session):
"""测试创建用户"""
user = User(
username='testuser',
email='[email protected]'
)
user.set_password('password123')
db_session.add(user)
db_session.commit()
assert user.id is not None
assert user.username == 'testuser'
assert user.email == '[email protected]'
assert user.created_at is not None
def test_password_hashing(self, db_session):
"""测试密码哈希"""
user = User(username='test')
user.set_password('password')
# 密码应该被哈希
assert user.password_hash != 'password'
# 验证正确密码
assert user.check_password('password') is True
# 验证错误密码
assert user.check_password('wrong') is False
def test_user_repr(self, db_session):
"""测试字符串表示"""
user = User(username='testuser')
assert 'testuser' in repr(user)
def test_unique_username(self, db_session):
"""测试用户名唯一性"""
user1 = User(username='duplicate', email='[email protected]')
user1.set_password('password')
db_session.add(user1)
db_session.commit()
user2 = User(username='duplicate', email='[email protected]')
user2.set_password('password')
db_session.add(user2)
with pytest.raises(Exception): # IntegrityError
db_session.commit()
def test_user_to_dict(self, sample_user):
"""测试字典序列化"""
user_dict = sample_user.to_dict()
assert user_dict['id'] == sample_user.id
assert user_dict['username'] == sample_user.username
assert 'password' not in user_dict # 敏感信息不应暴露
class TestPostModel:
"""文章模型测试"""
def test_create_post(self, db_session, sample_user):
"""测试创建文章"""
post = Post(
title='测试文章',
content='这是测试内容',
author=sample_user
)
db_session.add(post)
db_session.commit()
assert post.id is not None
assert post.author_id == sample_user.id
assert post.author.username == sample_user.username
def test_post_ordering(self, db_session, sample_user):
"""测试文章排序"""
from time import sleep
post1 = Post(title='第一篇', content='...', author=sample_user)
db_session.add(post1)
db_session.commit()
sleep(0.1) # 确保时间差
post2 = Post(title='第二篇', content='...', author=sample_user)
db_session.add(post2)
db_session.commit()
posts = Post.query.order_by(Post.created_at.desc()).all()
assert posts[0].title == '第二篇'
assert posts[1].title == '第一篇'
关系测试
class TestRelationships:
"""测试模型关系"""
def test_user_posts_relationship(self, db_session, sample_user):
"""测试用户-文章关系"""
# 创建多篇文章
for i in range(3):
post = Post(
title=f'文章 {i}',
content='内容',
author=sample_user
)
db_session.add(post)
db_session.commit()
# 测试关系
assert sample_user.posts.count() == 3
# 测试级联删除
db_session.delete(sample_user)
db_session.commit()
assert Post.query.count() == 0
def test_many_to_many(self, db_session):
"""测试多对多关系"""
from app.models import Tag
post = Post(title='测试', content='内容')
tag1 = Tag(name='Python')
tag2 = Tag(name='Flask')
post.tags.append(tag1)
post.tags.append(tag2)
db_session.add_all([post, tag1, tag2])
db_session.commit()
assert len(post.tags) == 2
assert tag1.posts.count() == 1
测试视图函数
路由测试
# tests/integration/test_routes.py
class TestPublicRoutes:
"""公开路由测试"""
def test_home_page(self, client):
"""测试首页"""
response = client.get('/')
assert response.status_code == 200
assert b'Welcome' in response.data
def test_about_page(self, client):
"""测试关于页面"""
response = client.get('/about')
assert response.status_code == 200
def test_404_page(self, client):
"""测试 404 页面"""
response = client.get('/nonexistent')
assert response.status_code == 404
class TestAuthRoutes:
"""认证路由测试"""
def test_login_page(self, client):
"""测试登录页面"""
response = client.get('/auth/login')
assert response.status_code == 200
assert b'login' in response.data.lower()
def test_login_success(self, client, sample_user):
"""测试登录成功"""
response = client.post('/auth/login', data={
'username': 'testuser',
'password': 'password123'
}, follow_redirects=True)
assert response.status_code == 200
assert b'欢迎' in response.data
def test_login_invalid_password(self, client, sample_user):
"""测试密码错误"""
response = client.post('/auth/login', data={
'username': 'testuser',
'password': 'wrongpassword'
})
assert response.status_code == 200 # 返回登录页面
assert b'用户名或密码错误' in response.data
def test_register(self, client):
"""测试注册"""
response = client.post('/auth/register', data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
}, follow_redirects=True)
assert response.status_code == 200
assert b'注册成功' in response.data
class TestProtectedRoutes:
"""受保护路由测试"""
def test_dashboard_requires_login(self, client):
"""测试需要登录"""
response = client.get('/dashboard')
assert response.status_code == 302
assert '/login' in response.headers['Location']
def test_dashboard_with_login(self, logged_in_client):
"""测试登录后访问"""
response = logged_in_client.get('/dashboard')
assert response.status_code == 200
def test_admin_requires_admin(self, logged_in_client):
"""测试需要管理员权限"""
response = logged_in_client.get('/admin')
assert response.status_code == 403
def test_admin_with_admin_user(self, admin_client):
"""测试管理员访问"""
response = admin_client.get('/admin')
assert response.status_code == 200
API 测试
# tests/integration/test_api.py
import json
class TestUserAPI:
"""用户 API 测试"""
def test_list_users(self, client, sample_user):
"""测试获取用户列表"""
response = client.get('/api/v1/users')
assert response.status_code == 200
data = response.json
assert isinstance(data, list)
assert len(data) > 0
def test_get_user(self, client, sample_user):
"""测试获取单个用户"""
response = client.get(f'/api/v1/users/{sample_user.id}')
assert response.status_code == 200
assert response.json['username'] == sample_user.username
def test_get_user_not_found(self, client):
"""测试获取不存在的用户"""
response = client.get('/api/v1/users/99999')
assert response.status_code == 404
def test_create_user(self, client):
"""测试创建用户"""
response = client.post('/api/v1/users', json={
'username': 'apiuser',
'email': '[email protected]',
'password': 'password123'
})
assert response.status_code == 201
assert response.json['username'] == 'apiuser'
def test_create_user_validation_error(self, client):
"""测试创建用户验证失败"""
response = client.post('/api/v1/users', json={
'username': '', # 空用户名
'email': 'invalid' # 无效邮箱
})
assert response.status_code == 400
assert 'errors' in response.json
def test_update_user(self, client, sample_user):
"""测试更新用户"""
response = client.put(
f'/api/v1/users/{sample_user.id}',
json={'username': 'updated_name'}
)
assert response.status_code == 200
assert response.json['username'] == 'updated_name'
def test_delete_user(self, client, sample_user):
"""测试删除用户"""
response = client.delete(f'/api/v1/users/{sample_user.id}')
assert response.status_code == 204
# 确认已删除
response = client.get(f'/api/v1/users/{sample_user.id}')
assert response.status_code == 404
class TestAuthenticationAPI:
"""认证 API 测试"""
def test_token_generation(self, client, sample_user):
"""测试获取 Token"""
response = client.post('/api/v1/auth/login', json={
'username': 'testuser',
'password': 'password123'
})
assert response.status_code == 200
assert 'access_token' in response.json
def test_protected_endpoint_with_token(self, client, sample_user):
"""测试使用 Token 访问受保护端点"""
# 获取 Token
login_response = client.post('/api/v1/auth/login', json={
'username': 'testuser',
'password': 'password123'
})
token = login_response.json['access_token']
# 使用 Token 访问
response = client.get(
'/api/v1/profile',
headers={'Authorization': f'Bearer {token}'}
)
assert response.status_code == 200
def test_protected_endpoint_without_token(self, client):
"""测试无 Token 访问受保护端点"""
response = client.get('/api/v1/profile')
assert response.status_code == 401
测试表单
# tests/unit/test_forms.py
from app.forms import LoginForm, RegistrationForm
class TestLoginForm:
"""登录表单测试"""
def test_valid_login_form(self, app):
"""测试有效的登录表单"""
with app.app_context():
form = LoginForm(data={
'username': 'testuser',
'password': 'password123',
'remember': True
})
assert form.validate() is True
def test_empty_username(self, app):
"""测试空用户名"""
with app.app_context():
form = LoginForm(data={
'username': '',
'password': 'password123'
})
assert form.validate() is False
assert 'username' in form.errors
def test_empty_password(self, app):
"""测试空密码"""
with app.app_context():
form = LoginForm(data={
'username': 'testuser',
'password': ''
})
assert form.validate() is False
assert 'password' in form.errors
class TestRegistrationForm:
"""注册表单测试"""
def test_valid_registration(self, app):
"""测试有效的注册表单"""
with app.app_context():
form = RegistrationForm(data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() is True
def test_password_mismatch(self, app):
"""测试密码不匹配"""
with app.app_context():
form = RegistrationForm(data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'different'
})
assert form.validate() is False
assert 'confirm' in form.errors
def test_invalid_email(self, app):
"""测试无效邮箱"""
with app.app_context():
form = RegistrationForm(data={
'username': 'newuser',
'email': 'not-an-email',
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() is False
assert 'email' in form.errors
def test_short_username(self, app):
"""测试用户名太短"""
with app.app_context():
form = RegistrationForm(data={
'username': 'ab', # 少于 3 个字符
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() is False
assert 'username' in form.errors
Mock 和依赖注入
Mock 用于隔离测试单元,避免依赖外部服务。
Mock 外部 API
# tests/unit/test_services.py
import pytest
from unittest.mock import patch, MagicMock
from app.services import NotificationService, PaymentService
class TestNotificationService:
"""通知服务测试"""
@patch('app.services.requests.post')
def test_send_email(self, mock_post):
"""测试发送邮件(Mock HTTP 请求)"""
# 配置 Mock 返回值
mock_post.return_value = MagicMock(
status_code=200,
json=lambda: {'success': True, 'id': 'mail-123'}
)
service = NotificationService()
result = service.send_email(
to='[email protected]',
subject='测试邮件',
body='邮件内容'
)
# 验证调用
assert result['success'] is True
mock_post.assert_called_once()
# 验证调用参数
call_args = mock_post.call_args
assert '[email protected]' in str(call_args)
@patch('app.services.requests.post')
def test_send_email_failure(self, mock_post):
"""测试发送邮件失败"""
mock_post.return_value = MagicMock(
status_code=500,
json=lambda: {'success': False, 'error': 'Server error'}
)
service = NotificationService()
result = service.send_email(
to='[email protected]',
subject='测试',
body='内容'
)
assert result['success'] is False
class TestPaymentService:
"""支付服务测试"""
@patch('stripe.Charge.create')
def test_process_payment(self, mock_charge):
"""测试支付处理"""
mock_charge.return_value = {
'id': 'ch_123',
'status': 'succeeded',
'amount': 1000
}
service = PaymentService()
result = service.process_payment(
amount=10.00,
currency='usd',
token='tok_visa'
)
assert result['success'] is True
assert result['charge_id'] == 'ch_123'
@patch('stripe.Charge.create')
def test_payment_declined(self, mock_charge):
"""测试支付被拒绝"""
import stripe
mock_charge.side_effect = stripe.error.CardError(
'Card declined', 'param', 'declined'
)
service = PaymentService()
result = service.process_payment(
amount=10.00,
currency='usd',
token='tok_declined'
)
assert result['success'] is False
assert 'declined' in result['error']
Mock 数据库
from unittest.mock import patch, MagicMock
class TestUserService:
"""用户服务测试"""
@patch('app.services.User.query')
def test_get_user_by_username(self, mock_query):
"""测试通过用户名获取用户"""
# 创建 Mock 用户
mock_user = MagicMock()
mock_user.username = 'testuser'
mock_user.email = '[email protected]'
# 配置 Mock 返回
mock_query.filter_by.return_value.first.return_value = mock_user
from app.services import get_user_by_username
user = get_user_by_username('testuser')
assert user.username == 'testuser'
mock_query.filter_by.assert_called_once_with(username='testuser')
@patch('app.services.db.session')
def test_create_user(self, mock_session):
"""测试创建用户"""
from app.services import create_user
user = create_user(
username='newuser',
email='[email protected]',
password='password123'
)
# 验证 add 和 commit 被调用
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
使用 pytest-mock
# pytest-mock 提供了更简洁的 Mock 接口
def test_with_mocker(mocker):
"""使用 pytest-mock"""
# mocker.patch 自动清理
mock_get = mocker.patch('requests.get')
mock_get.return_value.json.return_value = {'data': 'test'}
# 测试代码...
# 验证调用
mock_get.assert_called_once()
def test_spy_with_mocker(mocker, sample_user):
"""使用 Spy 监控方法调用"""
# spy 会记录调用但不改变行为
spy = mocker.spy(sample_user, 'check_password')
sample_user.check_password('password123')
spy.assert_called_once_with('password123')
测试覆盖率
配置覆盖率
创建 .coveragerc:
[run]
source = app
branch = True
omit =
app/__init__.py
app/config.py
*/tests/*
*/migrations/*
[report]
exclude_lines =
pragma: no cover
def __repr__
raise AssertionError
raise NotImplementedError
if __name__ == .__main__.:
if TYPE_CHECKING:
@abstractmethod
precision = 2
[html]
directory = htmlcov
运行覆盖率测试
# 运行测试并收集覆盖率
pytest --cov=app --cov-report=term-missing
# 生成 HTML 报告
pytest --cov=app --cov-report=html
# 同时生成多种报告
pytest --cov=app --cov-report=term --cov-report=html --cov-report=xml
覆盖率报告解读
Name Stmts Miss Cover Missing
------------------------------------------------------------
app/__init__.py 10 0 100%
app/models/__init__.py 0 0 100%
app/models/user.py 35 2 94% 45-46
app/models/post.py 28 5 82% 23, 45-48
app/routes/auth.py 45 10 78% 34-35, 67-72
------------------------------------------------------------
TOTAL 118 17 86%
- Stmts:语句总数
- Miss:未覆盖的语句数
- Cover:覆盖率百分比
- Missing:未覆盖的行号
覆盖率目标
- 单元测试:80% 以上
- 关键业务逻辑:95% 以上
- 整体项目:70% 以上
注意:覆盖率只是指标,不是目标。100% 覆盖率不代表没有 bug。
测试命令行命令
# tests/test_cli.py
def test_init_db_command(runner):
"""测试初始化数据库命令"""
result = runner.invoke(args=['init-db'])
assert 'Initialized' in result.output
def test_create_user_command(runner):
"""测试创建用户命令"""
result = runner.invoke(args=[
'create-user',
'--username', 'cliuser',
'--email', '[email protected]'
])
assert result.exit_code == 0
assert 'User cliuser created' in result.output
def test_command_with_error(runner):
"""测试命令错误处理"""
result = runner.invoke(args=['create-user']) # 缺少参数
assert result.exit_code != 0
assert 'Error' in result.output or 'Missing' in result.output
测试异步视图
import pytest
@pytest.mark.asyncio
async def test_async_view(client):
"""测试异步视图"""
response = await client.get('/async-data')
assert response.status_code == 200
@pytest.mark.asyncio
async def test_async_api(client):
"""测试异步 API"""
response = await client.post('/api/async-users', json={
'username': 'asyncuser',
'email': '[email protected]'
})
assert response.status_code == 201
测试最佳实践
1. 测试命名
# 好的命名:描述测试场景和预期结果
def test_login_with_valid_credentials_redirects_to_dashboard():
pass
def test_create_user_with_duplicate_username_returns_409():
pass
# 不好的命名:太模糊
def test_login():
pass
def test_user():
pass
2. 测试结构:AAA 模式
def test_user_creation():
# Arrange(准备):设置测试数据
user_data = {
'username': 'testuser',
'email': '[email protected]'
}
# Act(执行):执行被测试的操作
user = User(**user_data)
user.set_password('password')
db.session.add(user)
db.session.commit()
# Assert(断言):验证结果
assert user.id is not None
assert user.username == 'testuser'
3. 测试隔离
# 每个测试应该独立,不依赖其他测试
def test_create_user(db_session):
user = User(username='user1')
db_session.add(user)
db_session.commit()
assert user.id is not None
def test_create_another_user(db_session):
# 这个测试不应该依赖上一个测试的数据
user = User(username='user2')
db_session.add(user)
db_session.commit()
assert user.id is not None
4. 测试边界情况
class TestUserValidation:
"""测试用户验证边界情况"""
def test_empty_username(self):
"""空用户名"""
pass
def test_username_too_long(self):
"""用户名超长"""
pass
def test_username_special_chars(self):
"""用户名包含特殊字符"""
pass
def test_username_max_length(self):
"""用户名最大长度"""
pass
5. 使用参数化测试
@pytest.mark.parametrize('username,email,expected_valid', [
('validuser', '[email protected]', True),
('', '[email protected]', False), # 空用户名
('ab', '[email protected]', False), # 用户名太短
('validuser', 'invalid-email', False), # 无效邮箱
('validuser', '', False), # 空邮箱
])
def test_user_validation(username, email, expected_valid, app):
"""参数化测试用户验证"""
with app.app_context():
form = RegistrationForm(data={
'username': username,
'email': email,
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() == expected_valid
6. 避免测试实现细节
# 不好的测试:测试实现细节
def test_user_password_hash():
user = User(username='test')
user.set_password('password')
assert user.password_hash.startswith('pbkdf2:') # 依赖具体实现
# 好的测试:测试行为
def test_user_password_verification():
user = User(username='test')
user.set_password('password')
assert user.check_password('password') is True
assert user.check_password('wrong') is False
CI/CD 集成
GitHub Actions
# .github/workflows/tests.yml
name: Tests
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main ]
jobs:
test:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:15
env:
POSTGRES_USER: test
POSTGRES_PASSWORD: test
POSTGRES_DB: test_db
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-test.txt
- name: Run tests
env:
DATABASE_URL: postgresql://test:test@localhost:5432/test_db
SECRET_KEY: test-secret-key
run: |
pytest --cov=app --cov-report=xml --cov-report=term
- name: Upload coverage
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
GitLab CI
# .gitlab-ci.yml
stages:
- test
test:
stage: test
image: python:3.11
services:
- postgres:15
variables:
POSTGRES_DB: test_db
POSTGRES_USER: test
POSTGRES_PASSWORD: test
DATABASE_URL: postgresql://test:test@postgres:5432/test_db
script:
- pip install -r requirements.txt
- pip install -r requirements-test.txt
- pytest --cov=app --cov-report=term
coverage: '/TOTAL.+?(\d+%)$/'
常见问题
测试数据库迁移
def test_database_migration(app):
"""测试数据库迁移"""
from flask_migrate import upgrade
# 升级到最新版本
upgrade()
# 验证表存在
with app.app_context():
from sqlalchemy import inspect
inspector = inspect(db.engine)
tables = inspector.get_table_names()
assert 'users' in tables
assert 'posts' in tables
测试 Flask-Mail
from flask_mail import Mail, Message
def test_email_sending(app):
"""测试邮件发送"""
app.config['MAIL_SUPPRESS_SEND'] = True
with app.app_context():
with Mail(app).record_messages() as outbox:
# 发送邮件
msg = Message(
'测试邮件',
sender='[email protected]',
recipients=['[email protected]']
)
msg.body = '这是测试邮件内容'
Mail(app).send(msg)
# 验证
assert len(outbox) == 1
assert outbox[0].subject == '测试邮件'
测试 Flask-Cache
def test_cache(client, mocker):
"""测试缓存"""
# Mock 缓存
mock_cache = mocker.patch('app.extensions.cache.get')
mock_cache.return_value = {'cached': 'data'}
response = client.get('/api/cached-data')
assert response.json == {'cached': 'data'}
mock_cache.assert_called_once()
小结
本章详细介绍了 Flask 应用的测试方法:
- 测试基础:理解测试层次和金字塔模型
- pytest Fixtures:创建可复用的测试资源和数据
- 测试客户端:模拟 HTTP 请求测试视图
- 模型测试:验证数据模型的行为和关系
- API 测试:测试 RESTful API 端点
- Mock 技术:隔离外部依赖进行单元测试
- 测试覆盖率:衡量测试完整性
- 最佳实践:编写清晰、可靠、可维护的测试
良好的测试习惯是高质量软件的基石。投入时间编写测试会在长期开发中带来丰厚回报。