文本生成策略
generate() 方法是 Transformers 中最核心的文本生成接口,支持多种解码策略。理解这些策略对于控制生成文本的质量、多样性和连贯性至关重要。本章将深入讲解各种生成策略的原理和使用方法。
生成策略概述
文本生成的本质是:给定已生成的文本序列,预测下一个 token,循环往复直到生成结束。不同的解码策略决定了如何从模型输出的概率分布中选择下一个 token。
┌─────────────────────────────────────────────────────────────────┐
│ 文本生成流程 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 输入文本 "今天天气" │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ 模型前向 │ │
│ │ 推理 │ │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────┐ │
│ │ 词汇表概率分布 │ │
│ │ 很好: 0.15 │ │
│ │ 不错: 0.12 │ │
│ │ 真好: 0.10 │ │
│ │ ... │ │
│ └──────┬──────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────┐ │
│ │ 解码策略 │ │
│ │ ┌────────┐ ┌────────┐ ┌────────┐ │ │
│ │ │贪婪解码│ │束搜索 │ │采样 │ │ │
│ │ └────────┘ └────────┘ └────────┘ │ │
│ └──────┬──────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 选中的 token "很好" │
│ │ │
│ ▼ │
│ 重复上述过程直到生成结束 │
│ │
└─────────────────────────────────────────────────────────────────┘
基本生成方法
贪婪解码 (Greedy Decoding)
贪婪解码是最简单的策略:每一步都选择概率最高的 token。这是 generate() 的默认行为。
原理:直接选取概率最大的词,不考虑全局最优。
优点:
- 计算速度快
- 结果确定性(相同输入必然产生相同输出)
- 适合短文本生成
缺点:
- 容易产生重复内容
- 缺乏创造性
- 可能陷入局部最优
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
# 贪婪解码(默认行为)
output = model.generate(
**inputs,
max_new_tokens=50,
do_sample=False, # 不使用采样
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
# 输出示例:The future of artificial intelligence is bright. The technology is...
贪婪解码适合需要确定性输出的场景,如代码补全、格式化文本生成等。但在开放式生成任务中,往往会产生重复且缺乏创造性的文本。
束搜索 (Beam Search)
束搜索是一种启发式搜索算法,它在每一步保留多个候选序列(称为"束"),最终选择概率最高的完整序列。
原理:维护 num_beams 个候选序列,每一步扩展所有候选,保留总概率最高的 num_beams 个。
┌─────────────────────────────────────────────────────────────────┐
│ 束搜索示意图 (num_beams=2) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Step 0: "今天" │
│ │ │
│ ├── 天气很好 (0.15) ──┐ │
│ └── 天气不错 (0.12) ──┼── 保留 top-2 │
│ │ │
│ Step 1: 扩展两个候选 │
│ │ │
│ ├── "今天天气很好" ──┬── 适合外出 (0.20) │
│ │ └── 适合散步 (0.18) │
│ │ │
│ └── "今天天气不错" ──┬── 适合运动 (0.15) │
│ └── 适合游玩 (0.14) │
│ │
│ Step 2: 再次保留 top-2 组合 │
│ │ │
│ ├── "今天天气很好适合外出" (累计概率: 0.15×0.20=0.03) │
│ └── "今天天气很好适合散步" (累计概率: 0.15×0.18=0.027) │
│ │
└─────────────────────────────────────────────────────────────────┘
优点:
- 能够找到概率更高的序列
- 比贪婪解码更有可能找到全局较优解
- 适合需要高概率输出的任务(如翻译、摘要)
缺点:
- 计算量随束宽增加而线性增长
- 可能仍然产生重复
- 不适合开放式生成
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt")
# 束搜索
output = model.generate(
**inputs,
max_new_tokens=20,
num_beams=5, # 束宽为 5
early_stopping=True, # 当所有束都生成结束符时提前停止
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
# 输出示例:The capital of France is Paris, which is also the largest city in France.
返回多个候选序列:
# 返回多个候选
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=5,
num_return_sequences=3, # 返回 3 个不同的候选
early_stopping=True,
pad_token_id=tokenizer.eos_token_id
)
for i, output in enumerate(outputs):
print(f"候选 {i+1}: {tokenizer.decode(output, skip_special_tokens=True)}")
长度惩罚:
束搜索倾向于生成较短的序列,因为概率的连乘会使长序列的总概率降低。通过 length_penalty 参数可以缓解这个问题:
output = model.generate(
**inputs,
max_new_tokens=50,
num_beams=5,
length_penalty=1.0, # > 0 鼓励长序列,< 0 鼓励短序列,= 0 无惩罚
early_stopping=True,
pad_token_id=tokenizer.eos_token_id
)
采样 (Sampling)
采样策略从概率分布中随机选择下一个 token,而不是总是选择最高概率的词。这引入了随机性,使生成结果更多样化。
原理:根据模型输出的概率分布随机采样一个 token。
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt")
# 基本采样
torch.manual_seed(42) # 设置随机种子以复现结果
output = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True, # 启用采样
num_beams=1, # 不使用束搜索
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
纯采样的问题在于:概率分布可能非常平坦,导致选择了很多概率很低的词,生成的文本可能不连贯。因此,通常需要配合其他技术使用。
高级采样技术
温度 (Temperature)
温度参数用于调整概率分布的"陡峭程度":
其中 是模型输出的 logits, 是温度参数。
- :原始概率分布,不做调整
- :分布更陡峭,倾向于选择高概率词,输出更确定
- :分布更平坦,选择更随机,输出更多样化
- :接近贪婪解码
┌─────────────────────────────────────────────────────────────────┐
│ 温度参数效果对比 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 原始概率分布 (T=1): │
│ ┌────────────────────────────────────────┐ │
│ │ ████ 很好 (0.30) │ │
│ │ ███ 不错 (0.25) │ │
│ │ ██ 还行 (0.20) │ │
│ │ █ 一般 (0.15) │ │
│ │ 较差 (0.10) │ │
│ └────────────────────────────────────────┘ │
│ │
│ 低温度 T=0.5 (更确定): │
│ ┌────────────────────────────────────────┐ │
│ │ ████████ 很好 (0.45) │ │
│ │ ██████ 不错 (0.30) │ │
│ │ ███ 还行 (0.15) │ │
│ │ █ 一般 (0.07) │ │
│ │ 较差 (0.03) │ │
│ └────────────────────────────────────────┘ │
│ │
│ 高温度 T=1.5 (更随机): │
│ ┌────────────────────────────────────────┐ │
│ │ ████ 很好 (0.24) │ │
│ │ ████ 不错 (0.23) │ │
│ │ ███ 还行 (0.20) │ │
│ │ ███ 一般 (0.18) │ │
│ │ ██ 较差 (0.15) │ │
│ └────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
# 低温度:更确定的输出
output_low_temp = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
temperature=0.7, # 温度参数
pad_token_id=tokenizer.eos_token_id
)
# 高温度:更随机的输出
output_high_temp = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
temperature=1.5,
pad_token_id=tokenizer.eos_token_id
)
print("低温度输出:")
print(tokenizer.decode(output_low_temp[0], skip_special_tokens=True))
print("\n高温度输出:")
print(tokenizer.decode(output_high_temp[0], skip_special_tokens=True))
温度选择建议:
- 事实性任务(问答、翻译):
- 平衡任务(对话、写作):
- 创意任务(故事创作):
Top-K 采样
Top-K 采样只在概率最高的 K 个词中进行采样,将其他词的概率设为 0,然后重新归一化。
原理:保留概率最高的 K 个候选,截断分布的尾部。
┌─────────────────────────────────────────────────────────────────┐
│ Top-K 采样 (K=3) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 原始分布: │
│ ┌────────────────────────────────────────┐ │
│ │ ████████████ 很好 (0.35) ✓ │ │
│ │ ██████████ 不错 (0.30) ✓ │ │
│ │ ████████ 还行 (0.22) ✓ │ │
│ │ ████ 一般 (0.08) ✗ │ │
│ │ ███ 较差 (0.05) ✗ │ │
│ └────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 重新归一化后: │
│ ┌────────────────────────────────────────┐ │
│ │ ██████████████ 很好 (0.40) │ │
│ │ ████████████ 不错 (0.35) │ │
│ │ ██████████ 还行 (0.25) │ │
│ └────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
output = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
top_k=50, # 只在概率最高的 50 个词中采样
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
Top-K 的问题:固定的 K 值可能不适合所有情况。有时分布很集中(确定的情况),只需要几个候选;有时分布很平坦(不确定的情况),可能需要更多候选。
Top-P (Nucleus) 采样
Top-P 采样(也叫核采样)动态选择最小的一组词,使它们的累积概率达到 P。这解决了 Top-K 固定候选数量的问题。
原理:按概率从高到低排序,保留累积概率达到 P 的最小词集。
┌─────────────────────────────────────────────────────────────────┐
│ Top-P 采样 (P=0.9) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 情况1:分布集中(确定性高) │
│ ┌────────────────────────────────────────┐ │
│ │ ████████████████████ 很好 (0.60) ✓ │ │
│ │ ████████████ 不错 (0.25) ✓ │ ← 累积: 0.85 │
│ │ ████ 还行 (0.08) ✓ │ ← 累积: 0.93 > P │
│ │ ██ 一般 (0.04) ✗ │ │
│ │ █ 较差 (0.03) ✗ │ │
│ └────────────────────────────────────────┘ │
│ 只需保留 3 个词即可达到 P=0.9 │
│ │
│ 情况2:分布平坦(不确定性高) │
│ ┌────────────────────────────────────────┐ │
│ │ ████████ 很好 (0.20) ✓ │ │
│ │ ███████ 不错 (0.18) ✓ │ ← 累积: 0.38 │
│ │ ██████ 还行 (0.15) ✓ │ ← 累积: 0.53 │
│ │ █████ 一般 (0.14) ✓ │ ← 累积: 0.67 │
│ │ ████ 较差 (0.12) ✓ │ ← 累积: 0.79 │
│ │ ███ 还可以 (0.10) ✓ │ ← 累积: 0.89 │
│ │ ██ 马马虎虎 (0.08) ✓ │ ← 累积: 0.97 > P │
│ │ █ 糟糕 (0.03) ✗ │ │
│ └────────────────────────────────────────┘ │
│ 需要保留 7 个词才能达到 P=0.9 │
│ │
└─────────────────────────────────────────────────────────────────┘
output = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
top_p=0.9, # 累积概率阈值
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
组合使用 Top-K 和 Top-P
实践中,常常同时使用 Top-K 和 Top-P,先通过 Top-K 截断极端情况,再用 Top-P 进行精细控制:
output = model.generate(
**inputs,
max_new_tokens=100,
do_sample=True,
top_k=50, # 先限制在 top 50
top_p=0.95, # 再在累积概率 0.95 内采样
temperature=0.8, # 同时使用温度
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
推荐配置:
- 通用对话:
temperature=0.7, top_p=0.9 - 创意写作:
temperature=1.0, top_p=0.95 - 代码生成:
temperature=0.2, top_p=0.9
重复控制
重复惩罚 (Repetition Penalty)
重复惩罚通过降低已出现词的概率来减少重复:
其中 是惩罚因子, 时降低已出现词的概率。
output = model.generate(
**inputs,
max_new_tokens=100,
repetition_penalty=1.2, # 惩罚因子,> 1 时减少重复
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
惩罚因子选择:
1.0:不惩罚1.1 ~ 1.2:轻微惩罚,适合大多数情况1.5 ~ 2.0:强惩罚,可能导致输出不连贯
N-gram 重复禁止
no_repeat_ngram_size 参数可以完全禁止指定大小的 n-gram 重复:
output = model.generate(
**inputs,
max_new_tokens=100,
no_repeat_ngram_size=3, # 禁止 3-gram 重复
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
这意味着任何 3 个连续的词都不能出现第二次。这在翻译和摘要任务中特别有用。
长度控制
最大/最小长度
output = model.generate(
**inputs,
max_new_tokens=100, # 最大生成 token 数
min_new_tokens=20, # 最小生成 token 数
pad_token_id=tokenizer.eos_token_id
)
早停策略
配合束搜索使用,当所有候选序列都已生成结束符时提前停止:
output = model.generate(
**inputs,
max_new_tokens=100,
num_beams=5,
early_stopping=True, # 早停
pad_token_id=tokenizer.eos_token_id
)
高级功能
流式输出 (Streaming)
对于交互式应用,流式输出可以让用户看到生成过程,而不必等待完整输出:
from transformers import TextStreamer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
prompt = "Write a short story about a robot learning to love:"
inputs = tokenizer(prompt, return_tensors="pt")
# 使用 TextStreamer 实现流式输出
streamer = TextStreamer(tokenizer)
output = model.generate(
**inputs,
max_new_tokens=100,
streamer=streamer, # 流式输出
pad_token_id=tokenizer.eos_token_id
)
自定义流式处理器:
from transformers import TextIteratorStreamer
from threading import Thread
# 创建迭代器流式处理器
streamer = TextIteratorStreamer(tokenizer)
# 在单独的线程中运行生成
generation_kwargs = dict(
**inputs,
max_new_tokens=100,
streamer=streamer,
pad_token_id=tokenizer.eos_token_id
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 迭代获取生成的文本
for text in streamer:
print(text, end="", flush=True)
约束生成
约束生成可以强制模型输出包含特定内容:
from transformers import Constraint, PhrasalConstraint
# 强制输出包含特定短语
force_phrase = PhrasalConstraint(tokenizer.encode("artificial intelligence", add_special_tokens=False))
output = model.generate(
**inputs,
max_new_tokens=50,
constraints=[force_phrase],
pad_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
Bad Words 过滤
禁止生成特定的词或短语:
# 禁止生成的词 ID 列表
bad_words_ids = [
tokenizer.encode("badword", add_special_tokens=False),
tokenizer.encode("anotherbad", add_special_tokens=False)
]
output = model.generate(
**inputs,
max_new_tokens=50,
bad_words_ids=bad_words_ids,
pad_token_id=tokenizer.eos_token_id
)
自定义 Logits 处理器
对于更复杂的需求,可以自定义 logits 处理器:
from transformers import LogitsProcessor
class CustomLogitsProcessor(LogitsProcessor):
def __init__(self, boost_token_ids, boost_factor=2.0):
self.boost_token_ids = boost_token_ids
self.boost_factor = boost_factor
def __call__(self, input_ids, scores):
# 提升特定 token 的概率
for token_id in self.boost_token_ids:
scores[:, token_id] *= self.boost_factor
return scores
# 使用自定义处理器
boost_processor = CustomLogitsProcessor(
boost_token_ids=[tokenizer.encode("good", add_special_tokens=False)[0]],
boost_factor=1.5
)
output = model.generate(
**inputs,
max_new_tokens=50,
logits_processor=[boost_processor],
pad_token_id=tokenizer.eos_token_id
)
GenerationConfig 配置
GenerationConfig 类封装了所有生成参数,可以保存和加载配置:
from transformers import GenerationConfig
# 创建自定义配置
generation_config = GenerationConfig(
max_new_tokens=100,
do_sample=True,
temperature=0.8,
top_p=0.9,
top_k=50,
repetition_penalty=1.1,
pad_token_id=tokenizer.eos_token_id
)
# 保存配置
generation_config.save_pretrained("./my_generation_config")
# 加载配置
generation_config = GenerationConfig.from_pretrained("./my_generation_config")
# 使用配置生成
output = model.generate(
**inputs,
generation_config=generation_config
)
也可以将配置与模型一起保存:
# 保存模型和生成配置
model.generation_config = generation_config
model.save_pretrained("./my_model")
# 加载时会自动加载生成配置
model = AutoModelForCausalLM.from_pretrained("./my_model")
完整示例
对话生成
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 对话生成
def chat(user_input, chat_history_ids=None):
new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
# 拼接历史
if chat_history_ids is not None:
bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
else:
bot_input_ids = new_input_ids
# 生成回复
chat_history_ids = model.generate(
bot_input_ids,
max_new_tokens=100,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
pad_token_id=tokenizer.eos_token_id
)
# 只返回新生成的部分
response = tokenizer.decode(
chat_history_ids[:, bot_input_ids.shape[-1]:][0],
skip_special_tokens=True
)
return response, chat_history_ids
# 对话循环
history = None
while True:
user_input = input("User: ")
if user_input.lower() in ["bye", "exit", "quit"]:
break
response, history = chat(user_input, history)
print(f"Bot: {response}")
批量生成
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# 批量 prompts
prompts = [
"The quick brown fox",
"In the beginning",
"Once upon a time"
]
# 批量处理
inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
# 批量生成
outputs = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
temperature=0.8,
top_p=0.9,
num_return_sequences=1,
pad_token_id=tokenizer.eos_token_id
)
# 解码结果
for i, output in enumerate(outputs):
print(f"Prompt {i+1}: {prompts[i]}")
print(f"Generated: {tokenizer.decode(output, skip_special_tokens=True)}\n")
策略选择指南
| 任务类型 | 推荐策略 | 参数建议 |
|---|---|---|
| 代码补全 | 贪婪解码 | do_sample=False |
| 翻译 | 束搜索 | num_beams=4-6, length_penalty=1.0 |
| 摘要 | 束搜索 | num_beams=4, length_penalty=2.0 |
| 对话 | 采样 | temperature=0.7, top_p=0.9 |
| 创意写作 | 采样 | temperature=0.9-1.2, top_p=0.95 |
| 问答 | 贪婪/低温度采样 | temperature=0.3-0.5, top_p=0.9 |
常见问题
1. 输出重复怎么办?
# 方案1:增加重复惩罚
repetition_penalty=1.2
# 方案2:禁止 n-gram 重复
no_repeat_ngram_size=3
# 方案3:降低温度
temperature=0.7
2. 输出不连贯怎么办?
# 方案1:提高温度
temperature=0.9
# 方案2:使用 Top-P 采样
top_p=0.9
# 方案3:降低重复惩罚
repetition_penalty=1.0
3. 生成太短或太长?
# 控制长度
max_new_tokens=100
min_new_tokens=20
# 或使用长度惩罚(束搜索)
length_penalty=1.5 # 鼓励长输出