跳到主要内容

中间件与异常处理

中间件是处理请求和响应的拦截器,可以在请求到达路由之前或响应返回客户端之前执行代码。异常处理则是确保应用在出错时能够优雅地返回有意义的错误信息。

中间件

什么是中间件

中间件是一个函数,它在每个请求处理之前和响应返回之前执行:

创建中间件

使用 @app.middleware("http") 装饰器创建中间件:

import time
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
# 请求前:记录开始时间
start_time = time.perf_counter()

# 调用下一个中间件或路由处理
response = await call_next(request)

# 响应后:计算处理时间
process_time = time.perf_counter() - start_time

# 添加自定义响应头
response.headers["X-Process-Time"] = str(process_time)

return response

@app.get("/")
async def root():
return {"message": "Hello World"}

中间件执行顺序

当添加多个中间件时,执行顺序遵循"洋葱模型":

from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def middleware_a(request: Request, call_next):
print("A: 请求前")
response = await call_next(request)
print("A: 响应后")
return response

@app.middleware("http")
async def middleware_b(request: Request, call_next):
print("B: 请求前")
response = await call_next(request)
print("B: 响应后")
return response

# 执行顺序:
# 请求前: B → A → 路由
# 响应后: 路由 → A → B

输出:

B: 请求前
A: 请求前
A: 响应后
B: 响应后

中间件的典型用途

请求日志

import logging
from datetime import datetime
from fastapi import FastAPI, Request

app = FastAPI()

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@app.middleware("http")
async def log_requests(request: Request, call_next):
# 记录请求信息
logger.info(
f"请求开始: {request.method} {request.url} "
f"来自 {request.client.host if request.client else 'unknown'}"
)

start_time = datetime.now()

response = await call_next(request)

# 记录响应信息
process_time = (datetime.now() - start_time).total_seconds()
logger.info(
f"请求完成: {request.method} {request.url} "
f"状态码 {response.status_code} 耗时 {process_time:.3f}s"
)

return response

请求 ID 追踪

import uuid
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_request_id(request: Request, call_next):
# 生成或获取请求 ID
request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))

# 存储在请求状态中
request.state.request_id = request_id

response = await call_next(request)

# 添加到响应头
response.headers["X-Request-ID"] = request_id

return response

@app.get("/")
async def root(request: Request):
# 在路由中访问请求 ID
return {"request_id": request.state.request_id}

访问控制

from fastapi import FastAPI, Request, HTTPException, status

app = FastAPI()

# IP 黑名单
BLOCKED_IPS = {"192.168.1.100", "10.0.0.50"}

@app.middleware("http")
async def ip_filter(request: Request, call_next):
client_ip = request.client.host if request.client else None

if client_ip in BLOCKED_IPS:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="访问被拒绝"
)

return await call_next(request)

类形式的中间件

使用类创建可复用的中间件:

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

class ProcessTimeMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
import time
start_time = time.perf_counter()

response = await call_next(request)

process_time = time.perf_counter() - start_time
response.headers["X-Process-Time"] = str(process_time)

return response

app = FastAPI()
app.add_middleware(ProcessTimeMiddleware)

内置中间件

CORS 中间件

跨域资源共享(CORS)是前后端分离应用的必备配置:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000",
"http://localhost:8080",
"https://myapp.com",
],
allow_credentials=True,
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有请求头
expose_headers=["X-Total-Count"], # 暴露给客户端的响应头
max_age=3600, # 预检请求缓存时间
)

配置说明:

参数说明
allow_origins允许的源列表,["*"] 表示允许所有
allow_credentials是否允许发送 Cookie
allow_methods允许的 HTTP 方法
allow_headers允许的请求头
expose_headers暴露给客户端的响应头
max_age预检请求缓存时间(秒)

Trusted Host 中间件

防止 Host 头攻击:

from fastapi import FastAPI
from starlette.middleware.trustedhost import TrustedHostMiddleware

app = FastAPI()

app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["example.com", "*.example.com", "localhost"]
)

GZip 中间件

自动压缩响应:

from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware

app = FastAPI()

app.add_middleware(GZipMiddleware, minimum_size=1000) # 大于 1KB 才压缩

HTTPS 重定向

强制使用 HTTPS:

from fastapi import FastAPI
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

app = FastAPI()
app.add_middleware(HTTPSRedirectMiddleware)

Session 中间件

from fastapi import FastAPI
from fastapi.middleware.session import SessionMiddleware

app = FastAPI()

# 需要设置安全的密钥
app.add_middleware(
SessionMiddleware,
secret_key="your-secret-key",
session_cookie="sessionid",
max_age=3600, # 1 小时
)

异常处理

HTTPException

FastAPI 的内置异常,用于返回 HTTP 错误响应:

from fastapi import FastAPI, HTTPException, status

app = FastAPI()

items = {"foo": "The Foo Wrestlers", "bar": "The Bar Wrestlers"}

@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id not in items:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="商品不存在",
headers={"X-Error": "Not Found"} # 可选的自定义响应头
)
return {"item": items[item_id]}

自定义异常处理器

处理自定义异常

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

# 自定义异常
class ItemNotFoundException(Exception):
def __init__(self, item_id: str):
self.item_id = item_id

# 注册异常处理器
@app.exception_handler(ItemNotFoundException)
async def item_not_found_handler(request: Request, exc: ItemNotFoundException):
return JSONResponse(
status_code=404,
content={
"error": "ITEM_NOT_FOUND",
"message": f"商品 {exc.item_id} 不存在",
"item_id": exc.item_id
}
)

# 触发异常
@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id == "error":
raise ItemNotFoundException(item_id)
return {"item_id": item_id}

覆盖默认异常处理器

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError

app = FastAPI()

# 自定义验证错误响应
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=422,
content={
"error": "VALIDATION_ERROR",
"message": "请求数据验证失败",
"details": exc.errors()
}
)

处理所有异常

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import logging

app = FastAPI()
logger = logging.getLogger(__name__)

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
# 记录错误日志
logger.error(f"未处理的异常: {exc}", exc_info=True)

return JSONResponse(
status_code=500,
content={
"error": "INTERNAL_ERROR",
"message": "服务器内部错误"
}
)

完整的异常处理体系

from enum import Enum
from typing import Any
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel

app = FastAPI()

# 错误码枚举
class ErrorCode(str, Enum):
VALIDATION_ERROR = "VALIDATION_ERROR"
NOT_FOUND = "NOT_FOUND"
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
INTERNAL_ERROR = "INTERNAL_ERROR"

# 统一错误响应模型
class ErrorResponse(BaseModel):
error: ErrorCode
message: str
details: Any = None

# 自定义业务异常基类
class AppException(Exception):
def __init__(
self,
error_code: ErrorCode,
message: str,
status_code: int = 400,
details: Any = None
):
self.error_code = error_code
self.message = message
self.status_code = status_code
self.details = details

# 具体业务异常
class NotFoundException(AppException):
def __init__(self, resource: str, resource_id: str):
super().__init__(
error_code=ErrorCode.NOT_FOUND,
message=f"{resource} 不存在",
status_code=status.HTTP_404_NOT_FOUND,
details={"resource": resource, "id": resource_id}
)

class UnauthorizedException(AppException):
def __init__(self, message: str = "未授权"):
super().__init__(
error_code=ErrorCode.UNAUTHORIZED,
message=message,
status_code=status.HTTP_401_UNAUTHORIZED
)

class ForbiddenException(AppException):
def __init__(self, message: str = "权限不足"):
super().__init__(
error_code=ErrorCode.FORBIDDEN,
message=message,
status_code=status.HTTP_403_FORBIDDEN
)

# 异常处理器
@app.exception_handler(AppException)
async def app_exception_handler(request: Request, exc: AppException):
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(
error=exc.error_code,
message=exc.message,
details=exc.details
).model_dump()
)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=ErrorResponse(
error=ErrorCode.VALIDATION_ERROR,
message="请求数据验证失败",
details=exc.errors()
).model_dump()
)

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(
error=ErrorCode.INTERNAL_ERROR,
message=str(exc.detail)
).model_dump()
)

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
# 生产环境不应返回详细错误信息
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=ErrorResponse(
error=ErrorCode.INTERNAL_ERROR,
message="服务器内部错误"
).model_dump()
)

# 使用示例
@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id == "missing":
raise NotFoundException("商品", item_id)
if item_id == "forbidden":
raise ForbiddenException("无权访问此商品")
return {"item_id": item_id}

后台任务

后台任务允许在响应返回后执行操作,适用于发送邮件、日志记录等场景:

from fastapi import FastAPI, BackgroundTasks

app = FastAPI()

def send_email(email: str, message: str):
"""发送邮件(模拟)"""
print(f"发送邮件到 {email}: {message}")

def write_log(message: str):
"""写入日志"""
with open("log.txt", "a") as f:
f.write(message + "\n")

@app.post("/send-notification/{email}")
async def send_notification(
email: str,
background_tasks: BackgroundTasks
):
# 添加后台任务
background_tasks.add_task(
send_email,
email,
"您有新的通知"
)
background_tasks.add_task(
write_log,
f"发送通知到 {email}"
)

return {"message": "通知已发送", "email": email}

后台任务与依赖注入

from typing import Annotated
from fastapi import FastAPI, BackgroundTasks, Depends

app = FastAPI()

def get_background_tasks():
return BackgroundTasks()

@app.post("/items/")
async def create_item(
item: dict,
background_tasks: Annotated[BackgroundTasks, Depends()]
):
# 创建商品的逻辑
background_tasks.add_task(lambda: print(f"创建商品: {item}"))
return item

实际应用示例

请求日志中间件 + 异常处理

import logging
import time
import uuid
from contextlib import asynccontextmanager
from typing import Callable

from fastapi import FastAPI, Request, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# 应用
app = FastAPI(title="中间件与异常处理示例")

# 请求上下文
class RequestContext:
def __init__(self):
self.request_id: str = ""
self.start_time: float = 0

# 日志中间件
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
# 初始化请求上下文
ctx = RequestContext()
ctx.request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
ctx.start_time = time.perf_counter()

# 存储到请求状态
request.state.ctx = ctx

# 记录请求
logger.info(
f"[{ctx.request_id}] 请求开始: {request.method} {request.url.path}"
)

try:
response = await call_next(request)

# 计算处理时间
process_time = time.perf_counter() - ctx.start_time

# 添加响应头
response.headers["X-Request-ID"] = ctx.request_id
response.headers["X-Process-Time"] = f"{process_time:.3f}s"

# 记录响应
logger.info(
f"[{ctx.request_id}] 请求完成: "
f"状态码 {response.status_code}, 耗时 {process_time:.3f}s"
)

return response

except Exception as exc:
# 记录异常
logger.error(
f"[{ctx.request_id}] 请求异常: {exc}",
exc_info=True
)
raise

# 异常处理器
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
ctx = getattr(request.state, "ctx", None)
request_id = ctx.request_id if ctx else "unknown"

return JSONResponse(
status_code=exc.status_code,
content={
"error": "HTTP_ERROR",
"message": str(exc.detail),
"request_id": request_id
},
headers={"X-Request-ID": request_id}
)

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
ctx = getattr(request.state, "ctx", None)
request_id = ctx.request_id if ctx else "unknown"

logger.error(f"[{request_id}] 未处理异常: {exc}", exc_info=True)

return JSONResponse(
status_code=500,
content={
"error": "INTERNAL_ERROR",
"message": "服务器内部错误",
"request_id": request_id
},
headers={"X-Request-ID": request_id}
)

# 路由
@app.get("/")
async def root():
return {"message": "Hello World"}

@app.get("/error")
async def trigger_error():
raise ValueError("这是一个测试错误")

@app.get("/http-error")
async def trigger_http_error():
raise HTTPException(status_code=400, detail="这是一个 HTTP 错误")

小结

本章我们学习了:

  1. 中间件基础:理解中间件的执行流程和洋葱模型
  2. 创建中间件:使用装饰器和类创建中间件
  3. 内置中间件:CORS、Trusted Host、GZip 等
  4. 异常处理:自定义异常和异常处理器
  5. 后台任务:响应返回后执行操作

中间件和异常处理是构建健壮 Web 应用的基础:

  • 中间件用于横切关注点(日志、认证、压缩等)
  • 异常处理用于统一的错误响应格式
  • 后台任务用于异步处理非关键操作

练习

  1. 创建一个中间件,记录每个请求的方法、路径、状态码和处理时间
  2. 实现一个自定义异常体系,包含业务异常和统一的错误响应格式
  3. 配置 CORS 中间件,允许特定域名的跨域请求
  4. 使用后台任务实现"发送欢迎邮件"功能,在用户注册后异步发送