跳到主要内容

中间件与异常处理

中间件允许你在请求到达路由处理函数之前和响应返回给客户端之后执行代码。异常处理则让你能够优雅地处理错误情况,返回有意义的错误信息。

中间件

什么是中间件

中间件是一个函数,它在每个请求处理前后执行:

  1. 接收请求
  2. 执行预处理代码
  3. 将请求传递给下一个处理程序
  4. 执行后处理代码
  5. 返回响应

创建中间件

使用 @app.middleware("http") 装饰器创建中间件:

import time
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
# 请求前:记录开始时间
start_time = time.perf_counter()

# 调用下一个处理程序
response = await call_next(request)

# 响应后:计算处理时间并添加到响应头
process_time = time.perf_counter() - start_time
response.headers["X-Process-Time"] = f"{process_time:.4f}"

return response

中间件参数

参数说明
request请求对象,包含请求信息
call_next调用下一个处理程序的函数

请求和响应处理

from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
import json

app = FastAPI()

@app.middleware("http")
async def log_requests(request: Request, call_next):
# 读取请求体(注意:只能读取一次)
body = await request.body()

# 记录请求信息
print(f"请求方法: {request.method}")
print(f"请求路径: {request.url.path}")
print(f"请求头: {dict(request.headers)}")

# 处理请求
response = await call_next(request)

# 可以修改响应
response.headers["X-Request-ID"] = request.headers.get("X-Request-ID", "unknown")

return response

多个中间件的执行顺序

中间件按添加顺序的逆序执行(后进先出):

@app.middleware("http")
async def middleware_a(request: Request, call_next):
print("A: 请求前")
response = await call_next(request)
print("A: 响应后")
return response

@app.middleware("http")
async def middleware_b(request: Request, call_next):
print("B: 请求前")
response = await call_next(request)
print("B: 响应后")
return response

执行顺序:

B: 请求前
A: 请求前
路由处理
A: 响应后
B: 响应后

内置中间件

CORS 中间件

处理跨域资源共享(CORS):

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许的源
allow_credentials=True, # 允许携带凭据
allow_methods=["*"], # 允许的方法
allow_headers=["*"], # 允许的请求头
expose_headers=["X-Total-Count"], # 暴露给客户端的响应头
max_age=600, # 预检请求缓存时间
)

生产环境建议限制 allow_origins

app.add_middleware(
CORSMiddleware,
allow_origins=[
"https://yourdomain.com",
"https://www.yourdomain.com",
],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
)

HTTPS 重定向

强制使用 HTTPS:

from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

app.add_middleware(HTTPSRedirectMiddleware)

Trusted Host

限制允许的主机名:

from fastapi.middleware.trustedhost import TrustedHostMiddleware

app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["example.com", "*.example.com", "localhost"]
)

GZip 压缩

自动压缩响应:

from fastapi.middleware.gzip import GZipMiddleware

app.add_middleware(GZipMiddleware, minimum_size=1000) # 超过 1000 字节才压缩

自定义中间件类

继承 BaseHTTPMiddleware 创建更复杂的中间件:

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 请求前处理
print(f"Processing: {request.url}")

# 调用下一个处理程序
response = await call_next(request)

# 响应后处理
response.headers["X-Custom-Header"] = "Custom Value"

return response

app = FastAPI()
app.add_middleware(CustomMiddleware)

带状态的中间件

from typing import Callable
from fastapi import FastAPI, Request

class RateLimitMiddleware:
def __init__(self, app, requests_per_minute: int = 60):
self.app = app
self.requests_per_minute = requests_per_minute
self.requests = {} # 实际应用应使用 Redis

async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return

request = Request(scope, receive)
client_ip = request.client.host

# 检查请求限制
# ... 实现限流逻辑

await self.app(scope, receive, send)

app = FastAPI()
app.add_middleware(RateLimitMiddleware, requests_per_minute=100)

异常处理

HTTPException

FastAPI 提供了 HTTPException 用于抛出 HTTP 错误:

from fastapi import FastAPI, HTTPException, status

app = FastAPI()

items = {"foo": "The Foo Wrestlers", "bar": "The Bar Wrestlers"}

@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id not in items:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="商品不存在",
headers={"X-Error": "Not Found"}
)
return {"item": items[item_id]}

HTTPException 参数:

参数说明
status_codeHTTP 状态码
detail错误详情(会被 JSON 序列化)
headers额外的响应头

自定义异常

创建自定义异常类:

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

class UnicornException(Exception):
def __init__(self, name: str):
self.name = name

app = FastAPI()

@app.exception_handler(UnicornException)
async def unicorn_exception_handler(request: Request, exc: UnicornException):
return JSONResponse(
status_code=418,
content={"message": f"Oops! {exc.name} did something."},
)

@app.get("/unicorns/{name}")
async def read_unicorn(name: str):
if name == "yolo":
raise UnicornException(name=name)
return {"unicorn_name": name}

覆盖默认异常处理器

FastAPI 有默认的异常处理器,可以覆盖它们:

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError

app = FastAPI()

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
# 自定义验证错误的响应格式
return JSONResponse(
status_code=422,
content={
"success": False,
"message": "请求参数验证失败",
"errors": exc.errors()
}
)

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
# 自定义 HTTP 异常的响应格式
return JSONResponse(
status_code=exc.status_code,
content={
"success": False,
"message": exc.detail
}
)

全局异常处理器

捕获所有未处理的异常:

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import logging

logger = logging.getLogger(__name__)

app = FastAPI()

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
# 记录错误日志
logger.exception("Unhandled exception: %s", exc)

# 返回友好的错误信息
return JSONResponse(
status_code=500,
content={
"success": False,
"message": "服务器内部错误",
"detail": str(exc) if app.debug else None
}
)

使用 starlette.status

使用 status 模块提高代码可读性:

from fastapi import FastAPI, HTTPException, status

app = FastAPI()

@app.get("/items/{item_id}")
async def read_item(item_id: str):
if item_id not in items:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Item not found"
)
return items[item_id]

@app.post("/items/")
async def create_item(item: Item):
# 创建成功返回 201
return Response(status_code=status.HTTP_201_CREATED)

常用状态码常量:

常量含义
HTTP_200_OK200成功
HTTP_201_CREATED201已创建
HTTP_204_NO_CONTENT204无内容
HTTP_400_BAD_REQUEST400错误请求
HTTP_401_UNAUTHORIZED401未认证
HTTP_403_FORBIDDEN403禁止访问
HTTP_404_NOT_FOUND404未找到
HTTP_422_UNPROCESSABLE_ENTITY422验证错误
HTTP_500_INTERNAL_SERVER_ERROR500服务器错误

实际应用示例

请求日志中间件

import time
import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

app = FastAPI()

@app.middleware("http")
async def log_requests(request: Request, call_next):
start_time = time.time()

# 处理请求
response = await call_next(request)

# 记录日志
duration = time.time() - start_time
logger.info(
f"{request.method} {request.url.path} "
f"completed in {duration:.3f}s with status {response.status_code}"
)

return response

认证中间件

from fastapi import FastAPI, Request, HTTPException, status
from fastapi.responses import JSONResponse
from jose import jwt, JWTError

app = FastAPI()

SECRET_KEY = "your-secret-key"

@app.middleware("http")
async def auth_middleware(request: Request, call_next):
# 排除不需要认证的路径
if request.url.path in ["/login", "/register", "/docs", "/openapi.json"]:
return await call_next(request)

# 获取 token
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未提供认证凭据"
)

token = auth_header.split(" ")[1]

try:
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
# 将用户信息存入请求状态
request.state.user = payload
except JWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭据"
)

return await call_next(request)

错误响应统一格式

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from typing import Any

app = FastAPI()

class ErrorResponse(BaseModel):
"""统一错误响应格式"""
success: bool = False
code: int
message: str
data: Any = None

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(
code=exc.status_code,
message=str(exc.detail)
).model_dump()
)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
return JSONResponse(
status_code=422,
content=ErrorResponse(
code=422,
message="请求参数验证失败",
data=exc.errors()
).model_dump()
)

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
return JSONResponse(
status_code=500,
content=ErrorResponse(
code=500,
message="服务器内部错误"
).model_dump()
)

小结

本章我们学习了:

  1. 中间件基础:创建和使用中间件
  2. 中间件执行顺序:后进先出的处理方式
  3. 内置中间件:CORS、HTTPS 重定向、GZip 等
  4. 异常处理:HTTPException 和自定义异常
  5. 异常处理器:自定义错误响应格式

中间件的典型应用:

  • 请求日志
  • 认证授权
  • 请求限流
  • 响应压缩
  • CORS 处理

异常处理的关键点:

  • 使用 HTTPException 抛出标准 HTTP 错误
  • 自定义异常处理器统一错误格式
  • 全局异常处理器捕获未处理异常

练习

  1. 创建一个中间件,为每个请求添加唯一 ID
  2. 实现 CORS 中间件,只允许特定域名访问
  3. 创建自定义异常,并实现对应的异常处理器
  4. 实现一个简单的请求限流中间件