跳到主要内容

WebSocket

WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议。FastAPI 原生支持 WebSocket,可以轻松构建实时通信应用,如聊天室、实时通知、在线协作等。

什么是 WebSocket

传统的 HTTP 请求是单向的:客户端发送请求,服务器返回响应。每次通信都需要建立新的连接。

WebSocket 则不同:

  • 全双工通信:客户端和服务器可以同时发送消息
  • 持久连接:一次握手后,连接保持打开状态
  • 低延迟:无需重复建立连接,消息实时传递

基本使用

安装依赖

pip install websockets

创建 WebSocket 端点

from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse

app = FastAPI()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
# 接受 WebSocket 连接
await websocket.accept()

# 持续接收和发送消息
while True:
# 接收文本消息
data = await websocket.receive_text()
# 发送文本消息
await websocket.send_text(f"收到消息: {data}")

完整示例

下面是一个简单的聊天应用示例:

from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse

app = FastAPI()

# 简单的 HTML 客户端(仅用于演示)
html = """
<!DOCTYPE html>
<html>
<head>
<title>WebSocket 聊天</title>
</head>
<body>
<h1>WebSocket 聊天室</h1>
<form action="" onsubmit="sendMessage(event)">
<input type="text" id="messageText" autocomplete="off"/>
<button>发送</button>
</form>
<ul id='messages'></ul>
<script>
var ws = new WebSocket("ws://localhost:8000/ws");

ws.onmessage = function(event) {
var messages = document.getElementById('messages');
var message = document.createElement('li');
var content = document.createTextNode(event.data);
message.appendChild(content);
messages.appendChild(message);
};

function sendMessage(event) {
var input = document.getElementById("messageText");
ws.send(input.value);
input.value = '';
event.preventDefault();
}
</script>
</body>
</html>
"""

@app.get("/")
async def get():
"""返回聊天页面"""
return HTMLResponse(html)

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket 端点"""
await websocket.accept()

while True:
# 接收客户端消息
data = await websocket.receive_text()
# 回送消息
await websocket.send_text(f"你说: {data}")

运行应用后,访问 http://localhost:8000,你将看到一个简单的聊天界面。发送的每条消息都会被服务器回显。

消息类型

WebSocket 支持多种消息类型:

文本消息

# 接收文本
data = await websocket.receive_text()

# 发送文本
await websocket.send_text("Hello, World!")

二进制消息

# 接收二进制数据
data = await websocket.receive_bytes()

# 发送二进制数据
await websocket.send_bytes(b"Binary data")

JSON 消息

import json

# 接收 JSON
data = await websocket.receive_json()

# 发送 JSON
await websocket.send_json({"message": "Hello", "count": 42})

接收任意类型

# 接收消息,自动判断类型
message = await websocket.receive()

# message 是一个字典,包含:
# - {"type": "websocket.receive", "text": "..."} 文本消息
# - {"type": "websocket.receive", "bytes": b"..."} 二进制消息
# - {"type": "websocket.disconnect"} 断开连接

if "text" in message:
text = message["text"]
elif "bytes" in message:
data = message["bytes"]

连接管理

处理断开连接

当客户端断开连接时,receive_text() 会抛出 WebSocketDisconnect 异常:

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

try:
while True:
data = await websocket.receive_text()
await websocket.send_text(f"收到: {data}")
except WebSocketDisconnect:
print("客户端断开连接")

多客户端管理

对于聊天室等应用,需要管理多个 WebSocket 连接:

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from typing import List

app = FastAPI()

class ConnectionManager:
"""WebSocket 连接管理器"""

def __init__(self):
# 存储所有活跃的连接
self.active_connections: List[WebSocket] = []

async def connect(self, websocket: WebSocket):
"""接受新连接并添加到列表"""
await websocket.accept()
self.active_connections.append(websocket)

def disconnect(self, websocket: WebSocket):
"""从列表中移除连接"""
self.active_connections.remove(websocket)

async def send_personal_message(self, message: str, websocket: WebSocket):
"""发送私人消息给特定客户端"""
await websocket.send_text(message)

async def broadcast(self, message: str):
"""广播消息给所有客户端"""
for connection in self.active_connections:
await connection.send_text(message)

# 创建管理器实例
manager = ConnectionManager()

@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
# 接受连接
await manager.connect(websocket)

try:
while True:
data = await websocket.receive_text()

# 发送私人消息
await manager.send_personal_message(f"你说: {data}", websocket)

# 广播给其他客户端
await manager.broadcast(f"用户 {client_id} 说: {data}")

except WebSocketDisconnect:
# 客户端断开时清理连接
manager.disconnect(websocket)
await manager.broadcast(f"用户 {client_id} 离开了聊天室")

连接状态

from fastapi import WebSocketState

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

# 检查连接状态
if websocket.client_state == WebSocketState.CONNECTED:
print("客户端已连接")

if websocket.application_state == WebSocketState.CONNECTED:
print("应用已连接")

与依赖注入结合

WebSocket 端点可以使用 FastAPI 的依赖注入系统:

from typing import Annotated
from fastapi import (
Cookie,
Depends,
FastAPI,
Query,
WebSocket,
WebSocketException,
status,
)

app = FastAPI()

async def get_token(
websocket: WebSocket,
token: Annotated[str | None, Query()] = None,
):
"""验证 token 的依赖"""
if token is None or token != "secret-token":
# WebSocket 中不能使用 HTTPException
# 使用 WebSocketException 代替
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
return token

@app.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: Annotated[str, Depends(get_token)],
):
await websocket.accept()

while True:
data = await websocket.receive_text()
await websocket.send_text(f"Token: {token}, 消息: {data}")

重要提示:在 WebSocket 中不要使用 HTTPException,应该使用 WebSocketException

支持的依赖项

WebSocket 端点支持以下依赖:

  • Depends - 依赖注入
  • Security - 安全验证
  • Cookie - Cookie 参数
  • Header - 请求头
  • Path - 路径参数
  • Query - 查询参数
@app.websocket("/ws/{room_id}")
async def websocket_endpoint(
websocket: WebSocket,
room_id: str = Path(...), # 路径参数
token: str = Query(...), # 查询参数
session: str | None = Cookie(None), # Cookie
):
await websocket.accept()
# ...

关闭连接

服务器主动关闭

from fastapi import WebSocket

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

# 接收几条消息后关闭
for _ in range(5):
data = await websocket.receive_text()
await websocket.send_text(f"收到: {data}")

# 关闭连接,可以指定关闭码
await websocket.close(code=1000, reason="完成")

关闭码

代码名称含义
1000NORMAL正常关闭
1001GOING_AWAY端点离开
1008POLICY_VIOLATION策略违规
1011INTERNAL_ERROR服务器错误

使用 status 模块:

from fastapi import status

await websocket.close(code=status.WS_1000_NORMAL_CLOSURE)

实际应用示例

实时聊天室

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from typing import Dict, List
from datetime import datetime
import json

app = FastAPI()

class ChatRoom:
"""聊天室"""

def __init__(self, name: str):
self.name = name
self.connections: List[WebSocket] = []
self.messages: List[dict] = []

async def join(self, websocket: WebSocket, username: str):
await websocket.accept()
self.connections.append(websocket)

# 发送历史消息
for msg in self.messages[-50:]: # 最近 50 条
await websocket.send_json(msg)

# 广播加入消息
await self.broadcast({
"type": "system",
"content": f"{username} 加入了聊天室",
"timestamp": datetime.now().isoformat()
})

async def leave(self, websocket: WebSocket, username: str):
self.connections.remove(websocket)
await self.broadcast({
"type": "system",
"content": f"{username} 离开了聊天室",
"timestamp": datetime.now().isoformat()
})

async def broadcast(self, message: dict):
self.messages.append(message)
for connection in self.connections:
try:
await connection.send_json(message)
except:
pass # 忽略发送失败的连接

class ChatManager:
"""聊天管理器"""

def __init__(self):
self.rooms: Dict[str, ChatRoom] = {}

def get_or_create_room(self, name: str) -> ChatRoom:
if name not in self.rooms:
self.rooms[name] = ChatRoom(name)
return self.rooms[name]

manager = ChatManager()

@app.websocket("/ws/{room_name}/{username}")
async def websocket_endpoint(
websocket: WebSocket,
room_name: str,
username: str
):
room = manager.get_or_create_room(room_name)

try:
await room.join(websocket, username)

while True:
data = await websocket.receive_text()

message = {
"type": "message",
"username": username,
"content": data,
"timestamp": datetime.now().isoformat()
}

await room.broadcast(message)

except WebSocketDisconnect:
await room.leave(websocket, username)

实时数据推送

from fastapi import FastAPI, WebSocket
import asyncio
import random

app = FastAPI()

# 存储所有订阅者
subscribers: list[WebSocket] = []

@app.websocket("/ws/prices")
async def price_stream(websocket: WebSocket):
"""实时价格推送"""
await websocket.accept()
subscribers.append(websocket)

try:
while True:
# 模拟实时价格数据
price = {
"symbol": "AAPL",
"price": round(random.uniform(150, 200), 2),
"change": round(random.uniform(-5, 5), 2),
"timestamp": datetime.now().isoformat()
}

await websocket.send_json(price)
await asyncio.sleep(1) # 每秒推送一次

except WebSocketDisconnect:
subscribers.remove(websocket)

进度通知

from fastapi import FastAPI, WebSocket, BackgroundTasks
import asyncio

app = FastAPI()

# 存储任务进度连接
progress_connections: dict[str, WebSocket] = {}

async def long_running_task(task_id: str):
"""模拟长时间运行的任务"""
for progress in range(0, 101, 10):
await asyncio.sleep(1)

# 推送进度
if task_id in progress_connections:
ws = progress_connections[task_id]
await ws.send_json({
"task_id": task_id,
"progress": progress,
"status": "processing"
})

# 完成通知
if task_id in progress_connections:
ws = progress_connections[task_id]
await ws.send_json({
"task_id": task_id,
"progress": 100,
"status": "completed"
})

@app.websocket("/ws/progress/{task_id}")
async def progress_websocket(websocket: WebSocket, task_id: str):
"""任务进度 WebSocket"""
await websocket.accept()
progress_connections[task_id] = websocket

try:
# 等待客户端消息(保持连接)
while True:
await websocket.receive_text()
except:
pass
finally:
progress_connections.pop(task_id, None)

@app.post("/tasks/start")
async def start_task(background_tasks: BackgroundTasks):
"""启动后台任务"""
task_id = str(uuid.uuid4())
background_tasks.add_task(long_running_task, task_id)
return {"task_id": task_id}

生产环境注意事项

连接存储

上面的示例使用内存存储连接。在生产环境中:

  • 单进程:内存存储可以工作
  • 多进程/多服务器:需要使用 Redis、RabbitMQ 等进行跨进程通信

推荐使用 broadcaster 库:

from broadcaster import Broadcast

broadcast = Broadcast("redis://localhost:6379")

@app.on_event("startup")
async def startup():
await broadcast.connect()

@app.on_event("shutdown")
async def shutdown():
await broadcast.disconnect()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

async with broadcast.subscribe(channel="chat") as subscriber:
async for event in subscriber:
await websocket.send_text(event.message)

心跳机制

保持连接活跃,检测断开的客户端:

import asyncio

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()

# 心跳任务
async def heartbeat():
while True:
try:
await websocket.send_json({"type": "ping"})
await asyncio.sleep(30)
except:
break

asyncio.create_task(heartbeat())

try:
while True:
data = await websocket.receive_json()
if data.get("type") == "pong":
continue # 忽略心跳响应
# 处理其他消息
except WebSocketDisconnect:
pass

错误处理

from fastapi import WebSocket, WebSocketDisconnect, WebSocketException

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
try:
await websocket.accept()

while True:
try:
data = await websocket.receive_text()
# 处理消息
except WebSocketDisconnect:
print("客户端断开连接")
break
except Exception as e:
print(f"处理消息出错: {e}")
continue

except WebSocketException as e:
print(f"WebSocket 错误: {e}")
except Exception as e:
print(f"未知错误: {e}")
finally:
# 清理资源
pass

小结

本章我们学习了:

  1. WebSocket 基础:理解 WebSocket 与 HTTP 的区别
  2. 基本使用:创建、接收和发送消息
  3. 消息类型:文本、二进制、JSON 消息
  4. 连接管理:多客户端管理、断开处理
  5. 依赖注入:在 WebSocket 中使用依赖
  6. 实际应用:聊天室、实时推送、进度通知
  7. 生产部署:连接存储、心跳机制、错误处理

WebSocket 适用场景:

  • 实时聊天应用
  • 实时数据推送(股票、体育比分等)
  • 协作编辑
  • 在线游戏
  • 实时通知

练习

  1. 实现一个简单的聊天室,支持多用户同时聊天
  2. 创建一个实时数据仪表板,每秒推送随机数据
  3. 实现心跳机制,检测客户端是否存活
  4. 使用 Redis 广播实现跨进程的 WebSocket 通信