跳到主要内容

依赖注入

依赖注入(Dependency Injection)是 FastAPI 最强大和最具特色的特性之一。它提供了一种简洁、优雅的方式来共享逻辑、管理资源、实现认证等功能。

什么是依赖注入

依赖注入是一种设计模式,它允许代码声明自己需要什么"依赖",由框架负责提供这些依赖。

为什么需要依赖注入

在 Web 开发中,经常需要在多个地方重复相同的逻辑:

  • 数据库连接管理
  • 用户认证和授权
  • 公共参数处理(如分页)
  • 外部服务集成

依赖注入让这些逻辑可以定义一次,在多处复用。

基本使用

创建依赖

依赖是一个可调用对象(函数、类等),可以接收与路径操作函数相同的参数:

from typing import Annotated
from fastapi import Depends, FastAPI

app = FastAPI()

# 定义依赖函数
async def common_parameters(
q: str | None = None,
skip: int = 0,
limit: int = 100
):
return {"q": q, "skip": skip, "limit": limit}

# 使用依赖
@app.get("/items/")
async def read_items(commons: Annotated[dict, Depends(common_parameters)]):
return commons

@app.get("/users/")
async def read_users(commons: Annotated[dict, Depends(common_parameters)]):
return commons

当请求到达时,FastAPI 会:

  1. 调用依赖函数,传入相应的参数
  2. 获取依赖函数的返回值
  3. 将返回值注入到路径操作函数

共享 Annotated 依赖

当同一个依赖在多处使用时,可以创建类型别名:

from typing import Annotated

# 创建依赖类型别名
CommonsDep = Annotated[dict, Depends(common_parameters)]

@app.get("/items/")
async def read_items(commons: CommonsDep):
return commons

@app.get("/users/")
async def read_users(commons: CommonsDep):
return commons

这种方式:

  • 减少代码重复
  • 保持类型信息,编辑器仍能提供自动补全
  • 更易于维护

类作为依赖

基本类依赖

使用类作为依赖,可以更好地组织代码:

from fastapi import FastAPI, Depends
from typing import Annotated

app = FastAPI()

class CommonParams:
def __init__(
self,
q: str | None = None,
skip: int = 0,
limit: int = 100
):
self.q = q
self.skip = skip
self.limit = limit

@app.get("/items/")
async def read_items(commons: Annotated[CommonParams, Depends()]):
return {
"q": commons.q,
"skip": commons.skip,
"limit": commons.limit
}

注意 Depends() 没有传参数,FastAPI 会自动使用 CommonParams 作为依赖。

类依赖的优势

class Pagination:
"""分页参数类"""
def __init__(
self,
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100)
):
self.page = page
self.page_size = page_size
self.offset = (page - 1) * page_size

@property
def limit(self) -> int:
return self.page_size

# 使用
@app.get("/items/")
async def read_items(pagination: Annotated[Pagination, Depends()]):
# 可以直接使用 pagination.offset, pagination.limit
return {"offset": pagination.offset, "limit": pagination.limit}

依赖层次结构

依赖可以依赖其他依赖,形成层次结构:

from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException

app = FastAPI()

# 基础依赖
def get_db():
"""获取数据库连接"""
db = Database()
try:
yield db
finally:
db.close()

# 依赖基础依赖
def get_user(db: Annotated[Database, Depends(get_db)]):
"""从数据库获取用户"""
user = db.get_current_user()
if not user:
raise HTTPException(status_code=401, detail="未登录")
return user

# 更高层的依赖
def get_active_user(user: Annotated[User, Depends(get_user)]):
"""获取激活状态的用户"""
if not user.is_active:
raise HTTPException(status_code=403, detail="用户未激活")
return user

# 使用
@app.get("/profile/")
async def get_profile(user: Annotated[User, Depends(get_active_user)]):
return user

依赖执行顺序:

yield 依赖

使用 yield 的依赖可以在请求结束后执行清理代码:

# 数据库连接
def get_db():
db = Database()
try:
yield db
finally:
db.close() # 请求结束后自动关闭

# 文件处理
def get_file():
f = open("data.txt", "r")
try:
yield f
finally:
f.close() # 确保文件关闭

# 使用
@app.get("/items/")
async def read_items(db: Annotated[Database, Depends(get_db)]):
return db.query(...)

执行顺序:

  1. 依赖函数执行到 yield,返回值
  2. 路径操作函数执行
  3. 路径操作函数完成后,执行 yield 后的代码

多个 yield 依赖的执行顺序

def dep_a():
print("A: before yield")
yield "A"
print("A: after yield")

def dep_b():
print("B: before yield")
yield "B"
print("B: after yield")

@app.get("/")
async def root(
a: Annotated[str, Depends(dep_a)],
b: Annotated[str, Depends(dep_b)]
):
print("Handler")
return {"a": a, "b": b}

输出顺序:

A: before yield
B: before yield
Handler
B: after yield
A: after yield

注意:后执行的依赖先清理(类似栈结构)。

全局依赖

应用级依赖

为整个应用添加依赖:

from fastapi import FastAPI, Depends

app = FastAPI(dependencies=[Depends(verify_token), Depends(verify_key)])

@app.get("/items/")
async def read_items():
return [{"item": "Foo"}, {"item": "Bar"}]

@app.get("/users/")
async def read_users():
return [{"user": "Alice"}, {"user": "Bob"}]

所有路由都会先执行 verify_tokenverify_key 依赖。

路由器级依赖

为一组路由添加依赖:

from fastapi import FastAPI, APIRouter, Depends

app = FastAPI()
router = APIRouter(dependencies=[Depends(verify_token)])

@router.get("/items/")
async def read_items():
return [...]

@router.get("/users/")
async def read_users():
return [...]

app.include_router(router)

路由级依赖

为单个路由添加依赖:

@app.get("/items/", dependencies=[Depends(verify_token)])
async def read_items():
return [...]

路径操作装饰器中的依赖

当依赖的返回值不需要在函数中使用时,放在 dependencies 参数中:

@app.get("/items/", dependencies=[Depends(verify_token)])
async def read_items():
# verify_token 的返回值被忽略
return [...]

适用场景:

  • 验证请求头
  • 检查用户权限
  • 记录日志

异步与同步

FastAPI 会智能处理同步和异步依赖:

# 异步依赖
async def async_dep():
await some_async_operation()
return "async"

# 同步依赖
def sync_dep():
some_sync_operation()
return "sync"

# 可以混合使用
@app.get("/")
async def root(
a: Annotated[str, Depends(async_dep)],
b: Annotated[str, Depends(sync_dep)]
):
return {"a": a, "b": b}

实际应用示例

数据库会话管理

from typing import Annotated
from sqlmodel import Session, create_engine
from fastapi import Depends

engine = create_engine("sqlite:///database.db")

def get_session():
with Session(engine) as session:
yield session

SessionDep = Annotated[Session, Depends(get_session)]

@app.get("/users/")
async def read_users(session: SessionDep):
users = session.exec(select(User)).all()
return users

用户认证

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from typing import Annotated

app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
user = verify_token(token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭据"
)
return user

async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
if current_user.disabled:
raise HTTPException(status_code=400, detail="用户已禁用")
return current_user

# 使用
@app.get("/users/me")
async def read_users_me(
current_user: Annotated[User, Depends(get_current_active_user)]
):
return current_user

权限控制

from enum import Enum

class Role(str, Enum):
ADMIN = "admin"
USER = "user"
GUEST = "guest"

def require_role(required_role: Role):
"""创建角色检查依赖"""
async def role_checker(current_user: User = Depends(get_current_user)):
if current_user.role != required_role and current_user.role != Role.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足"
)
return current_user
return role_checker

# 使用
@app.delete("/users/{user_id}")
async def delete_user(
user_id: int,
user: User = Depends(require_role(Role.ADMIN))
):
# 只有管理员可以删除用户
return {"deleted": user_id}

分页参数

from fastapi import Query
from dataclasses import dataclass

@dataclass
class PaginationParams:
"""分页参数"""
page: int = Query(1, ge=1, description="页码")
page_size: int = Query(10, ge=1, le=100, description="每页数量")

@property
def offset(self) -> int:
return (self.page - 1) * self.page_size

@property
def limit(self) -> int:
return self.page_size

Pagination = Annotated[PaginationParams, Depends()]

@app.get("/items/")
async def list_items(pagination: Pagination):
items = db.query(Item).offset(pagination.offset).limit(pagination.limit).all()
return {
"items": items,
"page": pagination.page,
"page_size": pagination.page_size
}

缓存依赖

from functools import lru_cache

@lru_cache
def get_settings():
"""获取应用配置(单例)"""
return Settings()

SettingsDep = Annotated[Settings, Depends(get_settings)]

@app.get("/config")
async def get_config(settings: SettingsDep):
return {"app_name": settings.app_name}

依赖覆盖

在测试中可以覆盖依赖:

from fastapi.testclient import TestClient

# 原始依赖
def get_db():
return Database("production.db")

# 测试依赖
def get_test_db():
return Database("test.db")

app = FastAPI()
app.dependency_overrides[get_db] = get_test_db

client = TestClient(app)

最佳实践

1. 使用类型别名

# 定义类型别名
DBSession = Annotated[Session, Depends(get_session)]
CurrentUser = Annotated[User, Depends(get_current_user)]

# 使用
@app.get("/users/me")
async def get_me(user: CurrentUser, session: DBSession):
return user

2. 依赖分层

# 基础层:资源获取
def get_db(): ...

# 中间层:业务逻辑
def get_user(db: DBSession, token: str): ...

# 应用层:权限控制
def get_admin_user(user: User): ...

3. 使用 yield 进行资源清理

def get_db():
db = Database()
try:
yield db
finally:
db.close() # 确保资源释放

4. 避免循环依赖

# 错误:循环依赖
def dep_a(b = Depends(dep_b)): ...
def dep_b(a = Depends(dep_a)): ...

# 正确:重构依赖结构
def dep_base(): ...
def dep_a(base = Depends(dep_base)): ...
def dep_b(base = Depends(dep_base)): ...

小结

本章我们学习了:

  1. 依赖注入基础:创建和使用依赖
  2. 类作为依赖:使用类组织依赖逻辑
  3. 依赖层次:依赖可以依赖其他依赖
  4. yield 依赖:在请求结束后执行清理
  5. 全局依赖:应用级、路由器级、路由级依赖
  6. 实际应用:数据库管理、用户认证、权限控制

依赖注入的优势:

  • 代码复用:一次定义,多处使用
  • 易于测试:可以轻松覆盖依赖
  • 清晰的依赖关系:依赖层次明确
  • 自动文档:依赖信息包含在 OpenAPI 文档中

练习

  1. 创建一个分页依赖,包含 page、page_size 和自动计算的 offset
  2. 实现一个 JWT 认证依赖,从请求头获取并验证 token
  3. 创建数据库会话依赖,使用 yield 确保会话关闭
  4. 实现一个权限检查依赖工厂函数,根据角色限制访问