数据库集成
FastAPI 不强制使用特定的数据库,可以与任何数据库或 ORM 配合使用。本章介绍如何使用 SQLModel(基于 SQLAlchemy 和 Pydantic)进行数据库操作。
SQLModel 简介
SQLModel 是由 FastAPI 作者创建的 ORM 库,它结合了 SQLAlchemy 和 Pydantic 的优势:
- 基于 Pydantic:模型定义即是数据验证
- 基于 SQLAlchemy:强大的 SQL 支持
- 与 FastAPI 完美集成
- 支持多种数据库:PostgreSQL、MySQL、SQLite 等
安装
pip install sqlmodel
# Optional async support: SQLAlchemy's asyncio extra plus an async driver
# (asyncpg matches the postgresql+asyncpg URL used later in this chapter)
pip install "sqlalchemy[asyncio]" asyncpg
数据库连接
SQLite
from sqlmodel import SQLModel, create_engine

# The SQLite URL is the "sqlite:///" prefix followed by the database file name.
sqlite_file_name = "database.db"
sqlite_url = "sqlite:///" + sqlite_file_name

# echo=True prints the SQL statements the engine issues.
engine = create_engine(sqlite_url, echo=True)
在 FastAPI 中使用 SQLite 时,同一个请求可能由不同的线程处理,而 SQLite 默认只允许在创建连接的线程中使用该连接,因此需要通过 connect_args 传入 check_same_thread=False:
# connect_args is passed through to the SQLite driver when it opens a connection.
sqlite_connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, connect_args=sqlite_connect_args)
PostgreSQL
from sqlmodel import create_engine
# URL layout: postgresql://<user>:<password>@<host>:<port>/<database>
postgres_url = "postgresql://user:password@localhost:5432/mydatabase"
engine = create_engine(postgres_url)
MySQL
# The "mysql+pymysql" scheme names the dialect and the driver (PyMySQL).
mysql_url = "mysql+pymysql://user:password@localhost:3306/mydatabase"
engine = create_engine(mysql_url)
异步连接
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession

# Async engine backed by the asyncpg driver; echo=True logs the emitted SQL.
async_engine = create_async_engine(
    "postgresql+asyncpg://user:password@localhost/db",
    echo=True,
)

# async_sessionmaker is the asyncio-aware session factory. Pairing it with
# SQLModel's AsyncSession makes `await session.exec(...)` available.
# expire_on_commit=False keeps loaded attributes usable after commit without
# triggering implicit (and, under asyncio, disallowed) lazy reloads.
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
定义模型
基本模型
from typing import Optional
from sqlmodel import Field, SQLModel
class Hero(SQLModel, table=True):
    """Hero table model."""

    # Primary key; defaults to None on a freshly constructed object.
    id: Optional[int] = Field(primary_key=True, default=None)
    # index=True creates an index on this column.
    name: str = Field(index=True)
    secret_name: str
    age: Optional[int] = Field(index=True, default=None)
table=True 表示这是一个数据库表模型,而不仅仅是 Pydantic 模型。
字段类型
from datetime import datetime
from decimal import Decimal
from sqlmodel import Field, SQLModel
from typing import Optional
class Product(SQLModel, table=True):
    """Product table model showing a range of field types."""

    id: Optional[int] = Field(primary_key=True, default=None)
    name: str = Field(max_length=100)
    description: Optional[str] = Field(default=None)
    price: Decimal = Field(decimal_places=2)
    quantity: int = Field(default=0)
    is_active: bool = Field(default=True)
    # default_factory is called per object, so each row gets its own timestamp.
    created_at: datetime = Field(default_factory=datetime.now)

    # Foreign key referencing category.id
    category_id: Optional[int] = Field(foreign_key="category.id", default=None)
关系定义
from typing import List, Optional
from sqlmodel import Field, Relationship, SQLModel
class Category(SQLModel, table=True):
    """Category table model; the "one" side of the category/product relationship."""

    id: Optional[int] = Field(primary_key=True, default=None)
    name: str
    description: Optional[str] = None

    # One-to-many: paired with the `category` attribute on Product.
    products: List["Product"] = Relationship(back_populates="category")
class Product(SQLModel, table=True):
    """Product table model; the "many" side of the category/product relationship."""

    id: Optional[int] = Field(primary_key=True, default=None)
    name: str
    price: float
    category_id: Optional[int] = Field(foreign_key="category.id", default=None)

    # Relationship attribute, paired with `products` on Category.
    category: Optional[Category] = Relationship(back_populates="products")
# Many-to-many relationship
class HeroTeamLink(SQLModel, table=True):
    """Link table joining teams and heroes; both columns are primary-key foreign keys."""

    team_id: Optional[int] = Field(
        default=None,
        foreign_key="team.id",
        primary_key=True,
    )
    hero_id: Optional[int] = Field(
        default=None,
        foreign_key="hero.id",
        primary_key=True,
    )
class Team(SQLModel, table=True):
    """Team table model, linked to heroes through HeroTeamLink."""

    id: Optional[int] = Field(primary_key=True, default=None)
    name: str
    heroes: List["Hero"] = Relationship(
        back_populates="teams",
        link_model=HeroTeamLink,
    )
class Hero(SQLModel, table=True):
    """Hero table model, linked to teams through HeroTeamLink."""

    id: Optional[int] = Field(primary_key=True, default=None)
    name: str
    teams: List[Team] = Relationship(
        back_populates="heroes",
        link_model=HeroTeamLink,
    )
数据库会话管理
使用依赖注入
from typing import Annotated
from fastapi import FastAPI, Depends
from sqlmodel import Session, SQLModel, create_engine
app = FastAPI()
# A request's session may be used from more than one thread, so SQLite's
# same-thread check is disabled here.
engine = create_engine(
    "sqlite:///database.db",
    connect_args={"check_same_thread": False},
)
def get_session():
    """Yield a Session bound to `engine`; the with-block closes it afterwards."""
    session_scope = Session(engine)
    with session_scope as db:
        yield db


# Reusable annotation: a parameter typed SessionDep receives get_session()'s session.
SessionDep = Annotated[Session, Depends(get_session)]
创建表
@app.on_event("startup")
def on_startup():
    """Create the tables registered on SQLModel.metadata when the app starts."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)
Lifespan 方式(推荐)
from contextlib import asynccontextmanager
from fastapi import FastAPI
from sqlmodel import SQLModel, create_engine
# A request's session may be used from more than one thread, so SQLite's
# same-thread check is disabled here.
engine = create_engine(
    "sqlite:///database.db",
    connect_args={"check_same_thread": False},
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create tables before the app starts serving; release the engine on shutdown."""
    # Startup: create the tables
    SQLModel.metadata.create_all(engine)
    yield
    # Shutdown: close the engine's pooled connections
    engine.dispose()


app = FastAPI(lifespan=lifespan)
CRUD 操作
创建(Create)
from fastapi import FastAPI
from sqlmodel import Session
app = FastAPI()
@app.post("/heroes/", response_model=Hero)
def create_hero(hero: Hero, session: SessionDep):
    """Persist the submitted hero and return the stored row."""
    db = session
    db.add(hero)
    db.commit()
    # Refresh to pick up fields generated by the database.
    db.refresh(hero)
    return hero
读取(Read)
from fastapi import Query
from sqlmodel import select


@app.get("/heroes/", response_model=list[Hero])
def read_heroes(
    session: SessionDep,
    offset: int = 0,
    limit: int = Query(default=100, le=100),
):
    """Return up to `limit` heroes (capped at 100), skipping the first `offset` rows."""
    # Query(...) is evaluated when the function is defined, so `Query` has to
    # be imported before this point.
    statement = select(Hero).offset(offset).limit(limit)
    heroes = session.exec(statement).all()
    return heroes
@app.get("/heroes/{hero_id}", response_model=Hero)
def read_hero(hero_id: int, session: SessionDep):
    """Return the hero with the given id, or raise a 404 HTTPException."""
    found = session.get(Hero, hero_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Hero not found")
更新(Update)
from fastapi import HTTPException
@app.patch("/heroes/{hero_id}", response_model=Hero)
def update_hero(hero_id: int, hero: HeroUpdate, session: SessionDep):
    """Apply a partial update to the hero with the given id.

    Raises a 404 HTTPException when no hero has that id.
    """
    db_hero = session.get(Hero, hero_id)
    if not db_hero:
        raise HTTPException(status_code=404, detail="Hero not found")

    # exclude_unset=True keeps only the fields that were actually set.
    changes = hero.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(db_hero, field_name, changes[field_name])

    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
删除(Delete)
@app.delete("/heroes/{hero_id}")
def delete_hero(hero_id: int, session: SessionDep):
    """Delete the hero with the given id, or raise a 404 HTTPException."""
    target = session.get(Hero, hero_id)
    if not target:
        raise HTTPException(status_code=404, detail="Hero not found")

    session.delete(target)
    session.commit()
    result = {"ok": True}
    return result
模型分离
实际项目中,通常将模型分为不同的用途:
from datetime import datetime
from typing import Optional
from sqlmodel import Field, SQLModel
# Base model: holds the fields shared by the other hero models
class HeroBase(SQLModel):
    """Fields shared by the hero models that inherit from this class."""

    name: str
    secret_name: str
    age: Optional[int] = None
# Create model: receives the body of a create request
class HeroCreate(HeroBase):
    """Create-request payload; adds nothing to HeroBase."""
# Update model: receives the body of an update request; every field is optional
class HeroUpdate(SQLModel):
    """Update-request payload in which every field defaults to None."""

    name: Optional[str] = None
    secret_name: Optional[str] = None
    age: Optional[int] = None
# Database model: mapped to a database table
class Hero(HeroBase, table=True):
    """Table model: HeroBase fields plus a primary key and a creation timestamp."""

    id: Optional[int] = Field(primary_key=True, default=None)
    # default_factory is called per object, so each row gets its own timestamp.
    created_at: datetime = Field(default_factory=datetime.now)
# Read model: shape of the data returned in responses
class HeroRead(HeroBase):
    """Response payload: HeroBase fields plus a required id and created_at."""

    id: int
    created_at: datetime
使用示例:
@app.post("/heroes/", response_model=HeroRead)
def create_hero(hero: HeroCreate, session: SessionDep):
    """Create a hero from the request payload and return the stored row."""
    # Convert the request payload into the database model.
    record = Hero.model_validate(hero)
    session.add(record)
    session.commit()
    session.refresh(record)
    return record
@app.patch("/heroes/{hero_id}", response_model=HeroRead)
def update_hero(hero_id: int, hero: HeroUpdate, session: SessionDep):
    """Apply a partial update to the hero with the given id.

    Raises a 404 HTTPException when no hero has that id.
    """
    db_hero = session.get(Hero, hero_id)
    if not db_hero:
        raise HTTPException(status_code=404, detail="Hero not found")

    # Only fields that were set on the payload are copied onto the row.
    changes = hero.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(db_hero, field_name, changes[field_name])

    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
查询操作
基本查询
from sqlmodel import select
# Fetch every row
all_stmt = select(Hero)
heroes = session.exec(all_stmt).all()

# Filter with a condition
young_stmt = select(Hero).where(Hero.age < 30)
young_heroes = session.exec(young_stmt).all()

# Sort (descending by age)
sorted_stmt = select(Hero).order_by(Hero.age.desc())
sorted_heroes = session.exec(sorted_stmt).all()

# Paginate (first 10 rows)
page_stmt = select(Hero).offset(0).limit(10)
page_heroes = session.exec(page_stmt).all()
复杂条件
from sqlmodel import or_, and_, col
# OR condition
age_outside_range = or_(Hero.age < 20, Hero.age > 50)
heroes = session.exec(select(Hero).where(age_outside_range)).all()

# AND condition
age_in_range_and_named = and_(
    Hero.age >= 20,
    Hero.age <= 50,
    Hero.name.like("%man%"),
)
heroes = session.exec(select(Hero).where(age_in_range_and_named)).all()

# IN condition
wanted_names = ["Spider-Man", "Iron Man"]
heroes = session.exec(
    select(Hero).where(col(Hero.name).in_(wanted_names))
).all()

# Pattern match (case-insensitive)
name_pattern = Hero.name.ilike("%spider%")
heroes = session.exec(select(Hero).where(name_pattern)).all()
关联查询
# 预加载关联数据
from sqlmodel import select
from sqlalchemy.orm import selectinload
# Load each hero's teams eagerly via the selectinload option.
eager_stmt = select(Hero).options(selectinload(Hero.teams))
heroes_with_teams = session.exec(eager_stmt).all()
聚合查询
from sqlalchemy import func
# Count rows
count_stmt = select(func.count(Hero.id))
count = session.exec(count_stmt).one()

# Count per group (grouped by age)
per_age_stmt = select(Hero.age, func.count(Hero.id)).group_by(Hero.age)
result = session.exec(per_age_stmt).all()

# Aggregate functions: minimum, maximum and average age
stats_stmt = select(
    func.min(Hero.age),
    func.max(Hero.age),
    func.avg(Hero.age),
)
stats = session.exec(stats_stmt).one()
完整示例
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Annotated, Optional

from fastapi import FastAPI, HTTPException, Query, Depends
from pydantic import BaseModel
from sqlalchemy import func
from sqlmodel import Field, Session, SQLModel, create_engine, select
# Database configuration
sqlite_url = "sqlite:///heroes.db"
# check_same_thread=False is forwarded to the SQLite driver via connect_args.
engine = create_engine(sqlite_url, connect_args={"check_same_thread": False})
# Model definitions
class HeroBase(SQLModel):
    """Validated fields shared by the hero models that inherit from this class."""

    # Name must be 1 to 100 characters long.
    name: str = Field(max_length=100, min_length=1)
    secret_name: str
    # Age is optional; when given it must be >= 0.
    age: Optional[int] = Field(ge=0, default=None)
class HeroCreate(HeroBase):
    """Create-request payload; adds nothing to HeroBase."""
class HeroUpdate(SQLModel):
    """Update-request payload in which every field is optional.

    Derives from SQLModel, like the other hero models, because its fields
    are declared with sqlmodel's Field rather than pydantic's.
    """

    name: Optional[str] = Field(default=None, min_length=1, max_length=100)
    secret_name: Optional[str] = None
    age: Optional[int] = Field(default=None, ge=0)
class Hero(HeroBase, table=True):
    """Table model: HeroBase fields plus a primary key and two timestamps."""

    id: Optional[int] = Field(primary_key=True, default=None)
    # Both timestamps default to the time the object is constructed.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
class HeroRead(HeroBase):
    """Response payload: HeroBase fields plus id and both timestamps."""

    id: int
    created_at: datetime
    updated_at: datetime
class HeroList(BaseModel):
    """Paginated response: one page of heroes plus paging information."""

    # Heroes on the current page.
    items: list[HeroRead]
    # Number of heroes matching the query.
    total: int
    # Page number that was requested.
    page: int
    # Page size that was requested.
    page_size: int
# Session dependency
def get_session():
    """Yield a Session bound to `engine`; the with-block closes it afterwards."""
    session_scope = Session(engine)
    with session_scope as db:
        yield db


# Reusable annotation: a parameter typed SessionDep receives get_session()'s session.
SessionDep = Annotated[Session, Depends(get_session)]
# Application
def on_startup():
    """Create the tables registered on SQLModel.metadata."""
    SQLModel.metadata.create_all(engine)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler: set up the schema on startup, release the engine on shutdown.

    Replaces the `@app.on_event("startup")` hook with the lifespan approach
    that this chapter recommends.
    """
    on_startup()
    yield
    # Shutdown: close the engine's pooled connections.
    engine.dispose()


app = FastAPI(title="Hero API", lifespan=lifespan)
# Routes
@app.post("/heroes/", response_model=HeroRead, status_code=201)
def create_hero(hero: HeroCreate, session: SessionDep):
    """Create a hero from the request payload and return the stored row (HTTP 201)."""
    record = Hero.model_validate(hero)
    session.add(record)
    session.commit()
    session.refresh(record)
    return record
@app.get("/heroes/", response_model=HeroList)
def read_heroes(
    session: SessionDep,
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=100),
    name: Optional[str] = None,
):
    """Return one page of heroes together with the total number of matches.

    Args:
        session: Database session provided by the SessionDep dependency.
        page: 1-based page number.
        page_size: Number of heroes per page (1 to 100).
        name: Optional case-insensitive substring filter on the hero name.

    Returns:
        A HeroList with the page items, the total match count and the
        paging parameters that were used.
    """
    # Order by primary key so consecutive pages neither overlap nor skip
    # rows; without ORDER BY the database may return rows in any order.
    query = select(Hero).order_by(Hero.id)
    count_query = select(func.count(Hero.id))

    if name:
        # Build the filter once so the page query and the count query agree.
        name_filter = Hero.name.ilike(f"%{name}%")
        query = query.where(name_filter)
        count_query = count_query.where(name_filter)

    offset = (page - 1) * page_size
    heroes = session.exec(query.offset(offset).limit(page_size)).all()
    total = session.exec(count_query).one()

    return HeroList(
        items=heroes,
        total=total,
        page=page,
        page_size=page_size,
    )
@app.get("/heroes/{hero_id}", response_model=HeroRead)
def read_hero(hero_id: int, session: SessionDep):
    """Return the hero with the given id, or raise a 404 HTTPException."""
    found = session.get(Hero, hero_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Hero not found")
@app.patch("/heroes/{hero_id}", response_model=HeroRead)
def update_hero(hero_id: int, hero: HeroUpdate, session: SessionDep):
    """Apply a partial update to a hero and stamp its updated_at field.

    Raises a 404 HTTPException when no hero has that id.
    """
    db_hero = session.get(Hero, hero_id)
    if not db_hero:
        raise HTTPException(status_code=404, detail="Hero not found")

    # Only fields that were set on the payload are copied onto the row.
    changes = hero.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(db_hero, field_name, changes[field_name])
    # Record the modification time after the payload has been applied.
    db_hero.updated_at = datetime.now()

    session.add(db_hero)
    session.commit()
    session.refresh(db_hero)
    return db_hero
@app.delete("/heroes/{hero_id}")
def delete_hero(hero_id: int, session: SessionDep):
    """Delete the hero with the given id, or raise a 404 HTTPException."""
    target = session.get(Hero, hero_id)
    if not target:
        raise HTTPException(status_code=404, detail="Hero not found")

    session.delete(target)
    session.commit()
    confirmation = {"message": "Hero deleted successfully"}
    return confirmation
数据库迁移
使用 Alembic
# 安装 Alembic
pip install alembic
# 初始化
alembic init alembic
# 配置 alembic.ini 和 env.py
在 alembic/env.py 中配置:
from sqlmodel import SQLModel
from models import * # import all models (presumably so their tables are registered on SQLModel.metadata)
# Metadata object exposed to Alembic from env.py.
target_metadata = SQLModel.metadata
常用命令:
# 自动生成迁移文件
alembic revision --autogenerate -m "Add hero table"
# 执行迁移
alembic upgrade head
# 回退迁移
alembic downgrade -1
小结
本章我们学习了:
- 数据库连接:SQLite、PostgreSQL、MySQL 的连接方式
- 模型定义:使用 SQLModel 定义数据库表
- 关系映射:一对多、多对多关系
- CRUD 操作:增删改查的完整实现
- 模型分离:创建、更新、读取模型的分离
- 查询操作:条件查询、排序、分页、聚合
- 数据库迁移:使用 Alembic 管理数据库版本
SQLModel 的优势:
- 与 Pydantic 无缝集成
- 类型安全的查询
- 自动生成 OpenAPI 文档
- 支持 SQLAlchemy 的所有功能
练习
- 创建一个用户管理 API,包含注册、登录、个人信息修改功能
- 实现商品和分类的一对多关系,支持按分类查询商品
- 实现分页查询,包含总数统计
- 使用 Alembic 添加数据库迁移支持