
Hands-On RAG Implementation

This chapter presents a complete, production-grade RAG system implementation, covering the core modules: document processing, index construction, retrieval optimization, and an API service.

Project Structure

rag-project/
├── app/
│   ├── __init__.py
│   ├── main.py                  # FastAPI entry point
│   ├── config.py                # configuration management
│   ├── document_processor.py    # document processing
│   ├── embedder.py              # embedding models
│   ├── vector_store.py          # vector store management
│   ├── retriever.py             # retriever
│   ├── reranker.py              # reranking
│   ├── generator.py             # generation
│   └── rag_pipeline.py          # RAG pipeline
├── data/
│   ├── raw/                     # raw documents
│   └── processed/               # processed data
├── tests/
├── requirements.txt
└── docker-compose.yml
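
The requirements.txt listed above is never spelled out in this chapter. A minimal sketch covering the libraries imported below might look like this (package names follow the imports used in this chapter; versions are left unpinned here and should be pinned for production):

# requirements.txt (illustrative sketch; pin versions for production)
fastapi
uvicorn
python-multipart        # needed by FastAPI for file uploads
pydantic-settings
langchain
langchain-community
langchain-openai
chromadb
sentence-transformers
pypdf                   # backend for PyPDFLoader
unstructured            # backend for UnstructuredMarkdownLoader
numpy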

Configuration Management

# app/config.py
from pydantic_settings import BaseSettings
from functools import lru_cache

class Settings(BaseSettings):
    # OpenAI
    openai_api_key: str
    embedding_model: str = "text-embedding-3-small"
    llm_model: str = "gpt-4o-mini"

    # Vector database
    vector_db_type: str = "chroma"  # chroma, pinecone, milvus
    chroma_persist_dir: str = "./data/chroma"

    # Retrieval parameters
    chunk_size: int = 800
    chunk_overlap: int = 200
    retrieval_top_k: int = 20
    rerank_top_k: int = 5

    # Reranking
    reranker_model: str = "BAAI/bge-reranker-large"
    use_reranking: bool = True

    class Config:
        env_file = ".env"

@lru_cache()
def get_settings():
    return Settings()
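
Since Settings reads from a .env file, the matching file might look like the sketch below. The values are placeholders, and only openai_api_key has no default; pydantic-settings matches environment variable names to field names case-insensitively.

# .env (placeholder values)
OPENAI_API_KEY=sk-...
EMBEDDING_MODEL=text-embedding-3-small
LLM_MODEL=gpt-4o-mini
RETRIEVAL_TOP_K=20
USE_RERANKING=true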

Document Processor

# app/document_processor.py
from typing import List
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredMarkdownLoader,
    DirectoryLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
import hashlib

class DocumentProcessor:
    def __init__(self, chunk_size: int = 800, chunk_overlap: int = 200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            # Includes Chinese sentence-ending punctuation for mixed-language corpora
            separators=["\n\n", "\n", "。", "!", "?", ";", " ", ""]
        )

    def load_pdf(self, file_path: str) -> List[Document]:
        """Load a PDF file."""
        loader = PyPDFLoader(file_path)
        return loader.load()

    def load_directory(self, dir_path: str, glob_pattern: str = "**/*.pdf") -> List[Document]:
        """Batch-load documents from a directory."""
        loader = DirectoryLoader(
            dir_path,
            glob=glob_pattern,
            loader_cls=PyPDFLoader
        )
        return loader.load()

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents into chunks."""
        chunks = self.text_splitter.split_documents(documents)

        # Attach metadata
        for i, chunk in enumerate(chunks):
            chunk.metadata["chunk_id"] = self._generate_chunk_id(chunk)
            chunk.metadata["chunk_index"] = i

        return chunks

    def _generate_chunk_id(self, chunk: Document) -> str:
        """Generate a unique chunk ID."""
        content = chunk.page_content
        source = chunk.metadata.get("source", "")
        page = chunk.metadata.get("page", 0)

        unique_string = f"{source}:{page}:{content[:100]}"
        return hashlib.md5(unique_string.encode()).hexdigest()

    def process_files(self, file_paths: List[str]) -> List[Document]:
        """Process multiple files."""
        all_documents = []

        for file_path in file_paths:
            if file_path.endswith(".pdf"):
                docs = self.load_pdf(file_path)
            elif file_path.endswith(".txt"):
                docs = TextLoader(file_path).load()
            elif file_path.endswith(".md"):
                docs = UnstructuredMarkdownLoader(file_path).load()
            else:
                continue

            # Attach source metadata
            for doc in docs:
                doc.metadata["source"] = file_path

            all_documents.extend(docs)

        # Split into chunks
        return self.split_documents(all_documents)
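
As a quick sanity check, the processor can be exercised on its own; the file path below is a hypothetical example:

# Standalone usage sketch (hypothetical path)
processor = DocumentProcessor(chunk_size=800, chunk_overlap=200)
chunks = processor.process_files(["./data/raw/manual.pdf"])
print(f"{len(chunks)} chunks; first chunk_id: {chunks[0].metadata['chunk_id']}")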

Embedding Model Wrapper

# app/embedder.py
from typing import List
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
import numpy as np

class Embedder:
    def __init__(self, model_name: str = "text-embedding-3-small", use_local: bool = False):
        self.model_name = model_name
        self.use_local = use_local

        if use_local:
            # Local model (e.g. BGE); 'cuda' requires a GPU
            self.model = HuggingFaceEmbeddings(
                model_name=model_name,
                model_kwargs={'device': 'cuda'},
                encode_kwargs={'normalize_embeddings': True}
            )
        else:
            # OpenAI API
            self.model = OpenAIEmbeddings(model=model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents in batch."""
        return self.model.embed_documents(texts)

    def embed_query(self, query: str) -> List[float]:
        """Embed a query."""
        return self.model.embed_query(query)

    @staticmethod
    def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
        """Compute cosine similarity."""
        a = np.array(vec1)
        b = np.array(vec2)
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
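
A short sketch of how the wrapper is used, assuming OPENAI_API_KEY is set in the environment; the two queries are arbitrary examples:

# Usage sketch (requires OPENAI_API_KEY in the environment)
embedder = Embedder()
v1 = embedder.embed_query("How many days of annual leave do employees get?")
v2 = embedder.embed_query("vacation policy")
# Semantically related queries should score noticeably higher than unrelated ones
print(Embedder.cosine_similarity(v1, v2))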

Vector Store Management

# app/vector_store.py
from typing import List, Optional, Dict, Any
from langchain_community.vectorstores import Chroma
from langchain.schema import Document
import chromadb

class VectorStoreManager:
    def __init__(
        self,
        embedding_function,
        persist_directory: str = "./data/chroma",
        collection_name: str = "documents"
    ):
        self.embedding_function = embedding_function
        self.persist_directory = persist_directory
        self.collection_name = collection_name

        # Initialize Chroma
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.vectorstore = Chroma(
            client=self.client,
            collection_name=collection_name,
            embedding_function=embedding_function
        )

    def add_documents(self, documents: List[Document]) -> None:
        """Add documents to the vector store."""
        # Extract content
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        ids = [doc.metadata.get("chunk_id", str(i)) for i, doc in enumerate(documents)]

        # Add with stable chunk ids
        self.vectorstore.add_texts(
            texts=texts,
            metadatas=metadatas,
            ids=ids
        )

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        filter: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        """Similarity search."""
        return self.vectorstore.similarity_search(
            query=query,
            k=k,
            filter=filter
        )

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 5
    ) -> List[tuple]:
        """Similarity search with scores."""
        return self.vectorstore.similarity_search_with_score(query=query, k=k)

    def delete_by_ids(self, ids: List[str]) -> None:
        """Delete documents by ID (uses the underlying Chroma collection)."""
        self.vectorstore._collection.delete(ids=ids)

    def get_retriever(self, search_kwargs: Optional[Dict] = None):
        """Get a retriever."""
        return self.vectorstore.as_retriever(
            search_kwargs=search_kwargs or {"k": 5}
        )

    def get_collection_count(self) -> int:
        """Get the number of stored chunks."""
        return self.vectorstore._collection.count()
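
Continuing the sketches above, indexing the processed chunks and searching with a metadata filter might look like this (the query and filter values are illustrative):

# Usage sketch: index chunks, then search with a metadata filter
store = VectorStoreManager(embedding_function=embedder.model)
store.add_documents(chunks)
hits = store.similarity_search(
    "annual leave policy",
    k=5,
    filter={"source": "./data/raw/manual.pdf"}  # restrict to one source file
)
print(store.get_collection_count(), "chunks indexed;", len(hits), "hits")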

Reranker

# app/reranker.py
from typing import List, Tuple
from langchain.schema import Document
from sentence_transformers import CrossEncoder
import numpy as np

class Reranker:
    def __init__(self, model_name: str = "BAAI/bge-reranker-large", device: str = "cpu"):
        self.model = CrossEncoder(model_name, device=device)

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 5,
        min_score: float = 0.0
    ) -> List[Tuple[Document, float]]:
        """Rerank documents."""
        # Build (query, passage) input pairs
        pairs = [(query, doc.page_content) for doc in documents]

        # Score the pairs with the cross-encoder
        scores = self.model.predict(pairs)

        # Normalize to [0, 1]
        scores = self._normalize_scores(scores)

        # Merge and sort by score, descending
        results = list(zip(documents, scores))
        results.sort(key=lambda x: x[1], reverse=True)

        # Filter out low-scoring results
        results = [(doc, score) for doc, score in results if score >= min_score]

        return results[:top_k]

    def _normalize_scores(self, scores: np.ndarray) -> np.ndarray:
        """Min-max normalize scores."""
        min_score = scores.min()
        max_score = scores.max()
        if max_score == min_score:
            return np.ones_like(scores)
        return (scores - min_score) / (max_score - min_score)
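
Continuing the sketch, the candidates returned by the vector store can be reranked as follows; the min_score of 0.3 is an arbitrary example, not a tuned threshold:

# Usage sketch: rerank retrieved candidates
reranker = Reranker(device="cpu")
ranked = reranker.rerank("annual leave policy", hits, top_k=3, min_score=0.3)
for doc, score in ranked:
    print(f"{score:.2f}  {doc.page_content[:60]}")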

RAG Pipeline

# app/rag_pipeline.py
from typing import List, Dict, Any, Optional
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from .config import get_settings
from .document_processor import DocumentProcessor
from .embedder import Embedder
from .vector_store import VectorStoreManager
from .reranker import Reranker

class RAGPipeline:
    def __init__(self):
        self.settings = get_settings()

        # Initialize components
        self.embedder = Embedder(
            model_name=self.settings.embedding_model
        )

        self.vector_store = VectorStoreManager(
            embedding_function=self.embedder.model,
            persist_directory=self.settings.chroma_persist_dir
        )

        self.doc_processor = DocumentProcessor(
            chunk_size=self.settings.chunk_size,
            chunk_overlap=self.settings.chunk_overlap
        )

        self.reranker = Reranker(
            model_name=self.settings.reranker_model
        ) if self.settings.use_reranking else None

        self.llm = ChatOpenAI(
            model=self.settings.llm_model,
            temperature=0
        )

        self._setup_chain()

    def _setup_chain(self):
        """Set up the LLM chain."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an intelligent assistant. Answer the question using the reference documents below.
If the reference documents contain no relevant information, say so explicitly; do not make up an answer.

Reference documents:
{context}"""),
            ("human", "{question}")
        ])

        # The chain is invoked with {"context": ..., "question": ...},
        # which the prompt template consumes directly.
        self.chain = prompt | self.llm | StrOutputParser()

    def index_documents(self, file_paths: List[str]) -> Dict[str, Any]:
        """Index documents."""
        # Process the documents
        chunks = self.doc_processor.process_files(file_paths)

        # Store in the vector database
        self.vector_store.add_documents(chunks)

        return {
            "indexed_chunks": len(chunks),
            "total_documents": len(file_paths)
        }

    def retrieve(
        self,
        query: str,
        k: Optional[int] = None,
        rerank: bool = True
    ) -> List[Document]:
        """Retrieve relevant documents."""
        k = k or self.settings.retrieval_top_k

        # First-stage retrieval
        documents = self.vector_store.similarity_search(query, k=k)

        # Reranking
        if rerank and self.reranker:
            reranked = self.reranker.rerank(
                query=query,
                documents=documents,
                top_k=self.settings.rerank_top_k
            )
            return [doc for doc, score in reranked]

        return documents[:self.settings.rerank_top_k]

    def query(
        self,
        question: str,
        return_sources: bool = False
    ) -> Dict[str, Any]:
        """Run retrieval and generate an answer."""
        # Retrieve
        documents = self.retrieve(question)

        # Assemble the context
        context = "\n\n---\n\n".join([doc.page_content for doc in documents])

        # Generate the answer
        answer = self.chain.invoke({
            "context": context,
            "question": question
        })

        result = {"answer": answer}

        if return_sources:
            result["sources"] = [
                {
                    "content": doc.page_content[:200] + "...",
                    "metadata": doc.metadata
                }
                for doc in documents
            ]

        return result

    def query_stream(self, question: str):
        """Streaming query."""
        documents = self.retrieve(question)
        context = "\n\n---\n\n".join([doc.page_content for doc in documents])

        for chunk in self.chain.stream({
            "context": context,
            "question": question
        }):
            yield chunk
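
The streaming method is consumed as a plain generator; a minimal sketch:

# Consuming the stream token by token
rag = RAGPipeline()
for token in rag.query_stream("What is the company's annual leave policy?"):
    print(token, end="", flush=True)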

API Service

# app/main.py
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Optional
import tempfile
import os

from .rag_pipeline import RAGPipeline

app = FastAPI(title="RAG API", version="1.0.0")

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the RAG pipeline
rag = RAGPipeline()

class QueryRequest(BaseModel):
    question: str
    return_sources: bool = False

class QueryResponse(BaseModel):
    answer: str
    sources: Optional[List[dict]] = None

class IndexResponse(BaseModel):
    indexed_chunks: int
    total_documents: int

@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """Query endpoint."""
    try:
        result = rag.query(
            question=request.question,
            return_sources=request.return_sources
        )
        return QueryResponse(**result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/query/stream")
async def query_stream(request: QueryRequest):
    """Streaming query endpoint."""
    def generate():
        for chunk in rag.query_stream(request.question):
            yield chunk

    return StreamingResponse(generate(), media_type="text/event-stream")

@app.post("/index", response_model=IndexResponse)
async def index_documents(files: List[UploadFile] = File(...)):
    """Document indexing endpoint."""
    file_paths = []
    try:
        # Save the uploaded files to temporary paths
        for file in files:
            suffix = os.path.splitext(file.filename)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(await file.read())
                file_paths.append(tmp.name)

        # Index
        result = rag.index_documents(file_paths)
        return IndexResponse(**result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up temporary files even if indexing fails
        for path in file_paths:
            os.unlink(path)

@app.get("/health")
async def health():
    """Health check."""
    return {"status": "healthy", "documents": rag.vector_store.get_collection_count()}

Docker Deployment

# docker-compose.yml
version: '3.8'

services:
  rag-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - EMBEDDING_MODEL=text-embedding-3-small
      - LLM_MODEL=gpt-4o-mini
    volumes:
      - ./data:/app/data
    restart: unless-stopped

# Dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

Usage Example

# Example: using the RAG system
from app.rag_pipeline import RAGPipeline

# Initialize
rag = RAGPipeline()

# Index documents
result = rag.index_documents(["./docs/manual.pdf", "./docs/policy.pdf"])
print(f"Indexed {result['indexed_chunks']} chunks")

# Query
response = rag.query(
    question="What is the company's annual leave policy?",
    return_sources=True
)

print("Answer:", response["answer"])
print("\nSources:")
for source in response["sources"]:
    print(f"- {source['metadata']['source']}")

Summary

This chapter implemented a complete, production-grade RAG system:

  1. Document processing: multiple formats, sensible chunking
  2. Embedding models: both the OpenAI API and local models
  3. Vector store: Chroma, easy to deploy
  4. Reranking: improves retrieval precision
  5. API service: RESTful endpoints with streaming output
  6. Docker deployment: single-command deployment
