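Database Integration
FastAPI does not require any particular database; it works with any database or ORM. This chapter shows how to perform database operations with SQLModel, which is built on SQLAlchemy and Pydantic.
Introduction to SQLModel
SQLModel is an ORM library created by the author of FastAPI. It combines the strengths of SQLAlchemy and Pydantic:
- Built on Pydantic: a model definition also serves as data validation
- Built on SQLAlchemy: full-featured SQL support
- Seamless integration with FastAPI
- Supports many databases: PostgreSQL, MySQL, SQLite, and others
Installation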
pip install sqlmodel
# For async support, also install SQLAlchemy's asyncio extra (plus an async driver, listed later)
pip install "sqlalchemy[asyncio]"
Database Connections
SQLite
from sqlmodel import SQLModel, create_engine
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
engine = create_engine(sqlite_url, echo=True)  # echo=True logs every SQL statement
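When using SQLite with FastAPI, pass check_same_thread=False. By default SQLite only lets a connection be used in the thread that created it, but a single FastAPI request may use more than one thread (for example, in dependencies):
engine = create_engine(
    sqlite_url,
    connect_args={"check_same_thread": False}
)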
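To see what the flag changes, here is a small stdlib-only sketch that uses Python's sqlite3 module directly, not SQLModel. A connection created in the main thread is used from a worker thread, once with the default check and once with it disabled. The helper name query_from_other_thread is hypothetical.

```python
import sqlite3
import threading

def query_from_other_thread(check_same_thread: bool) -> str:
    """Open a connection in this thread, then use it from a worker thread."""
    conn = sqlite3.connect(":memory:", check_same_thread=check_same_thread)
    outcome: list[str] = []

    def worker() -> None:
        try:
            conn.execute("SELECT 1").fetchone()
            outcome.append("ok")
        except sqlite3.ProgrammingError:
            # Raised when a connection created in one thread is used in another
            outcome.append("ProgrammingError")

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    conn.close()  # closing from the creating thread is always allowed
    return outcome[0]

print(query_from_other_thread(True))   # ProgrammingError
print(query_from_other_thread(False))  # ok
```

Disabling the check only removes the thread-ownership test. It does not make one connection safe for simultaneous use by several threads.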
PostgreSQL
from sqlmodel import create_engine
postgres_url = "postgresql://user:password@localhost:5432/mydatabase"
engine = create_engine(postgres_url)
MySQL
mysql_url = "mysql+pymysql://user:password@localhost:3306/mydatabase"
engine = create_engine(mysql_url)
Async Connections
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
async_engine = create_async_engine(
    "postgresql+asyncpg://user:password@localhost/db",
    echo=True
)
# async_sessionmaker (SQLAlchemy 2.0+) is the async counterpart of sessionmaker
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False  # keep attributes readable after commit without reloading
)
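Defining Models
Basic Model
from typing import Optional
from sqlmodel import Field, SQLModel
class Hero(SQLModel, table=True):
    """Hero model"""
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(index=True)  # create an index on this column
    secret_name: str
    age: Optional[int] = Field(default=None, index=True)
table=True declares that the model is mapped to a database table and is not just a Pydantic model. id is Optional with a default of None because the database assigns it when the row is inserted.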
Field Types
from datetime import datetime
from decimal import Decimal
from sqlmodel import Field, SQLModel
from typing import Optional
class Product(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(max_length=100)
    description: Optional[str] = Field(default=None)
    price: Decimal = Field(decimal_places=2)  # exact decimal, at most 2 decimal places
    quantity: int = Field(default=0)
    is_active: bool = Field(default=True)
    created_at: datetime = Field(default_factory=datetime.now)  # called per row, not once at import
    # Foreign key to the id column of the category table
    category_id: Optional[int] = Field(default=None, foreign_key="category.id")
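The example above uses Decimal instead of float for price because Decimal stores base-10 values exactly. Field(decimal_places=2) also rejects values with more than two decimal places. A stdlib-only sketch follows. to_price is a hypothetical helper, and ROUND_HALF_UP is just one rounding policy, chosen here for illustration.

```python
from decimal import Decimal, ROUND_HALF_UP

# Binary floats cannot represent most decimal fractions exactly
print(0.1 + 0.2 == 0.3)                                 # False

# Decimal keeps base-10 digits exactly, which suits money
print(Decimal("0.1") + Decimal("0.2") == Decimal("0.3"))  # True

def to_price(value: str) -> Decimal:
    """Round to 2 decimal places (cents), ties rounded away from zero."""
    return Decimal(value).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

print(to_price("2.345"))   # 2.35
print(to_price("19.999"))  # 20.00
```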
Relationships
from typing import List, Optional
from sqlmodel import Field, Relationship, SQLModel
class Category(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    description: Optional[str] = None
    # One-to-many: one category has many products
    products: List["Product"] = Relationship(back_populates="category")
class Product(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    price: float
    category_id: Optional[int] = Field(default=None, foreign_key="category.id")
    # Relationship attribute: the many-to-one side of Category.products
    category: Optional[Category] = Relationship(back_populates="products")
# Many-to-many: the link table holds one row per (team, hero) pair
class HeroTeamLink(SQLModel, table=True):
    team_id: Optional[int] = Field(default=None, foreign_key="team.id", primary_key=True)
    hero_id: Optional[int] = Field(default=None, foreign_key="hero.id", primary_key=True)
class Team(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    heroes: List["Hero"] = Relationship(back_populates="teams", link_model=HeroTeamLink)
class Hero(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    teams: List[Team] = Relationship(back_populates="heroes", link_model=HeroTeamLink)
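At the SQL level, link_model corresponds to an association table whose primary key is the (team_id, hero_id) pair. The stdlib sqlite3 sketch below builds that layout by hand. The table names follow SQLModel's default of lowercasing the class name. The DDL is hand-written for illustration and is not the exact SQL SQLModel emits. build_demo and teams_of are hypothetical helpers.

```python
import sqlite3

def build_demo() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE team (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
        CREATE TABLE hero (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
        -- One row per (team, hero) pair; the pair itself is the primary key
        CREATE TABLE heroteamlink (
            team_id INTEGER NOT NULL REFERENCES team(id),
            hero_id INTEGER NOT NULL REFERENCES hero(id),
            PRIMARY KEY (team_id, hero_id)
        );
        INSERT INTO team (id, name) VALUES (1, 'Avengers'), (2, 'Defenders');
        INSERT INTO hero (id, name) VALUES (1, 'Spider-Man'), (2, 'Iron Man');
        INSERT INTO heroteamlink (team_id, hero_id) VALUES (1, 1), (2, 1), (1, 2);
    """)
    return conn

def teams_of(conn: sqlite3.Connection, hero_name: str) -> list[str]:
    """Follow hero -> heroteamlink -> team, the same path Hero.teams takes."""
    rows = conn.execute(
        """
        SELECT team.name FROM team
        JOIN heroteamlink ON heroteamlink.team_id = team.id
        JOIN hero ON hero.id = heroteamlink.hero_id
        WHERE hero.name = ?
        ORDER BY team.name
        """,
        (hero_name,),
    ).fetchall()
    return [row[0] for row in rows]

conn = build_demo()
print(teams_of(conn, "Spider-Man"))  # ['Avengers', 'Defenders']
print(teams_of(conn, "Iron Man"))    # ['Avengers']
```

Because the pair is the primary key, inserting the same (team, hero) link twice fails with sqlite3.IntegrityError.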
Database Session Management
Using Dependency Injection
from typing import Annotated
from fastapi import FastAPI, Depends
from sqlmodel import Session, SQLModel, create_engine
app = FastAPI()
engine = create_engine("sqlite:///database.db")
def get_session():
    # One session per request; the with-block closes it once the request is done
    with Session(engine) as session:
        yield session
SessionDep = Annotated[Session, Depends(get_session)]
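Roughly speaking, FastAPI runs the code before yield and injects the yielded value. After the request, it resumes the generator so the with-block can close the session. The sketch below imitates that flow in plain Python. FakeSession and run_request are hypothetical stand-ins, not FastAPI internals.

```python
class FakeSession:
    """Hypothetical stand-in for sqlmodel.Session that records whether it was closed."""
    def __init__(self) -> None:
        self.closed = False

    def __enter__(self) -> "FakeSession":
        return self

    def __exit__(self, *exc_info) -> None:
        self.closed = True  # the real Session closes itself here

def get_session():
    with FakeSession() as session:
        yield session

def run_request(dependency, handler):
    """Simplified model of how a yield dependency wraps one request."""
    gen = dependency()
    session = next(gen)      # run the dependency up to its yield
    try:
        return handler(session), session
    finally:
        gen.close()          # GeneratorExit at the yield makes the with-block exit

result, session = run_request(get_session, lambda s: f"closed during request: {s.closed}")
print(result)          # closed during request: False
print(session.closed)  # True
```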
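The Lifespan Approach (Recommended)
Use the lifespan parameter to manage application lifecycle events, such as creating database tables at startup and releasing resources at shutdown. The FastAPI documentation recommends this approach, which is more flexible and easier to test than the on_event decorators:
from contextlib import asynccontextmanager
from fastapi import FastAPI
from sqlmodel import SQLModel, create_engine
engine = create_engine("sqlite:///database.db")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: create the database tables
    SQLModel.metadata.create_all(engine)
    yield
    # Shutdown: release resources, such as the connection pool
    # engine.dispose()  # if needed
app = FastAPI(lifespan=lifespan)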
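How it works:
lifespan is an async context manager. Code before yield runs when the application starts, and code after yield runs when it shuts down. Advantages of this approach:
- Clearer resource management: startup and shutdown logic live in the same function, so they are easier to maintain
- Easy state sharing: state initialized at startup can be stored in app.state
- Easier testing: using TestClient as a context manager triggers the lifespan events
- Officially recommended: the FastAPI documentation explicitly recommends it in place of on_event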
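You can check the ordering without FastAPI. The sketch below drives an asynccontextmanager the way a server would. serve() is a hypothetical stand-in for the real ASGI lifespan protocol.

```python
import asyncio
from contextlib import asynccontextmanager

events: list[str] = []

@asynccontextmanager
async def lifespan(app):
    events.append("startup: create tables")    # runs before yield
    yield
    events.append("shutdown: dispose engine")  # runs after yield

async def serve(app=None) -> None:
    """Hypothetical server loop: enter lifespan, serve requests, then exit."""
    async with lifespan(app):
        events.append("serving requests")

asyncio.run(serve())
print(events)
# ['startup: create tables', 'serving requests', 'shutdown: dispose engine']
```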
Sharing State
To use resources created at startup from your routes, store them in app.state:
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from sqlmodel import SQLModel, create_engine
engine = create_engine("sqlite:///database.db")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: initialize resources
    app.state.engine = engine
    SQLModel.metadata.create_all(engine)
    yield
    # Shutdown: clean up
    app.state.engine = None
app = FastAPI(lifespan=lifespan)
@app.get("/info")
async def info(request: Request):
    # Read application state through the request
    return {"engine": str(request.app.state.engine)}
About on_event (Deprecated)
FastAPI still supports the @app.on_event("startup") and @app.on_event("shutdown") decorators. They are deprecated, and the documentation recommends lifespan instead:
# Deprecated approach
@app.on_event("startup")
def on_startup():
    SQLModel.metadata.create_all(engine)
@app.on_event("shutdown")
def on_shutdown():
    pass  # cleanup logic
Note: if you pass a lifespan parameter, startup and shutdown event handlers are not called at all. Use either lifespan or on_event, not both.
CRUD Operations
Create
from fastapi import FastAPI
from sqlmodel import Session
app = FastAPI()
@app.post("/heroes/", response_model=Hero)
def create_hero(hero: Hero, session: SessionDep):
    session.add(hero)
    session.commit()
    session.refresh(hero)  # reload to pick up database-generated fields such as id
    return hero
Read
from fastapi import HTTPException, Query
from sqlmodel import select
@app.get("/heroes/", response_model=list[Hero])
def read_heroes(
    session: SessionDep,
    offset: int = 0,
    limit: int = Query(default=100, le=100)
):
    heroes = session.exec(select(Hero).offset(offset).limit(limit)).all()
    return heroes
@app.get("/heroes/{hero_id}", response_model=Hero)
def read_hero(hero_id: int, session: SessionDep):
    hero = session.get(Hero, hero_id)  # primary-key lookup; None if no such row
    if not hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    return hero
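select(...).offset(offset).limit(limit) becomes LIMIT/OFFSET in SQL. The stdlib sqlite3 sketch below mimics read_heroes against a throwaway in-memory table with made-up data. It also adds ORDER BY id: without an ORDER BY, SQL does not guarantee row order, so pages could overlap or skip rows.

```python
import sqlite3

MAX_LIMIT = 100  # mirrors Query(default=100, le=100) in read_heroes

def read_heroes(conn: sqlite3.Connection, offset: int = 0, limit: int = 100) -> list[str]:
    if limit > MAX_LIMIT:
        # FastAPI rejects such a request with 422 instead of clamping the value
        raise ValueError(f"limit must be <= {MAX_LIMIT}")
    rows = conn.execute(
        "SELECT name FROM hero ORDER BY id LIMIT ? OFFSET ?", (limit, offset)
    ).fetchall()
    return [row[0] for row in rows]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE hero (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
conn.executemany("INSERT INTO hero (name) VALUES (?)", [(f"hero-{i}",) for i in range(1, 8)])

print(read_heroes(conn, offset=0, limit=3))  # ['hero-1', 'hero-2', 'hero-3']
print(read_heroes(conn, offset=6, limit=3))  # ['hero-7']
```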
Update
from fastapi import HTTPException
# HeroUpdate (every field optional) is defined in the model separation section
@app.patch("/heroes/{hero_id}", response_model=Hero)
def update_hero(hero_id: int, hero: HeroUpdate, session: SessionDep):
    db_hero = session.get(Hero, hero_id)
    if not db_hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    hero_data = hero.model_dump(exclude_unset=True)  # only the fields the client actually sent
    for key, value in hero_data.items():
        setattr(db_hero, key, value)
    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
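exclude_unset leaves alone any field the client omitted, but still applies a field the client sent explicitly as null. The sketch below simulates this with a dataclass and plain dicts instead of Pydantic. DbHero and apply_patch are hypothetical, and each dict stands in for what model_dump would return.

```python
from dataclasses import asdict, dataclass
from typing import Optional

@dataclass
class DbHero:
    """Hypothetical stand-in for a Hero row returned by session.get()."""
    name: str
    secret_name: str
    age: Optional[int]

def apply_patch(db_hero: DbHero, sent_fields: dict) -> DbHero:
    # sent_fields plays the role of hero.model_dump(exclude_unset=True):
    # it holds only the keys that appeared in the request body
    for key, value in sent_fields.items():
        setattr(db_hero, key, value)
    return db_hero

hero = DbHero(name="Spider-Man", secret_name="Peter Parker", age=16)
apply_patch(hero, {"age": 17})    # body {"age": 17}: other fields untouched
print(asdict(hero))  # {'name': 'Spider-Man', 'secret_name': 'Peter Parker', 'age': 17}
apply_patch(hero, {"age": None})  # body {"age": null}: explicitly clears age
print(asdict(hero))  # {'name': 'Spider-Man', 'secret_name': 'Peter Parker', 'age': None}

# Without exclude_unset, unset optional fields are dumped as None and overwrite data
other = DbHero(name="Iron Man", secret_name="Tony Stark", age=48)
apply_patch(other, {"name": None, "secret_name": None, "age": 49})
print(asdict(other))  # {'name': None, 'secret_name': None, 'age': 49}
```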
Delete
@app.delete("/heroes/{hero_id}")
def delete_hero(hero_id: int, session: SessionDep):
    hero = session.get(Hero, hero_id)
    if not hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    session.delete(hero)
    session.commit()
    return {"ok": True}
Separating Models
Real projects usually split models by purpose:
from datetime import datetime
from typing import Optional
from sqlmodel import Field, SQLModel
# Base model: fields shared by the other models
class HeroBase(SQLModel):
    name: str
    secret_name: str
    age: Optional[int] = None
# Create model: request body for creation
class HeroCreate(HeroBase):
    pass
# Update model: request body for updates; every field is optional
class HeroUpdate(SQLModel):
    name: Optional[str] = None
    secret_name: Optional[str] = None
    age: Optional[int] = None
# Table model: mapped to the database table
class Hero(HeroBase, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    created_at: datetime = Field(default_factory=datetime.now)
# Read model: response body
class HeroRead(HeroBase):
    id: int
    created_at: datetime
Usage:
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate, session: SessionDep):
    db_hero = Hero.model_validate(hero)  # convert the request model into a table model
    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
@app.patch("/heroes/{hero_id}", response_model=HeroRead)
def update_hero(hero_id: int, hero: HeroUpdate, session: SessionDep):
    db_hero = session.get(Hero, hero_id)
    if not db_hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    hero_data = hero.model_dump(exclude_unset=True)
    for key, value in hero_data.items():
        setattr(db_hero, key, value)
    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
Queries
Basic Queries
from sqlmodel import select
# All rows
heroes = session.exec(select(Hero)).all()
# Filter
young_heroes = session.exec(
    select(Hero).where(Hero.age < 30)
).all()
# Sort (descending)
sorted_heroes = session.exec(
    select(Hero).order_by(Hero.age.desc())
).all()
# Paginate
page_heroes = session.exec(
    select(Hero).offset(0).limit(10)
).all()
Complex Conditions
from sqlmodel import or_, and_, col
# OR
heroes = session.exec(
    select(Hero).where(or_(Hero.age < 20, Hero.age > 50))
).all()
# AND
heroes = session.exec(
    select(Hero).where(
        and_(
            Hero.age >= 20,
            Hero.age <= 50,
            Hero.name.like("%man%")  # LIKE case sensitivity depends on the database
        )
    )
).all()
# IN
heroes = session.exec(
    select(Hero).where(col(Hero.name).in_(["Spider-Man", "Iron Man"]))
).all()
# Pattern match
heroes = session.exec(
    select(Hero).where(Hero.name.ilike("%spider%"))  # ilike ignores case
).all()
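These conditions map onto ordinary SQL operators. Below is a stdlib sqlite3 sketch of the equivalent WHERE clauses, using made-up data. lower(...) LIKE lower(...) serves as a portable case-insensitive match; SQLAlchemy compiles ilike to this form on databases without a native ILIKE. The exact SQL SQLAlchemy generates may differ. names() is a hypothetical helper.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE hero (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)")
conn.executemany(
    "INSERT INTO hero (name, age) VALUES (?, ?)",
    [("Spider-Man", 16), ("Iron Man", 48), ("Black Widow", 35), ("Superman", 60)],
)

def names(sql: str, params: tuple = ()) -> list[str]:
    return [row[0] for row in conn.execute(sql, params).fetchall()]

# or_(Hero.age < 20, Hero.age > 50)
print(names("SELECT name FROM hero WHERE age < 20 OR age > 50 ORDER BY name"))
# ['Spider-Man', 'Superman']

# and_(...) with a case-insensitive name match
print(names("SELECT name FROM hero WHERE age >= 20 AND age <= 50 "
            "AND lower(name) LIKE lower(?) ORDER BY name", ("%MAN%",)))
# ['Iron Man']

# col(Hero.name).in_([...])
print(names("SELECT name FROM hero WHERE name IN (?, ?) ORDER BY name",
            ("Spider-Man", "Iron Man")))
# ['Iron Man', 'Spider-Man']
```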
Eager Loading Related Data
# Preload related rows so accessing hero.teams does not trigger one query per hero
from sqlmodel import select
from sqlalchemy.orm import selectinload
heroes_with_teams = session.exec(
    select(Hero).options(selectinload(Hero.teams))
).all()
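Without eager loading, accessing hero.teams for each hero in a list issues one extra query per hero, the N+1 problem. selectinload instead loads all related rows with one additional SELECT ... WHERE ... IN (...). The stdlib sqlite3 sketch below reproduces both patterns by hand and counts the statements. The SQL is hand-written for illustration, not what SQLAlchemy emits verbatim, and lazy_load and selectin_load are hypothetical names.

```python
import sqlite3
from collections import defaultdict

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE team (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
    CREATE TABLE hero (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
    CREATE TABLE heroteamlink (team_id INTEGER, hero_id INTEGER,
                               PRIMARY KEY (team_id, hero_id));
    INSERT INTO team VALUES (1, 'Avengers'), (2, 'Defenders');
    INSERT INTO hero VALUES (1, 'Spider-Man'), (2, 'Iron Man'), (3, 'Daredevil');
    INSERT INTO heroteamlink VALUES (1, 1), (2, 1), (1, 2), (2, 3);
""")
TEAMS_SQL = ("SELECT heroteamlink.hero_id, team.name FROM team "
             "JOIN heroteamlink ON heroteamlink.team_id = team.id ")

def lazy_load():
    """One query for heroes, then one more per hero (the N+1 pattern)."""
    heroes = conn.execute("SELECT id, name FROM hero ORDER BY id").fetchall()
    queries, result = 1, {}
    for hero_id, name in heroes:
        rows = conn.execute(TEAMS_SQL + "WHERE heroteamlink.hero_id = ? ORDER BY team.name",
                            (hero_id,)).fetchall()
        queries += 1
        result[name] = [team for _, team in rows]
    return result, queries

def selectin_load():
    """One query for heroes, plus a single IN (...) query for all their teams."""
    heroes = conn.execute("SELECT id, name FROM hero ORDER BY id").fetchall()
    ids = [hero_id for hero_id, _ in heroes]
    placeholders = ", ".join("?" for _ in ids)
    rows = conn.execute(TEAMS_SQL + f"WHERE heroteamlink.hero_id IN ({placeholders}) "
                        "ORDER BY team.name", ids).fetchall()
    queries = 2
    teams = defaultdict(list)
    for hero_id, team in rows:
        teams[hero_id].append(team)
    return {name: teams[hero_id] for hero_id, name in heroes}, queries

lazy, lazy_queries = lazy_load()
eager, eager_queries = selectin_load()
print(lazy == eager, lazy_queries, eager_queries)  # True 4 2
```

The lazy pattern grows with the number of heroes, while the selectin pattern stays at two statements.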
Aggregate Queries
from sqlalchemy import func
# Count rows
count = session.exec(select(func.count(Hero.id))).one()
# Count per group
result = session.exec(
    select(Hero.age, func.count(Hero.id)).group_by(Hero.age)
).all()
# Aggregate functions (NULL ages are ignored)
stats = session.exec(
    select(
        func.min(Hero.age),
        func.max(Hero.age),
        func.avg(Hero.age)
    )
).one()
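NULLs behave differently across these aggregates. A stdlib sqlite3 sketch with made-up data, one hero having no age, shows the same three queries in raw SQL:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE hero (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)")
conn.executemany(
    "INSERT INTO hero (name, age) VALUES (?, ?)",
    [("Spider-Man", 16), ("Iron Man", 48), ("Black Widow", 35),
     ("Hulk", 48), ("Vision", None)],
)

# func.count(Hero.id): every row has an id, so all 5 are counted
count = conn.execute("SELECT count(id) FROM hero").fetchone()[0]

# group_by(Hero.age): heroes with a NULL age form their own group
groups = conn.execute(
    "SELECT age, count(id) FROM hero GROUP BY age ORDER BY age"
).fetchall()

# min/max/avg skip NULLs, so avg is 147 / 4 rather than 147 / 5
stats = conn.execute("SELECT min(age), max(age), avg(age) FROM hero").fetchone()

print(count)   # 5
print(groups)  # [(None, 1), (16, 1), (35, 1), (48, 2)]
print(stats)   # (16, 48, 36.75)
```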
Complete Example
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Annotated, Optional
from fastapi import FastAPI, HTTPException, Query, Depends
from pydantic import BaseModel
from sqlalchemy import func
from sqlmodel import Field, Session, SQLModel, create_engine, select
# Database configuration
sqlite_url = "sqlite:///heroes.db"
engine = create_engine(sqlite_url, connect_args={"check_same_thread": False})
# Models
class HeroBase(SQLModel):
    name: str = Field(min_length=1, max_length=100)
    secret_name: str
    age: Optional[int] = Field(default=None, ge=0)
class HeroCreate(HeroBase):
    pass
class HeroUpdate(BaseModel):
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    secret_name: Optional[str] = None
    age: Optional[int] = Field(None, ge=0)
class Hero(HeroBase, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
class HeroRead(HeroBase):
    id: int
    created_at: datetime
    updated_at: datetime
class HeroList(BaseModel):
    items: list[HeroRead]
    total: int
    page: int
    page_size: int
# Session dependency
def get_session():
    with Session(engine) as session:
        yield session
SessionDep = Annotated[Session, Depends(get_session)]
# Lifespan events (recommended approach)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: create the database tables
    SQLModel.metadata.create_all(engine)
    yield
    # Shutdown: release resources if needed
    # engine.dispose()
# Application
app = FastAPI(title="Hero API", lifespan=lifespan)
# Routes
@app.post("/heroes/", response_model=HeroRead, status_code=201)
def create_hero(hero: HeroCreate, session: SessionDep):
    db_hero = Hero.model_validate(hero)
    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
@app.get("/heroes/", response_model=HeroList)
def read_heroes(
    session: SessionDep,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    name: Optional[str] = None
):
    offset = (page - 1) * page_size
    # Build the data query and a count query with the same filter
    query = select(Hero)
    count_query = select(func.count(Hero.id))
    if name:
        query = query.where(Hero.name.ilike(f"%{name}%"))
        count_query = count_query.where(Hero.name.ilike(f"%{name}%"))
    # Fetch one page (ordered, so pages are stable) plus the total count
    heroes = session.exec(query.order_by(Hero.id).offset(offset).limit(page_size)).all()
    total = session.exec(count_query).one()
    return HeroList(
        items=[HeroRead.model_validate(h) for h in heroes],
        total=total,
        page=page,
        page_size=page_size
    )
@app.get("/heroes/{hero_id}", response_model=HeroRead)
def read_hero(hero_id: int, session: SessionDep):
    hero = session.get(Hero, hero_id)
    if not hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    return hero
@app.patch("/heroes/{hero_id}", response_model=HeroRead)
def update_hero(hero_id: int, hero: HeroUpdate, session: SessionDep):
    db_hero = session.get(Hero, hero_id)
    if not db_hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    hero_data = hero.model_dump(exclude_unset=True)
    for key, value in hero_data.items():
        setattr(db_hero, key, value)
    db_hero.updated_at = datetime.now()
    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
@app.delete("/heroes/{hero_id}")
def delete_hero(hero_id: int, session: SessionDep):
    hero = session.get(Hero, hero_id)
    if not hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    session.delete(hero)
    session.commit()
    return {"message": "Hero deleted successfully"}
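The list response carries total, page and page_size, which is enough for a client to derive the rest. A sketch of that arithmetic follows. page_window, total_pages and has_next are hypothetical and are not part of the HeroList model above.

```python
def page_window(page: int, page_size: int, total: int) -> dict:
    """Hypothetical helper: the offset read_heroes uses, plus derived page info."""
    offset = (page - 1) * page_size                     # same formula as read_heroes
    total_pages = (total + page_size - 1) // page_size  # integer ceiling division
    return {"offset": offset, "total_pages": total_pages, "has_next": page < total_pages}

print(page_window(page=1, page_size=10, total=95))
# {'offset': 0, 'total_pages': 10, 'has_next': True}
print(page_window(page=10, page_size=10, total=95))
# {'offset': 90, 'total_pages': 10, 'has_next': False}
print(page_window(page=1, page_size=10, total=0))
# {'offset': 0, 'total_pages': 0, 'has_next': False}
```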
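Async Database Operations
For high-concurrency, I/O-bound applications, async database access is worth considering. FastAPI supports async natively. With an async database driver, a worker can serve other requests while it waits on the database instead of blocking.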
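Async helps with I/O-bound work because the event loop runs other coroutines while one awaits the database. The stdlib sketch below uses asyncio.sleep as a stand-in for a query. That is a simplifying assumption: real gains depend on the driver, the pool size and the database.

```python
import asyncio
import time

async def fake_query(delay: float) -> str:
    await asyncio.sleep(delay)   # stands in for waiting on the database
    return "rows"

async def sequential(n: int, delay: float) -> float:
    start = time.perf_counter()
    for _ in range(n):
        await fake_query(delay)
    return time.perf_counter() - start

async def concurrent(n: int, delay: float) -> float:
    start = time.perf_counter()
    await asyncio.gather(*(fake_query(delay) for _ in range(n)))
    return time.perf_counter() - start

seq = asyncio.run(sequential(5, 0.1))   # roughly 5 x 0.1 s
con = asyncio.run(concurrent(5, 0.1))   # roughly 0.1 s: the waits overlap
print(f"sequential {seq:.2f}s, concurrent {con:.2f}s")
```

In a real application the overlapping waits come from different requests, each with its own session. A single AsyncSession must not be shared by concurrent tasks.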
Async Engine Configuration
from typing import Annotated, AsyncGenerator
from fastapi import Depends
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
# Async engine (note the driver in the URL: postgresql+asyncpg)
async_engine = create_async_engine(
    "postgresql+asyncpg://user:password@localhost/db",
    echo=True,
    pool_size=10,     # connections kept open in the pool
    max_overflow=20   # extra connections allowed beyond pool_size under load
)
# Async session factory
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False
)
# Async session dependency
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSessionLocal() as session:
        yield session
AsyncSessionDep = Annotated[AsyncSession, Depends(get_async_session)]
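With pool_size=10 and max_overflow=20, the engine opens at most 30 connections at once. Further requests wait for a connection to be returned, up to pool_timeout (30 seconds by default), after which an error is raised. The asyncio.Semaphore sketch below models only that cap. It is a simplification and not how SQLAlchemy's pool is implemented; simulate is a hypothetical helper.

```python
import asyncio

async def simulate(pool_size: int = 10, max_overflow: int = 20, requests: int = 50) -> int:
    """Return the peak number of connections in use at the same time."""
    limit = pool_size + max_overflow      # cap on simultaneously open connections
    pool = asyncio.Semaphore(limit)
    in_use = 0
    peak = 0

    async def handle_request() -> None:
        nonlocal in_use, peak
        async with pool:                  # wait here if every connection is taken
            in_use += 1
            peak = max(peak, in_use)
            await asyncio.sleep(0.01)     # simulated query
            in_use -= 1

    await asyncio.gather(*(handle_request() for _ in range(requests)))
    return peak

print(asyncio.run(simulate()))            # 30
print(asyncio.run(simulate(requests=5)))  # 5
```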
Async CRUD Example
from contextlib import asynccontextmanager
from typing import Annotated, Optional
from fastapi import FastAPI, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlmodel import Field, SQLModel
# Models
class Hero(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(index=True)
    secret_name: str
    age: Optional[int] = Field(default=None, index=True)
class HeroCreate(BaseModel):
    name: str
    secret_name: str
    age: Optional[int] = None
class HeroUpdate(BaseModel):
    name: Optional[str] = None
    secret_name: Optional[str] = None
    age: Optional[int] = None
# Database configuration
async_engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/db")
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
async def get_session():
    async with AsyncSessionLocal() as session:
        yield session
SessionDep = Annotated[AsyncSession, Depends(get_session)]
# Application
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: create tables; create_all is synchronous, so run it via run_sync
    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
    yield
app = FastAPI(lifespan=lifespan)
# Async CRUD operations
@app.post("/heroes/", response_model=Hero)
async def create_hero(hero: HeroCreate, session: SessionDep):
    """Create a hero (async)"""
    db_hero = Hero(**hero.model_dump())
    session.add(db_hero)
    await session.commit()
    await session.refresh(db_hero)
    return db_hero
@app.get("/heroes/", response_model=list[Hero])
async def read_heroes(
    session: SessionDep,
    offset: int = 0,
    limit: int = Query(default=100, le=100)
):
    """List heroes (async)"""
    result = await session.execute(select(Hero).offset(offset).limit(limit))
    return result.scalars().all()
@app.get("/heroes/{hero_id}", response_model=Hero)
async def read_hero(hero_id: int, session: SessionDep):
    """Get a single hero (async)"""
    hero = await session.get(Hero, hero_id)
    if not hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    return hero
@app.patch("/heroes/{hero_id}", response_model=Hero)
async def update_hero(hero_id: int, hero: HeroUpdate, session: SessionDep):
    """Update a hero (async)"""
    db_hero = await session.get(Hero, hero_id)
    if not db_hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    hero_data = hero.model_dump(exclude_unset=True)
    for key, value in hero_data.items():
        setattr(db_hero, key, value)
    session.add(db_hero)
    await session.commit()
    await session.refresh(db_hero)
    return db_hero
@app.delete("/heroes/{hero_id}")
async def delete_hero(hero_id: int, session: SessionDep):
    """Delete a hero (async)"""
    hero = await session.get(Hero, hero_id)
    if not hero:
        raise HTTPException(status_code=404, detail="Hero not found")
    await session.delete(hero)  # AsyncSession.delete is a coroutine
    await session.commit()
    return {"ok": True}
Async Query Techniques
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
# Filter and sort
async def get_heroes_by_age(session: AsyncSession, min_age: int):
    result = await session.execute(
        select(Hero).where(Hero.age >= min_age).order_by(Hero.age)
    )
    return result.scalars().all()
# Aggregate
async def count_heroes(session: AsyncSession):
    result = await session.execute(select(func.count(Hero.id)))
    return result.scalar()
# Eager loading (needs the many-to-many Hero/Team models); in async code,
# lazily loading Hero.teams on attribute access fails, so load it up front
async def get_heroes_with_teams(session: AsyncSession):
    result = await session.execute(
        select(Hero).options(selectinload(Hero.teams))
    )
    return result.scalars().all()
# Pagination
async def get_heroes_paginated(session: AsyncSession, page: int, page_size: int):
    offset = (page - 1) * page_size
    # Fetch one page
    result = await session.execute(
        select(Hero).order_by(Hero.id).offset(offset).limit(page_size)
    )
    items = result.scalars().all()
    # Fetch the total; one AsyncSession must not run queries concurrently,
    # so this runs after the first query rather than through asyncio.gather
    count_result = await session.execute(select(func.count(Hero.id)))
    total = count_result.scalar()
    return {"items": items, "total": total, "page": page, "page_size": page_size}
Async Database Drivers
Async drivers for common databases:
| Database | Async driver | Example connection URL |
|---|---|---|
| PostgreSQL | asyncpg | postgresql+asyncpg://user:pass@host/db |
| MySQL | aiomysql | mysql+aiomysql://user:pass@host/db |
| SQLite | aiosqlite | sqlite+aiosqlite:///./database.db |
Install the async driver for your database:
pip install asyncpg # PostgreSQL
pip install aiomysql # MySQL
pip install aiosqlite # SQLite
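Choosing Between Sync and Async
Use async when:
- Building high-concurrency API services
- Many requests spend most of their time waiting on the database
- Database access is combined with other async I/O, such as outbound HTTP requests or file operations
Use sync when:
- The application is a simple CRUD service
- There are few database operations
- High concurrency is not required
Sync and async can be mixed in one project. FastAPI runs path operations and dependencies declared with plain def in a threadpool, so blocking sync database calls there do not stall the event loop. Calling sync database code directly inside an async def endpoint, however, blocks every other request on that worker.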
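The warning above can be illustrated with the stdlib. Here asyncio.to_thread plays the role of FastAPI's threadpool; FastAPI itself uses Starlette's run_in_threadpool. time.sleep stands in for a blocking driver call, and the endpoint names are hypothetical.

```python
import asyncio
import time

def blocking_query(delay: float) -> str:
    time.sleep(delay)  # stands in for a synchronous driver call
    return "rows"

async def endpoint_blocking(delay: float) -> str:
    # Anti-pattern: a sync DB call directly inside async def blocks the event loop
    return blocking_query(delay)

async def endpoint_offloaded(delay: float) -> str:
    # Run the blocking call in a worker thread, similar in spirit to how
    # FastAPI handles plain def endpoints
    return await asyncio.to_thread(blocking_query, delay)

async def timed(endpoint, n: int = 3, delay: float = 0.2) -> float:
    start = time.perf_counter()
    await asyncio.gather(*(endpoint(delay) for _ in range(n)))
    return time.perf_counter() - start

blocked = asyncio.run(timed(endpoint_blocking))     # about 3 x 0.2 s: calls queue up
offloaded = asyncio.run(timed(endpoint_offloaded))  # about 0.2 s: threads overlap
print(f"blocking {blocked:.2f}s, offloaded {offloaded:.2f}s")
```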
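Database Migrations
Using Alembic
# Install Alembic
pip install alembic
# Initialize the migration environment
alembic init alembic
# Then set sqlalchemy.url in alembic.ini and target_metadata in env.py
Configure alembic/env.py:
from sqlmodel import SQLModel
from models import *  # import every model so its table is registered on SQLModel.metadata
target_metadata = SQLModel.metadata
Common commands:
# Generate a migration file by comparing the models with the database
alembic revision --autogenerate -m "Add hero table"
# Apply all migrations up to the latest revision
alembic upgrade head
# Roll back one migration
alembic downgrade -1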
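Summary
In this chapter we covered:
- Database connections: connecting to SQLite, PostgreSQL, and MySQL
- Model definitions: defining database tables with SQLModel
- Relationship mapping: one-to-many and many-to-many relationships
- CRUD operations: complete create, read, update, and delete implementations
- Model separation: separate models for creation, updates, and reads
- Queries: filtering, sorting, pagination, and aggregation
- Database migrations: managing schema versions with Alembic
Advantages of SQLModel:
- Seamless integration with Pydantic
- Type-safe queries
- Automatic OpenAPI documentation when models are used in FastAPI routes
- Full access to SQLAlchemy features when needed
Exercises
- Build a user management API with registration, login, and profile editing
- Implement a one-to-many relationship between products and categories, with support for listing products by category
- Implement paginated queries that include a total count
- Add database migration support with Alembic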