安全认证
安全性是 Web 应用的关键组成部分。FastAPI 提供了完善的安全工具,支持多种认证方式,包括 OAuth2、JWT、API Key 等。本章将详细介绍如何实现安全的用户认证和授权。
安全基础概念
认证 vs 授权
- 认证(Authentication):验证用户身份,回答"你是谁"的问题
- 授权(Authorization):验证用户权限,回答"你能做什么"的问题
常见认证方式
| 方式 | 说明 | 适用场景 |
|---|---|---|
| API Key | 简单的密钥认证 | 服务间调用、公开 API |
| OAuth2 密码流程 | 用户名密码换令牌 | 自家应用 |
| OAuth2 授权码 | 第三方登录 | 社交登录 |
| JWT | 无状态令牌认证 | 分布式系统、微服务 |
API Key 认证
API Key 是最简单的认证方式,适用于服务间调用或公开 API。
查询参数方式
from typing import Annotated
from fastapi import FastAPI, Query, HTTPException, Security
from fastapi.security import APIKeyQuery
app = FastAPI()
# 模拟有效的 API Keys
api_keys = ["api-key-1", "api-key-2", "api-key-3"]
# 定义 API Key 获取方式
api_key_query = APIKeyQuery(name="api_key")
async def get_api_key(
api_key: Annotated[str, Security(api_key_query)]
):
if api_key not in api_keys:
raise HTTPException(
status_code=401,
detail="无效的 API Key"
)
return api_key
@app.get("/items/")
async def read_items(api_key: Annotated[str, Depends(get_api_key)]):
return {"message": "受保护的数据", "api_key": api_key}
访问方式:/items/?api_key=api-key-1
Header 方式
更推荐使用请求头传递 API Key:
from typing import Annotated
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
app = FastAPI()
api_keys = ["secret-api-key"]
api_key_header = APIKeyHeader(name="X-API-Key")
async def get_api_key(
api_key: Annotated[str, Security(api_key_header)]
):
if api_key not in api_keys:
raise HTTPException(
status_code=401,
detail="无效的 API Key"
)
return api_key
@app.get("/protected/")
async def protected_route(api_key: Annotated[str, Depends(get_api_key)]):
return {"message": "认证成功"}
访问方式:
curl -H "X-API-Key: secret-api-key" http://localhost:8000/protected/
组合多种方式
允许通过多种方式传递 API Key:
from typing import Annotated
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import APIKeyQuery, APIKeyHeader, APIKey
app = FastAPI()
api_keys = ["secret-key"]
# 定义多种获取方式
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def get_api_key(
api_key_query: Annotated[str | None, Security(api_key_query)] = None,
api_key_header: Annotated[str | None, Security(api_key_header)] = None
):
# 优先使用 Header,其次使用 Query
api_key = api_key_header or api_key_query
if api_key is None:
raise HTTPException(
status_code=401,
detail="缺少 API Key"
)
if api_key not in api_keys:
raise HTTPException(
status_code=401,
detail="无效的 API Key"
)
return api_key
@app.get("/protected/")
async def protected_route(api_key: Annotated[str, Depends(get_api_key)]):
return {"message": "认证成功", "api_key": api_key}
OAuth2 密码流程
OAuth2 密码流程适用于用户直接在应用中输入用户名和密码的场景。
基本实现
from typing import Annotated
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
app = FastAPI()
# 模拟用户数据库
fake_users_db = {
"johndoe": {
"username": "johndoe",
"full_name": "John Doe",
"email": "[email protected]",
"hashed_password": "fakehashedsecret",
"disabled": False,
}
}
# OAuth2 令牌获取端点
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def fake_hash_password(password: str):
return "fakehashed" + password
class User(BaseModel):
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
class UserInDB(User):
hashed_password: str
def get_user(db, username: str):
if username in db:
user_dict = db[username]
return UserInDB(**user_dict)
def fake_decode_token(token: str):
# 实际应用中应该验证 JWT
user = get_user(fake_users_db, token)
return user
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)]
):
user = fake_decode_token(token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
):
if current_user.disabled:
raise HTTPException(status_code=400, detail="用户已禁用")
return current_user
# 获取令牌端点
@app.post("/token")
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
user_dict = fake_users_db.get(form_data.username)
if not user_dict:
raise HTTPException(
status_code=400,
detail="用户名或密码错误"
)
user = UserInDB(**user_dict)
hashed_password = fake_hash_password(form_data.password)
if not hashed_password == user.hashed_password:
raise HTTPException(
status_code=400,
detail="用户名或密码错误"
)
return {"access_token": user.username, "token_type": "bearer"}
# 受保护的路由
@app.get("/users/me")
async def read_users_me(
current_user: Annotated[User, Depends(get_current_active_user)]
):
return current_user
OAuth2PasswordRequestForm
这是一个表单类,包含以下字段:
username: 用户名password: 密码scope: 权限范围(可选)grant_type: 授权类型(可选)client_id: 客户端 ID(可选)client_secret: 客户端密钥(可选)
客户端请求令牌的方式:
curl -X POST "http://localhost:8000/token" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "username=johndoe&password=secret"
JWT 令牌认证
JWT(JSON Web Token)是一种无状态的认证方式,非常适合微服务架构。
安装依赖
pip install pyjwt "pwdlib[argon2]"
完整实现
from datetime import datetime, timedelta, timezone
from typing import Annotated
import jwt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError
from pwdlib import PasswordHash
from pydantic import BaseModel
# 配置
SECRET_KEY = "your-secret-key-here" # 生产环境应使用 openssl rand -hex 32 生成
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
app = FastAPI()
# 密码哈希
password_hash = PasswordHash.recommended()
DUMMY_HASH = password_hash.hash("dummypassword")
# 模拟用户数据库
fake_users_db = {
"johndoe": {
"username": "johndoe",
"full_name": "John Doe",
"email": "[email protected]",
"hashed_password": password_hash.hash("secret"),
"disabled": False,
}
}
# 模型
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str | None = None
class User(BaseModel):
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
class UserInDB(User):
hashed_password: str
# OAuth2
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# 工具函数
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return password_hash.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""生成密码哈希"""
return password_hash.hash(password)
def get_user(db: dict, username: str) -> UserInDB | None:
"""从数据库获取用户"""
if username in db:
user_dict = db[username]
return UserInDB(**user_dict)
return None
def authenticate_user(db: dict, username: str, password: str) -> UserInDB | bool:
"""认证用户"""
user = get_user(db, username)
if not user:
# 即使没有用户也验证一次,防止时序攻击
verify_password(password, DUMMY_HASH)
return False
if not verify_password(password, user.hashed_password):
return False
return user
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)]
) -> User:
"""获取当前用户"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except InvalidTokenError:
raise credentials_exception
user = get_user(fake_users_db, username=token_data.username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)]
) -> User:
"""获取当前活跃用户"""
if current_user.disabled:
raise HTTPException(status_code=400, detail="用户已禁用")
return current_user
# 路由
@app.post("/token")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
) -> Token:
"""获取访问令牌"""
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username},
expires_delta=access_token_expires
)
return Token(access_token=access_token, token_type="bearer")
@app.get("/users/me", response_model=User)
async def read_users_me(
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""获取当前用户信息"""
return current_user
@app.get("/users/me/items")
async def read_own_items(
current_user: Annotated[User, Depends(get_current_active_user)]
):
"""获取当前用户的物品"""
return [{"item_id": "Foo", "owner": current_user.username}]
JWT 令牌结构
JWT 令牌由三部分组成,用点号分隔:
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c
- Header(头部):算法和令牌类型
- Payload(载荷):用户数据(如用户名、过期时间)
- Signature(签名):验证令牌是否被篡改
安全注意事项
- 密钥管理:不要在代码中硬编码密钥,使用环境变量
- HTTPS:生产环境必须使用 HTTPS
- 令牌过期:设置合理的过期时间
- 刷新令牌:长期应用应实现刷新令牌机制
import os
from datetime import timedelta
# 从环境变量读取配置
SECRET_KEY = os.environ.get("SECRET_KEY")
if not SECRET_KEY:
raise ValueError("必须设置 SECRET_KEY 环境变量")
# 不同类型的令牌可以有不同的过期时间
ACCESS_TOKEN_EXPIRE = timedelta(minutes=15)
REFRESH_TOKEN_EXPIRE = timedelta(days=7)
权限控制
基于角色的访问控制
from enum import Enum
from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException, status
app = FastAPI()
class Role(str, Enum):
ADMIN = "admin"
USER = "user"
GUEST = "guest"
class User(BaseModel):
username: str
role: Role
# 模拟获取当前用户
async def get_current_user() -> User:
# 实际应用中从 JWT 或 Session 获取
return User(username="johndoe", role=Role.USER)
# 权限检查工厂函数
def require_role(required_role: Role):
"""创建角色检查依赖"""
async def role_checker(current_user: Annotated[User, Depends(get_current_user)]):
# 定义角色权限级别
role_hierarchy = {
Role.ADMIN: 3,
Role.USER: 2,
Role.GUEST: 1
}
user_level = role_hierarchy.get(current_user.role, 0)
required_level = role_hierarchy.get(required_role, 0)
if user_level < required_level:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足"
)
return current_user
return role_checker
# 不同权限级别的路由
@app.get("/public")
async def public_route():
"""公开路由,无需认证"""
return {"message": "公开数据"}
@app.get("/user", dependencies=[Depends(require_role(Role.USER))])
async def user_route():
"""需要用户权限"""
return {"message": "用户数据"}
@app.get("/admin", dependencies=[Depends(require_role(Role.ADMIN))])
async def admin_route():
"""需要管理员权限"""
return {"message": "管理员数据"}
权限装饰器
更优雅的权限控制方式:
from functools import wraps
from typing import Callable
def require_permissions(*permissions: str):
"""权限装饰器"""
def decorator(func: Callable):
@wraps(func)
async def wrapper(*args, current_user: User = Depends(get_current_user), **kwargs):
# 检查用户是否拥有所需权限
user_permissions = get_user_permissions(current_user) # 假设的函数
for perm in permissions:
if perm not in user_permissions:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"缺少权限: {perm}"
)
return await func(*args, current_user=current_user, **kwargs)
return wrapper
return decorator
@app.get("/sensitive-data")
@require_permissions("read:sensitive", "admin:access")
async def get_sensitive_data(current_user: User = Depends(get_current_user)):
return {"data": "敏感数据"}
密码安全
密码哈希
永远不要明文存储密码,使用安全的哈希算法:
from pwdlib import PasswordHash
# 使用推荐的 Argon2 算法
password_hash = PasswordHash.recommended()
# 哈希密码
hashed = password_hash.hash("user_password")
# 验证密码
is_valid = password_hash.verify("user_password", hashed)
密码强度验证
import re
from pydantic import BaseModel, field_validator
class UserCreate(BaseModel):
username: str
password: str
@field_validator('password')
@classmethod
def validate_password(cls, v: str) -> str:
if len(v) < 8:
raise ValueError('密码至少需要 8 个字符')
if not re.search(r'[A-Z]', v):
raise ValueError('密码需要包含至少一个大写字母')
if not re.search(r'[a-z]', v):
raise ValueError('密码需要包含至少一个小写字母')
if not re.search(r'\d', v):
raise ValueError('密码需要包含至少一个数字')
return v
CORS 配置
跨域资源共享(CORS)是前后端分离应用的重要安全配置:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# 配置 CORS
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000", # 开发环境
"https://myapp.com", # 生产环境
],
allow_credentials=True,
allow_methods=["*"], # 允许所有方法
allow_headers=["*"], # 允许所有头
max_age=3600, # 预检请求缓存时间
)
# 或使用正则匹配
app.add_middleware(
CORSMiddleware,
allow_origin_regex=r"https://.*\.myapp\.com", # 匹配所有子域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
安全最佳实践
1. 使用 HTTPS
生产环境必须使用 HTTPS:
from fastapi import FastAPI, Request
app = FastAPI()
@app.middleware("http")
async def ensure_https(request: Request, call_next):
# 检查是否是 HTTPS
if request.url.scheme != "https":
# 重定向到 HTTPS
https_url = request.url.replace(scheme="https")
from fastapi.responses import RedirectResponse
return RedirectResponse(https_url, status_code=301)
return await call_next(request)
2. 防止常见攻击
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
app = FastAPI()
# 速率限制(需要配合 Redis 等)
rate_limit_store = {}
@app.middleware("http")
async def rate_limit(request: Request, call_next):
client_ip = request.client.host
# 简单的速率限制示例
if client_ip in rate_limit_store:
if rate_limit_store[client_ip] > 100: # 每分钟最多 100 次请求
raise HTTPException(status_code=429, detail="请求过于频繁")
rate_limit_store[client_ip] += 1
else:
rate_limit_store[client_ip] = 1
return await call_next(request)
3. 安全头设置
from fastapi import FastAPI
from starlette.middleware.trustedhost import TrustedHostMiddleware
app = FastAPI()
# 信任的主机
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["example.com", "*.example.com"]
)
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
response = await call_next(request)
# 安全头
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response
完整示例
from datetime import datetime, timedelta, timezone
from typing import Annotated
import os
import jwt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError
from pwdlib import PasswordHash
from pydantic import BaseModel, EmailStr, field_validator
# 配置
SECRET_KEY = os.environ.get("SECRET_KEY", "dev-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
app = FastAPI(title="安全认证示例")
# 密码哈希
password_hash = PasswordHash.recommended()
DUMMY_HASH = password_hash.hash("dummypassword")
# 模拟数据库
users_db: dict[str, dict] = {}
# 模型
class UserCreate(BaseModel):
username: str
email: EmailStr
password: str
@field_validator('password')
@classmethod
def validate_password(cls, v: str) -> str:
if len(v) < 8:
raise ValueError('密码至少需要 8 个字符')
return v
class User(BaseModel):
username: str
email: str
disabled: bool = False
class UserInDB(User):
hashed_password: str
class Token(BaseModel):
access_token: str
token_type: str
# OAuth2
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# 工具函数
def get_user(username: str) -> UserInDB | None:
if username in users_db:
return UserInDB(**users_db[username])
return None
def authenticate_user(username: str, password: str) -> UserInDB | bool:
user = get_user(username)
if not user:
password_hash.verify(password, DUMMY_HASH)
return False
if not password_hash.verify(password, user.hashed_password):
return False
return user
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
to_encode = data.copy()
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username = payload.get("sub")
if not username:
raise credentials_exception
except InvalidTokenError:
raise credentials_exception
user = get_user(username)
if not user:
raise credentials_exception
return user
# 路由
@app.post("/register", status_code=201)
async def register(user: UserCreate):
"""用户注册"""
if user.username in users_db:
raise HTTPException(status_code=400, detail="用户名已存在")
users_db[user.username] = {
"username": user.username,
"email": user.email,
"hashed_password": password_hash.hash(user.password),
"disabled": False
}
return {"message": "注册成功"}
@app.post("/token", response_model=Token)
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
"""用户登录"""
user = authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
)
return Token(access_token=access_token, token_type="bearer")
@app.get("/users/me", response_model=User)
async def read_users_me(current_user: Annotated[User, Depends(get_current_user)]):
"""获取当前用户信息"""
return current_user
小结
本章我们学习了:
- API Key 认证:最简单的认证方式
- OAuth2 密码流程:用户名密码换令牌
- JWT 令牌:无状态认证的实现
- 权限控制:基于角色的访问控制
- 密码安全:哈希和强度验证
- CORS 配置:跨域资源共享
- 安全最佳实践:HTTPS、速率限制、安全头
安全是 Web 应用的基石,务必在生产环境中落实这些安全措施。
练习
- 实现一个完整的用户注册、登录、获取用户信息流程
- 添加基于角色的权限控制,区分管理员和普通用户
- 实现密码强度验证,要求至少包含大小写字母和数字
- 配置 CORS,只允许特定域名的请求