跳到主要内容

Embedding 嵌入模型

Embedding(嵌入)是将文本转换为数值向量的关键技术,是语义搜索、RAG、推荐系统等 AI 应用的基础。本章介绍 Spring AI 中的 Embedding 模型使用。

什么是 Embedding?

Embedding 将文本映射到高维向量空间,使语义相似的文本在向量空间中距离更近。

┌─────────────────────────────────────────────────────────────┐
│ Embedding 工作原理 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 文本输入 向量输出 │
│ ─────────────────────────────────────── │
│ "猫是一种宠物" ───> [0.12, 0.45, 0.78, ...] │
│ "狗是人类的朋友" ───> [0.15, 0.42, 0.75, ...] │
│ "汽车是交通工具" ───> [0.89, 0.12, 0.34, ...] │
│ │
│ 语义相似度: │
│ 猫/狗: 0.92 (高相似度) │
│ 猫/汽车: 0.15 (低相似度) │
│ │
└─────────────────────────────────────────────────────────────┘

核心概念

概念说明
向量一组浮点数,如 [0.1, 0.2, 0.3, ...]
维度向量的长度,常见有 768、1536、3072
相似度两个向量之间的距离,常用余弦相似度
语义空间高维向量空间,相似概念距离近

EmbeddingModel 接口

Spring AI 提供统一的 EmbeddingModel 接口:

public interface EmbeddingModel extends Model<Prompt, EmbeddingResponse> {

// 获取单个文本的嵌入向量
float[] embed(String text);

// 获取多个文本的嵌入向量
EmbeddingResponse embedForResponse(List<String> texts);

// 批量嵌入
List<List<Float>> embed(List<String> texts);

// 获取维度
int dimensions();
}

支持的 Embedding 模型

OpenAI Embedding

spring:
ai:
openai:
api-key: ${OPENAI_API_KEY}
embedding:
options:
model: text-embedding-3-small # 或 text-embedding-3-large
dimensions: 1536 # 可选,指定输出维度

可用模型:

模型维度特点
text-embedding-3-small512/1536性价比高,适合大多数场景
text-embedding-3-large256/1024/3072最高质量,适合高精度需求
text-embedding-ada-0021536旧版本,不推荐新项目使用

Ollama 本地模型

spring:
ai:
ollama:
base-url: http://localhost:11434
embedding:
model: nomic-embed-text
options:
dimension: 768

推荐本地模型:

模型维度大小说明
nomic-embed-text768274MB高质量,支持长文本
all-minilm38445MB轻量级,速度快
mxbai-embed-large1024670MB高质量多语言

Azure OpenAI

spring:
ai:
azure:
openai:
api-key: ${AZURE_API_KEY}
endpoint: ${AZURE_ENDPOINT}
embedding:
options:
deployment-name: text-embedding-ada-002

其他提供商

<!-- 阿里云通义千问 -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-alibaba-starter</artifactId>
</dependency>

<!-- Vertex AI -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-vertex-ai-embedding-spring-boot-starter</artifactId>
</dependency>

基本使用

获取单个嵌入向量

@Service
public class EmbeddingService {

@Autowired
private EmbeddingModel embeddingModel;

public float[] getEmbedding(String text) {
return embeddingModel.embed(text);
}

public void printEmbedding(String text) {
float[] embedding = embeddingModel.embed(text);

System.out.println("文本: " + text);
System.out.println("维度: " + embedding.length);
System.out.println("向量前10位: " +
Arrays.toString(Arrays.copyOf(embedding, 10)));
}
}

批量获取嵌入向量

@GetMapping("/batch-embed")
public List<EmbeddingResult> batchEmbed(@RequestBody List<String> texts) {
EmbeddingResponse response = embeddingModel.embedForResponse(texts);

return response.getResults().stream()
.map(result -> new EmbeddingResult(
texts.get(response.getResults().indexOf(result)),
result.getOutput(),
result.getIndex()
))
.toList();
}

record EmbeddingResult(String text, float[] embedding, int index) {}

获取嵌入维度

@GetMapping("/dimensions")
public int getDimensions() {
return embeddingModel.dimensions();
}

相似度计算

余弦相似度

余弦相似度是衡量两个向量相似程度的常用方法:

@Service
public class SimilarityService {

@Autowired
private EmbeddingModel embeddingModel;

/**
* 计算两个文本的余弦相似度
* 返回值范围:-1 到 1,值越大表示越相似
*/
public double cosineSimilarity(String text1, String text2) {
float[] embedding1 = embeddingModel.embed(text1);
float[] embedding2 = embeddingModel.embed(text2);

return cosineSimilarity(embedding1, embedding2);
}

private double cosineSimilarity(float[] vector1, float[] vector2) {
double dotProduct = 0.0;
double norm1 = 0.0;
double norm2 = 0.0;

for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
norm1 += vector1[i] * vector1[i];
norm2 += vector2[i] * vector2[i];
}

return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
}
}

相似度搜索服务

@Service
public class SemanticSearchService {

@Autowired
private EmbeddingModel embeddingModel;

private final Map<String, float[]> documentEmbeddings = new ConcurrentHashMap<>();

/**
* 添加文档到搜索库
*/
public void addDocument(String id, String content) {
float[] embedding = embeddingModel.embed(content);
documentEmbeddings.put(id, embedding);
}

/**
* 批量添加文档
*/
public void addDocuments(Map<String, String> documents) {
List<String> contents = new ArrayList<>(documents.values());
EmbeddingResponse response = embeddingModel.embedForResponse(contents);

List<String> ids = new ArrayList<>(documents.keySet());
for (int i = 0; i < ids.size(); i++) {
documentEmbeddings.put(ids.get(i), response.getResult(i).getOutput());
}
}

/**
* 语义搜索
*/
public List<SearchResult> search(String query, int topK) {
float[] queryEmbedding = embeddingModel.embed(query);

return documentEmbeddings.entrySet().stream()
.map(entry -> new SearchResult(
entry.getKey(),
cosineSimilarity(queryEmbedding, entry.getValue())
))
.sorted(Comparator.comparingDouble(SearchResult::similarity).reversed())
.limit(topK)
.toList();
}

record SearchResult(String documentId, double similarity) {}
}

实际应用示例

1. 文本相似度比较

@RestController
@RequestMapping("/api/similarity")
public class SimilarityController {

@Autowired
private SimilarityService similarityService;

@PostMapping("/compare")
public SimilarityResponse compare(@RequestBody CompareRequest request) {
double similarity = similarityService.cosineSimilarity(
request.text1(),
request.text2()
);

String interpretation = interpretSimilarity(similarity);

return new SimilarityResponse(similarity, interpretation);
}

private String interpretSimilarity(double similarity) {
if (similarity > 0.9) return "非常相似,语义几乎相同";
if (similarity > 0.7) return "高度相似,语义密切相关";
if (similarity > 0.5) return "中等相似,有一定关联";
if (similarity > 0.3) return "低相似度,关联较弱";
return "几乎不相关";
}
}

record CompareRequest(String text1, String text2) {}
record SimilarityResponse(double similarity, String interpretation) {}

2. 文档聚类

@Service
public class DocumentClusteringService {

@Autowired
private EmbeddingModel embeddingModel;

/**
* 将文档按语义相似度分组
*/
public Map<String, List<String>> clusterDocuments(
List<String> documents,
double threshold) {

// 获取所有嵌入向量
EmbeddingResponse response = embeddingModel.embedForResponse(documents);
List<float[]> embeddings = response.getResults().stream()
.map(Embedding::getOutput)
.toList();

// 简单的聚类算法
Map<Integer, List<Integer>> clusters = new HashMap<>();
boolean[] assigned = new boolean[documents.size()];

for (int i = 0; i < documents.size(); i++) {
if (assigned[i]) continue;

List<Integer> cluster = new ArrayList<>();
cluster.add(i);
assigned[i] = true;

for (int j = i + 1; j < documents.size(); j++) {
if (!assigned[j] &&
cosineSimilarity(embeddings.get(i), embeddings.get(j)) >= threshold) {
cluster.add(j);
assigned[j] = true;
}
}

clusters.put(i, cluster);
}

// 转换结果
Map<String, List<String>> result = new HashMap<>();
clusters.forEach((center, indices) -> {
result.put(
"Cluster-" + center,
indices.stream().map(documents::get).toList()
);
});

return result;
}
}

3. 推荐系统

@Service
public class ContentRecommendationService {

@Autowired
private EmbeddingModel embeddingModel;

private final Map<Long, float[]> itemEmbeddings = new ConcurrentHashMap<>();
private final Map<Long, String> itemContents = new ConcurrentHashMap<>();

/**
* 添加内容项
*/
public void addItem(Long id, String content) {
itemEmbeddings.put(id, embeddingModel.embed(content));
itemContents.put(id, content);
}

/**
* 根据内容推荐相似项目
*/
public List<Recommendation> recommend(String query, int topK) {
float[] queryEmbedding = embeddingModel.embed(query);

return itemEmbeddings.entrySet().stream()
.map(entry -> new Recommendation(
entry.getKey(),
itemContents.get(entry.getKey()),
cosineSimilarity(queryEmbedding, entry.getValue())
))
.sorted(Comparator.comparingDouble(Recommendation::score).reversed())
.limit(topK)
.toList();
}

/**
* 根据项目ID推荐相似项目
*/
public List<Recommendation> recommendByItem(Long itemId, int topK) {
float[] targetEmbedding = itemEmbeddings.get(itemId);
if (targetEmbedding == null) {
return List.of();
}

return itemEmbeddings.entrySet().stream()
.filter(entry -> !entry.getKey().equals(itemId))
.map(entry -> new Recommendation(
entry.getKey(),
itemContents.get(entry.getKey()),
cosineSimilarity(targetEmbedding, entry.getValue())
))
.sorted(Comparator.comparingDouble(Recommendation::score).reversed())
.limit(topK)
.toList();
}

record Recommendation(Long id, String content, double score) {}
}

4. 重复内容检测

@Service
public class DuplicateDetectionService {

@Autowired
private EmbeddingModel embeddingModel;

private static final double DUPLICATE_THRESHOLD = 0.95;

/**
* 检测文本是否与已有内容重复
*/
public DuplicateCheckResult checkDuplicate(String newText, List<String> existingTexts) {
float[] newEmbedding = embeddingModel.embed(newText);

for (int i = 0; i < existingTexts.size(); i++) {
float[] existingEmbedding = embeddingModel.embed(existingTexts.get(i));
double similarity = cosineSimilarity(newEmbedding, existingEmbedding);

if (similarity >= DUPLICATE_THRESHOLD) {
return new DuplicateCheckResult(
true,
i,
similarity,
existingTexts.get(i)
);
}
}

return new DuplicateCheckResult(false, -1, 0, null);
}

/**
* 批量检测重复内容
*/
public List<DuplicatePair> findDuplicates(List<String> texts, double threshold) {
List<DuplicatePair> duplicates = new ArrayList<>();
EmbeddingResponse response = embeddingModel.embedForResponse(texts);
List<float[]> embeddings = response.getResults().stream()
.map(Embedding::getOutput)
.toList();

for (int i = 0; i < texts.size(); i++) {
for (int j = i + 1; j < texts.size(); j++) {
double similarity = cosineSimilarity(embeddings.get(i), embeddings.get(j));
if (similarity >= threshold) {
duplicates.add(new DuplicatePair(i, j, similarity));
}
}
}

return duplicates;
}

record DuplicateCheckResult(
boolean isDuplicate,
int duplicateIndex,
double similarity,
String duplicateContent
) {}

record DuplicatePair(int index1, int index2, double similarity) {}
}

与 VectorStore 结合

Embedding 模型与向量存储结合实现高效的语义搜索:

@Service
public class VectorSearchService {

@Autowired
private VectorStore vectorStore;

/**
* 添加文档到向量存储
*/
public void addDocument(String id, String content, Map<String, Object> metadata) {
Document document = new Document(id, content, metadata);
vectorStore.add(List.of(document));
}

/**
* 语义搜索
*/
public List<SearchResult> search(String query, int topK, double threshold) {
SearchRequest request = SearchRequest.query(query)
.withTopK(topK)
.withSimilarityThreshold(threshold);

return vectorStore.similaritySearch(request).stream()
.map(doc -> new SearchResult(
doc.getId(),
doc.getContent(),
doc.getMetadata(),
doc.getScore()
))
.toList();
}

/**
* 带过滤条件的搜索
*/
public List<SearchResult> searchWithFilter(
String query,
String category,
int topK) {

SearchRequest request = SearchRequest.query(query)
.withTopK(topK)
.withFilterExpression("category == '" + category + "'");

return vectorStore.similaritySearch(request).stream()
.map(doc -> new SearchResult(
doc.getId(),
doc.getContent(),
doc.getMetadata(),
doc.getScore()
))
.toList();
}

record SearchResult(String id, String content, Map<String, Object> metadata, Double score) {}
}

性能优化

批量处理

@Service
public class BatchEmbeddingService {

@Autowired
private EmbeddingModel embeddingModel;

private static final int BATCH_SIZE = 100;

/**
* 批量处理大量文本
*/
public Map<String, float[]> batchEmbed(Map<String, String> documents) {
Map<String, float[]> results = new ConcurrentHashMap<>();
List<String> ids = new ArrayList<>(documents.keySet());

// 分批处理
Lists.partition(ids, BATCH_SIZE).forEach(batch -> {
List<String> batchContents = batch.stream()
.map(documents::get)
.toList();

EmbeddingResponse response = embeddingModel.embedForResponse(batchContents);

for (int i = 0; i < batch.size(); i++) {
results.put(batch.get(i), response.getResult(i).getOutput());
}
});

return results;
}
}

缓存嵌入向量

@Service
public class CachedEmbeddingService {

@Autowired
private EmbeddingModel embeddingModel;

private final Cache<String, float[]> embeddingCache =
Caffeine.newBuilder()
.maximumSize(10000)
.expireAfterAccess(Duration.ofHours(1))
.build();

/**
* 获取嵌入向量(带缓存)
*/
public float[] getEmbedding(String text) {
return embeddingCache.get(text, key -> embeddingModel.embed(text));
}

/**
* 预热缓存
*/
public void warmup(List<String> commonTexts) {
commonTexts.forEach(this::getEmbedding);
}
}

最佳实践

1. 选择合适的维度

// 高维度 = 更精确,但存储和计算成本高
// 低维度 = 更快,但可能损失精度

// 文档搜索:1536 维度(平衡)
// 大规模语义匹配:768 维度(效率优先)
// 精确匹配任务:3072 维度(精度优先)

2. 处理长文本

@Service
public class LongTextEmbeddingService {

@Autowired
private EmbeddingModel embeddingModel;

private static final int MAX_TOKENS = 8000;

/**
* 处理超长文本:分割后取平均
*/
public float[] embedLongText(String longText) {
// 分割文本
List<String> chunks = splitText(longText, MAX_TOKENS);

// 获取每个块的嵌入
EmbeddingResponse response = embeddingModel.embedForResponse(chunks);

// 平均嵌入向量
float[] avgEmbedding = new float[embeddingModel.dimensions()];
for (Embedding embedding : response.getResults()) {
float[] vec = embedding.getOutput();
for (int i = 0; i < vec.length; i++) {
avgEmbedding[i] += vec[i];
}
}

for (int i = 0; i < avgEmbedding.length; i++) {
avgEmbedding[i] /= chunks.size();
}

return avgEmbedding;
}

private List<String> splitText(String text, int maxTokens) {
// 实现文本分割逻辑
// 可以按段落、句子或固定长度分割
return List.of(text); // 简化实现
}
}

3. 处理多语言

// 使用支持多语言的模型
spring:
ai:
openai:
embedding:
options:
model: text-embedding-3-large # 支持多语言

// 或使用专门的多语言模型
spring:
ai:
ollama:
embedding:
model: nomic-embed-text # 支持多语言

小结

本章我们学习了:

  1. Embedding 概念:将文本转换为数值向量
  2. 支持的模型:OpenAI、Ollama、Azure 等多种提供商
  3. 基本使用:获取嵌入向量、批量处理
  4. 相似度计算:余弦相似度及其应用
  5. 实际应用:语义搜索、推荐系统、重复检测
  6. 性能优化:批量处理、缓存策略

练习

  1. 实现一个简单的语义搜索引擎
  2. 创建一个内容推荐系统
  3. 实现文档去重功能

参考资源