Embedding 嵌入模型
Embedding(嵌入)是将文本转换为数值向量的关键技术,是语义搜索、RAG、推荐系统等 AI 应用的基础。本章介绍 Spring AI 中的 Embedding 模型使用。
什么是 Embedding?
Embedding 将文本映射到高维向量空间,使语义相似的文本在向量空间中距离更近。
┌─────────────────────────────────────────────────────────────┐
│ Embedding 工作原理 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 文本输入 向量输出 │
│ ─────────────────────────────────────── │
│ "猫是一种宠物" ───> [0.12, 0.45, 0.78, ...] │
│ "狗是人类的朋友" ───> [0.15, 0.42, 0.75, ...] │
│ "汽车是交通工具" ───> [0.89, 0.12, 0.34, ...] │
│ │
│ 语义相似度: │
│ 猫/狗: 0.92 (高相似度) │
│ 猫/汽车: 0.15 (低相似度) │
│ │
└─────────────────────────────────────────────────────────────┘
核心概念
| 概念 | 说明 |
|---|---|
| 向量 | 一组浮点数,如 [0.1, 0.2, 0.3, ...] |
| 维度 | 向量的长度,常见有 768、1536、3072 |
| 相似度 | 两个向量之间的距离,常用余弦相似度 |
| 语义空间 | 高维向量空间,相似概念距离近 |
EmbeddingModel 接口
Spring AI 提供统一的 EmbeddingModel 接口:
public interface EmbeddingModel extends Model<Prompt, EmbeddingResponse> {
// 获取单个文本的嵌入向量
float[] embed(String text);
// 获取多个文本的嵌入向量
EmbeddingResponse embedForResponse(List<String> texts);
// 批量嵌入
List<List<Float>> embed(List<String> texts);
// 获取维度
int dimensions();
}
支持的 Embedding 模型
OpenAI Embedding
spring:
ai:
openai:
api-key: ${OPENAI_API_KEY}
embedding:
options:
model: text-embedding-3-small # 或 text-embedding-3-large
dimensions: 1536 # 可选,指定输出维度
可用模型:
| 模型 | 维度 | 特点 |
|---|---|---|
text-embedding-3-small | 512/1536 | 性价比高,适合大多数场景 |
text-embedding-3-large | 256/1024/3072 | 最高质量,适合高精度需求 |
text-embedding-ada-002 | 1536 | 旧版本,不推荐新项目使用 |
Ollama 本地模型
spring:
ai:
ollama:
base-url: http://localhost:11434
embedding:
model: nomic-embed-text
options:
dimension: 768
推荐本地模型:
| 模型 | 维度 | 大小 | 说明 |
|---|---|---|---|
nomic-embed-text | 768 | 274MB | 高质量,支持长文本 |
all-minilm | 384 | 45MB | 轻量级,速度快 |
mxbai-embed-large | 1024 | 670MB | 高质量多语言 |
Azure OpenAI
spring:
ai:
azure:
openai:
api-key: ${AZURE_API_KEY}
endpoint: ${AZURE_ENDPOINT}
embedding:
options:
deployment-name: text-embedding-ada-002
其他提供商
<!-- 阿里云通义千问 -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-alibaba-starter</artifactId>
</dependency>
<!-- Vertex AI -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-vertex-ai-embedding-spring-boot-starter</artifactId>
</dependency>
基本使用
获取单个嵌入向量
@Service
public class EmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
public float[] getEmbedding(String text) {
return embeddingModel.embed(text);
}
public void printEmbedding(String text) {
float[] embedding = embeddingModel.embed(text);
System.out.println("文本: " + text);
System.out.println("维度: " + embedding.length);
System.out.println("向量前10位: " +
Arrays.toString(Arrays.copyOf(embedding, 10)));
}
}
批量获取嵌入向量
@GetMapping("/batch-embed")
public List<EmbeddingResult> batchEmbed(@RequestBody List<String> texts) {
EmbeddingResponse response = embeddingModel.embedForResponse(texts);
return response.getResults().stream()
.map(result -> new EmbeddingResult(
texts.get(response.getResults().indexOf(result)),
result.getOutput(),
result.getIndex()
))
.toList();
}
record EmbeddingResult(String text, float[] embedding, int index) {}
获取嵌入维度
@GetMapping("/dimensions")
public int getDimensions() {
return embeddingModel.dimensions();
}
相似度计算
余弦相似度
余弦相似度是衡量两个向量相似程度的常用方法:
@Service
public class SimilarityService {
@Autowired
private EmbeddingModel embeddingModel;
/**
* 计算两个文本的余弦相似度
* 返回值范围:-1 到 1,值越大表示越相似
*/
public double cosineSimilarity(String text1, String text2) {
float[] embedding1 = embeddingModel.embed(text1);
float[] embedding2 = embeddingModel.embed(text2);
return cosineSimilarity(embedding1, embedding2);
}
private double cosineSimilarity(float[] vector1, float[] vector2) {
double dotProduct = 0.0;
double norm1 = 0.0;
double norm2 = 0.0;
for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
norm1 += vector1[i] * vector1[i];
norm2 += vector2[i] * vector2[i];
}
return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
}
}
相似度搜索服务
@Service
public class SemanticSearchService {
@Autowired
private EmbeddingModel embeddingModel;
private final Map<String, float[]> documentEmbeddings = new ConcurrentHashMap<>();
/**
* 添加文档到搜索库
*/
public void addDocument(String id, String content) {
float[] embedding = embeddingModel.embed(content);
documentEmbeddings.put(id, embedding);
}
/**
* 批量添加文档
*/
public void addDocuments(Map<String, String> documents) {
List<String> contents = new ArrayList<>(documents.values());
EmbeddingResponse response = embeddingModel.embedForResponse(contents);
List<String> ids = new ArrayList<>(documents.keySet());
for (int i = 0; i < ids.size(); i++) {
documentEmbeddings.put(ids.get(i), response.getResult(i).getOutput());
}
}
/**
* 语义搜索
*/
public List<SearchResult> search(String query, int topK) {
float[] queryEmbedding = embeddingModel.embed(query);
return documentEmbeddings.entrySet().stream()
.map(entry -> new SearchResult(
entry.getKey(),
cosineSimilarity(queryEmbedding, entry.getValue())
))
.sorted(Comparator.comparingDouble(SearchResult::similarity).reversed())
.limit(topK)
.toList();
}
record SearchResult(String documentId, double similarity) {}
}
实际应用示例
1. 文本相似度比较
@RestController
@RequestMapping("/api/similarity")
public class SimilarityController {
@Autowired
private SimilarityService similarityService;
@PostMapping("/compare")
public SimilarityResponse compare(@RequestBody CompareRequest request) {
double similarity = similarityService.cosineSimilarity(
request.text1(),
request.text2()
);
String interpretation = interpretSimilarity(similarity);
return new SimilarityResponse(similarity, interpretation);
}
private String interpretSimilarity(double similarity) {
if (similarity > 0.9) return "非常相似,语义几乎相同";
if (similarity > 0.7) return "高度相似,语义密切相关";
if (similarity > 0.5) return "中等相似,有一定关联";
if (similarity > 0.3) return "低相似度,关联较弱";
return "几乎不相关";
}
}
record CompareRequest(String text1, String text2) {}
record SimilarityResponse(double similarity, String interpretation) {}
2. 文档聚类
@Service
public class DocumentClusteringService {
@Autowired
private EmbeddingModel embeddingModel;
/**
* 将文档按语义相似度分组
*/
public Map<String, List<String>> clusterDocuments(
List<String> documents,
double threshold) {
// 获取所有嵌入向量
EmbeddingResponse response = embeddingModel.embedForResponse(documents);
List<float[]> embeddings = response.getResults().stream()
.map(Embedding::getOutput)
.toList();
// 简单的聚类算法
Map<Integer, List<Integer>> clusters = new HashMap<>();
boolean[] assigned = new boolean[documents.size()];
for (int i = 0; i < documents.size(); i++) {
if (assigned[i]) continue;
List<Integer> cluster = new ArrayList<>();
cluster.add(i);
assigned[i] = true;
for (int j = i + 1; j < documents.size(); j++) {
if (!assigned[j] &&
cosineSimilarity(embeddings.get(i), embeddings.get(j)) >= threshold) {
cluster.add(j);
assigned[j] = true;
}
}
clusters.put(i, cluster);
}
// 转换结果
Map<String, List<String>> result = new HashMap<>();
clusters.forEach((center, indices) -> {
result.put(
"Cluster-" + center,
indices.stream().map(documents::get).toList()
);
});
return result;
}
}
3. 推荐系统
@Service
public class ContentRecommendationService {
@Autowired
private EmbeddingModel embeddingModel;
private final Map<Long, float[]> itemEmbeddings = new ConcurrentHashMap<>();
private final Map<Long, String> itemContents = new ConcurrentHashMap<>();
/**
* 添加内容项
*/
public void addItem(Long id, String content) {
itemEmbeddings.put(id, embeddingModel.embed(content));
itemContents.put(id, content);
}
/**
* 根据内容推荐相似项目
*/
public List<Recommendation> recommend(String query, int topK) {
float[] queryEmbedding = embeddingModel.embed(query);
return itemEmbeddings.entrySet().stream()
.map(entry -> new Recommendation(
entry.getKey(),
itemContents.get(entry.getKey()),
cosineSimilarity(queryEmbedding, entry.getValue())
))
.sorted(Comparator.comparingDouble(Recommendation::score).reversed())
.limit(topK)
.toList();
}
/**
* 根据项目ID推荐相似项目
*/
public List<Recommendation> recommendByItem(Long itemId, int topK) {
float[] targetEmbedding = itemEmbeddings.get(itemId);
if (targetEmbedding == null) {
return List.of();
}
return itemEmbeddings.entrySet().stream()
.filter(entry -> !entry.getKey().equals(itemId))
.map(entry -> new Recommendation(
entry.getKey(),
itemContents.get(entry.getKey()),
cosineSimilarity(targetEmbedding, entry.getValue())
))
.sorted(Comparator.comparingDouble(Recommendation::score).reversed())
.limit(topK)
.toList();
}
record Recommendation(Long id, String content, double score) {}
}
4. 重复内容检测
@Service
public class DuplicateDetectionService {
@Autowired
private EmbeddingModel embeddingModel;
private static final double DUPLICATE_THRESHOLD = 0.95;
/**
* 检测文本是否与已有内容重复
*/
public DuplicateCheckResult checkDuplicate(String newText, List<String> existingTexts) {
float[] newEmbedding = embeddingModel.embed(newText);
for (int i = 0; i < existingTexts.size(); i++) {
float[] existingEmbedding = embeddingModel.embed(existingTexts.get(i));
double similarity = cosineSimilarity(newEmbedding, existingEmbedding);
if (similarity >= DUPLICATE_THRESHOLD) {
return new DuplicateCheckResult(
true,
i,
similarity,
existingTexts.get(i)
);
}
}
return new DuplicateCheckResult(false, -1, 0, null);
}
/**
* 批量检测重复内容
*/
public List<DuplicatePair> findDuplicates(List<String> texts, double threshold) {
List<DuplicatePair> duplicates = new ArrayList<>();
EmbeddingResponse response = embeddingModel.embedForResponse(texts);
List<float[]> embeddings = response.getResults().stream()
.map(Embedding::getOutput)
.toList();
for (int i = 0; i < texts.size(); i++) {
for (int j = i + 1; j < texts.size(); j++) {
double similarity = cosineSimilarity(embeddings.get(i), embeddings.get(j));
if (similarity >= threshold) {
duplicates.add(new DuplicatePair(i, j, similarity));
}
}
}
return duplicates;
}
record DuplicateCheckResult(
boolean isDuplicate,
int duplicateIndex,
double similarity,
String duplicateContent
) {}
record DuplicatePair(int index1, int index2, double similarity) {}
}
与 VectorStore 结合
Embedding 模型与向量存储结合实现高效的语义搜索:
@Service
public class VectorSearchService {
@Autowired
private VectorStore vectorStore;
/**
* 添加文档到向量存储
*/
public void addDocument(String id, String content, Map<String, Object> metadata) {
Document document = new Document(id, content, metadata);
vectorStore.add(List.of(document));
}
/**
* 语义搜索
*/
public List<SearchResult> search(String query, int topK, double threshold) {
SearchRequest request = SearchRequest.query(query)
.withTopK(topK)
.withSimilarityThreshold(threshold);
return vectorStore.similaritySearch(request).stream()
.map(doc -> new SearchResult(
doc.getId(),
doc.getContent(),
doc.getMetadata(),
doc.getScore()
))
.toList();
}
/**
* 带过滤条件的搜索
*/
public List<SearchResult> searchWithFilter(
String query,
String category,
int topK) {
SearchRequest request = SearchRequest.query(query)
.withTopK(topK)
.withFilterExpression("category == '" + category + "'");
return vectorStore.similaritySearch(request).stream()
.map(doc -> new SearchResult(
doc.getId(),
doc.getContent(),
doc.getMetadata(),
doc.getScore()
))
.toList();
}
record SearchResult(String id, String content, Map<String, Object> metadata, Double score) {}
}
性能优化
批量处理
@Service
public class BatchEmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
private static final int BATCH_SIZE = 100;
/**
* 批量处理大量文本
*/
public Map<String, float[]> batchEmbed(Map<String, String> documents) {
Map<String, float[]> results = new ConcurrentHashMap<>();
List<String> ids = new ArrayList<>(documents.keySet());
// 分批处理
Lists.partition(ids, BATCH_SIZE).forEach(batch -> {
List<String> batchContents = batch.stream()
.map(documents::get)
.toList();
EmbeddingResponse response = embeddingModel.embedForResponse(batchContents);
for (int i = 0; i < batch.size(); i++) {
results.put(batch.get(i), response.getResult(i).getOutput());
}
});
return results;
}
}
缓存嵌入向量
@Service
public class CachedEmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
private final Cache<String, float[]> embeddingCache =
Caffeine.newBuilder()
.maximumSize(10000)
.expireAfterAccess(Duration.ofHours(1))
.build();
/**
* 获取嵌入向量(带缓存)
*/
public float[] getEmbedding(String text) {
return embeddingCache.get(text, key -> embeddingModel.embed(text));
}
/**
* 预热缓存
*/
public void warmup(List<String> commonTexts) {
commonTexts.forEach(this::getEmbedding);
}
}
最佳实践
1. 选择合适的维度
// 高维度 = 更精确,但存储和计算成本高
// 低维度 = 更快,但可能损失精度
// 文档搜索:1536 维度(平衡)
// 大规模语义匹配:768 维度(效率优先)
// 精确匹配任务:3072 维度(精度优先)
2. 处理长文本
@Service
public class LongTextEmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
private static final int MAX_TOKENS = 8000;
/**
* 处理超长文本:分割后取平均
*/
public float[] embedLongText(String longText) {
// 分割文本
List<String> chunks = splitText(longText, MAX_TOKENS);
// 获取每个块的嵌入
EmbeddingResponse response = embeddingModel.embedForResponse(chunks);
// 平均嵌入向量
float[] avgEmbedding = new float[embeddingModel.dimensions()];
for (Embedding embedding : response.getResults()) {
float[] vec = embedding.getOutput();
for (int i = 0; i < vec.length; i++) {
avgEmbedding[i] += vec[i];
}
}
for (int i = 0; i < avgEmbedding.length; i++) {
avgEmbedding[i] /= chunks.size();
}
return avgEmbedding;
}
private List<String> splitText(String text, int maxTokens) {
// 实现文本分割逻辑
// 可以按段落、句子或固定长度分割
return List.of(text); // 简化实现
}
}
3. 处理多语言
// 使用支持多语言的模型
spring:
ai:
openai:
embedding:
options:
model: text-embedding-3-large # 支持多语言
// 或使用专门的多语言模型
spring:
ai:
ollama:
embedding:
model: nomic-embed-text # 支持多语言
小结
本章我们学习了:
- Embedding 概念:将文本转换为数值向量
- 支持的模型:OpenAI、Ollama、Azure 等多种提供商
- 基本使用:获取嵌入向量、批量处理
- 相似度计算:余弦相似度及其应用
- 实际应用:语义搜索、推荐系统、重复检测
- 性能优化:批量处理、缓存策略
练习
- 实现一个简单的语义搜索引擎
- 创建一个内容推荐系统
- 实现文档去重功能