跳到主要内容

RAG 检索增强生成

RAG(Retrieval Augmented Generation,检索增强生成)是将私有数据与 AI 模型结合的关键技术。本章介绍 Spring AI 中的 RAG 实现。

什么是 RAG?

RAG 解决了一个核心问题:如何让 AI 模型回答它训练数据中没有的问题。

┌─────────────────────────────────────────────────────────────┐
│ RAG 工作流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 用户提问 │───>│ 向量搜索 │───>│ 检索相关 │ │
│ └─────────────┘ └─────────────┘ │ 文档片段 │ │
│ │ └──────┬──────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Embedding │ │ 构建增强 │ │
│ │ 模型 │ │ 提示词 │ │
│ └─────────────┘ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ AI 模型 │ │
│ │ 生成答案 │ │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘

核心组件

1. 文档加载(ETL)

将非结构化文档转换为可检索的数据:

@Service
public class DocumentLoader {

@Autowired
private VectorStore vectorStore;

@Autowired
private TikaDocumentReader tikaDocumentReader;

@Autowired
private TokenTextSplitter textSplitter;

public void loadDocument(Resource resource) {
// 1. 读取文档
TikaDocumentReader reader = new TikaDocumentReader(resource);
List<Document> documents = reader.get();

// 2. 分割文档
List<Document> splitDocuments = textSplitter.apply(documents);

// 3. 存储到向量数据库
vectorStore.add(splitDocuments);
}

public void loadDirectory(String directoryPath) throws IOException {
Files.walk(Paths.get(directoryPath))
.filter(Files::isRegularFile)
.filter(path -> path.toString().endsWith(".pdf") ||
path.toString().endsWith(".txt") ||
path.toString().endsWith(".md"))
.forEach(path -> {
loadDocument(new FileSystemResource(path));
});
}
}

2. 文本分割器

@Configuration
public class RagConfig {

@Bean
public TokenTextSplitter textSplitter() {
return new TokenTextSplitter(
512, // 默认块大小
128, // 默认重叠大小
5, // 最小块大小
10000, // 最大块大小
true // 保持段落完整
);
}
}

3. 向量存储

@Configuration
public class VectorStoreConfig {

@Bean
public VectorStore vectorStore(EmbeddingModel embeddingModel) {
return SimpleVectorStore.builder(embeddingModel).build();
}
}

使用 QuestionAnswerAdvisor

Spring AI 提供了 QuestionAnswerAdvisor 简化 RAG 实现:

基本配置

@Service
public class RagService {

private final ChatClient chatClient;

public RagService(ChatClient.Builder builder, VectorStore vectorStore) {
this.chatClient = builder
.defaultAdvisors(new QuestionAnswerAdvisor(vectorStore))
.build();
}

public String ask(String question) {
return chatClient.prompt()
.user(question)
.call()
.content();
}
}

自定义检索参数

public String askWithContext(String question, int topK) {
return chatClient.prompt()
.user(question)
.advisors(advisor -> advisor
.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "category == 'technical'")
.param(QuestionAnswerAdvisor.SEARCH_TOP_K, topK))
.call()
.content();
}

完整 RAG 示例

1. 配置类

@Configuration
public class RagConfiguration {

@Bean
public TokenTextSplitter textSplitter() {
return new TokenTextSplitter(
512, // 块大小
128, // 重叠大小
5, // 最小块大小
10000, // 最大块大小
true // 保持段落完整
);
}

@Bean
public ChatClient ragChatClient(ChatClient.Builder builder, VectorStore vectorStore) {
return builder
.defaultSystem("""
你是一个专业的知识库助手。
请根据提供的上下文信息回答用户问题。
如果上下文中没有相关信息,请诚实说明不知道。
回答时要引用来源。
""")
.defaultAdvisors(new QuestionAnswerAdvisor(vectorStore))
.build();
}
}

2. 文档服务

@Service
public class DocumentService {

@Autowired
private VectorStore vectorStore;

@Autowired
private TokenTextSplitter textSplitter;

/**
* 添加文本到知识库
*/
public void addText(String content, Map<String, Object> metadata) {
Document document = new Document(content, metadata);
vectorStore.add(List.of(document));
}

/**
* 添加文档文件
*/
public void addDocument(Resource resource, Map<String, Object> metadata) {
TikaDocumentReader reader = new TikaDocumentReader(resource);
List<Document> documents = reader.get();

// 添加元数据
documents.forEach(doc -> doc.getMetadata().putAll(metadata));

// 分割文档
List<Document> splitDocuments = textSplitter.apply(documents);

// 存储
vectorStore.add(splitDocuments);
}

/**
* 搜索相似文档
*/
public List<Document> search(String query, int topK) {
return vectorStore.similaritySearch(
SearchRequest.query(query).withTopK(topK)
);
}

/**
* 删除文档
*/
public void delete(List<String> ids) {
vectorStore.delete(ids);
}
}

3. 问答服务

@Service
public class QaService {

private final ChatClient chatClient;
private final VectorStore vectorStore;

public QaService(ChatClient.Builder builder, VectorStore vectorStore) {
this.vectorStore = vectorStore;
this.chatClient = builder
.defaultSystem("""
你是一个专业的知识库助手。
请根据提供的上下文回答用户问题。
如果上下文中没有相关信息,请说明。
""")
.defaultAdvisors(new QuestionAnswerAdvisor(vectorStore))
.build();
}

public Answer ask(String question) {
ChatResponse response = chatClient.prompt()
.user(question)
.call()
.chatResponse();

return new Answer(
response.getResult().getOutput().getContent(),
response.getMetadata().getUsage().getTotalTokens()
);
}

public Flux<String> askStream(String question) {
return chatClient.prompt()
.user(question)
.stream()
.content();
}
}

record Answer(String content, int tokens) {}

4. 控制器

@RestController
@RequestMapping("/api/rag")
public class RagController {

@Autowired
private DocumentService documentService;

@Autowired
private QaService qaService;

/**
* 上传文档
*/
@PostMapping("/documents")
public ResponseEntity<?> uploadDocument(
@RequestParam("file") MultipartFile file,
@RequestParam(required = false) String category) throws IOException {

Map<String, Object> metadata = new HashMap<>();
metadata.put("filename", file.getOriginalFilename());
metadata.put("category", category);

documentService.addDocument(
new ByteArrayResource(file.getBytes()),
metadata
);

return ResponseEntity.ok(Map.of("message", "文档上传成功"));
}

/**
* 问答接口
*/
@PostMapping("/ask")
public Answer ask(@RequestBody Question question) {
return qaService.ask(question.text());
}

/**
* 流式问答接口
*/
@GetMapping(value = "/ask/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> askStream(@RequestParam String question) {
return qaService.askStream(question);
}

/**
* 搜索文档
*/
@GetMapping("/search")
public List<DocumentResult> search(@RequestParam String query,
@RequestParam(defaultValue = "5") int topK) {
List<Document> documents = documentService.search(query, topK);
return documents.stream()
.map(doc -> new DocumentResult(
doc.getContent(),
doc.getMetadata()
))
.toList();
}
}

record Question(String text) {}
record DocumentResult(String content, Map<String, Object> metadata) {}

元数据过滤

添加过滤条件

@GetMapping("/ask/category")
public String askByCategory(
@RequestParam String question,
@RequestParam String category) {

return chatClient.prompt()
.user(question)
.advisors(advisor -> advisor
.param(QuestionAnswerAdvisor.FILTER_EXPRESSION,
"category == '" + category + "'"))
.call()
.content();
}

复杂过滤表达式

// 多条件过滤
String filterExpression = "category == 'technical' && year >= 2023";

// 或条件
String filterExpression = "category == 'java' || category == 'spring'";

// 模糊匹配
String filterExpression = "title like '%Spring%'";

// 使用过滤
return chatClient.prompt()
.user(question)
.advisors(advisor -> advisor
.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, filterExpression))
.call()
.content();

向量数据库配置

PGVector

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
</dependency>
spring:
datasource:
url: jdbc:postgresql://localhost:5432/ai_db
username: postgres
password: postgres
ai:
vectorstore:
pgvector:
index-type: HNSW
distance-type: COSINE_DISTANCE
dimensions: 1536

Redis

spring:
ai:
vectorstore:
redis:
uri: redis://localhost:6379
index: knowledge-base
prefix: "doc:"

Milvus

spring:
ai:
vectorstore:
milvus:
client:
host: localhost
port: 19530
databaseName: default
collectionName: documents
dimension: 1536

最佳实践

1. 合理分割文档

// 根据文档类型调整分割策略
@Bean
public TokenTextSplitter textSplitter() {
return new TokenTextSplitter(
800, // 较大的块保留更多上下文
200, // 足够的重叠保持连贯性
5, // 最小块大小
10000, // 最大块大小
true // 保持段落完整
);
}

2. 添加丰富的元数据

Map<String, Object> metadata = new HashMap<>();
metadata.put("source", "company-wiki");
metadata.put("category", "technical");
metadata.put("author", "team-a");
metadata.put("created_at", LocalDateTime.now().toString());
metadata.put("tags", List.of("java", "spring", "backend"));

Document document = new Document(content, metadata);

3. 控制检索数量

// 根据问题复杂度调整检索数量
int topK = question.length() > 100 ? 10 : 5;

return chatClient.prompt()
.user(question)
.advisors(advisor -> advisor
.param(QuestionAnswerAdvisor.SEARCH_TOP_K, topK))
.call()
.content();

小结

本章我们学习了:

  1. RAG 概念:检索增强生成的工作原理
  2. 核心组件:文档加载、分割、向量存储
  3. QuestionAnswerAdvisor:简化 RAG 实现
  4. 元数据过滤:精准检索相关文档
  5. 向量数据库:配置不同的存储后端

练习

  1. 构建一个企业知识库问答系统
  2. 实现基于文档分类的检索过滤
  3. 添加文档上传和管理功能

参考资源