跳到主要内容

测试

测试是确保代码质量的重要手段。FastAPI 提供了完善的测试支持,可以轻松编写单元测试和集成测试。

测试基础

安装依赖

pip install pytest httpx

创建测试客户端

使用 TestClient 创建测试客户端:

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/")
async def read_root():
    """Return the fixed greeting payload served at the root path."""
    greeting = {"message": "Hello World"}
    return greeting


# Test client bound to the app defined above.
client = TestClient(app)


def test_read_root():
    """GET / answers 200 with the greeting payload."""
    expected = {"message": "Hello World"}
    resp = client.get("/")
    assert resp.status_code == 200
    assert resp.json() == expected

运行测试

# 运行所有测试
pytest

# 运行指定文件
pytest test_main.py

# 显示详细信息
pytest -v

# 显示打印输出
pytest -s

测试 GET 请求

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/items/{item_id}")
async def read_item(item_id: int, q: str | None = None):
    """Echo the path parameter and the optional query parameter.

    Args:
        item_id: Item identifier taken from the URL path.
        q: Optional query-string value; ``None`` when the caller omits it.

    Returns:
        A dict with the keys ``item_id`` and ``q``.
    """
    return {"item_id": item_id, "q": q}


client = TestClient(app)


def test_read_item():
    """A request without a query string reports ``q`` as null."""
    response = client.get("/items/42")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": None}


def test_read_item_with_query():
    """A request with ``?q=search`` echoes the query value back."""
    response = client.get("/items/42?q=search")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": "search"}

测试 POST 请求

from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    """Schema used for both the request body and the response."""

    name: str
    price: float


@app.post("/items/")
async def create_item(item: Item):
    """Return the validated item unchanged."""
    return item


client = TestClient(app)


def test_create_item():
    """A complete JSON body is accepted and echoed back."""
    payload = {"name": "商品", "price": 99.9}
    resp = client.post("/items/", json=payload)
    assert resp.status_code == 200
    assert resp.json() == {"name": "商品", "price": 99.9}


def test_create_item_invalid():
    """A body without ``price`` is rejected with 422."""
    incomplete = {"name": "商品"}  # price is missing
    resp = client.post("/items/", json=incomplete)
    assert resp.status_code == 422
assert response.status_code == 422

测试文件上传

from fastapi import FastAPI, UploadFile, File
from fastapi.testclient import TestClient

app = FastAPI()


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Read the uploaded file and report its name and byte length."""
    payload = await file.read()
    size = len(payload)
    return {"filename": file.filename, "size": size}


client = TestClient(app)


def test_upload_file():
    """Uploading a 17-byte text file returns its name and size."""
    # (filename, content, content type) tuple passed under the "file" field.
    file_content = b"test file content"
    upload = ("test.txt", file_content, "text/plain")

    resp = client.post("/upload/", files={"file": upload})

    assert resp.status_code == 200
    assert resp.json() == {"filename": "test.txt", "size": 17}

测试表单数据

from fastapi import FastAPI, Form
from fastapi.testclient import TestClient

app = FastAPI()


@app.post("/login/")
async def login(username: str = Form(...), password: str = Form(...)):
    """Accept the two form fields and return only the username."""
    return {"username": username}


client = TestClient(app)


def test_login():
    """Form fields sent through ``data=`` yield a 200 with the username."""
    credentials = {"username": "user", "password": "pass"}
    resp = client.post("/login/", data=credentials)
    assert resp.status_code == 200
    assert resp.json() == {"username": "user"}

测试认证

测试 OAuth2

from fastapi import FastAPI, Depends
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient

app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@app.get("/users/me")
async def read_users_me(token: str = Depends(oauth2_scheme)):
    """Return the token resolved by the ``oauth2_scheme`` dependency."""
    return {"token": token}


client = TestClient(app)


def test_with_auth():
    """A request with an ``Authorization: Bearer`` header gets its token echoed."""
    bearer = {"Authorization": "Bearer test-token"}
    resp = client.get("/users/me", headers=bearer)
    assert resp.status_code == 200
    assert resp.json() == {"token": "test-token"}


def test_without_auth():
    """A request that carries no token is answered with 401."""
    resp = client.get("/users/me")
    assert resp.status_code == 401

测试 API Key

from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
from fastapi.testclient import TestClient

app = FastAPI()
api_key_header = APIKeyHeader(name="X-API-Key")


async def get_api_key(api_key: str = Security(api_key_header)):
    """Validate the value of the ``X-API-Key`` header.

    Args:
        api_key: Header value supplied through ``api_key_header``.

    Returns:
        The key itself when it equals ``"secret-key"``.

    Raises:
        HTTPException: With status 403 when the key is anything else.
    """
    if api_key != "secret-key":
        raise HTTPException(status_code=403)
    return api_key


@app.get("/protected")
async def protected(api_key: str = Depends(get_api_key)):
    """Return a success message once ``get_api_key`` has accepted the key."""
    return {"message": "success"}


client = TestClient(app)


def test_valid_api_key():
    """The expected key is accepted."""
    response = client.get(
        "/protected",
        headers={"X-API-Key": "secret-key"},
    )
    assert response.status_code == 200


def test_invalid_api_key():
    """Any other key is rejected with 403."""
    response = client.get(
        "/protected",
        headers={"X-API-Key": "wrong-key"},
    )
    assert response.status_code == 403

测试依赖覆盖

在测试中可以覆盖依赖,使用测试数据或模拟服务:

from typing import Annotated
from fastapi import FastAPI, Depends
from fastapi.testclient import TestClient

app = FastAPI()


# Real dependency: hard-coded here; the original tutorial notes that a real
# application would read the user from a database or a token.
async def get_current_user():
    """Return the user served by the application outside of tests."""
    return {"id": 1, "name": "Real User"}


@app.get("/users/me")
async def read_users_me(user: Annotated[dict, Depends(get_current_user)]):
    """Return whatever user the ``get_current_user`` dependency provides."""
    return user


# Test dependency: stands in for get_current_user during tests.
async def override_get_current_user():
    """Return a fixed test user."""
    return {"id": 999, "name": "Test User"}


client = TestClient(app)


def test_read_users_me():
    """With the override installed, the endpoint returns the test user."""
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        response = client.get("/users/me")
        assert response.status_code == 200
        assert response.json() == {"id": 999, "name": "Test User"}
    finally:
        # Remove the override even when an assertion fails, so it cannot
        # leak into tests that run afterwards.
        app.dependency_overrides.clear()

测试数据库

使用测试数据库

import pytest
from typing import Generator
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine, select

# Test database
TEST_DATABASE_URL = "sqlite:///./test.db"
# SQLite connections refuse use from a thread other than the one that created
# them unless check_same_thread is disabled; a request served through
# TestClient may touch the session from more than one thread.
test_engine = create_engine(
    TEST_DATABASE_URL,
    connect_args={"check_same_thread": False},
)


def get_test_session():
    """Yield a session bound to the test engine; it is closed afterwards."""
    with Session(test_engine) as session:
        yield session


app = FastAPI()
# NOTE(review): `get_session` is not defined in this snippet; presumably it is
# the application's real session dependency and must be imported — confirm.
app.dependency_overrides[get_session] = get_test_session


@pytest.fixture(scope="function")
def client() -> Generator[TestClient, None, None]:
    """Yield a TestClient backed by freshly created tables."""
    # Create the tables before each test.
    SQLModel.metadata.create_all(test_engine)

    with TestClient(app) as c:
        yield c

    # Drop every table after the test (this removes the tables themselves,
    # not only the rows).
    SQLModel.metadata.drop_all(test_engine)


def test_create_item(client: TestClient):
    """Creating an item answers 201."""
    # NOTE(review): the POST /items/ route is not defined in this snippet;
    # presumably it comes from the real application — confirm.
    response = client.post("/items/", json={"name": "Test Item"})
    assert response.status_code == 201

使用内存数据库

from typing import Generator

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

# NOTE(review): `app` and `get_session` are not defined in this snippet;
# presumably they come from the application module
# (e.g. `from main import app, get_session`) — confirm.


@pytest.fixture(name="session")
def session_fixture() -> Generator[Session, None, None]:
    """Yield a session on a fresh in-memory SQLite database."""
    # An in-memory SQLite database lives inside a single connection.
    # StaticPool makes every checkout reuse that one connection, and
    # check_same_thread=False lets the app's threads use it; without both,
    # other threads would see an empty database with no tables.
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)

    with Session(engine) as session:
        yield session


@pytest.fixture(name="client")
def client_fixture(session: Session) -> Generator[TestClient, None, None]:
    """Yield a TestClient whose session dependency is the test session."""

    def get_session_override():
        return session

    app.dependency_overrides[get_session] = get_session_override

    client = TestClient(app)
    yield client

    # Remove the override once the test has finished.
    app.dependency_overrides.clear()

测试异常

from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/items/{item_id}")
async def read_item(item_id: int):
    """Return the item id, or raise a 404 when the id is 0."""
    if item_id != 0:
        return {"item_id": item_id}
    raise HTTPException(status_code=404, detail="Item not found")


client = TestClient(app)


def test_item_not_found():
    """Requesting id 0 yields 404 with the expected detail message."""
    resp = client.get("/items/0")
    body = resp.json()
    assert resp.status_code == 404
    assert body["detail"] == "Item not found"

测试事件

使用 Lifespan(推荐方式)

FastAPI 推荐使用 lifespan 参数来管理应用的生命周期事件。在测试中,TestClient 会自动触发 lifespan 事件:

from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.testclient import TestClient

# Application state shared between the lifespan handler and the route.
state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Record startup in ``state`` before the yield and shutdown after it."""
    # Startup phase.
    state.update(db="connected", startup=True)
    yield
    # Shutdown phase.
    state.update(db=None, shutdown=True)


app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root():
    """Report the current value stored under ``state["db"]``."""
    return {"db": state.get("db")}


def test_lifespan():
    """Startup state is set inside the ``with`` block, shutdown state after it."""
    # Entering the context triggers startup.
    with TestClient(app) as client:
        assert state["startup"] is True
        assert state["db"] == "connected"

        resp = client.get("/")
        assert resp.json() == {"db": "connected"}

    # Leaving the context triggers shutdown.
    assert state["shutdown"] is True

工作原理

TestClient 作为上下文管理器使用时:

  1. 进入 with 块时,执行 lifespan 中 yield 之前的代码(启动)
  2. 执行测试代码
  3. 退出 with 块时,执行 lifespan 中 yield 之后的代码(关闭)

关于 on_event(已不推荐)

虽然 @app.on_event("startup") 和 @app.on_event("shutdown") 仍然可用,但官方推荐使用 lifespan 方式:

# Deprecated style
@app.on_event("startup")
async def startup():
    """Flag in ``state`` that the startup handler ran."""
    state.update(startup=True)


@app.on_event("shutdown")
async def shutdown():
    """Flag in ``state`` that the shutdown handler ran."""
    state.update(shutdown=True)

注意:如果同时定义了 lifespan 和 on_event 处理器,只有 lifespan 中的代码会被执行。

pytest 高级用法

使用 pytest fixtures

import pytest
from fastapi.testclient import TestClient

@pytest.fixture
def client():
    """Yield a TestClient for the app defined in ``main``."""
    from main import app

    # Enter the client as a context manager so the app's lifespan
    # (startup/shutdown) runs around each test; a bare TestClient(app)
    # does not trigger it.
    with TestClient(app) as test_client:
        yield test_client


@pytest.fixture
def auth_headers():
    """Return the Authorization header shared by authenticated tests."""
    return {"Authorization": "Bearer test-token"}


def test_protected_route(client, auth_headers):
    """The protected route answers 200 when the auth header is sent."""
    response = client.get("/protected", headers=auth_headers)
    assert response.status_code == 200

参数化测试

import pytest

@pytest.mark.parametrize(
    "item_id,expected_status",
    [
        (1, 200),
        (2, 200),
        (999, 404),
        (0, 404),
    ],
)
def test_read_item(client, item_id, expected_status):
    """Each item id must produce the status code paired with it."""
    url = f"/items/{item_id}"
    resp = client.get(url)
    assert resp.status_code == expected_status

测试分组

# Tag tests with marks so they can be selected from the command line.
@pytest.mark.slow
def test_slow_operation():
    """Stand-in for a slow test: sleeps for one second."""
    from time import sleep

    sleep(1)
    assert True


@pytest.mark.integration
def test_database():
    """Placeholder for an integration test."""
    pass


# Run only the tests carrying a given mark:
# pytest -m slow
# pytest -m "not slow"
# pytest -m "slow and integration"

测试覆盖率

安装 pytest-cov

pip install pytest-cov

运行测试并生成覆盖率报告

# 终端输出
pytest --cov=app tests/

# 生成 HTML 报告
pytest --cov=app --cov-report=html tests/

配置 .coveragerc

[run]
source = app
omit =
app/tests/*
*/__pycache__/*

[report]
exclude_lines =
pragma: no cover
if __name__ == .__main__.:

完整测试示例

# test_main.py
from contextlib import asynccontextmanager
from typing import Annotated, Generator

import pytest
from fastapi import Depends, FastAPI, HTTPException
from fastapi.testclient import TestClient
from pydantic import BaseModel
from sqlmodel import Field, Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool

# Models
class Item(SQLModel, table=True):
    """Table model for an item; ``id`` is the primary key."""

    id: int | None = Field(default=None, primary_key=True)
    name: str
    price: float


class ItemCreate(BaseModel):
    """Request body accepted when creating an item."""

    name: str
    price: float

# Database
TEST_URL = "sqlite:///:memory:"
# An in-memory SQLite database lives inside a single connection.
# StaticPool makes every checkout reuse that one connection, and
# check_same_thread=False lets the threads that serve requests use it;
# without both, request handlers would see an empty database with no tables.
engine = create_engine(
    TEST_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)


def get_session():
    """Yield a session bound to ``engine``; it is closed afterwards."""
    with Session(engine) as session:
        yield session


SessionDep = Annotated[Session, Depends(get_session)]

# Lifespan event handling
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the database tables before the app starts serving."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)  # startup: create the tables
    yield
    # Shutdown: nothing to clean up in this example.


# Application
app = FastAPI(lifespan=lifespan)

@app.post("/items/", response_model=Item)
def create_item(item: ItemCreate, session: SessionDep):
    """Store a new item and return the refreshed row."""
    db_item = Item.model_validate(item)
    session.add(db_item)
    session.commit()
    # Reload the row after the commit before returning it.
    session.refresh(db_item)
    return db_item


@app.get("/items/", response_model=list[Item])
def read_items(session: SessionDep):
    """Return every stored item."""
    statement = select(Item)
    return session.exec(statement).all()


@app.get("/items/{item_id}", response_model=Item)
def read_item(item_id: int, session: SessionDep):
    """Return the item with the given id, or raise a 404 when it is missing."""
    item = session.get(Item, item_id)
    if item:
        return item
    raise HTTPException(status_code=404, detail="Item not found")

# Tests
@pytest.fixture(name="client")
def client_fixture():
    """Yield a client over freshly created tables and drop them afterwards."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)
    with TestClient(app) as test_client:
        yield test_client
    metadata.drop_all(engine)

class TestItems:
    """Tests for the item endpoints."""

    def test_create_item(self, client):
        """Creating an item returns its fields together with an id."""
        payload = {"name": "测试商品", "price": 99.9}
        resp = client.post("/items/", json=payload)
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "测试商品"
        assert body["price"] == 99.9
        assert "id" in body

    def test_create_item_invalid(self, client):
        """A body without ``price`` is rejected with 422."""
        incomplete = {"name": "测试商品"}  # price is missing
        resp = client.post("/items/", json=incomplete)
        assert resp.status_code == 422

    def test_read_items(self, client):
        """Listing returns both items created beforehand."""
        # Seed two items.
        for name, price in (("商品1", 10), ("商品2", 20)):
            client.post("/items/", json={"name": name, "price": price})

        resp = client.get("/items/")
        assert resp.status_code == 200
        assert len(resp.json()) == 2

    @pytest.mark.parametrize(
        "item_id,should_exist",
        [
            (1, True),
            (999, False),
        ],
    )
    def test_read_item(self, client, item_id, should_exist):
        """An existing id answers 200 and an unknown id answers 404."""
        # Seed one item.
        client.post("/items/", json={"name": "测试商品", "price": 99.9})

        resp = client.get(f"/items/{item_id}")
        expected_status = 200 if should_exist else 404
        assert resp.status_code == expected_status

小结

本章我们学习了:

  1. 测试客户端:使用 TestClient 测试 API
  2. 请求测试:测试 GET、POST、文件上传等
  3. 认证测试:测试 OAuth2 和 API Key 认证
  4. 依赖覆盖:在测试中替换依赖
  5. 数据库测试:使用测试数据库
  6. pytest 高级用法:fixtures、参数化、分组
  7. 测试覆盖率:使用 pytest-cov

测试的最佳实践:

  • 每个测试应该独立,不依赖其他测试
  • 使用 fixtures 共享测试数据和配置
  • 覆盖正常和异常情况
  • 测试验证逻辑和安全功能

练习

  1. 编写一个用户注册接口的完整测试用例
  2. 测试一个需要认证的接口,使用依赖覆盖模拟用户
  3. 使用内存数据库测试 CRUD 操作
  4. 编写参数化测试,验证输入验证逻辑