跳到主要内容

测试

测试是确保代码质量的重要手段。FastAPI 提供了完善的测试支持,可以轻松编写单元测试和集成测试。

测试基础

安装依赖

pip install pytest httpx

创建测试客户端

使用 TestClient 创建测试客户端:

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/")
async def read_root():
    """Return a fixed greeting payload."""
    return {"message": "Hello World"}


# The test client drives the app in-process; no running server is needed.
client = TestClient(app)


def test_read_root():
    """GET / answers 200 with the greeting payload."""
    resp = client.get("/")
    expected = {"message": "Hello World"}
    assert resp.status_code == 200
    assert resp.json() == expected

运行测试

# 运行所有测试
pytest

# 运行指定文件
pytest test_main.py

# 显示详细信息
pytest -v

# 显示打印输出
pytest -s

测试 GET 请求

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/items/{item_id}")
async def read_item(item_id: int, q: str | None = None):
    """Echo the path parameter and the optional `q` query parameter.

    `q` is annotated as `str | None` because its default is None; a bare
    `str = None` is an implicit Optional that type checkers reject.
    """
    return {"item_id": item_id, "q": q}


client = TestClient(app)


def test_read_item():
    """Path parameter only: `q` falls back to None."""
    response = client.get("/items/42")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": None}


def test_read_item_with_query():
    """Query parameters passed via `params=` are URL-encoded by the client."""
    response = client.get("/items/42", params={"q": "search"})
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": "search"}

测试 POST 请求

from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel

app = FastAPI()


class Item(BaseModel):
    """JSON body accepted (and returned) by POST /items/."""

    name: str
    price: float


@app.post("/items/")
async def create_item(item: Item):
    """Echo back the validated item."""
    return item


client = TestClient(app)


def test_create_item():
    """A valid JSON body is accepted and echoed back unchanged."""
    payload = {"name": "商品", "price": 99.9}
    resp = client.post("/items/", json=payload)
    assert resp.status_code == 200
    assert resp.json() == payload


def test_create_item_invalid():
    """Leaving out the required `price` field produces a 422 validation error."""
    incomplete = {"name": "商品"}
    resp = client.post("/items/", json=incomplete)
    assert resp.status_code == 422

测试文件上传

from fastapi import FastAPI, UploadFile, File
from fastapi.testclient import TestClient

app = FastAPI()


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Read the whole upload and report its filename and size in bytes."""
    data = await file.read()
    return {"filename": file.filename, "size": len(data)}


client = TestClient(app)


def test_upload_file():
    """A multipart upload reports the original filename and byte count."""
    body = b"test file content"
    # (filename, content, content type) tuple for the multipart field "file"
    upload = {"file": ("test.txt", body, "text/plain")}

    resp = client.post("/upload/", files=upload)

    assert resp.status_code == 200
    assert resp.json() == {"filename": "test.txt", "size": len(body)}

测试表单数据

from fastapi import FastAPI, Form
from fastapi.testclient import TestClient

app = FastAPI()


@app.post("/login/")
async def login(username: str = Form(...), password: str = Form(...)):
    """Accept form-encoded credentials; only the username is echoed back."""
    return {"username": username}


client = TestClient(app)


def test_login():
    """Form fields are sent with `data=` (form-encoded, not JSON)."""
    form = {"username": "user", "password": "pass"}
    resp = client.post("/login/", data=form)
    assert resp.status_code == 200
    assert resp.json() == {"username": "user"}

测试认证

测试 OAuth2

from fastapi import FastAPI, Depends
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient

app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@app.get("/users/me")
async def read_users_me(token: str = Depends(oauth2_scheme)):
    """Return the bearer token extracted by the OAuth2 scheme."""
    return {"token": token}


client = TestClient(app)


def test_with_auth():
    """A bearer token in the Authorization header reaches the endpoint."""
    auth = {"Authorization": "Bearer test-token"}
    resp = client.get("/users/me", headers=auth)
    assert resp.status_code == 200
    assert resp.json() == {"token": "test-token"}


def test_without_auth():
    """Without any token the request is rejected with 401."""
    resp = client.get("/users/me")
    assert resp.status_code == 401

测试 API Key

# `Depends` is used by the /protected route below, so it must be imported too
# (the original import line omitted it, raising NameError at import time).
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
from fastapi.testclient import TestClient

app = FastAPI()
api_key_header = APIKeyHeader(name="X-API-Key")


async def get_api_key(api_key: str = Security(api_key_header)):
    """Validate the X-API-Key header; reject anything but the expected key with 403."""
    if api_key != "secret-key":
        raise HTTPException(status_code=403)
    return api_key


@app.get("/protected")
async def protected(api_key: str = Depends(get_api_key)):
    """Only reachable when get_api_key accepted the header."""
    return {"message": "success"}


client = TestClient(app)


def test_valid_api_key():
    """The correct key grants access."""
    response = client.get(
        "/protected",
        headers={"X-API-Key": "secret-key"},
    )
    assert response.status_code == 200


def test_invalid_api_key():
    """A wrong key is rejected by get_api_key with 403."""
    response = client.get(
        "/protected",
        headers={"X-API-Key": "wrong-key"},
    )
    assert response.status_code == 403

测试依赖覆盖

在测试中可以覆盖依赖,使用测试数据或模拟服务:

from typing import Annotated

from fastapi import FastAPI, Depends
from fastapi.testclient import TestClient

app = FastAPI()


# Real dependency: in an actual application this would load the user
# from the database or a token.
async def get_current_user():
    return {"id": 1, "name": "Real User"}


@app.get("/users/me")
async def read_users_me(user: Annotated[dict, Depends(get_current_user)]):
    return user


# Test replacement: returns a fixed test user instead of the real one.
async def override_get_current_user():
    return {"id": 999, "name": "Test User"}


client = TestClient(app)


def test_read_users_me():
    """The endpoint uses the overriding dependency while it is installed."""
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        response = client.get("/users/me")
        assert response.status_code == 200
        assert response.json() == {"id": 999, "name": "Test User"}
    finally:
        # Remove the override even when an assertion fails; otherwise it
        # would silently leak into every later test using this app.
        app.dependency_overrides.clear()

测试数据库

使用测试数据库

import pytest
from typing import Generator

from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine

# The app under test and the dependency being replaced must come from the
# application itself; a fresh FastAPI() here would have no routes and
# `get_session` would be undefined. Adjust the module path to your project.
from main import app, get_session

# Separate database file used only by the tests.
TEST_DATABASE_URL = "sqlite:///./test.db"
# TestClient runs the app in other threads, so SQLite connections must be
# allowed to cross threads.
test_engine = create_engine(
    TEST_DATABASE_URL,
    connect_args={"check_same_thread": False},
)


def get_test_session() -> Generator[Session, None, None]:
    """Yield a session bound to the test database instead of the real one."""
    with Session(test_engine) as session:
        yield session


@pytest.fixture(scope="function")
def client() -> Generator[TestClient, None, None]:
    """Provide a TestClient over freshly created tables for each test."""
    # Create the tables before each test
    SQLModel.metadata.create_all(test_engine)
    app.dependency_overrides[get_session] = get_test_session

    with TestClient(app) as c:
        yield c

    # Undo the override and drop everything so tests stay independent
    app.dependency_overrides.clear()
    SQLModel.metadata.drop_all(test_engine)


def test_create_item(client: TestClient):
    response = client.post("/items/", json={"name": "Test Item"})
    assert response.status_code == 201

使用内存数据库(SQLite 内存数据库只存在于创建它的那个连接中,需配合 `StaticPool` 和 `check_same_thread=False`,让 TestClient 的工作线程共享同一个数据库)

import pytest
from typing import Generator

from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

# The app and the dependency to override come from the application module
# (adjust the path to your project layout).
from main import app, get_session


@pytest.fixture(name="session")
def session_fixture() -> Generator[Session, None, None]:
    """Yield a session on a brand-new in-memory database with all tables created."""
    # An in-memory SQLite database exists only inside the connection that
    # created it. StaticPool reuses that single connection for every checkout,
    # and check_same_thread=False lets TestClient's worker threads use it;
    # without both, the endpoint would see a different, empty database.
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)

    with Session(engine) as session:
        yield session


@pytest.fixture(name="client")
def client_fixture(session: Session) -> Generator[TestClient, None, None]:
    """Yield a TestClient whose `get_session` dependency returns the test session."""

    def get_session_override():
        return session

    app.dependency_overrides[get_session] = get_session_override

    client = TestClient(app)
    yield client

    app.dependency_overrides.clear()

测试异常

from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

app = FastAPI()


@app.get("/items/{item_id}")
async def read_item(item_id: int):
    """Return the item id; id 0 is treated as a missing item."""
    if item_id != 0:
        return {"item_id": item_id}
    raise HTTPException(status_code=404, detail="Item not found")


client = TestClient(app)


def test_item_not_found():
    """HTTPException is turned into a JSON error with its `detail` message."""
    resp = client.get("/items/0")
    assert resp.status_code == 404
    body = resp.json()
    assert body["detail"] == "Item not found"

测试事件

测试生命周期事件(`@app.on_event` 在新版 FastAPI 中已被弃用,新代码推荐使用下文的 Lifespan)

from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
# Flags flipped by the event handlers so the test can observe them.
state = {"startup": False, "shutdown": False}


# NOTE: @app.on_event is deprecated in current FastAPI in favour of the
# `lifespan=` parameter (see the next example); it still works for old code.
@app.on_event("startup")
async def startup():
    state["startup"] = True


@app.on_event("shutdown")
async def shutdown():
    state["shutdown"] = True


def test_lifespan():
    """Entering the TestClient context runs startup; leaving it runs shutdown."""
    # Reset the module-level flags so the assertions reflect this run only;
    # otherwise a previous start/stop of the same app leaves shutdown=True.
    state.update(startup=False, shutdown=False)

    with TestClient(app):
        assert state["startup"] is True
        assert state["shutdown"] is False

    assert state["shutdown"] is True

使用 Lifespan

from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.testclient import TestClient

state = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up shared resources before `yield` and release them after it."""
    # Startup
    state["db"] = "connected"
    yield
    # Shutdown
    state.pop("db", None)


app = FastAPI(lifespan=lifespan)


def test_with_lifespan():
    """Both halves of the lifespan run: setup on enter, cleanup on exit."""
    with TestClient(app):
        assert state.get("db") == "connected"

    # Leaving the context runs the code after `yield`.
    assert "db" not in state

pytest 高级用法

使用 pytest fixtures

import pytest
from fastapi.testclient import TestClient


@pytest.fixture
def client():
    """Yield a TestClient for the application in `main`.

    Using the client as a context manager runs the app's startup/shutdown
    (lifespan) around each test, like the other fixtures in this chapter.
    """
    # Imported inside the fixture so `main` is only loaded when a test needs it.
    from main import app

    with TestClient(app) as c:
        yield c


@pytest.fixture
def auth_headers():
    """Authorization header carrying a test bearer token."""
    return {"Authorization": "Bearer test-token"}


def test_protected_route(client, auth_headers):
    response = client.get("/protected", headers=auth_headers)
    assert response.status_code == 200

参数化测试

import pytest


@pytest.mark.parametrize(
    ("item_id", "expected_status"),
    [
        (1, 200),
        (2, 200),
        (999, 404),
        (0, 404),
    ],
)
def test_read_item(client, item_id, expected_status):
    """Run once per (item_id, expected_status) pair.

    Relies on a `client` fixture defined elsewhere; presumably ids 1 and 2
    exist in that app and 999 / 0 do not — confirm against the fixture data.
    """
    url = f"/items/{item_id}"
    assert client.get(url).status_code == expected_status

测试分组

import time

import pytest

# Tag tests with marks. Custom marks should be registered, e.g. in pytest.ini:
#
#   [pytest]
#   markers =
#       slow: slow-running tests
#       integration: integration tests
#
# otherwise pytest emits PytestUnknownMarkWarning for each unknown mark.


@pytest.mark.slow
def test_slow_operation():
    time.sleep(1)
    assert True


@pytest.mark.integration
def test_database():
    # Integration test placeholder
    pass


# Run only the tests carrying particular marks:
# pytest -m slow
# pytest -m "not slow"
# pytest -m "slow and integration"

测试覆盖率

安装 pytest-cov

pip install pytest-cov

运行测试并生成覆盖率报告

# 终端输出
pytest --cov=app tests/

# 生成 HTML 报告
pytest --cov=app --cov-report=html tests/

配置 .coveragerc

[run]
source = app
# Continuation lines of a multi-line value must be indented; unindented
# lines would be parsed as new (malformed) options.
omit =
    app/tests/*
    */__pycache__/*

[report]
# Setting exclude_lines replaces coverage's default list, so
# "pragma: no cover" is listed explicitly. The "." in the __main__ pattern
# is a regex wildcard matching either quote style.
exclude_lines =
    pragma: no cover
    if __name__ == .__main__.:

完整测试示例

# test_main.py
from contextlib import asynccontextmanager
from typing import Annotated, Generator

import pytest
from fastapi import Depends, FastAPI, HTTPException
from fastapi.testclient import TestClient
from pydantic import BaseModel
from sqlmodel import Field, Session, SQLModel, create_engine, select
from sqlmodel.pool import StaticPool


# Models
class Item(SQLModel, table=True):
    """Item table; `id` is assigned by the database on insert."""

    id: int | None = Field(default=None, primary_key=True)
    name: str
    price: float


class ItemCreate(BaseModel):
    """Request body for creating an item (no id)."""

    name: str
    price: float


# Database
TEST_URL = "sqlite:///:memory:"
# An in-memory SQLite database exists only inside the connection that created
# it. StaticPool makes every checkout reuse that single connection, so the
# tables created by the fixture are visible to the endpoints, and
# check_same_thread=False lets TestClient's worker threads use it. Without
# these, requests fail with "no such table: item".
engine = create_engine(
    TEST_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)


def get_session() -> Generator[Session, None, None]:
    """Yield a database session for the duration of one request."""
    with Session(engine) as session:
        yield session


SessionDep = Annotated[Session, Depends(get_session)]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the tables on startup (replaces the deprecated on_event hook)."""
    SQLModel.metadata.create_all(engine)
    yield


# Application
app = FastAPI(lifespan=lifespan)


@app.post("/items/", response_model=Item)
def create_item(item: ItemCreate, session: SessionDep):
    """Persist a new item and return it including its generated id."""
    db_item = Item.model_validate(item)
    session.add(db_item)
    session.commit()
    session.refresh(db_item)
    return db_item


@app.get("/items/", response_model=list[Item])
def read_items(session: SessionDep):
    """Return every stored item."""
    return session.exec(select(Item)).all()


@app.get("/items/{item_id}", response_model=Item)
def read_item(item_id: int, session: SessionDep):
    """Return one item, or 404 when the id does not exist."""
    item = session.get(Item, item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    return item


# Tests
@pytest.fixture(name="client")
def client_fixture() -> Generator[TestClient, None, None]:
    """Give each test empty tables and a client with the lifespan running."""
    SQLModel.metadata.create_all(engine)
    with TestClient(app) as client:
        yield client
    # Drop everything so the next test starts from an empty database.
    SQLModel.metadata.drop_all(engine)


class TestItems:
    """Tests for the item endpoints."""

    def test_create_item(self, client):
        """Creating an item returns it with a generated id."""
        response = client.post(
            "/items/",
            json={"name": "测试商品", "price": 99.9},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "测试商品"
        assert data["price"] == 99.9
        assert "id" in data

    def test_create_item_invalid(self, client):
        """A body missing `price` fails validation."""
        response = client.post(
            "/items/",
            json={"name": "测试商品"},  # price is missing
        )
        assert response.status_code == 422

    def test_read_items(self, client):
        """The list endpoint returns everything created in this test."""
        # Create test data
        client.post("/items/", json={"name": "商品1", "price": 10})
        client.post("/items/", json={"name": "商品2", "price": 20})

        response = client.get("/items/")
        assert response.status_code == 200
        assert len(response.json()) == 2

    @pytest.mark.parametrize("item_id,should_exist", [
        (1, True),
        (999, False),
    ])
    def test_read_item(self, client, item_id, should_exist):
        """The first item created in a fresh table gets id 1; 999 is absent."""
        # Create test data
        client.post("/items/", json={"name": "测试商品", "price": 99.9})

        response = client.get(f"/items/{item_id}")
        if should_exist:
            assert response.status_code == 200
        else:
            assert response.status_code == 404

小结

本章我们学习了:

  1. 测试客户端:使用 TestClient 测试 API
  2. 请求测试:测试 GET、POST、文件上传等
  3. 认证测试:测试 OAuth2 和 API Key 认证
  4. 依赖覆盖:在测试中替换依赖
  5. 数据库测试:使用测试数据库
  6. pytest 高级用法:fixtures、参数化、分组
  7. 测试覆盖率:使用 pytest-cov

测试的最佳实践:

  • 每个测试应该独立,不依赖其他测试
  • 使用 fixtures 共享测试数据和配置
  • 覆盖正常和异常情况
  • 测试验证逻辑和安全功能

练习

  1. 编写一个用户注册接口的完整测试用例
  2. 测试一个需要认证的接口,使用依赖覆盖模拟用户
  3. 使用内存数据库测试 CRUD 操作
  4. 编写参数化测试,验证输入验证逻辑