跳到主要内容

测试

测试是软件开发中不可或缺的环节。良好的测试不仅能确保代码按预期工作,还能在重构时提供安全网,防止引入新的 bug。Flask 提供了完善的测试支持,配合 pytest 框架可以轻松构建全面的测试体系。

理解测试的价值

在深入技术细节之前,让我们先理解为什么测试如此重要。

测试的层次

软件测试通常分为几个层次,每个层次关注不同的方面:

  • 单元测试(Unit Tests):测试单个函数或方法的行为,是最小粒度的测试
  • 集成测试(Integration Tests):测试多个组件协同工作的行为
  • 功能测试(Functional Tests):从用户角度测试完整的功能流程
  • 端到端测试(End-to-End Tests):模拟真实用户操作,测试整个系统

Flask 的测试工具主要支持前三类测试,端到端测试通常需要额外的工具如 Selenium 或 Playwright。

测试金字塔

测试金字塔是一种测试策略模型,建议:

  1. 编写大量快速的单元测试作为基础
  2. 编写适量的集成测试
  3. 编写少量耗时的端到端测试

这种策略既能保证代码质量,又能保持测试效率。

安装测试依赖

# 核心测试工具
pip install pytest pytest-cov

# Flask 测试插件(可选但推荐)
pip install pytest-flask

# Mock 工具
pip install pytest-mock

# 代码覆盖率
pip install coverage

创建 requirements-test.txt

pytest>=7.0.0
pytest-cov>=4.0.0
pytest-flask>=1.2.0
pytest-mock>=3.10.0
coverage>=7.0.0

测试配置

Flask 测试配置

为测试环境创建专门的配置类:

# config.py
import os

class Config:
"""基础配置"""
SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-key')
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL', 'sqlite:///app.db')
SQLALCHEMY_TRACK_MODIFICATIONS = False

class TestConfig(Config):
"""测试配置"""
TESTING = True # 启用测试模式
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' # 内存数据库
WTF_CSRF_ENABLED = False # 禁用 CSRF(测试时方便)
SERVER_NAME = 'localhost:5000' # 设置服务器名(用于 url_for)

# 禁用邮件发送
MAIL_SUPPRESS_SEND = True

# 更快的密码哈希(仅用于测试)
SECURITY_PASSWORD_HASH = 'plaintext'

TESTING = True 会启用以下行为:

  • 异常传播不被捕获(便于调试)
  • 模板错误不被捕获
  • 测试客户端可用

项目结构

推荐的项目结构:

myapp/
├── app/
│ ├── __init__.py # 应用工厂
│ ├── models/
│ ├── routes/
│ ├── services/
│ └── utils/
├── migrations/
├── tests/
│ ├── __init__.py
│ ├── conftest.py # pytest 配置和共享 fixtures
│ ├── fixtures/ # 测试数据
│ │ └── users.json
│ ├── unit/ # 单元测试
│ │ ├── test_models.py
│ │ ├── test_utils.py
│ │ └── test_services.py
│ ├── integration/ # 集成测试
│ │ ├── test_auth.py
│ │ └── test_api.py
│ └── functional/ # 功能测试
│ └── test_user_flow.py
├── config.py
├── requirements.txt
├── requirements-test.txt
└── pytest.ini # pytest 配置文件

pytest 配置文件

创建 pytest.ini 配置测试运行方式:

[pytest]
# 测试文件和目录的匹配模式
python_files = test_*.py
python_classes = Test*
python_functions = test_*

# 测试目录
testpaths = tests

# 命令行选项
addopts =
-v # 详细输出
--tb=short # 简短的回溯信息
--strict-markers # 严格标记模式
-ra # 显示所有测试摘要

# 标记定义
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
integration: marks tests as integration tests
unit: marks tests as unit tests
api: marks tests for API endpoints

# 环境变量
env =
FLASK_ENV=testing
TESTING=True

pytest Fixtures

Fixtures 是 pytest 最强大的功能之一,用于提供测试所需的资源和环境。

基础 Fixtures

# tests/conftest.py
import pytest
from app import create_app
from app.extensions import db
from config import TestConfig

@pytest.fixture
def app():
"""创建测试应用实例"""
# 创建应用
app = create_app(TestConfig)

# 创建应用上下文
with app.app_context():
# 创建数据库表
db.create_all()

# 返回应用实例
yield app

# 清理:删除所有表
db.session.remove()
db.drop_all()

@pytest.fixture
def client(app):
"""创建测试客户端"""
return app.test_client()

@pytest.fixture
def runner(app):
"""创建 CLI 测试运行器"""
return app.test_cli_runner()

@pytest.fixture
def app_context(app):
"""创建应用上下文"""
with app.app_context():
yield

数据库 Fixtures

使用事务回滚确保测试之间相互隔离:

# tests/conftest.py
from app.models import User, Post

@pytest.fixture
def db_session(app):
"""数据库会话 fixture,每个测试后回滚"""
with app.app_context():
# 开始事务
connection = db.engine.connect()
transaction = connection.begin()

# 绑定会话到连接
options = dict(bind=connection, binds={})
session = db.create_scoped_session(options=options)
db.session = session

yield session

# 回滚事务
transaction.rollback()
connection.close()
session.remove()

@pytest.fixture
def sample_user(db_session):
"""创建示例用户"""
user = User(
username='testuser',
email='[email protected]',
is_active=True
)
user.set_password('password123')
db_session.add(user)
db_session.commit()
return user

@pytest.fixture
def admin_user(db_session):
"""创建管理员用户"""
user = User(
username='admin',
email='[email protected]',
is_active=True,
is_admin=True
)
user.set_password('admin123')
db_session.add(user)
db_session.commit()
return user

@pytest.fixture
def sample_posts(db_session, sample_user):
"""创建示例文章"""
posts = []
for i in range(5):
post = Post(
title=f'测试文章 {i+1}',
content=f'这是第 {i+1} 篇测试文章的内容',
author=sample_user
)
db_session.add(post)
posts.append(post)
db_session.commit()
return posts

认证 Fixtures

简化需要登录的测试:

# tests/conftest.py

@pytest.fixture
def logged_in_client(client, sample_user):
"""已登录的测试客户端"""
with client.session_transaction() as session:
session['user_id'] = sample_user.id
session['_fresh'] = True
return client

@pytest.fixture
def admin_client(client, admin_user):
"""管理员登录的测试客户端"""
with client.session_transaction() as session:
session['user_id'] = admin_user.id
session['_fresh'] = True
return client

# 使用 Flask-Login 时
@pytest.fixture
def auth_client(client, sample_user):
"""使用 Flask-Login 登录的客户端"""
from flask_login import login_user

with client:
# 通过登录路由登录
client.post('/auth/login', data={
'username': 'testuser',
'password': 'password123'
})
yield client

工厂 Fixtures

使用工厂模式创建灵活的测试数据:

# tests/conftest.py

@pytest.fixture
def user_factory(db_session):
"""用户工厂 fixture"""
created_users = []

def create_user(**kwargs):
defaults = {
'username': f'user_{len(created_users)}',
'email': f'user_{len(created_users)}@example.com',
'is_active': True
}
defaults.update(kwargs)

user = User(**defaults)
if 'password' in kwargs:
user.set_password(kwargs['password'])
else:
user.set_password('default_password')

db_session.add(user)
db_session.commit()
created_users.append(user)
return user

yield create_user

# 清理创建的用户
for user in created_users:
db_session.delete(user)
db_session.commit()

@pytest.fixture
def post_factory(db_session, user_factory):
"""文章工厂 fixture"""
created_posts = []

def create_post(author=None, **kwargs):
if author is None:
author = user_factory()

defaults = {
'title': f'文章 {len(created_posts)}',
'content': '这是测试内容',
'author': author
}
defaults.update(kwargs)

post = Post(**defaults)
db_session.add(post)
db_session.commit()
created_posts.append(post)
return post

yield create_post

使用工厂 fixture:

def test_user_posts(post_factory, user_factory):
# 创建用户和文章
user = user_factory(username='alice')
post1 = post_factory(author=user, title='第一篇文章')
post2 = post_factory(author=user, title='第二篇文章')

# 测试...
assert user.posts.count() == 2

测试客户端

Flask 的测试客户端是测试 Web 应用的核心工具,它允许在不运行服务器的情况下发送 HTTP 请求。

发送请求

def test_get_request(client):
"""测试 GET 请求"""
response = client.get('/')

# 检查状态码
assert response.status_code == 200

# 检查响应内容(字节)
assert b'Hello' in response.data

# 检查响应内容(文本)
assert 'Hello' in response.text

# 检查 Content-Type
assert response.content_type == 'text/html; charset=utf-8'

def test_post_request(client):
"""测试 POST 请求"""
response = client.post('/api/users', json={
'username': 'newuser',
'email': '[email protected]'
})

assert response.status_code == 201
assert response.json['username'] == 'newuser'

def test_with_headers(client):
"""测试带请求头的请求"""
response = client.get(
'/api/profile',
headers={
'Authorization': 'Bearer test-token',
'Accept': 'application/json'
}
)
assert response.status_code == 200

def test_with_query_params(client):
"""测试带查询参数的请求"""
response = client.get('/api/users', query_string={
'page': 2,
'per_page': 10,
'search': 'john'
})
assert response.status_code == 200

表单数据测试

def test_form_submission(client):
"""测试表单提交"""
response = client.post('/auth/register', data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
}, follow_redirects=True)

assert response.status_code == 200
assert b'注册成功' in response.data

def test_file_upload(client):
"""测试文件上传"""
import io

data = {
'title': '测试文件',
'file': (io.BytesIO(b'文件内容'), 'test.txt', 'text/plain')
}

response = client.post(
'/upload',
data=data,
content_type='multipart/form-data'
)

assert response.status_code == 200

JSON API 测试

import json

def test_json_api(client):
"""测试 JSON API"""
# 发送 JSON 数据
response = client.post(
'/api/users',
json={
'username': 'apiuser',
'email': '[email protected]',
'password': 'password123'
}
)

assert response.status_code == 201

# 访问 JSON 响应
data = response.json
assert data['username'] == 'apiuser'
assert 'id' in data

def test_api_error_response(client):
"""测试 API 错误响应"""
response = client.post('/api/users', json={
'username': '', # 空用户名应触发错误
'email': 'invalid-email'
})

assert response.status_code == 400
assert 'error' in response.json
assert 'username' in response.json['errors']

跟随重定向

def test_redirect_follow(client):
"""测试跟随重定向"""
# 不跟随重定向
response = client.get('/logout')
assert response.status_code == 302

# 跟随重定向
response = client.get('/logout', follow_redirects=True)
assert response.status_code == 200

# 检查重定向历史
assert len(response.history) == 1
assert response.history[0].status_code == 302
assert response.request.path == '/login'

def test_login_redirect(client):
"""测试登录后重定向"""
# 访问需要认证的页面
response = client.get('/dashboard')
assert response.status_code == 302
assert '/login' in response.headers['Location']

# 登录后重定向回原页面
response = client.post('/login?next=/dashboard', data={
'username': 'testuser',
'password': 'password123'
}, follow_redirects=True)

assert response.request.path == '/dashboard'

Session 操作

from flask import session

def test_access_session(client):
"""访问 session"""
with client:
# 发送请求
client.post('/login', data={
'username': 'testuser',
'password': 'password123'
})

# 在 with 块内可以访问 session
assert 'user_id' in session
assert session['user_id'] == 1

def test_modify_session(client):
"""修改 session"""
# 预设 session
with client.session_transaction() as sess:
sess['user_id'] = 1
sess['username'] = 'testuser'

# 现在发送请求,session 已设置
response = client.get('/dashboard')
assert response.status_code == 200

def test_clear_session(client):
"""清除 session"""
with client.session_transaction() as sess:
sess['user_id'] = 1

# 登出
response = client.get('/logout', follow_redirects=True)

with client:
client.get('/') # 触发请求上下文
assert 'user_id' not in session

测试模型

模型创建和验证

# tests/unit/test_models.py
import pytest
from app.models import User, Post
from app.extensions import db
from datetime import datetime

class TestUserModel:
"""用户模型测试"""

def test_create_user(self, db_session):
"""测试创建用户"""
user = User(
username='testuser',
email='[email protected]'
)
user.set_password('password123')

db_session.add(user)
db_session.commit()

assert user.id is not None
assert user.username == 'testuser'
assert user.email == '[email protected]'
assert user.created_at is not None

def test_password_hashing(self, db_session):
"""测试密码哈希"""
user = User(username='test')
user.set_password('password')

# 密码应该被哈希
assert user.password_hash != 'password'

# 验证正确密码
assert user.check_password('password') is True

# 验证错误密码
assert user.check_password('wrong') is False

def test_user_repr(self, db_session):
"""测试字符串表示"""
user = User(username='testuser')
assert 'testuser' in repr(user)

def test_unique_username(self, db_session):
"""测试用户名唯一性"""
user1 = User(username='duplicate', email='[email protected]')
user1.set_password('password')
db_session.add(user1)
db_session.commit()

user2 = User(username='duplicate', email='[email protected]')
user2.set_password('password')
db_session.add(user2)

with pytest.raises(Exception): # IntegrityError
db_session.commit()

def test_user_to_dict(self, sample_user):
"""测试字典序列化"""
user_dict = sample_user.to_dict()

assert user_dict['id'] == sample_user.id
assert user_dict['username'] == sample_user.username
assert 'password' not in user_dict # 敏感信息不应暴露

class TestPostModel:
"""文章模型测试"""

def test_create_post(self, db_session, sample_user):
"""测试创建文章"""
post = Post(
title='测试文章',
content='这是测试内容',
author=sample_user
)

db_session.add(post)
db_session.commit()

assert post.id is not None
assert post.author_id == sample_user.id
assert post.author.username == sample_user.username

def test_post_ordering(self, db_session, sample_user):
"""测试文章排序"""
from time import sleep

post1 = Post(title='第一篇', content='...', author=sample_user)
db_session.add(post1)
db_session.commit()

sleep(0.1) # 确保时间差

post2 = Post(title='第二篇', content='...', author=sample_user)
db_session.add(post2)
db_session.commit()

posts = Post.query.order_by(Post.created_at.desc()).all()
assert posts[0].title == '第二篇'
assert posts[1].title == '第一篇'

关系测试

class TestRelationships:
"""测试模型关系"""

def test_user_posts_relationship(self, db_session, sample_user):
"""测试用户-文章关系"""
# 创建多篇文章
for i in range(3):
post = Post(
title=f'文章 {i}',
content='内容',
author=sample_user
)
db_session.add(post)
db_session.commit()

# 测试关系
assert sample_user.posts.count() == 3

# 测试级联删除
db_session.delete(sample_user)
db_session.commit()

assert Post.query.count() == 0

def test_many_to_many(self, db_session):
"""测试多对多关系"""
from app.models import Tag

post = Post(title='测试', content='内容')
tag1 = Tag(name='Python')
tag2 = Tag(name='Flask')

post.tags.append(tag1)
post.tags.append(tag2)

db_session.add_all([post, tag1, tag2])
db_session.commit()

assert len(post.tags) == 2
assert tag1.posts.count() == 1

测试视图函数

路由测试

# tests/integration/test_routes.py

class TestPublicRoutes:
"""公开路由测试"""

def test_home_page(self, client):
"""测试首页"""
response = client.get('/')
assert response.status_code == 200
assert b'Welcome' in response.data

def test_about_page(self, client):
"""测试关于页面"""
response = client.get('/about')
assert response.status_code == 200

def test_404_page(self, client):
"""测试 404 页面"""
response = client.get('/nonexistent')
assert response.status_code == 404

class TestAuthRoutes:
"""认证路由测试"""

def test_login_page(self, client):
"""测试登录页面"""
response = client.get('/auth/login')
assert response.status_code == 200
assert b'login' in response.data.lower()

def test_login_success(self, client, sample_user):
"""测试登录成功"""
response = client.post('/auth/login', data={
'username': 'testuser',
'password': 'password123'
}, follow_redirects=True)

assert response.status_code == 200
assert b'欢迎' in response.data

def test_login_invalid_password(self, client, sample_user):
"""测试密码错误"""
response = client.post('/auth/login', data={
'username': 'testuser',
'password': 'wrongpassword'
})

assert response.status_code == 200 # 返回登录页面
assert b'用户名或密码错误' in response.data

def test_register(self, client):
"""测试注册"""
response = client.post('/auth/register', data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
}, follow_redirects=True)

assert response.status_code == 200
assert b'注册成功' in response.data

class TestProtectedRoutes:
"""受保护路由测试"""

def test_dashboard_requires_login(self, client):
"""测试需要登录"""
response = client.get('/dashboard')
assert response.status_code == 302
assert '/login' in response.headers['Location']

def test_dashboard_with_login(self, logged_in_client):
"""测试登录后访问"""
response = logged_in_client.get('/dashboard')
assert response.status_code == 200

def test_admin_requires_admin(self, logged_in_client):
"""测试需要管理员权限"""
response = logged_in_client.get('/admin')
assert response.status_code == 403

def test_admin_with_admin_user(self, admin_client):
"""测试管理员访问"""
response = admin_client.get('/admin')
assert response.status_code == 200

API 测试

# tests/integration/test_api.py
import json

class TestUserAPI:
"""用户 API 测试"""

def test_list_users(self, client, sample_user):
"""测试获取用户列表"""
response = client.get('/api/v1/users')

assert response.status_code == 200
data = response.json
assert isinstance(data, list)
assert len(data) > 0

def test_get_user(self, client, sample_user):
"""测试获取单个用户"""
response = client.get(f'/api/v1/users/{sample_user.id}')

assert response.status_code == 200
assert response.json['username'] == sample_user.username

def test_get_user_not_found(self, client):
"""测试获取不存在的用户"""
response = client.get('/api/v1/users/99999')
assert response.status_code == 404

def test_create_user(self, client):
"""测试创建用户"""
response = client.post('/api/v1/users', json={
'username': 'apiuser',
'email': '[email protected]',
'password': 'password123'
})

assert response.status_code == 201
assert response.json['username'] == 'apiuser'

def test_create_user_validation_error(self, client):
"""测试创建用户验证失败"""
response = client.post('/api/v1/users', json={
'username': '', # 空用户名
'email': 'invalid' # 无效邮箱
})

assert response.status_code == 400
assert 'errors' in response.json

def test_update_user(self, client, sample_user):
"""测试更新用户"""
response = client.put(
f'/api/v1/users/{sample_user.id}',
json={'username': 'updated_name'}
)

assert response.status_code == 200
assert response.json['username'] == 'updated_name'

def test_delete_user(self, client, sample_user):
"""测试删除用户"""
response = client.delete(f'/api/v1/users/{sample_user.id}')

assert response.status_code == 204

# 确认已删除
response = client.get(f'/api/v1/users/{sample_user.id}')
assert response.status_code == 404

class TestAuthenticationAPI:
"""认证 API 测试"""

def test_token_generation(self, client, sample_user):
"""测试获取 Token"""
response = client.post('/api/v1/auth/login', json={
'username': 'testuser',
'password': 'password123'
})

assert response.status_code == 200
assert 'access_token' in response.json

def test_protected_endpoint_with_token(self, client, sample_user):
"""测试使用 Token 访问受保护端点"""
# 获取 Token
login_response = client.post('/api/v1/auth/login', json={
'username': 'testuser',
'password': 'password123'
})
token = login_response.json['access_token']

# 使用 Token 访问
response = client.get(
'/api/v1/profile',
headers={'Authorization': f'Bearer {token}'}
)

assert response.status_code == 200

def test_protected_endpoint_without_token(self, client):
"""测试无 Token 访问受保护端点"""
response = client.get('/api/v1/profile')
assert response.status_code == 401

测试表单

# tests/unit/test_forms.py
from app.forms import LoginForm, RegistrationForm

class TestLoginForm:
"""登录表单测试"""

def test_valid_login_form(self, app):
"""测试有效的登录表单"""
with app.app_context():
form = LoginForm(data={
'username': 'testuser',
'password': 'password123',
'remember': True
})
assert form.validate() is True

def test_empty_username(self, app):
"""测试空用户名"""
with app.app_context():
form = LoginForm(data={
'username': '',
'password': 'password123'
})
assert form.validate() is False
assert 'username' in form.errors

def test_empty_password(self, app):
"""测试空密码"""
with app.app_context():
form = LoginForm(data={
'username': 'testuser',
'password': ''
})
assert form.validate() is False
assert 'password' in form.errors

class TestRegistrationForm:
"""注册表单测试"""

def test_valid_registration(self, app):
"""测试有效的注册表单"""
with app.app_context():
form = RegistrationForm(data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() is True

def test_password_mismatch(self, app):
"""测试密码不匹配"""
with app.app_context():
form = RegistrationForm(data={
'username': 'newuser',
'email': '[email protected]',
'password': 'password123',
'confirm': 'different'
})
assert form.validate() is False
assert 'confirm' in form.errors

def test_invalid_email(self, app):
"""测试无效邮箱"""
with app.app_context():
form = RegistrationForm(data={
'username': 'newuser',
'email': 'not-an-email',
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() is False
assert 'email' in form.errors

def test_short_username(self, app):
"""测试用户名太短"""
with app.app_context():
form = RegistrationForm(data={
'username': 'ab', # 少于 3 个字符
'email': '[email protected]',
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() is False
assert 'username' in form.errors

Mock 和依赖注入

Mock 用于隔离测试单元,避免依赖外部服务。

Mock 外部 API

# tests/unit/test_services.py
import pytest
from unittest.mock import patch, MagicMock
from app.services import NotificationService, PaymentService

class TestNotificationService:
"""通知服务测试"""

@patch('app.services.requests.post')
def test_send_email(self, mock_post):
"""测试发送邮件(Mock HTTP 请求)"""
# 配置 Mock 返回值
mock_post.return_value = MagicMock(
status_code=200,
json=lambda: {'success': True, 'id': 'mail-123'}
)

service = NotificationService()
result = service.send_email(
to='[email protected]',
subject='测试邮件',
body='邮件内容'
)

# 验证调用
assert result['success'] is True
mock_post.assert_called_once()

# 验证调用参数
call_args = mock_post.call_args
assert '[email protected]' in str(call_args)

@patch('app.services.requests.post')
def test_send_email_failure(self, mock_post):
"""测试发送邮件失败"""
mock_post.return_value = MagicMock(
status_code=500,
json=lambda: {'success': False, 'error': 'Server error'}
)

service = NotificationService()
result = service.send_email(
to='[email protected]',
subject='测试',
body='内容'
)

assert result['success'] is False

class TestPaymentService:
"""支付服务测试"""

@patch('stripe.Charge.create')
def test_process_payment(self, mock_charge):
"""测试支付处理"""
mock_charge.return_value = {
'id': 'ch_123',
'status': 'succeeded',
'amount': 1000
}

service = PaymentService()
result = service.process_payment(
amount=10.00,
currency='usd',
token='tok_visa'
)

assert result['success'] is True
assert result['charge_id'] == 'ch_123'

@patch('stripe.Charge.create')
def test_payment_declined(self, mock_charge):
"""测试支付被拒绝"""
import stripe
mock_charge.side_effect = stripe.error.CardError(
'Card declined', 'param', 'declined'
)

service = PaymentService()
result = service.process_payment(
amount=10.00,
currency='usd',
token='tok_declined'
)

assert result['success'] is False
assert 'declined' in result['error']

Mock 数据库

from unittest.mock import patch, MagicMock

class TestUserService:
"""用户服务测试"""

@patch('app.services.User.query')
def test_get_user_by_username(self, mock_query):
"""测试通过用户名获取用户"""
# 创建 Mock 用户
mock_user = MagicMock()
mock_user.username = 'testuser'
mock_user.email = '[email protected]'

# 配置 Mock 返回
mock_query.filter_by.return_value.first.return_value = mock_user

from app.services import get_user_by_username
user = get_user_by_username('testuser')

assert user.username == 'testuser'
mock_query.filter_by.assert_called_once_with(username='testuser')

@patch('app.services.db.session')
def test_create_user(self, mock_session):
"""测试创建用户"""
from app.services import create_user

user = create_user(
username='newuser',
email='[email protected]',
password='password123'
)

# 验证 add 和 commit 被调用
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()

使用 pytest-mock

# pytest-mock 提供了更简洁的 Mock 接口

def test_with_mocker(mocker):
"""使用 pytest-mock"""
# mocker.patch 自动清理
mock_get = mocker.patch('requests.get')
mock_get.return_value.json.return_value = {'data': 'test'}

# 测试代码...

# 验证调用
mock_get.assert_called_once()

def test_spy_with_mocker(mocker, sample_user):
"""使用 Spy 监控方法调用"""
# spy 会记录调用但不改变行为
spy = mocker.spy(sample_user, 'check_password')

sample_user.check_password('password123')

spy.assert_called_once_with('password123')

测试覆盖率

配置覆盖率

创建 .coveragerc

[run]
source = app
branch = True
omit =
app/__init__.py
app/config.py
*/tests/*
*/migrations/*

[report]
exclude_lines =
pragma: no cover
def __repr__
raise AssertionError
raise NotImplementedError
if __name__ == .__main__.:
if TYPE_CHECKING:
@abstractmethod
precision = 2

[html]
directory = htmlcov

运行覆盖率测试

# 运行测试并收集覆盖率
pytest --cov=app --cov-report=term-missing

# 生成 HTML 报告
pytest --cov=app --cov-report=html

# 同时生成多种报告
pytest --cov=app --cov-report=term --cov-report=html --cov-report=xml

覆盖率报告解读

Name                           Stmts   Miss  Cover   Missing
------------------------------------------------------------
app/__init__.py 10 0 100%
app/models/__init__.py 0 0 100%
app/models/user.py 35 2 94% 45-46
app/models/post.py 28 5 82% 23, 45-48
app/routes/auth.py 45 10 78% 34-35, 67-72
------------------------------------------------------------
TOTAL 118 17 86%
  • Stmts:语句总数
  • Miss:未覆盖的语句数
  • Cover:覆盖率百分比
  • Missing:未覆盖的行号

覆盖率目标

  • 单元测试:80% 以上
  • 关键业务逻辑:95% 以上
  • 整体项目:70% 以上

注意:覆盖率只是指标,不是目标。100% 覆盖率不代表没有 bug。

测试命令行命令

# tests/test_cli.py

def test_init_db_command(runner):
"""测试初始化数据库命令"""
result = runner.invoke(args=['init-db'])
assert 'Initialized' in result.output

def test_create_user_command(runner):
"""测试创建用户命令"""
result = runner.invoke(args=[
'create-user',
'--username', 'cliuser',
'--email', '[email protected]'
])

assert result.exit_code == 0
assert 'User cliuser created' in result.output

def test_command_with_error(runner):
"""测试命令错误处理"""
result = runner.invoke(args=['create-user']) # 缺少参数
assert result.exit_code != 0
assert 'Error' in result.output or 'Missing' in result.output

测试异步视图

import pytest

@pytest.mark.asyncio
async def test_async_view(client):
"""测试异步视图"""
response = await client.get('/async-data')
assert response.status_code == 200

@pytest.mark.asyncio
async def test_async_api(client):
"""测试异步 API"""
response = await client.post('/api/async-users', json={
'username': 'asyncuser',
'email': '[email protected]'
})

assert response.status_code == 201

测试最佳实践

1. 测试命名

# 好的命名:描述测试场景和预期结果
def test_login_with_valid_credentials_redirects_to_dashboard():
pass

def test_create_user_with_duplicate_username_returns_409():
pass

# 不好的命名:太模糊
def test_login():
pass

def test_user():
pass

2. 测试结构:AAA 模式

def test_user_creation():
# Arrange(准备):设置测试数据
user_data = {
'username': 'testuser',
'email': '[email protected]'
}

# Act(执行):执行被测试的操作
user = User(**user_data)
user.set_password('password')
db.session.add(user)
db.session.commit()

# Assert(断言):验证结果
assert user.id is not None
assert user.username == 'testuser'

3. 测试隔离

# 每个测试应该独立,不依赖其他测试
def test_create_user(db_session):
user = User(username='user1')
db_session.add(user)
db_session.commit()
assert user.id is not None

def test_create_another_user(db_session):
# 这个测试不应该依赖上一个测试的数据
user = User(username='user2')
db_session.add(user)
db_session.commit()
assert user.id is not None

4. 测试边界情况

class TestUserValidation:
"""测试用户验证边界情况"""

def test_empty_username(self):
"""空用户名"""
pass

def test_username_too_long(self):
"""用户名超长"""
pass

def test_username_special_chars(self):
"""用户名包含特殊字符"""
pass

def test_username_max_length(self):
"""用户名最大长度"""
pass

5. 使用参数化测试

@pytest.mark.parametrize('username,email,expected_valid', [
('validuser', '[email protected]', True),
('', '[email protected]', False), # 空用户名
('ab', '[email protected]', False), # 用户名太短
('validuser', 'invalid-email', False), # 无效邮箱
('validuser', '', False), # 空邮箱
])
def test_user_validation(username, email, expected_valid, app):
"""参数化测试用户验证"""
with app.app_context():
form = RegistrationForm(data={
'username': username,
'email': email,
'password': 'password123',
'confirm': 'password123'
})
assert form.validate() == expected_valid

6. 避免测试实现细节

# 不好的测试:测试实现细节
def test_user_password_hash():
user = User(username='test')
user.set_password('password')
assert user.password_hash.startswith('pbkdf2:') # 依赖具体实现

# 好的测试:测试行为
def test_user_password_verification():
user = User(username='test')
user.set_password('password')
assert user.check_password('password') is True
assert user.check_password('wrong') is False

CI/CD 集成

GitHub Actions

# .github/workflows/tests.yml
name: Tests

on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest

services:
postgres:
image: postgres:15
env:
POSTGRES_USER: test
POSTGRES_PASSWORD: test
POSTGRES_DB: test_db
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5

steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements-test.txt

- name: Run tests
env:
DATABASE_URL: postgresql://test:test@localhost:5432/test_db
SECRET_KEY: test-secret-key
run: |
pytest --cov=app --cov-report=xml --cov-report=term

- name: Upload coverage
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml

GitLab CI

# .gitlab-ci.yml
stages:
- test

test:
stage: test
image: python:3.11
services:
- postgres:15
variables:
POSTGRES_DB: test_db
POSTGRES_USER: test
POSTGRES_PASSWORD: test
DATABASE_URL: postgresql://test:test@postgres:5432/test_db
script:
- pip install -r requirements.txt
- pip install -r requirements-test.txt
- pytest --cov=app --cov-report=term
coverage: '/TOTAL.+?(\d+%)$/'

常见问题

测试数据库迁移

def test_database_migration(app):
"""测试数据库迁移"""
from flask_migrate import upgrade

# 升级到最新版本
upgrade()

# 验证表存在
with app.app_context():
from sqlalchemy import inspect
inspector = inspect(db.engine)
tables = inspector.get_table_names()

assert 'users' in tables
assert 'posts' in tables

测试 Flask-Mail

from flask_mail import Mail, Message

def test_email_sending(app):
"""测试邮件发送"""
app.config['MAIL_SUPPRESS_SEND'] = True

with app.app_context():
with Mail(app).record_messages() as outbox:
# 发送邮件
msg = Message(
'测试邮件',
sender='[email protected]',
recipients=['[email protected]']
)
msg.body = '这是测试邮件内容'
Mail(app).send(msg)

# 验证
assert len(outbox) == 1
assert outbox[0].subject == '测试邮件'

测试 Flask-Cache

def test_cache(client, mocker):
"""测试缓存"""
# Mock 缓存
mock_cache = mocker.patch('app.extensions.cache.get')
mock_cache.return_value = {'cached': 'data'}

response = client.get('/api/cached-data')

assert response.json == {'cached': 'data'}
mock_cache.assert_called_once()

小结

本章详细介绍了 Flask 应用的测试方法:

  1. 测试基础:理解测试层次和金字塔模型
  2. pytest Fixtures:创建可复用的测试资源和数据
  3. 测试客户端:模拟 HTTP 请求测试视图
  4. 模型测试:验证数据模型的行为和关系
  5. API 测试:测试 RESTful API 端点
  6. Mock 技术:隔离外部依赖进行单元测试
  7. 测试覆盖率:衡量测试完整性
  8. 最佳实践:编写清晰、可靠、可维护的测试

良好的测试习惯是高质量软件的基石。投入时间编写测试会在长期开发中带来丰厚回报。

参考资料