中间件与异常处理
中间件允许你在请求到达路由处理函数之前和响应返回给客户端之后执行代码。异常处理则让你能够优雅地处理错误情况,返回有意义的错误信息。
中间件
什么是中间件
中间件是一个函数,它在每个请求处理前后执行:
- 接收请求
- 执行预处理代码
- 将请求传递给下一个处理程序
- 执行后处理代码
- 返回响应
创建中间件
使用 @app.middleware("http") 装饰器创建中间件:
import time
from fastapi import FastAPI, Request
app = FastAPI()
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
# 请求前:记录开始时间
start_time = time.perf_counter()
# 调用下一个处理程序
response = await call_next(request)
# 响应后:计算处理时间并添加到响应头
process_time = time.perf_counter() - start_time
response.headers["X-Process-Time"] = f"{process_time:.4f}"
return response
中间件参数
| 参数 | 说明 |
|---|---|
request | 请求对象,包含请求信息 |
call_next | 调用下一个处理程序的函数 |
请求和响应处理
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
import json
app = FastAPI()
@app.middleware("http")
async def log_requests(request: Request, call_next):
# 读取请求体(注意:只能读取一次)
body = await request.body()
# 记录请求信息
print(f"请求方法: {request.method}")
print(f"请求路径: {request.url.path}")
print(f"请求头: {dict(request.headers)}")
# 处理请求
response = await call_next(request)
# 可以修改响应
response.headers["X-Request-ID"] = request.headers.get("X-Request-ID", "unknown")
return response
多个中间件的执行顺序
中间件按添加顺序的逆序执行(后进先出):
@app.middleware("http")
async def middleware_a(request: Request, call_next):
print("A: 请求前")
response = await call_next(request)
print("A: 响应后")
return response
@app.middleware("http")
async def middleware_b(request: Request, call_next):
print("B: 请求前")
response = await call_next(request)
print("B: 响应后")
return response
执行顺序:
B: 请求前
A: 请求前
路由处理
A: 响应后
B: 响应后
内置中间件
CORS 中间件
处理跨域资源共享(CORS):
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许的源
allow_credentials=True, # 允许携带凭据
allow_methods=["*"], # 允许的方法
allow_headers=["*"], # 允许的请求头
expose_headers=["X-Total-Count"], # 暴露给客户端的响应头
max_age=600, # 预检请求缓存时间
)
生产环境建议限制 allow_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://yourdomain.com",
"https://www.yourdomain.com",
],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
)
HTTPS 重定向
强制使用 HTTPS:
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
app.add_middleware(HTTPSRedirectMiddleware)
Trusted Host
限制允许的主机名:
from fastapi.middleware.trustedhost import TrustedHostMiddleware
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["example.com", "*.example.com", "localhost"]
)
GZip 压缩
自动压缩响应:
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware, minimum_size=1000) # 超过 1000 字节才压缩
自定义中间件类
继承 BaseHTTPMiddleware 创建更复杂的中间件:
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 请求前处理
print(f"Processing: {request.url}")
# 调用下一个处理程序
response = await call_next(request)
# 响应后处理
response.headers["X-Custom-Header"] = "Custom Value"
return response
app = FastAPI()
app.add_middleware(CustomMiddleware)
带状态的中间件
from typing import Callable
from fastapi import FastAPI, Request
class RateLimitMiddleware:
def __init__(self, app, requests_per_minute: int = 60):
self.app = app
self.requests_per_minute = requests_per_minute
self.requests = {} # 实际应用应使用 Redis
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope, receive)
client_ip = request.client.host
# 检查请求限制
# ... 实现限流逻辑
await self.app(scope, receive, send)
app = FastAPI()
app.add_middleware(RateLimitMiddleware, requests_per_minute=100)
异常处理
HTTPException
FastAPI 提供了 HTTPException 用于抛出 HTTP 错误:
from fastapi import FastAPI, HTTPException, status
app = FastAPI()
items = {"foo": "The Foo Wrestlers", "bar": "The Bar Wrestlers"}
@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id not in items:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="商品不存在",
headers={"X-Error": "Not Found"}
)
return {"item": items[item_id]}
HTTPException 参数:
| 参数 | 说明 |
|---|---|
status_code | HTTP 状态码 |
detail | 错误详情(会被 JSON 序列化) |
headers | 额外的响应头 |
自定义异常
创建自定义异常类:
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
class UnicornException(Exception):
def __init__(self, name: str):
self.name = name
app = FastAPI()
@app.exception_handler(UnicornException)
async def unicorn_exception_handler(request: Request, exc: UnicornException):
return JSONResponse(
status_code=418,
content={"message": f"Oops! {exc.name} did something."},
)
@app.get("/unicorns/{name}")
async def read_unicorn(name: str):
if name == "yolo":
raise UnicornException(name=name)
return {"unicorn_name": name}
覆盖默认异常处理器
FastAPI 有默认的异常处理器,可以覆盖它们:
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
app = FastAPI()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
# 自定义验证错误的响应格式
return JSONResponse(
status_code=422,
content={
"success": False,
"message": "请求参数验证失败",
"errors": exc.errors()
}
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
# 自定义 HTTP 异常的响应格式
return JSONResponse(
status_code=exc.status_code,
content={
"success": False,
"message": exc.detail
}
)
全局异常处理器
捕获所有未处理的异常:
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import logging
logger = logging.getLogger(__name__)
app = FastAPI()
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
# 记录错误日志
logger.exception("Unhandled exception: %s", exc)
# 返回友好的错误信息
return JSONResponse(
status_code=500,
content={
"success": False,
"message": "服务器内部错误",
"detail": str(exc) if app.debug else None
}
)
使用 starlette.status
使用 status 模块提高代码可读性:
from fastapi import FastAPI, HTTPException, status
app = FastAPI()
@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id not in items:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Item not found"
)
return items[item_id]
@app.post("/items/")
async def create_item(item: Item):
# 创建成功返回 201
return Response(status_code=status.HTTP_201_CREATED)
常用状态码常量:
| 常量 | 值 | 含义 |
|---|---|---|
HTTP_200_OK | 200 | 成功 |
HTTP_201_CREATED | 201 | 已创建 |
HTTP_204_NO_CONTENT | 204 | 无内容 |
HTTP_400_BAD_REQUEST | 400 | 错误请求 |
HTTP_401_UNAUTHORIZED | 401 | 未认证 |
HTTP_403_FORBIDDEN | 403 | 禁止访问 |
HTTP_404_NOT_FOUND | 404 | 未找到 |
HTTP_422_UNPROCESSABLE_ENTITY | 422 | 验证错误 |
HTTP_500_INTERNAL_SERVER_ERROR | 500 | 服务器错误 |
实际应用示例
请求日志中间件
import time
import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
logger = logging.getLogger(__name__)
app = FastAPI()
@app.middleware("http")
async def log_requests(request: Request, call_next):
start_time = time.time()
# 处理请求
response = await call_next(request)
# 记录日志
duration = time.time() - start_time
logger.info(
f"{request.method} {request.url.path} "
f"completed in {duration:.3f}s with status {response.status_code}"
)
return response
认证中间件
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.responses import JSONResponse
from jose import jwt, JWTError
app = FastAPI()
SECRET_KEY = "your-secret-key"
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
# 排除不需要认证的路径
if request.url.path in ["/login", "/register", "/docs", "/openapi.json"]:
return await call_next(request)
# 获取 token
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未提供认证凭据"
)
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
# 将用户信息存入请求状态
request.state.user = payload
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭据"
)
return await call_next(request)
错误响应统一格式
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from typing import Any
app = FastAPI()
class ErrorResponse(BaseModel):
"""统一错误响应格式"""
success: bool = False
code: int
message: str
data: Any = None
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(
code=exc.status_code,
message=str(exc.detail)
).model_dump()
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=422,
content=ErrorResponse(
code=422,
message="请求参数验证失败",
data=exc.errors()
).model_dump()
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
return JSONResponse(
status_code=500,
content=ErrorResponse(
code=500,
message="服务器内部错误"
).model_dump()
)
小结
本章我们学习了:
- 中间件基础:创建和使用中间件
- 中间件执行顺序:后进先出的处理方式
- 内置中间件:CORS、HTTPS 重定向、GZip 等
- 异常处理:HTTPException 和自定义异常
- 异常处理器:自定义错误响应格式
中间件的典型应用:
- 请求日志
- 认证授权
- 请求限流
- 响应压缩
- CORS 处理
异常处理的关键点:
- 使用 HTTPException 抛出标准 HTTP 错误
- 自定义异常处理器统一错误格式
- 全局异常处理器捕获未处理异常
练习
- 创建一个中间件,为每个请求添加唯一 ID
- 实现 CORS 中间件,只允许特定域名访问
- 创建自定义异常,并实现对应的异常处理器
- 实现一个简单的请求限流中间件