测试
测试是确保代码质量的重要手段。FastAPI 提供了完善的测试支持,可以轻松编写单元测试和集成测试。
测试基础
安装依赖
pip install pytest httpx
创建测试客户端
使用 TestClient 创建测试客户端:
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/")
async def read_root():
    """Return a fixed greeting payload."""
    return {"message": "Hello World"}


# Test client bound directly to the app object.
client = TestClient(app)


def test_read_root():
    """GET / answers 200 with the greeting."""
    expected = {"message": "Hello World"}
    resp = client.get("/")
    assert resp.status_code == 200
    assert resp.json() == expected
运行测试
# 运行所有测试
pytest
# 运行指定文件
pytest test_main.py
# 显示详细信息
pytest -v
# 显示打印输出
pytest -s
测试 GET 请求
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/items/{item_id}")
async def read_item(item_id: int, q: str | None = None):
    """Echo the path id (converted to int) and the optional ``q`` query value.

    ``q`` is optional, so its annotation is ``str | None`` rather than a bare
    ``str`` with a ``None`` default.
    """
    return {"item_id": item_id, "q": q}


client = TestClient(app)


def test_read_item():
    # Basic request: the "42" path segment becomes int 42, q defaults to None.
    response = client.get("/items/42")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": None}


def test_read_item_with_query():
    # Request carrying a query-string parameter.
    response = client.get("/items/42?q=search")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": "search"}
测试 POST 请求
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    """Request body schema: both fields are required."""

    name: str
    price: float


@app.post("/items/")
async def create_item(item: Item):
    """Echo back the validated item."""
    return item


client = TestClient(app)


def test_create_item():
    # json= serializes the dict as a JSON request body.
    payload = {"name": "商品", "price": 99.9}
    response = client.post("/items/", json=payload)
    assert response.status_code == 200
    assert response.json() == payload


def test_create_item_invalid():
    # "price" is omitted, so body validation fails with 422.
    incomplete = {"name": "商品"}
    response = client.post("/items/", json=incomplete)
    assert response.status_code == 422
测试文件上传
from fastapi import FastAPI, UploadFile, File
from fastapi.testclient import TestClient

# NOTE: File/UploadFile parameters need the `python-multipart` package
# installed (pip install python-multipart), per the FastAPI docs.
app = FastAPI()


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Read the uploaded file and report its filename and size in bytes."""
    content = await file.read()
    return {"filename": file.filename, "size": len(content)}


client = TestClient(app)


def test_upload_file():
    # In-memory file contents; no file on disk is needed.
    file_content = b"test file content"
    # files= maps form field -> (filename, content, content type) and is
    # sent as multipart/form-data.
    response = client.post(
        "/upload/",
        files={"file": ("test.txt", file_content, "text/plain")},
    )
    assert response.status_code == 200
    # Derive the expected size from the payload rather than a magic number.
    assert response.json() == {"filename": "test.txt", "size": len(file_content)}
测试表单数据
from fastapi import FastAPI, Form
from fastapi.testclient import TestClient

# NOTE: Form(...) parameters need the `python-multipart` package installed,
# per the FastAPI docs.
app = FastAPI()


@app.post("/login/")
async def login(username: str = Form(...), password: str = Form(...)):
    """Accept form-encoded credentials and echo back only the username."""
    return {"username": username}


client = TestClient(app)


def test_login():
    # data= sends the fields form-encoded (not JSON), matching Form(...).
    form_fields = {"username": "user", "password": "pass"}
    resp = client.post("/login/", data=form_fields)
    assert resp.status_code == 200
    assert resp.json() == {"username": "user"}
测试认证
测试 OAuth2
from fastapi import FastAPI, Depends
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient

app = FastAPI()
# Dependency that extracts the bearer token from the Authorization header.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@app.get("/users/me")
async def read_users_me(token: str = Depends(oauth2_scheme)):
    """Return the raw bearer token the scheme extracted."""
    return {"token": token}


client = TestClient(app)


def test_with_auth():
    # Supply the token via an Authorization header.
    auth = {"Authorization": "Bearer test-token"}
    resp = client.get("/users/me", headers=auth)
    assert resp.status_code == 200
    assert resp.json() == {"token": "test-token"}


def test_without_auth():
    # No token at all: the request is rejected with 401.
    resp = client.get("/users/me")
    assert resp.status_code == 401
测试 API Key
# `Depends` is used by the /protected route below, so it must be imported
# alongside `Security`; without it this example fails with NameError.
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
from fastapi.testclient import TestClient

app = FastAPI()
# Reads the API key from the "X-API-Key" request header.
api_key_header = APIKeyHeader(name="X-API-Key")


async def get_api_key(api_key: str = Security(api_key_header)):
    """Return the key if it matches the expected value.

    Raises:
        HTTPException: 403 when the supplied key is wrong.
    """
    if api_key != "secret-key":
        raise HTTPException(status_code=403)
    return api_key


@app.get("/protected")
async def protected(api_key: str = Depends(get_api_key)):
    return {"message": "success"}


client = TestClient(app)


def test_valid_api_key():
    response = client.get(
        "/protected",
        headers={"X-API-Key": "secret-key"},
    )
    assert response.status_code == 200
    assert response.json() == {"message": "success"}


def test_invalid_api_key():
    response = client.get(
        "/protected",
        headers={"X-API-Key": "wrong-key"},
    )
    assert response.status_code == 403
测试依赖覆盖
在测试中可以覆盖依赖,使用测试数据或模拟服务:
from typing import Annotated

from fastapi import FastAPI, Depends
from fastapi.testclient import TestClient

app = FastAPI()


# Real dependency: resolves the current user.
async def get_current_user():
    # A real application would load this from the database or a token.
    return {"id": 1, "name": "Real User"}


@app.get("/users/me")
async def read_users_me(user: Annotated[dict, Depends(get_current_user)]):
    return user


# Test dependency: returns a fixed test user.
async def override_get_current_user():
    return {"id": 999, "name": "Test User"}


client = TestClient(app)


def test_read_users_me():
    # Swap in the test dependency for the duration of this test.
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        response = client.get("/users/me")
        assert response.status_code == 200
        assert response.json() == {"id": 999, "name": "Test User"}
    finally:
        # Always remove the override, even when an assertion fails, so it
        # cannot leak into other tests that use the same app.
        app.dependency_overrides.clear()
测试数据库
使用测试数据库
import pytest
from typing import Generator
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine

# The app under test and the real session dependency to replace. A bare
# `FastAPI()` created here would have no routes, and `get_session` would be
# undefined.
from main import app, get_session

# Separate file-based database used only by the tests.
TEST_DATABASE_URL = "sqlite:///./test.db"
# TestClient runs sync dependencies/endpoints in worker threads, so SQLite's
# same-thread check must be disabled.
test_engine = create_engine(
    TEST_DATABASE_URL,
    connect_args={"check_same_thread": False},
)


def get_test_session() -> Generator[Session, None, None]:
    """Yield a session bound to the test database instead of the real one."""
    with Session(test_engine) as session:
        yield session


@pytest.fixture(scope="function")
def client() -> Generator[TestClient, None, None]:
    """Provide a TestClient backed by a freshly created test schema."""
    # Create all tables before each test.
    SQLModel.metadata.create_all(test_engine)
    app.dependency_overrides[get_session] = get_test_session
    try:
        with TestClient(app) as c:
            yield c
    finally:
        # Remove the override and drop all tables so tests stay independent.
        app.dependency_overrides.clear()
        SQLModel.metadata.drop_all(test_engine)


def test_create_item(client: TestClient):
    response = client.post("/items/", json={"name": "Test Item"})
    # NOTE(review): assumes main's POST /items/ declares status_code=201 and
    # accepts a body with only "name" — confirm against main.py.
    assert response.status_code == 201
使用内存数据库
import pytest
from typing import Generator
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

# The app under test and the real session dependency to replace.
from main import app, get_session


@pytest.fixture(name="session")
def session_fixture() -> Generator[Session, None, None]:
    """Yield a session on a fresh in-memory SQLite database."""
    # Each new connection to an in-memory SQLite URL gets its own, empty
    # database. StaticPool reuses a single connection so the tables created
    # here are the ones the app sees. check_same_thread=False lets
    # TestClient's worker threads use that connection.
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        yield session
    engine.dispose()


@pytest.fixture(name="client")
def client_fixture(session: Session) -> Generator[TestClient, None, None]:
    """Yield a TestClient whose session dependency is the test session."""

    def get_session_override():
        return session

    app.dependency_overrides[get_session] = get_session_override
    try:
        yield TestClient(app)
    finally:
        # Always undo the override so it does not leak into other tests.
        app.dependency_overrides.clear()
测试异常
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/items/{item_id}")
async def read_item(item_id: int):
    """Return the item id, or raise 404 when the id is 0."""
    if item_id != 0:
        return {"item_id": item_id}
    raise HTTPException(status_code=404, detail="Item not found")


client = TestClient(app)


def test_item_not_found():
    resp = client.get("/items/0")
    assert resp.status_code == 404
    # The HTTPException detail is returned under the "detail" key.
    body = resp.json()
    assert body["detail"] == "Item not found"
测试事件
测试生命周期事件
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
# Flags flipped by the lifecycle handlers so the test can observe them.
state = {"startup": False, "shutdown": False}


# NOTE: newer FastAPI versions deprecate on_event in favor of a lifespan
# context manager; on_event is kept here to demonstrate the older API.
@app.on_event("startup")
async def startup():
    """Record that the startup handler ran."""
    state["startup"] = True


@app.on_event("shutdown")
async def shutdown():
    """Record that the shutdown handler ran."""
    state["shutdown"] = True


def test_lifespan():
    with TestClient(app):
        # Inside the context: startup has fired, shutdown has not.
        assert state["startup"] is True
        assert state["shutdown"] is False
    # Exiting the context fires shutdown.
    assert state["shutdown"] is True
使用 Lifespan
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.testclient import TestClient

# Records resources set up by the lifespan handler.
state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up resources on startup and release them on shutdown."""
    # Startup
    state["db"] = "connected"
    yield
    # Shutdown
    state.pop("db", None)


app = FastAPI(lifespan=lifespan)


def test_with_lifespan():
    # Entering the TestClient context runs startup; leaving it runs shutdown.
    with TestClient(app):
        assert state.get("db") == "connected"
    # Verify the shutdown half of the lifespan as well.
    assert "db" not in state
pytest 高级用法
使用 pytest fixtures
import pytest
from fastapi.testclient import TestClient


@pytest.fixture
def client():
    """Yield a TestClient for the app defined in ``main``.

    The client is used as a context manager so the app's startup/shutdown
    (lifespan) handlers run around each test; a bare ``TestClient(app)``
    never triggers them.
    """
    from main import app

    with TestClient(app) as c:
        yield c


@pytest.fixture
def auth_headers():
    """Headers carrying a bearer token for protected routes."""
    return {"Authorization": "Bearer test-token"}


def test_protected_route(client, auth_headers):
    response = client.get("/protected", headers=auth_headers)
    assert response.status_code == 200
参数化测试
import pytest


# Each tuple below becomes its own test case.
# NOTE(review): relies on a `client` fixture defined elsewhere (e.g. in
# conftest.py) and on an app where items 1 and 2 exist — confirm.
@pytest.mark.parametrize(
    ("item_id", "expected_status"),
    [
        (1, 200),
        (2, 200),
        (999, 404),
        (0, 404),
    ],
)
def test_read_item(client, item_id, expected_status):
    """GET /items/{item_id} returns the expected status for each id."""
    status = client.get(f"/items/{item_id}").status_code
    assert status == expected_status
测试分组
import time

import pytest

# Tag tests with custom marks so they can be selected or excluded.
# Register custom marks to avoid PytestUnknownMarkWarning, e.g. in pytest.ini:
#   [pytest]
#   markers =
#       slow: slow-running tests
#       integration: integration tests


@pytest.mark.slow
def test_slow_operation():
    time.sleep(1)
    assert True


@pytest.mark.integration
def test_database():
    # Integration test placeholder.
    pass


# Run tests selected by mark:
#   pytest -m slow                     # only tests marked slow
#   pytest -m "not slow"               # everything except slow tests
#   pytest -m "slow and integration"   # only tests carrying both marks
测试覆盖率
安装 pytest-cov
pip install pytest-cov
运行测试并生成覆盖率报告
# 终端输出
pytest --cov=app tests/
# 生成 HTML 报告
pytest --cov=app --cov-report=html tests/
配置 .coveragerc
[run]
source = app
omit =
app/tests/*
*/__pycache__/*
[report]
exclude_lines =
pragma: no cover
if __name__ == .__main__.:
完整测试示例
# test_main.py
from contextlib import asynccontextmanager
from typing import Generator, Annotated

import pytest
from fastapi import FastAPI, Depends, HTTPException
from fastapi.testclient import TestClient
from pydantic import BaseModel
from sqlmodel import Session, SQLModel, create_engine, Field, select
from sqlmodel.pool import StaticPool


# Models
class Item(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    price: float


class ItemCreate(BaseModel):
    name: str
    price: float


# Database
TEST_URL = "sqlite:///:memory:"
# An in-memory SQLite database lives per connection. StaticPool makes every
# session share one connection, so they all see the same tables.
# check_same_thread=False lets TestClient's worker threads use it.
engine = create_engine(
    TEST_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)


def get_session() -> Generator[Session, None, None]:
    """Yield a database session, closed when the request finishes."""
    with Session(engine) as session:
        yield session


SessionDep = Annotated[Session, Depends(get_session)]


# Application
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create tables on startup (replaces the deprecated on_event hook)."""
    SQLModel.metadata.create_all(engine)
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/items/", response_model=Item)
def create_item(item: ItemCreate, session: SessionDep):
    """Persist a new item and return it with its generated id."""
    db_item = Item.model_validate(item)
    session.add(db_item)
    session.commit()
    session.refresh(db_item)
    return db_item


@app.get("/items/", response_model=list[Item])
def read_items(session: SessionDep):
    """Return all items."""
    return session.exec(select(Item)).all()


@app.get("/items/{item_id}", response_model=Item)
def read_item(item_id: int, session: SessionDep):
    """Return one item, or 404 if the id does not exist."""
    item = session.get(Item, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    return item


# Tests
@pytest.fixture(name="client")
def client_fixture() -> Generator[TestClient, None, None]:
    """Yield a TestClient on a fresh schema; drop all tables afterwards."""
    SQLModel.metadata.create_all(engine)
    with TestClient(app) as client:
        yield client
    SQLModel.metadata.drop_all(engine)


class TestItems:
    """Tests for the item endpoints."""

    def test_create_item(self, client):
        """Creating an item returns it with a generated id."""
        response = client.post(
            "/items/",
            json={"name": "测试商品", "price": 99.9},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "测试商品"
        assert data["price"] == 99.9
        assert "id" in data

    def test_create_item_invalid(self, client):
        """A body missing "price" fails validation."""
        response = client.post(
            "/items/",
            json={"name": "测试商品"},  # missing price
        )
        assert response.status_code == 422

    def test_read_items(self, client):
        """Listing returns every created item."""
        client.post("/items/", json={"name": "商品1", "price": 10})
        client.post("/items/", json={"name": "商品2", "price": 20})

        response = client.get("/items/")
        assert response.status_code == 200
        assert len(response.json()) == 2

    @pytest.mark.parametrize("should_exist", [True, False])
    def test_read_item(self, client, should_exist):
        """An existing id returns 200; an id that was never created returns 404."""
        created = client.post("/items/", json={"name": "测试商品", "price": 99.9})
        # Use the id the API actually assigned instead of assuming it is 1.
        item_id = created.json()["id"]
        if not should_exist:
            item_id += 1000  # an id no test has created

        response = client.get(f"/items/{item_id}")
        if should_exist:
            assert response.status_code == 200
        else:
            assert response.status_code == 404
小结
本章我们学习了:
- 测试客户端:使用 TestClient 测试 API
- 请求测试:测试 GET、POST、文件上传等
- 认证测试:测试 OAuth2 和 API Key 认证
- 依赖覆盖:在测试中替换依赖
- 数据库测试:使用测试数据库
- pytest 高级用法:fixtures、参数化、分组
- 测试覆盖率:使用 pytest-cov
测试的最佳实践:
- 每个测试应该独立,不依赖其他测试
- 使用 fixtures 共享测试数据和配置
- 覆盖正常和异常情况
- 测试验证逻辑和安全功能
练习
- 编写一个用户注册接口的完整测试用例
- 测试一个需要认证的接口,使用依赖覆盖模拟用户
- 使用内存数据库测试 CRUD 操作
- 编写参数化测试,验证输入验证逻辑