依赖注入
依赖注入(Dependency Injection)是 FastAPI 最强大和最具特色的特性之一。它提供了一种简洁、优雅的方式来共享逻辑、管理资源、实现认证等功能。
什么是依赖注入
依赖注入是一种设计模式,它允许代码声明自己需要什么"依赖",由框架负责提供这些依赖。
为什么需要依赖注入
在 Web 开发中,经常需要在多个地方重复相同的逻辑:
- 数据库连接管理
- 用户认证和授权
- 公共参数处理(如分页)
- 外部服务集成
依赖注入让这些逻辑可以定义一次,在多处复用。
基本使用
创建依赖
依赖是一个可调用对象(函数、类等),可以接收与路径操作函数相同的参数:
from typing import Annotated
from fastapi import Depends, FastAPI
app = FastAPI()
# 定义依赖函数
async def common_parameters(
q: str | None = None,
skip: int = 0,
limit: int = 100
):
return {"q": q, "skip": skip, "limit": limit}
# 使用依赖
@app.get("/items/")
async def read_items(commons: Annotated[dict, Depends(common_parameters)]):
return commons
@app.get("/users/")
async def read_users(commons: Annotated[dict, Depends(common_parameters)]):
return commons
当请求到达时,FastAPI 会:
- 调用依赖函数,传入相应的参数
- 获取依赖函数的返回值
- 将返回值注入到路径操作函数
共享 Annotated 依赖
当同一个依赖在多处使用时,可以创建类型别名:
from typing import Annotated
# 创建依赖类型别名
CommonsDep = Annotated[dict, Depends(common_parameters)]
@app.get("/items/")
async def read_items(commons: CommonsDep):
return commons
@app.get("/users/")
async def read_users(commons: CommonsDep):
return commons
这种方式:
- 减少代码重复
- 保持类型信息,编辑器仍能提供自动补全
- 更易于维护
类作为依赖
基本类依赖
使用类作为依赖,可以更好地组织代码:
from fastapi import FastAPI, Depends
from typing import Annotated
app = FastAPI()
class CommonParams:
def __init__(
self,
q: str | None = None,
skip: int = 0,
limit: int = 100
):
self.q = q
self.skip = skip
self.limit = limit
@app.get("/items/")
async def read_items(commons: Annotated[CommonParams, Depends()]):
return {
"q": commons.q,
"skip": commons.skip,
"limit": commons.limit
}
注意 Depends() 没有传参数,FastAPI 会自动使用 CommonParams 作为依赖。
类依赖的优势
class Pagination:
"""分页参数类"""
def __init__(
self,
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100)
):
self.page = page
self.page_size = page_size
self.offset = (page - 1) * page_size
@property
def limit(self) -> int:
return self.page_size
# 使用
@app.get("/items/")
async def read_items(pagination: Annotated[Pagination, Depends()]):
# 可以直接使用 pagination.offset, pagination.limit
return {"offset": pagination.offset, "limit": pagination.limit}
依赖层次结构
依赖可以依赖其他依赖,形成层次结构:
from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException
app = FastAPI()
# 基础依赖
def get_db():
"""获取数据库连接"""
db = Database()
try:
yield db
finally:
db.close()
# 依赖基础依赖
def get_user(db: Annotated[Database, Depends(get_db)]):
"""从数据库获取用户"""
user = db.get_current_user()
if not user:
raise HTTPException(status_code=401, detail="未登录")
return user
# 更高层的依赖
def get_active_user(user: Annotated[User, Depends(get_user)]):
"""获取激活状态的用户"""
if not user.is_active:
raise HTTPException(status_code=403, detail="用户未激活")
return user
# 使用
@app.get("/profile/")
async def get_profile(user: Annotated[User, Depends(get_active_user)]):
return user
依赖执行顺序:
yield 依赖
使用 yield 的依赖可以在请求结束后执行清理代码:
# 数据库连接
def get_db():
db = Database()
try:
yield db
finally:
db.close() # 请求结束后自动关闭
# 文件处理
def get_file():
f = open("data.txt", "r")
try:
yield f
finally:
f.close() # 确保文件关闭
# 使用
@app.get("/items/")
async def read_items(db: Annotated[Database, Depends(get_db)]):
return db.query(...)
执行顺序:
- 依赖函数执行到
yield,返回值 - 路径操作函数执行
- 路径操作函数完成后,执行
yield后的代码
多个 yield 依赖的执行顺序
def dep_a():
print("A: before yield")
yield "A"
print("A: after yield")
def dep_b():
print("B: before yield")
yield "B"
print("B: after yield")
@app.get("/")
async def root(
a: Annotated[str, Depends(dep_a)],
b: Annotated[str, Depends(dep_b)]
):
print("Handler")
return {"a": a, "b": b}
输出顺序:
A: before yield
B: before yield
Handler
B: after yield
A: after yield
注意:后执行的依赖先清理(类似栈结构)。
全局依赖
应用级依赖
为整个应用添加依赖:
from fastapi import FastAPI, Depends
app = FastAPI(dependencies=[Depends(verify_token), Depends(verify_key)])
@app.get("/items/")
async def read_items():
return [{"item": "Foo"}, {"item": "Bar"}]
@app.get("/users/")
async def read_users():
return [{"user": "Alice"}, {"user": "Bob"}]
所有路由都会先执行 verify_token 和 verify_key 依赖。
路由器级依赖
为一组路由添加依赖:
from fastapi import FastAPI, APIRouter, Depends
app = FastAPI()
router = APIRouter(dependencies=[Depends(verify_token)])
@router.get("/items/")
async def read_items():
return [...]
@router.get("/users/")
async def read_users():
return [...]
app.include_router(router)
路由级依赖
为单个路由添加依赖:
@app.get("/items/", dependencies=[Depends(verify_token)])
async def read_items():
return [...]
路径操作装饰器中的依赖
当依赖的返回值不需要在函数中使用时,放在 dependencies 参数中:
@app.get("/items/", dependencies=[Depends(verify_token)])
async def read_items():
# verify_token 的返回值被忽略
return [...]
适用场景:
- 验证请求头
- 检查用户权限
- 记录日志
异步与同步
FastAPI 会智能处理同步和异步依赖:
# 异步依赖
async def async_dep():
await some_async_operation()
return "async"
# 同步依赖
def sync_dep():
some_sync_operation()
return "sync"
# 可以混合使用
@app.get("/")
async def root(
a: Annotated[str, Depends(async_dep)],
b: Annotated[str, Depends(sync_dep)]
):
return {"a": a, "b": b}
实际应用示例
数据库会话管理
from typing import Annotated
from sqlmodel import Session, create_engine
from fastapi import Depends
engine = create_engine("sqlite:///database.db")
def get_session():
with Session(engine) as session:
yield session
SessionDep = Annotated[Session, Depends(get_session)]
@app.get("/users/")
async def read_users(session: SessionDep):
users = session.exec(select(User)).all()
return users
用户认证
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from typing import Annotated
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
user = verify_token(token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭据"
)
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
if current_user.disabled:
raise HTTPException(status_code=400, detail="用户已禁用")
return current_user
# 使用
@app.get("/users/me")
async def read_users_me(
current_user: Annotated[User, Depends(get_current_active_user)]
):
return current_user
权限控制
from enum import Enum
class Role(str, Enum):
ADMIN = "admin"
USER = "user"
GUEST = "guest"
def require_role(required_role: Role):
"""创建角色检查依赖"""
async def role_checker(current_user: User = Depends(get_current_user)):
if current_user.role != required_role and current_user.role != Role.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足"
)
return current_user
return role_checker
# 使用
@app.delete("/users/{user_id}")
async def delete_user(
user_id: int,
user: User = Depends(require_role(Role.ADMIN))
):
# 只有管理员可以删除用户
return {"deleted": user_id}
分页参数
from fastapi import Query
from dataclasses import dataclass
@dataclass
class PaginationParams:
"""分页参数"""
page: int = Query(1, ge=1, description="页码")
page_size: int = Query(10, ge=1, le=100, description="每页数量")
@property
def offset(self) -> int:
return (self.page - 1) * self.page_size
@property
def limit(self) -> int:
return self.page_size
Pagination = Annotated[PaginationParams, Depends()]
@app.get("/items/")
async def list_items(pagination: Pagination):
items = db.query(Item).offset(pagination.offset).limit(pagination.limit).all()
return {
"items": items,
"page": pagination.page,
"page_size": pagination.page_size
}
缓存依赖
from functools import lru_cache
@lru_cache
def get_settings():
"""获取应用配置(单例)"""
return Settings()
SettingsDep = Annotated[Settings, Depends(get_settings)]
@app.get("/config")
async def get_config(settings: SettingsDep):
return {"app_name": settings.app_name}
依赖覆盖
在测试中可以覆盖依赖:
from fastapi.testclient import TestClient
# 原始依赖
def get_db():
return Database("production.db")
# 测试依赖
def get_test_db():
return Database("test.db")
app = FastAPI()
app.dependency_overrides[get_db] = get_test_db
client = TestClient(app)
最佳实践
1. 使用类型别名
# 定义类型别名
DBSession = Annotated[Session, Depends(get_session)]
CurrentUser = Annotated[User, Depends(get_current_user)]
# 使用
@app.get("/users/me")
async def get_me(user: CurrentUser, session: DBSession):
return user
2. 依赖分层
# 基础层:资源获取
def get_db(): ...
# 中间层:业务逻辑
def get_user(db: DBSession, token: str): ...
# 应用层:权限控制
def get_admin_user(user: User): ...
3. 使用 yield 进行资源清理
def get_db():
db = Database()
try:
yield db
finally:
db.close() # 确保资源释放
4. 避免循环依赖
# 错误:循环依赖
def dep_a(b = Depends(dep_b)): ...
def dep_b(a = Depends(dep_a)): ...
# 正确:重构依赖结构
def dep_base(): ...
def dep_a(base = Depends(dep_base)): ...
def dep_b(base = Depends(dep_base)): ...
小结
本章我们学习了:
- 依赖注入基础:创建和使用依赖
- 类作为依赖:使用类组织依赖逻辑
- 依赖层次:依赖可以依赖其他依赖
- yield 依赖:在请求结束后执行清理
- 全局依赖:应用级、路由器级、路由级依赖
- 实际应用:数据库管理、用户认证、权限控制
依赖注入的优势:
- 代码复用:一次定义,多处使用
- 易于测试:可以轻松覆盖依赖
- 清晰的依赖关系:依赖层次明确
- 自动文档:依赖信息包含在 OpenAPI 文档中
练习
- 创建一个分页依赖,包含 page、page_size 和自动计算的 offset
- 实现一个 JWT 认证依赖,从请求头获取并验证 token
- 创建数据库会话依赖,使用 yield 确保会话关闭
- 实现一个权限检查依赖工厂函数,根据角色限制访问