X-transformers项目解析
- AI
- 1天前
- 12热度
- 0评论
X-transformers项目架构概览
架构关系图
训练脚本层 (train_*.py)
├── train_copy.py
├── train_with_muon.py
├── train_entropy_tokenizer.py
├── train_parity.py
├── train_enwik8.py
├── train_length_extrapolate.py
├── train_gpt_vae.py
├── train_belief_state.py
└── train_free.py
↓
x_transformers 核心库层
├── x_transformers.py (核心Transformer实现)
├── attend.py (注意力机制)
├── continuous.py (连续表示)
├── neo_mlp.py (MLP组件)
├── multi_input.py (多输入处理)
├── dpo.py (直接偏好优化)
├── xval.py (交叉验证)
├── entropy_based_tokenizer.py (熵分词器)
├── gpt_vae.py (GPT-VAE模型)
└── free_transformer.py (自由Transformer)
↓
包装器层 (Wrapper)
├── autoregressive_wrapper.py (自回归包装器)
├── xl_autoregressive_wrapper.py (XL自回归包装器)
├── nonautoregressive_wrapper.py (非自回归包装器)
├── belief_state_wrapper.py (信念状态包装器)
├── up_wrapper.py (上采样包装器)
核心依赖关系分析
1. 分层架构
- 训练脚本层:各种实验和任务的训练入口点
- 核心模型层:实现不同的Transformer变体和组件
- 包装器层:为模型添加特定功能(自回归、非自回归等)
2. 核心数据流
输入数据 → 预处理 → Transformer核心 → 包装器 → 输出
↓ ↓ ↓ ↓ ↓
训练脚本 → 分词器 → 注意力机制 → 特定任务 → 损失计算
3. 关键模块依赖
x_transformers.py 是核心模块,依赖:
attend.py:注意力机制实现continuous.py:连续表示处理neo_mlp.py:前馈网络
包装器依赖关系:
autoregressive_wrapper.py:基础自回归训练xl_autoregressive_wrapper.py:扩展的自回归(可能支持更长序列)belief_state_wrapper.py:用于信念状态跟踪任务nonautoregressive_wrapper.py:非自回归生成
4. 训练脚本与核心库的对应关系
| 训练脚本 | 主要依赖的x_transformers模块 | 任务类型 |
|---|---|---|
| train_gpt_vae.py | gpt_vae.py + autoregressive_wrapper | 变分自编码器 |
| train_belief_state.py | belief_state_wrapper | 信念状态建模 |
| train_entropy_tokenizer.py | entropy_based_tokenizer.py | 分词器训练 |
| train_free.py | free_transformer.py | 自由Transformer实验 |
| train_enwik8.py | xl_autoregressive_wrapper | 语言建模 |
| train_length_extrapolate.py | xl_autoregressive_wrapper | 长度外推 |
5. 架构特点
- 模块化设计:每个文件功能单一,职责明确
- 可插拔包装器:通过包装器模式扩展模型功能
- 实验友好:每个训练脚本对应一个特定实验
- 研究导向:包含多种前沿研究方向(熵分词、VAE、信念状态等)
6. 核心数据流示例
以train_gpt_vae.py为例:
原始文本 → entropy_based_tokenizer → 词元序列
↓
gpt_vae.py (编码器-解码器) → 潜在表示
↓
autoregressive_wrapper → 自回归生成
↓
重构损失 + KL散度 → 反向传播
这个架构支持快速实验不同的Transformer变体和训练策略,同时保持核心组件的复用性。
x-transformers-main 源码全解析
📄 文件: train_copy.py
代码分析报告
1. 文件功能摘要
这是一个使用XTransformer模型进行序列复制任务的训练脚本,模型学习将输入序列复制到输出序列。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| XTransformer | 一种Transformer架构的变体,包含编码器和解码器,用于序列到序列的任务 |
| LayerNorm | 层归一化,对神经网络层的输出进行标准化,提高训练稳定性 |
| Einsum | Einstein求和约定,用于简洁表达张量运算(如矩阵乘法、转置等) |
| Residual | 残差连接,将输入直接加到输出上,缓解梯度消失问题 |
| Attention | 注意力机制,让模型关注输入序列中不同部分的重要性 |
| CogSigned | 可能指代某种特定的注意力机制变体(具体实现需查看x_transformers库) |
| Token Embedding | 词嵌入,将离散的token转换为连续的向量表示 |
| Mask | 掩码,用于在训练时屏蔽某些位置(如填充位置或未来位置) |
| Backward | 反向传播,计算损失函数对模型参数的梯度 |
| Optimizer | 优化器,根据梯度更新模型参数以最小化损失 |
3. 代码逐行/逐块注释
import tqdm
import torch
import torch.optim as optim
from x_transformers import XTransformer
# 常量定义
NUM_BATCHES = int(1e5) # 总训练批次:100,000
BATCH_SIZE = 32 # 每批样本数:32
LEARNING_RATE = 3e-4 # 学习率:0.0003
GENERATE_EVERY = 100 # 每100批次生成一次预测结果
NUM_TOKENS = 16 + 2 # token数量:18(16个数据token + 2个特殊token)
ENC_SEQ_LEN = 32 # 编码器输入序列长度:32
DEC_SEQ_LEN = 64 + 1 # 解码器输出序列长度:65(比输入长一倍+1)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # 自动选择设备
# 数据生成辅助函数
def cycle():
"""
无限循环生成训练数据
生成格式:源序列 + 目标序列(前缀1 + 源序列 + 源序列)
"""
while True:
# 创建前缀:形状为(BATCH_SIZE, 1),所有值为1
prefix = torch.ones((BATCH_SIZE, 1)).long().to(DEVICE)
# 生成源序列:从2到NUM_TOKENS-1的随机整数
# 形状:(BATCH_SIZE, ENC_SEQ_LEN) = (32, 32)
src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().to(DEVICE)
# 创建目标序列:拼接[前缀, 源序列, 源序列]
# 形状:(BATCH_SIZE, 1+32+32) = (32, 65)
tgt = torch.cat((prefix, src, src), 1)
# 创建源序列掩码:全为True,表示所有位置都参与计算
src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().to(DEVICE)
yield (src, tgt, src_mask) # 返回(源序列, 目标序列, 掩码)
# 实例化模型
model = XTransformer(
dim = 128, # 模型维度:128
tie_token_emb = True, # 编码器和解码器共享词嵌入权重
return_tgt_loss = True, # 返回目标序列的损失
enc_num_tokens=NUM_TOKENS, # 编码器token数:18
enc_depth = 3, # 编码器层数:3
enc_heads = 8, # 编码器注意力头数:8
enc_max_seq_len = ENC_SEQ_LEN, # 编码器最大序列长度:32
enc_attn_cog_signed = True, # 编码器使用cog_signed注意力
dec_num_tokens = NUM_TOKENS, # 解码器token数:18
dec_depth = 3, # 解码器层数:3
dec_heads = 8, # 解码器注意力头数:8
dec_max_seq_len = DEC_SEQ_LEN, # 解码器最大序列长度:65
dec_attn_cog_signed = True # 解码器使用cog_signed注意力
).to(DEVICE) # 将模型移动到指定设备
# 优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train() # 设置为训练模式
# 获取一批训练数据
src, tgt, src_mask = next(cycle())
# 前向传播:计算损失
# src形状: (32, 32), tgt形状: (32, 65), src_mask形状: (32, 32)
loss = model(src, tgt, mask=src_mask)
# 反向传播:计算梯度
loss.backward()
print(f'{i}: {loss.item()}') # 打印当前损失
# 优化步骤:更新参数
optim.step()
optim.zero_grad() # 清空梯度
# 每100批次进行一次预测生成
if i != 0 and i % GENERATE_EVERY == 0:
model.eval() # 设置为评估模式
# 获取新的测试数据(只取第一个样本)
src, _, src_mask = next(cycle())
src, src_mask = src[:1], src_mask[:1] # 只保留第一个样本
# 创建起始token:值为1,形状(1, 1)
start_tokens = (torch.ones((1, 1)) * 1).long().to(DEVICE)
# 生成预测序列
# 输入:源序列(1, 32),起始token(1, 1),生成长度32
sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask=src_mask)
# 计算错误数:比较源序列和生成序列
incorrects = (src != sample).long().abs().sum()
# 打印结果
print(f"input: ", src)
print(f"predicted output: ", sample)
print(f"incorrects: {incorrects}")
关键张量维度变化说明:
-
数据生成阶段:
src: (32, 32) → 32个样本,每个样本32个tokentgt: (32, 65) → 前缀(1) + 源序列(32) + 源序列(32)src_mask: (32, 32) → 与src形状相同的全1掩码
-
模型输入/输出:
- 编码器输入:
src→ (32, 32) - 解码器输入:
tgt→ (32, 65) - 模型输出损失: 标量值
- 编码器输入:
-
生成阶段:
- 测试输入:
src[:1]→ (1, 32) - 起始token:
start_tokens→ (1, 1) - 生成输出:
sample→ (1, 32)
- 测试输入:
任务特点:
这是一个序列复制任务,模型需要学习将输入序列复制两次(中间有分隔符1)。目标序列的结构是:[1] + 输入序列 + 输入序列。模型在解码时从起始token(1)开始,生成与输入序列相同的内容。
📄 文件: train_with_muon.py
代码分析报告
1. 文件功能摘要
这是一个使用MuonAdamAtan2优化器训练基于Transformer的自回归语言模型的脚本,用于在enwik8数据集上进行字符级文本生成任务。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| TransformerWrapper | x-transformers库中的包装器,用于简化Transformer模型的构建 |
| Decoder | Transformer的解码器部分,用于自回归生成任务 |
| AutoregressiveWrapper | 自回归包装器,为模型添加自回归训练和生成功能 |
| MuonAdamAtan2 | 一种特殊的优化器,使用atan2函数和muon参数进行梯度更新 |
| Rotary Position Embedding | 旋转位置编码,一种相对位置编码方法,通过旋转矩阵将位置信息融入注意力机制 |
| Gradient Accumulation | 梯度累积,通过多个小批次累积梯度再更新参数,模拟更大批次训练 |
| KV Cache | 键值缓存,在自回归生成时缓存之前计算的键值对,避免重复计算 |
3. 代码逐行/逐块注释
# /// script
# dependencies = [
# "x-transformers",
# "adam-atan2-pytorch>=0.2.4",
# ]
# ///
# 依赖声明:指定运行所需的外部库
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from adam_atan2_pytorch import MuonAdamAtan2 # 导入特殊的优化器
# 常量定义
NUM_BATCHES = int(1e5) # 总训练批次:100,000
BATCH_SIZE = 4 # 批次大小:4个序列
GRADIENT_ACCUMULATE_EVERY = 4 # 梯度累积步数:每4步更新一次参数
LEARNING_RATE = 1e-4 # 学习率:0.0001
VALIDATE_EVERY = 100 # 验证频率:每100批次验证一次
GENERATE_EVERY = 500 # 生成频率:每500批次生成文本
GENERATE_LENGTH = 1024 # 生成文本长度:1024个字符
SEQ_LEN = 1024 # 序列长度:1024个字符
# 辅助函数
def cycle(loader):
"""创建无限循环的数据加载器"""
while True:
for data in loader:
yield data
def decode_token(token):
"""将token解码为字符"""
return str(chr(max(32, token))) # 确保字符值≥32(可打印字符)
def decode_tokens(tokens):
"""将token序列解码为字符串"""
return ''.join(list(map(decode_token, tokens)))
# 实例化GPT风格的解码器模型
model = TransformerWrapper(
num_tokens = 256, # 词汇表大小:256个字符(ASCII扩展)
max_seq_len = SEQ_LEN, # 最大序列长度:1024
attn_layers = Decoder(
dim = 512, # 模型维度:512
depth = 6, # Transformer层数:6
heads = 8, # 注意力头数:8
rotary_pos_emb = True # 使用旋转位置编码
)
)
# 包装模型以支持自回归训练和生成
ar_wrapper = AutoregressiveWrapper(model)
model.cuda() # 将模型移动到GPU
# 准备enwik8数据集
with gzip.open('./data/enwik8.gz') as file:
# 读取前95MB数据,转换为numpy数组
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 分割数据集:前90MB训练,后5MB验证
train_x, valid_x = np.split(data, [int(90e6)])
# 转换为PyTorch张量
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
"""文本采样数据集类"""
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 完整文本数据
self.seq_len = seq_len # 序列长度
def __getitem__(self, index):
# 随机选择起始位置,确保有足够的长度
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
# 提取seq_len+1长度的序列(+1用于预测下一个token)
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda() # 移动到GPU
def __len__(self):
# 数据集大小 = 总数据长度 // 序列长度
return self.data.size(0) // self.seq_len
# 创建训练和验证数据集
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# 创建无限循环的数据加载器
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
# 优化器配置
optim = MuonAdamAtan2(
muon_params = model.muon_parameters(), # 获取模型的muon参数
params = model.parameters(), # 所有模型参数
remove_muon_params_from_params = True, # 从普通参数中移除muon参数
lr = LEARNING_RATE # 学习率
)
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train() # 设置为训练模式
# 梯度累积循环
for __ in range(GRADIENT_ACCUMULATE_EVERY):
# 获取下一个批次并计算损失
loss = ar_wrapper(next(train_loader))
# 梯度归一化并反向传播
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
print(f'training loss: {loss.item()}')
# 梯度裁剪,防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# 更新参数
optim.step()
# 清空梯度
optim.zero_grad()
# 验证
if i % VALIDATE_EVERY == 0:
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算
loss = ar_wrapper(next(val_loader))
print(f'validation loss: {loss.item()}')
# 文本生成
if i % GENERATE_EVERY == 0:
model.eval()
# 从验证集中随机选择一个序列作为提示(去掉最后一个token)
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp)
print(f'%s \n\n %s', (prime, '*' * 100))
# 自回归生成文本
sample = ar_wrapper.generate(
prompts = inp, # 输入提示
seq_len = GENERATE_LENGTH, # 生成长度
cache_kv = True # 启用KV缓存加速生成
)
output_str = decode_tokens(sample)
print(output_str)
关键代码逻辑说明
1. 数据流维度变化
- 原始数据:
(95,000,000,)一维字节数组 - 训练数据:
(90,000,000,),验证数据:(5,000,000,) - 每个批次:
(BATCH_SIZE=4, SEQ_LEN+1=1025),其中+1用于下一个token预测
2. 模型架构
- 输入:
(4, 1024)的token序列 - 经过TransformerWrapper:嵌入层 → 位置编码 → 6层Decoder
- 输出:
(4, 1024, 256)的logits(每个位置对256个字符的预测)
3. 训练策略
- 梯度累积:每4个批次才更新一次参数,相当于有效批次大小为16
- 梯度裁剪:限制梯度范数不超过0.5,防止训练不稳定
- 混合精度训练:通过
AutoregressiveWrapper自动处理
4. 生成过程
- 使用自回归方式逐个生成字符
cache_kv=True:缓存注意力机制的键值对,避免重复计算- 从验证集随机采样提示,展示模型生成能力
📄 文件: train_entropy_tokenizer.py
代码分析报告
1. 文件功能摘要
这个文件用于训练一个基于熵的tokenizer,结合Transformer模型在enwik8数据集上进行自回归语言建模训练,并定期生成文本样本。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| TransformerWrapper | 一个包装器类,为Transformer模型添加词嵌入和位置编码等基础组件 |
| Decoder | Transformer的解码器部分,通常用于自回归生成任务 |
| AutoregressiveWrapper | 将模型包装为自回归模型,用于序列生成任务 |
| EntropyBasedTokenizer | 基于熵的tokenizer,根据信息熵动态分割文本为token |
| LayerNorm | 层归一化,对每个样本的特征进行归一化,提高训练稳定性 |
| Einsum | Einstein求和约定,用于简洁表达张量运算(如矩阵乘法、转置等) |
| Residual | 残差连接,将输入直接加到输出上,缓解梯度消失问题 |
| Rotary Position Embedding | 旋转位置编码,通过旋转矩阵将位置信息融入注意力机制 |
| Gradient Accumulation | 梯度累积,通过多次前向传播累积梯度再更新参数,模拟更大batch size |
| Gradient Clipping | 梯度裁剪,防止梯度爆炸,限制梯度范数 |
3. 代码逐行/逐块注释
# 导入必要的库
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 常量定义
NUM_BATCHES = int(1e5) # 总训练批次:100,000
BATCH_SIZE = 4 # 每批样本数:4
GRADIENT_ACCUMULATE_EVERY = 4 # 梯度累积步数:每4步更新一次参数
LEARNING_RATE = 1e-4 # 学习率:0.0001
VALIDATE_EVERY = 100 # 验证频率:每100批验证一次
GENERATE_EVERY = 100 # 生成频率:每100批生成一次文本
GENERATE_LENGTH = 1024 # 生成文本长度:1024个token
SEQ_LEN = 1024 # 序列长度:1024
# 辅助函数
def cycle(loader):
"""创建无限循环的数据加载器,确保训练不会因数据耗尽而停止"""
while True:
for data in loader:
yield data
def decode_token(token):
"""将token ID转换为可读字符(ASCII码转换)"""
return str(chr(max(32, token))) # 确保token值≥32(可打印字符)
def decode_tokens(tokens):
"""将token序列转换为字符串"""
return ''.join(list(map(decode_token, tokens)))
# 实例化GPT风格的解码器模型
model = TransformerWrapper(
num_tokens = 256, # 词汇表大小:256(对应ASCII字符)
max_seq_len = SEQ_LEN, # 最大序列长度:1024
attn_layers = Decoder(
dim = 512, # 模型维度:512
depth = 6, # Transformer层数:6
heads = 8, # 注意力头数:8
rotary_pos_emb = True # 使用旋转位置编码
)
)
# 实例化基于熵的tokenizer
tokenizer = EntropyBasedTokenizer(
model, # 使用的模型
entropy_threshold = 2.5 # 熵阈值:决定何时分割token
)
# 将模型包装为自回归模型并移到GPU
model = AutoregressiveWrapper(model) # 添加自回归训练功能
model.cuda() # 将模型移到CUDA设备(GPU)
# 准备enwik8数据集
with gzip.open('./data/enwik8.gz') as file:
# 读取前95MB数据,转换为numpy数组
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 分割数据集:前90MB训练,后5MB验证
train_x, valid_x = np.split(data, [int(90e6)])
# 转换为PyTorch张量
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
"""自定义数据集类,用于从长文本中随机采样固定长度序列"""
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 原始数据张量
self.seq_len = seq_len # 序列长度
def __getitem__(self, index):
# 随机选择起始位置,确保有足够的长度
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
# 提取seq_len+1长度的序列(+1用于预测下一个token)
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda() # 返回GPU上的张量
def __len__(self):
# 数据集大小 = 总数据长度 // 序列长度
return self.data.size(0) // self.seq_len
# 创建训练和验证数据集
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# 创建无限循环的数据加载器
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
# 优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train() # 设置为训练模式
# 梯度累积:多次前向传播,累积梯度
for __ in range(GRADIENT_ACCUMULATE_EVERY):
# 获取下一个batch,计算损失
loss = model(next(train_loader)) # 输入形状: (BATCH_SIZE, SEQ_LEN+1)
# 梯度归一化:除以累积步数,模拟更大batch size
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
print(f'training loss: {loss.item()}')
# 梯度裁剪:限制梯度范数不超过0.5,防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# 参数更新
optim.step() # 根据梯度更新参数
optim.zero_grad() # 清空梯度
# 定期验证
if i % VALIDATE_EVERY == 0:
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
# 定期生成文本样本
if i % GENERATE_EVERY == 0:
model.eval()
# 从验证集随机选择一个序列(去掉最后一个token作为输入)
inp = random.choice(val_dataset)[:-1]
# 使用tokenizer分割输入序列
# return_segmented_seq=True 返回分割后的token列表
tokens = tokenizer(inp, return_segmented_seq = True)
# 使用分隔符连接各个token的字符串表示
delimiter = " \u275A " # Unicode分隔符:❚
output_str = delimiter.join([decode_tokens(token) for token in tokens])
print(f"{output_str}\n\n")
关键张量维度变化说明
-
数据加载阶段:
- 原始数据:
(SEQ_LEN+1,)一维张量 - 批量处理:
(BATCH_SIZE, SEQ_LEN+1)二维张量
- 原始数据:
-
模型输入/输出:
- 输入形状:
(BATCH_SIZE, SEQ_LEN) - 输出形状:
(BATCH_SIZE, SEQ_LEN, num_tokens),其中num_tokens=256 - 损失计算:模型内部将输入右移一位作为目标,计算交叉熵损失
- 输入形状:
-
EntropyBasedTokenizer工作原理:
- 基于模型预测的熵值动态分割序列
- 高熵区域(不确定性高)可能被分割为更小的token
- 低熵区域(确定性高)可能被合并为更大的token
训练流程总结
- 加载enwik8数据集(维基百科压缩数据)
- 构建Transformer解码器模型
- 使用基于熵的tokenizer动态分割文本
- 通过梯度累积训练模型
- 定期验证模型性能
- 定期生成文本样本展示训练效果
📄 文件: train_parity.py
1. 文件功能摘要
这个文件用于训练一个Transformer模型,通过课程学习(curriculum learning)策略来解决二进制奇偶性(parity)问题,即预测二进制序列中每个位置之前所有位的奇偶性(模2求和)。
2. 核心术语解释
- LayerNorm:层归一化,对神经网络中某一层的所有神经元输出进行归一化处理,加速训练并提高稳定性。
- Einsum:爱因斯坦求和约定,一种简洁的张量运算表示法,用于实现复杂的张量乘法、转置、求和等操作。
- Residual:残差连接,将层的输入直接加到输出上,有助于缓解深度网络中的梯度消失问题。
- TransformerWrapper:一个包装器,用于将词嵌入层和Transformer层组合在一起,方便构建完整的Transformer模型。
- Decoder:解码器,这里指Transformer的解码器层,通常用于自回归生成任务。
- Cross Entropy Loss:交叉熵损失,用于分类任务,衡量模型预测概率分布与真实标签之间的差异。
- Curriculum Learning:课程学习,一种训练策略,从简单样本开始,逐步增加难度,帮助模型更好地学习复杂模式。
- Hybridization:混合模型,这里指将Transformer与RNN(GRU)结合,以增强模型的状态跟踪能力。
- Gradient Clipping:梯度裁剪,限制梯度的大小,防止训练过程中梯度爆炸。
3. 代码逐行/逐块注释
import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
from x_transformers import TransformerWrapper, Decoder
# 常量定义
BATCH_SIZE = 256
LEARNING_RATE = 3e-4
EVAL_EVERY = 500 # 每500步评估一次
EVAL_LENGTHS = (16, 32, 64, 128, 256, 512) # 评估时使用的序列长度
TRAIN_MAX_LENGTH = EVAL_LENGTHS[-2] # 训练最大长度 = 256
LOSS_THRES_INCREASE_LEN = 1e-3 # 损失阈值,低于此值可考虑增加序列长度
MEET_CRITERIA_THRES_INCREASE_LEN = 10 # 连续满足条件的次数阈值
HYBRIDIZE_WITH_RNN = True # 是否与RNN混合
# 模型参数
dim = 64 # 模型维度
heads = 4 # 注意力头数
dim_head = 32 # 每个注意力头的维度
decoder_kwargs = dict() # 解码器参数字典
# 如果启用RNN混合
if HYBRIDIZE_WITH_RNN:
from torch.nn import GRU
decoder_kwargs = dict(
attn_hybrid_fold_axial_dim = 4, # 每4个token进行一次循环,有助于泛化奇偶性任务
attn_hybrid_learned_mix = True, # 学习混合权重
attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True) # GRU模块,输入dim,输出dim_head*heads
)
# 实例化模型
model = TransformerWrapper(
num_tokens = 2, # 词汇表大小,二进制只有0和1
max_seq_len = 0, # 0表示不使用绝对位置编码
attn_layers = Decoder(
dim = dim,
depth = 3, # 3层解码器
heads = heads,
attn_dim_head = dim_head,
shift_tokens = 1, # 移位token,有助于奇偶性训练,但单独使用无法泛化
**decoder_kwargs # 传入混合RNN参数
)
).cuda() # 将模型移到GPU
# 优化器
from lion_pytorch.cautious_lion import Lion
optimizer = Lion(model.parameters(), lr = LEARNING_RATE, cautious_factor = 0.1)
# 数据生成器
def cycle(length):
while True:
seq = torch.randint(0, 2, (BATCH_SIZE, length)).cuda() # 生成随机二进制序列
labels = (seq.cumsum(dim = -1) % 2) # 计算累积和模2,得到每个位置的奇偶性标签
yield (seq, labels)
# 数据加载器
train_dl = cycle(TRAIN_MAX_LENGTH) # 训练数据生成器,最大长度256
eval_dls = {eval_length: cycle(eval_length) for eval_length in EVAL_LENGTHS} # 不同长度的评估数据生成器
print(f'training at max length: {TRAIN_MAX_LENGTH}')
# 训练循环
i = 0 # 训练步数计数器
meet_criteria = 0 # 连续满足条件的计数器
train_seq_len = 1 # 当前训练序列长度,从1开始
stop_length = EVAL_LENGTHS[-2] # 停止长度 = 256
with tqdm.tqdm(mininterval = 10., desc = 'training') as pbar:
while train_seq_len < stop_length:
model.train()
seq, labels = next(train_dl) # 获取一批训练数据
# 课程学习:截取当前训练长度的序列
seq = seq[:, :train_seq_len]
labels = labels[:, :train_seq_len]
logits = model(seq) # 前向传播,得到预测logits
# 计算损失
loss = F.cross_entropy(logits.transpose(-1, -2), labels, reduction = 'none')
last_loss = loss[:, -1].mean() # 只关注最后一个位置的损失
loss.mean().backward() # 反向传播
# 检查是否满足增加序列长度的条件
if last_loss.item() < LOSS_THRES_INCREASE_LEN:
meet_criteria += 1
else:
meet_criteria = 0
# 如果连续满足条件达到阈值,增加训练序列长度
if meet_criteria >= MEET_CRITERIA_THRES_INCREASE_LEN:
meet_criteria = 0
train_seq_len += 1
print(f'criteria met, incrementing to {train_seq_len}')
print(f'({train_seq_len})| {i}: {last_loss.item()}')
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # 梯度裁剪,防止梯度爆炸
optimizer.step() # 更新参数
optimizer.zero_grad() # 清零梯度
last_step = train_seq_len == stop_length # 检查是否达到最终训练长度
if last_step:
print(f'made it to training length {train_seq_len}. running final eval to check for generalization')
# 定期评估或最后一步评估
if last_step or (i + 1) % EVAL_EVERY == 0:
model.eval()
print('\n')
# 在不同长度上评估模型
for eval_length, eval_dl in eval_dls.items():
incorrects = 0
seq, labels = next(eval_dl)
logits = model(seq)
pred = logits[:, -1].argmax(dim = -1) # 取最后一个位置的预测
incorrects = (pred != labels[:, -1]).abs().sum().item() # 计算错误数
frac_incorrect = incorrects * 100 / BATCH_SIZE # 计算错误率
print(f"{eval_length}\t - frac incorrect:\t {frac_incorrect:.1f}%")
print('\n')
i += 1
pbar.update(1) # 更新进度条
关键代码解释:
-
数据生成:
labels = (seq.cumsum(dim = -1) % 2)cumsum(dim=-1):沿序列维度计算累积和。% 2:取模2,得到奇偶性(0表示偶数,1表示奇数)。
-
课程学习策略:
- 训练从长度为1的序列开始。
- 当模型在最后一个位置的损失连续10次低于阈值(1e-3)时,将训练序列长度增加1。
- 逐步增加难度,直到达到目标长度256。
-
损失计算:
loss = F.cross_entropy(logits.transpose(-1, -2), labels, reduction = 'none')logits形状为(batch_size, seq_len, 2),需要转置为(batch_size, 2, seq_len)以适应交叉熵损失函数的输入要求。reduction='none':不进行降维,返回每个位置的损失。
-
混合模型:
- 当
HYBRIDIZE_WITH_RNN=True时,Transformer会与GRU结合。 attn_hybrid_fold_axial_dim=4:每4个token进行一次循环处理,有助于模型学习奇偶性模式。- 这种混合设计旨在增强模型的状态跟踪能力,解决纯Transformer在奇偶性任务上的泛化问题。
- 当
📄 文件: train_enwik8.py
代码分析报告
1. 文件功能摘要
这是一个使用Transformer模型在enwik8数据集上进行自回归语言模型训练的脚本,实现了GPT风格的文本生成训练流程。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| TransformerWrapper | x-transformers库中的包装器,用于将Transformer核心层与词嵌入、位置编码等组件组合 |
| Decoder | Transformer的解码器部分,通常用于自回归生成任务 |
| AutoregressiveWrapper | 自回归包装器,为模型添加因果掩码和生成功能 |
| rotary_pos_emb | 旋转位置编码,一种相对位置编码方法,能更好地处理长序列 |
| attn_orthog_projected_values | 注意力机制中的正交投影值,可能用于提升注意力机制的稳定性 |
| gradient_accumulate | 梯度累积,通过多次前向传播累积梯度再更新参数,模拟更大批次训练 |
| enwik8 | 一个包含1亿字节维基百科文本的数据集,常用于语言模型基准测试 |
3. 代码逐行/逐块注释
# /// script
# dependencies = [
# "tqdm",
# "x-transformers",
# "wandb"
# ]
# ///
# 这是一个特殊的注释格式,可能用于某些脚本运行器自动安装依赖
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 导入必要的库:
# - x_transformers: 提供Transformer模型实现
# - tqdm: 进度条显示
# - gzip: 处理压缩文件
# - torch: 深度学习框架
# 训练超参数配置
NUM_BATCHES = int(1e5) # 总训练批次:100,000
BATCH_SIZE = 4 # 批次大小:4个序列
GRADIENT_ACCUMULATE_EVERY = 4 # 梯度累积步数:每4步更新一次参数
LEARNING_RATE = 1e-4 # 学习率:0.0001
VALIDATE_EVERY = 100 # 每100批次验证一次
GENERATE_EVERY = 500 # 每500批次生成一次文本
GENERATE_LENGTH = 1024 # 生成文本长度:1024个token
SEQ_LEN = 1024 # 序列长度:1024个token
TRACK_EXPERIMENT_ONLINE = False # 是否在线跟踪实验(WandB)
# 辅助函数定义
def cycle(loader):
"""创建无限循环的数据加载器"""
while True:
for data in loader:
yield data
def decode_token(token):
"""将token解码为字符(ASCII范围32-255)"""
return str(chr(max(32, token))) # 确保token在可打印ASCII范围内
def decode_tokens(tokens):
"""将token序列解码为字符串"""
return ''.join(list(map(decode_token, tokens)))
# 实例化GPT风格的解码器模型
model = TransformerWrapper(
num_tokens = 256, # 词汇表大小:256(对应ASCII字符)
max_seq_len = SEQ_LEN, # 最大序列长度:1024
attn_layers = Decoder(
dim = 512, # 模型维度:512
depth = 6, # Transformer层数:6
heads = 8, # 注意力头数:8
rotary_pos_emb = True, # 使用旋转位置编码
attn_orthog_projected_values = True, # 使用正交投影值
attn_orthog_projected_values_per_head = True # 每个头单独正交投影
)
)
# 添加自回归包装器(添加因果掩码和生成功能)
model = AutoregressiveWrapper(model)
model.cuda() # 将模型移动到GPU
# 准备enwik8数据
with gzip.open('./data/enwik8.gz') as file:
# 读取前95MB数据(enwik8总大小100MB)
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 分割数据集:前90MB训练,后5MB验证
train_x, valid_x = np.split(data, [int(90e6)])
# 转换为PyTorch张量
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
"""文本采样数据集,用于随机截取序列"""
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 完整文本数据
self.seq_len = seq_len # 序列长度
def __getitem__(self, index):
# 随机选择起始位置(确保有足够的后续token)
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
# 截取seq_len+1长度的序列(+1是为了包含预测目标)
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda() # 移动到GPU
def __len__(self):
# 数据集大小 = 总数据长度 // 序列长度
return self.data.size(0) // self.seq_len
# 创建训练和验证数据集
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# 创建无限循环的数据加载器
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
# 优化器配置
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 实验跟踪(WandB)
import wandb
wandb.init(project = 'enwik8', mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled')
wandb.run.name = 'baseline' # 实验名称
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train() # 设置为训练模式
# 梯度累积循环
for __ in range(GRADIENT_ACCUMULATE_EVERY):
# 获取一个批次数据并计算损失
loss = model(next(train_loader))
# 梯度累积:将损失除以累积步数后反向传播
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
# 记录训练损失
print(f'training loss: {loss.item()}')
wandb.log(dict(loss = loss.item()))
# 梯度裁剪(防止梯度爆炸)
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# 参数更新
optim.step()
optim.zero_grad() # 清空梯度
# 定期验证
if i % VALIDATE_EVERY == 0:
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
wandb.log(dict(valid_loss = loss.item()))
# 定期生成文本
if i % GENERATE_EVERY == 0:
model.eval()
# 从验证集中随机选择一个序列作为生成起点(去掉最后一个token)
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp) # 解码为字符串
print(f'%s \n\n %s', (prime, '*' * 100)) # 打印输入文本和分隔符
# 生成新文本
sample = model.generate(
prompts = inp, # 输入提示
seq_len = GENERATE_LENGTH, # 生成长度
cache_kv = True # 缓存键值对以加速生成
)
# 解码并输出生成的文本
output_str = decode_tokens(sample)
print(output_str)
关键流程说明
-
数据流:
- 原始字节数据 → NumPy数组 → PyTorch张量 → GPU张量
- 每个训练样本形状:[SEQ_LEN+1],其中前SEQ_LEN个token是输入,最后一个token是预测目标
-
梯度累积机制:
- 实际批次大小 = BATCH_SIZE × GRADIENT_ACCUMULATE_EVERY = 4 × 4 = 16
- 每4次前向传播累积梯度,然后进行一次参数更新
-
文本生成流程:
- 从验证集随机选取起始序列
- 使用自回归方式逐个生成token
- 将生成的token序列解码为可读文本
-
模型特点:
- 使用旋转位置编码,适合处理长序列
- 正交投影注意力值,可能提升训练稳定性
- 因果掩码确保自回归性质(只能看到左侧token)
📄 文件: train_length_extrapolate.py
代码分析报告
1. 文件功能摘要
这是一个用于训练Transformer模型进行长度外推(length extrapolation)实验的训练脚本,使用enwik8数据集训练一个类似GPT的自回归语言模型,并在不同序列长度上验证模型性能。
2. 核心术语解释
- TransformerWrapper: x_transformers库中的包装器,用于封装Transformer模型的基本配置
- Decoder: Transformer的解码器部分,通常用于自回归生成任务
- AutoregressiveWrapper: 自回归包装器,为模型添加自回归训练和生成功能
- LayerNorm: 层归一化,用于稳定神经网络训练
- Einsum: Einstein求和约定,用于高效的张量运算
- Residual: 残差连接,将输入直接加到输出上,缓解梯度消失问题
- Dynamic Position Bias: 动态位置偏置,一种相对位置编码方法
- Gradient Accumulation: 梯度累积,通过多次前向传播累积梯度再更新参数,模拟更大批次训练
- Length Extrapolation: 长度外推,指模型在训练时使用较短序列,但能在推理时处理更长序列的能力
3. 代码逐行/逐块注释
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 常量定义
NUM_BATCHES = int(1e5) # 总训练批次:100,000
BATCH_SIZE = 4 # 批次大小:4
GRADIENT_ACCUMULATE_EVERY = 4 # 梯度累积步数:每4步更新一次参数
LEARNING_RATE = 1e-4 # 学习率:0.0001
GENERATE_EVERY = 500 # 每500步生成一次文本
GENERATE_LENGTH = 256 # 生成文本长度:256
SEQ_LEN = 256 # 训练序列长度:256
VALIDATE_EVERY = 100 # 每100步验证一次
VALIDATE_SEQ_LENS = (256, 512, 1024, 2048, 4096) # 验证时使用的不同序列长度
# 辅助函数
def cycle(loader):
"""创建无限循环的数据加载器"""
while True:
for data in loader:
yield data
def decode_token(token):
"""将token解码为字符(ASCII码转换)"""
return str(chr(max(32, token))) # 确保字符可打印(ASCII >= 32)
def decode_tokens(tokens):
"""将token序列解码为字符串"""
return ''.join(list(map(decode_token, tokens)))
# 实例化GPT-like解码器模型
model = TransformerWrapper(
num_tokens = 256, # 词汇表大小:256(对应ASCII字符)
max_seq_len = SEQ_LEN, # 最大序列长度:256
use_abs_pos_emb = False, # 不使用绝对位置编码
attn_layers = Decoder(
dim = 512, # 模型维度:512
depth = 6, # Transformer层数:6
heads = 8, # 注意力头数:8
dynamic_pos_bias = True, # 使用动态位置偏置(相对位置编码)
)
)
model = AutoregressiveWrapper(model) # 包装为自回归模型
model.cuda() # 将模型移动到GPU
# 准备enwik8数据
with gzip.open('./data/enwik8.gz') as file:
# 读取前95MB数据,enwik8是维基百科的原始文本数据集
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 分割训练集和验证集:前90MB训练,后5MB验证
train_x, valid_x = np.split(data, [int(90e6)])
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
"""文本采样数据集类"""
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 原始数据张量
self.seq_len = seq_len # 序列长度
def __getitem__(self, index):
# 随机选择起始位置,确保有足够的长度
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
# 提取seq_len+1长度的序列(+1是为了包含目标token)
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda() # 移动到GPU
def __len__(self):
# 数据集大小为数据长度除以序列长度
return self.data.size(0) // self.seq_len
# 创建训练数据集和加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
# 创建验证数据集(用于文本生成)
val_dataset_generate = TextSamplerDataset(data_val, SEQ_LEN)
# 为不同序列长度创建验证加载器
val_loaders = dict()
for valid_seq_len in VALIDATE_SEQ_LENS:
val_dataset = TextSamplerDataset(data_val, valid_seq_len)
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loaders[valid_seq_len] = val_loader # 存储不同长度的加载器
# 优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train() # 设置为训练模式
# 梯度累积循环
for __ in range(GRADIENT_ACCUMULATE_EVERY):
# 获取一个批次数据并计算损失
# 注意:AutoregressiveWrapper会自动处理输入输出,计算交叉熵损失
loss = model(next(train_loader))
# 累积梯度,除以累积步数进行归一化
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
print(f'training loss: {loss.item()}')
# 梯度裁剪,防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step() # 更新参数
optim.zero_grad() # 清空梯度
# 定期验证
if i % VALIDATE_EVERY == 0:
print(f'validation losses:\n')
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算
for valid_seq_len in VALIDATE_SEQ_LENS:
val_loader = val_loaders[valid_seq_len]
# 在不同序列长度上计算损失
loss = model(next(val_loader))
print(f'[{valid_seq_len}]:\t {loss.item()}')
print('\n')
# 定期生成文本
if i % GENERATE_EVERY == 0:
model.eval()
# 从验证集中随机选择一个序列(去掉最后一个token作为prompt)
inp = random.choice(val_dataset_generate)[:-1]
prime = decode_tokens(inp)
print(f'%s \n\n %s', (prime, '*' * 100))
# 生成文本
sample = model.generate(
prompts = inp, # 输入prompt
seq_len = GENERATE_LENGTH, # 生成长度
cache_kv = True # 缓存键值对,加速生成
)
output_str = decode_tokens(sample)
print(f'{output_str}\n\n')
关键代码解释
1. 模型架构
- 使用
TransformerWrapper和Decoder构建6层Transformer - 维度512,8个注意力头,使用动态位置偏置
- 词汇表大小为256,对应ASCII字符
2. 数据流维度变化
- 原始数据:
[total_tokens](一维张量) - 采样后:
[batch_size, seq_len+1](二维张量) - 模型内部:
[batch_size, seq_len, dim](三维张量)
3. 训练策略
- 梯度累积:每4个批次累积梯度后更新一次,相当于使用batch_size=16训练
- 长度外推验证:在训练过程中使用不同长度(256-4096)验证模型泛化能力
- 定期生成:每500步生成文本,监控模型学习进度
4. 实验设计特点
- 训练时使用固定长度256,但验证时测试更长的序列
- 使用动态位置偏置,理论上能更好地处理未见过的序列长度
- 整个实验旨在测试Transformer模型在长度外推任务上的表现
📄 文件: train_gpt_vae.py
代码分析报告
1. 文件功能摘要
这是一个用于训练GPT-VAE(基于变分自编码器的生成式预训练变换器)模型的训练脚本,主要用于文本数据的压缩表示学习和条件文本生成。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| GPTVAE | 结合了GPT(生成式预训练变换器)和VAE(变分自编码器)的混合模型,既能生成文本又能学习数据的潜在表示 |
| VAE (Variational Autoencoder) | 变分自编码器,一种生成模型,通过编码器学习数据的潜在分布,通过解码器从潜在变量重建数据 |
| KL散度损失 (KL Divergence Loss) | 衡量学习到的潜在分布与先验分布(通常是标准正态分布)之间的差异,用于正则化潜在空间 |
| 自回归损失 (Autoregressive Loss) | GPT模型的标准损失,基于序列中前一个token预测下一个token的交叉熵损失 |
| 梯度累积 (Gradient Accumulation) | 在内存有限时使用的技术,通过多次前向传播累积梯度,然后一次性更新参数 |
| 潜在变量 (Latents) | VAE学习到的低维连续表示,用于控制生成过程 |
| 旋转位置编码 (Rotary Position Embedding) | 一种相对位置编码方法,通过旋转矩阵将位置信息融入注意力计算 |
3. 代码逐行/逐块注释
from x_transformers.gpt_vae import GPTVAE
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch import tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 常量定义
NUM_BATCHES = int(1e5) # 总训练批次:100,000
BATCH_SIZE = 4 # 每批样本数:4
GRADIENT_ACCUMULATE_EVERY = 4 # 梯度累积步数:每4步更新一次参数
LEARNING_RATE = 1e-4 # 学习率:0.0001
VALIDATE_EVERY = 100 # 验证频率:每100批验证一次
GENERATE_EVERY = 500 # 生成频率:每500批生成一次文本
GENERATE_LENGTH = 512 # 生成文本长度:512个token
SEQ_LEN = 512 # 序列长度:512个token
# 辅助函数
def cycle(loader):
"""创建无限循环的数据加载器,确保训练过程中数据不会耗尽"""
while True:
for data in loader:
yield data
def decode_token(token):
"""将token解码为字符,确保输出可打印字符(ASCII码≥32)"""
return str(chr(max(32, token)))
def decode_tokens(tokens):
"""将token序列解码为字符串"""
return ''.join(list(map(decode_token, tokens)))
# 实例化GPT-VAE模型
model = GPTVAE(
num_tokens = 256, # 词汇表大小:256(对应ASCII字符)
max_seq_len = SEQ_LEN, # 最大序列长度:512
dim = 512, # 模型维度:512
depth = 6, # 解码器层数:6
heads = 8, # 注意力头数:8
rotary_pos_emb = True, # 使用旋转位置编码
enc_depth = 3, # 编码器层数:3
vae_kl_loss_weight = 1., # KL散度损失权重:1.0
dim_latent = 1 # 潜在空间维度:1(示例中压缩到1维)
).cuda() # 将模型移到GPU
latents = tensor([1.]).cuda() # 创建潜在变量张量,值为1.0,用于控制生成
# 准备enwik8数据集
with gzip.open('./data/enwik8.gz') as file:
# 读取95MB数据,转换为numpy数组
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 分割数据集:前90MB为训练集,后5MB为验证集
train_x, valid_x = np.split(data, [int(90e6)])
# 转换为PyTorch张量
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
"""自定义数据集类,用于从长文本中随机采样固定长度的序列"""
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 原始文本数据
self.seq_len = seq_len # 序列长度
def __getitem__(self, index):
# 随机选择起始位置,确保有足够的长度
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
# 提取seq_len+1长度的序列(+1用于创建输入-目标对)
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda() # 返回GPU上的张量
def __len__(self):
# 数据集大小 = 总数据长度 // 序列长度
return self.data.size(0) // self.seq_len
# 创建数据集和数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
# 优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train() # 设置为训练模式
# 梯度累积循环
for __ in range(GRADIENT_ACCUMULATE_EVERY):
# 前向传播,获取总损失和各个损失分量
# loss: 总损失 = ar_loss + vae_kl_loss * vae_kl_loss_weight
# ar_loss: 自回归损失(语言建模损失)
# vae_kl_loss: VAE的KL散度损失
loss, (ar_loss, vae_kl_loss) = model(next(train_loader), return_all_losses = True)
# 梯度累积:将损失除以累积步数后反向传播
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
# 打印训练损失
print(f'training loss: {ar_loss.item():.4f}\t| kl loss: {vae_kl_loss.item():.4f}')
# 梯度裁剪:防止梯度爆炸,限制梯度范数不超过0.5
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step() # 更新模型参数
optim.zero_grad() # 清空梯度
# 验证阶段
if i % VALIDATE_EVERY == 0:
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算
loss, (ar_loss, _) = model(next(val_loader), return_all_losses = True)
print(f'validation loss: {ar_loss.item():.4f}')
# 文本生成演示
if i % GENERATE_EVERY == 0:
model.eval()
# 从验证集中随机选择一个序列作为提示(去掉最后一个token)
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp) # 解码为可读文本
print(f'%s \n\n %s', (prime, '*' * 100)) # 打印提示文本
# 使用潜在变量=1.0生成文本
sample = model.generate(
prompts = inp, # 输入提示
seq_len = GENERATE_LENGTH, # 生成长度
cache_kv = True, # 缓存键值对以提高效率
latents = latents # 潜在变量控制生成
)
output_str = decode_tokens(sample) # 解码生成结果
print(f'\n\nlatent {latents.tolist()} - ', output_str) # 打印生成文本
# 使用潜在变量=-1.0生成对比文本
sample_other_direction = model.generate(
prompts = inp,
seq_len = GENERATE_LENGTH,
cache_kv = True,
latents = -latents # 相反方向的潜在变量
)
output_str = decode_tokens(sample_other_direction)
print(f'\n\nlatent {(-latents).tolist()} - ', output_str) # 打印对比生成文本
关键流程说明
数据流维度变化
- 输入数据:
[BATCH_SIZE, SEQ_LEN+1]→ 批大小×序列长度 - 模型内部:
- 编码器将输入压缩为潜在变量:
[BATCH_SIZE, dim_latent](本例中dim_latent=1) - 解码器从潜在变量重建/生成序列:
[BATCH_SIZE, SEQ_LEN, num_tokens]
- 编码器将输入压缩为潜在变量:
- 损失计算:
- 自回归损失: 基于序列中前一个token预测下一个token
- KL散度损失: 衡量潜在分布与标准正态分布的差异
训练策略特点
- 梯度累积: 在内存有限的情况下模拟更大批次训练
- 双损失优化: 同时优化语言建模能力和潜在空间结构
- 条件生成: 通过潜在变量控制文本生成风格/内容
- 定期验证和生成: 监控模型性能并可视化生成效果
潜在空间探索
代码展示了如何通过改变潜在变量值(+1.0和-1.0)来探索潜在空间的不同区域,观察生成文本的变化,这是VAE模型的重要特性。
📄 文件: train_belief_state.py
代码分析报告
1. 文件功能摘要
这是一个使用双向Transformer模型(BeliefStateWrapper)在enwik8文本数据集上进行语言模型训练的脚本,支持前向和后向文本生成。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| TransformerWrapper | 一个封装类,将词嵌入、位置编码等组件与Transformer层组合在一起 |
| Decoder | Transformer的解码器层,包含自注意力机制和前馈网络 |
| BeliefStateWrapper | 一个特殊的包装器,同时管理前向和后向语言模型,支持双向文本生成 |
| AutoregressiveWrapper | 自回归包装器,用于序列生成任务 |
| LayerNorm | 层归一化,对每个样本的特征维度进行归一化,提高训练稳定性 |
| Residual | 残差连接,将输入直接加到输出上,缓解梯度消失问题 |
| Rotary Position Embedding | 旋转位置编码,一种相对位置编码方法,通过旋转矩阵将位置信息融入注意力计算 |
| Gradient Accumulation | 梯度累积,通过多个小批次累积梯度再更新参数,模拟更大批次训练 |
| KV Cache | 键值缓存,在自回归生成时缓存之前计算的键值对,避免重复计算 |
3. 代码逐行/逐块注释
from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 常量定义
NUM_BATCHES = int(1e5) # 总训练批次数量
BATCH_SIZE = 2 # 每个批次的样本数
GRADIENT_ACCUMULATE_EVERY = 8 # 梯度累积步数(实际批次大小 = 2×8 = 16)
LEARNING_RATE = 1e-4 # 学习率
VALIDATE_EVERY = 100 # 每100个批次验证一次
GENERATE_EVERY = 500 # 每500个批次生成文本一次
GENERATE_LENGTH = 256 # 生成文本的长度
SEQ_LEN = 256 # 输入序列长度
FORWARD_BACKWARD_SAME_MODEL = True # 是否使用相同的模型进行前向和后向生成
# 辅助函数
def cycle(loader):
"""创建一个无限循环的数据加载器"""
while True:
for data in loader:
yield data
def decode_token(token):
"""将token解码为字符(ASCII码转字符)"""
return str(chr(max(32, token))) # 确保token值≥32(可打印字符)
def decode_tokens(tokens):
"""将token序列解码为字符串"""
return ''.join(list(map(decode_token, tokens)))
# 实例化GPT风格的前向解码器模型
forward_model = TransformerWrapper(
num_tokens = 256, # 词汇表大小(256个ASCII字符)
max_seq_len = SEQ_LEN, # 最大序列长度
attn_layers = Decoder(
dim = 512, # 模型维度
depth = 6, # Transformer层数
heads = 8, # 注意力头数
rotary_pos_emb = True # 使用旋转位置编码
)
)
backward_model = None
# 如果不使用相同模型,则创建单独的后向模型
if not FORWARD_BACKWARD_SAME_MODEL:
backward_model = TransformerWrapper(
num_tokens = 256,
max_seq_len = SEQ_LEN,
attn_layers = Decoder(
dim = 512,
depth = 4, # 后向模型使用更少的层(4层 vs 6层)
heads = 8,
rotary_pos_emb = True
)
)
# 创建BeliefStateWrapper,包装前向和后向模型
model = BeliefStateWrapper(
forward_decoder = forward_model, # 前向语言模型
backward_decoder = backward_model # 后向语言模型(可能为None)
)
model.cuda() # 将模型移动到GPU
# 准备enwik8数据集
with gzip.open('./data/enwik8.gz') as file:
# 读取95MB数据,转换为numpy数组
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 分割训练集(90MB)和验证集(5MB)
train_x, valid_x = np.split(data, [int(90e6)])
# 转换为PyTorch张量
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
"""文本采样数据集,用于随机抽取固定长度的文本片段"""
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 完整文本数据
self.seq_len = seq_len # 序列长度
def __getitem__(self, index):
# 随机选择起始位置
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
# 提取seq_len+1长度的序列(+1是为了包含标签)
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda() # 移动到GPU
def __len__(self):
# 数据集大小 = 总数据长度 // 序列长度
return self.data.size(0) // self.seq_len
# 创建训练和验证数据集及数据加载器
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
# 优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
model.train() # 设置为训练模式
# 梯度累积循环
for __ in range(GRADIENT_ACCUMULATE_EVERY):
# 获取一个批次的数据并计算损失
loss = model(next(train_loader))
# 梯度累积:将损失除以累积步数后反向传播
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
print(f'training loss: {loss.item()}')
# 梯度裁剪,防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# 更新参数并清零梯度
optim.step()
optim.zero_grad()
# 验证
if i % VALIDATE_EVERY == 0:
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算
loss = model(next(val_loader))
print(f'validation loss: {loss.item()}')
# 文本生成
if i % GENERATE_EVERY == 0:
model.eval()
# 从验证集中随机选择一个序列(去掉最后一个token作为prompt)
inp = random.choice(val_dataset)[:-1]
prime = decode_tokens(inp) # 解码prompt
print(f'%s \n\n %s', (prime, '*' * 100))
print('forwards:\n')
# 前向生成文本
sample = model.generate_with_suffix_cond(
prompts = inp, # 输入prompt
seq_len = GENERATE_LENGTH, # 生成长度
cache_kv = True # 启用KV缓存加速生成
)
output_str = decode_tokens(sample)
print(output_str)
print('\nbackwards:\n')
# 后向生成文本(从prompt开始向左生成)
sample = model.generate_with_suffix_cond(
prompts = inp,
seq_len = GENERATE_LENGTH,
cache_kv = True,
decode_backwards = True # 启用后向生成
)
# 翻转生成的序列(因为后向生成是从右向左的)
output_str = decode_tokens(sample.flip(0))
print(output_str)
关键代码逻辑说明
1. BeliefStateWrapper的工作原理
- 同时训练前向和后向语言模型
- 前向模型预测下一个token:P(xt | x<t)
- 后向模型预测前一个token:P(xt | x>t)
- 支持双向文本生成能力
2. 梯度累积的实现
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
- 连续处理8个小批次(每个批次2个样本)
- 每次反向传播前将损失除以8
- 累积梯度,相当于使用16个样本的大批次进行更新
3. 文本生成的特殊处理
- 前向生成:从prompt开始向右生成文本
- 后向生成:从prompt开始向左生成文本,生成后需要
flip(0)翻转序列 - KV缓存:缓存之前计算的键值对,避免重复计算,显著加速自回归生成
4. 数据流维度变化
- 原始数据:
[95,000,000]字节数组 - 训练批次:
[BATCH_SIZE, SEQ_LEN+1]=[2, 257] - 模型输入:
[2, 256](最后一个token作为标签) - 模型输出:
[2, 256, 256](每个位置对256个token的预测概率)
📄 文件: train_free.py
代码分析报告
1. 文件功能摘要
这是一个使用FreeTransformer模型在enwik8数据集上进行文本生成训练的脚本,结合了自回归语言建模和变分自编码器(VAE)技术。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| FreeTransformer | 一种结合了自回归语言模型和变分自编码器的Transformer架构,支持可控文本生成 |
| LayerNorm | 层归一化,用于稳定神经网络训练,对每个样本的特征进行归一化 |
| Einsum | Einstein求和约定,用于简洁表达张量运算(如矩阵乘法、转置等) |
| Residual | 残差连接,将输入直接加到输出上,缓解深度网络中的梯度消失问题 |
| KL散度 | Kullback-Leibler散度,衡量两个概率分布差异的指标,在VAE中用于正则化潜在空间 |
| 自回归(AR) | 模型基于先前生成的token预测下一个token的生成方式 |
| 变分自编码器(VAE) | 一种生成模型,通过编码器将输入映射到潜在空间,再通过解码器重构 |
| 梯度累积 | 将多个小批次的梯度累加后再更新参数,模拟更大批次训练 |
| enwik8 | 一个包含前1亿字节维基百科数据的文本数据集,常用于语言模型评估 |
3. 代码逐行/逐块注释
# /// script
# dependencies = [
# "tqdm",
# "x-transformers>=2.11.0",
# ]
# ///
# 脚本依赖声明(用于某些包管理器)
from x_transformers.free_transformer import FreeTransformer # 导入FreeTransformer模型
from math import log
import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch import tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
# 常量定义
NUM_BATCHES = int(1e5) # 总训练批次:100,000
BATCH_SIZE = 4 # 批次大小:4个序列
GRADIENT_ACCUMULATE_EVERY = 4 # 梯度累积步数:每4个小批次更新一次参数
LEARNING_RATE = 1e-4 # 学习率:0.0001
VALIDATE_EVERY = 100 # 验证频率:每100批次验证一次
GENERATE_EVERY = 250 # 生成频率:每250批次生成一次文本
GENERATE_LENGTH = 512 # 生成文本长度:512个token
PRIME_LENGTH = 32 # 提示文本长度:32个token
SEQ_LEN = 512 # 序列长度:512个token
LATENT_BITS = 8 # 潜在变量位数:8位(256个可能值)
NAT = log(2) # 自然单位:ln(2),用于KL散度阈值
# 辅助函数
def cycle(loader):
"""创建无限循环的数据加载器"""
while True:
for data in loader:
yield data
def decode_token(token):
"""将token解码为字符"""
return str(chr(max(32, token))) # 确保ASCII值≥32(可打印字符)
def decode_tokens(tokens):
"""将token序列解码为字符串"""
return ''.join(list(map(decode_token, tokens)))
# 实例化GPT-like解码器模型
model = FreeTransformer(
num_tokens = 256, # 词汇表大小:256(对应字节级编码)
max_seq_len = SEQ_LEN, # 最大序列长度:512
dim = 512, # 模型维度:512
heads = 8, # 注意力头数:8
dec_head_depth = 4, # 解码器头部深度:4层
dec_tail_depth = 4, # 解码器尾部深度:4层
enc_depth = 3, # 编码器深度:3层
kl_loss_weight = 1., # KL散度损失权重:1.0
per_token_latents = True, # 每个token使用独立潜在变量
kl_loss_threshold = NAT, # KL散度阈值:ln(2)
latent_bits = LATENT_BITS # 潜在变量位数:8
).cuda() # 将模型移到GPU
# 创建随机潜在变量索引(形状为空的张量,返回标量值)
one_hot_indices = torch.randint(0, 2 ** LATENT_BITS, ())
# 准备enwik8数据
with gzip.open('./data/enwik8.gz') as file:
# 读取前95MB数据,转换为numpy数组
data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
# 分割数据:前90MB训练,后5MB验证
train_x, valid_x = np.split(data, [int(90e6)])
# 转换为PyTorch张量
data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
class TextSamplerDataset(Dataset):
"""文本采样数据集类"""
def __init__(self, data, seq_len):
super().__init__()
self.data = data # 完整文本数据
self.seq_len = seq_len # 序列长度
def __getitem__(self, index):
# 随机选择起始位置(确保有足够的后续token)
rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
# 提取seq_len+1长度的序列(+1用于创建输入-目标对)
full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
return full_seq.cuda() # 移到GPU
def __len__(self):
# 数据集大小 = 总数据长度 // 序列长度
return self.data.size(0) // self.seq_len
# 创建训练和验证数据集
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
# 创建无限循环的数据加载器
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
# 优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# 训练循环
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
model.train() # 设置为训练模式
# 梯度累积循环
for __ in range(GRADIENT_ACCUMULATE_EVERY):
# 前向传播,获取总损失和分解损失
# loss: 总损失,ar_loss: 自回归损失,vae_kl_loss: VAE的KL散度损失
loss, (ar_loss, vae_kl_loss) = model(next(train_loader), return_all_losses = True)
# 梯度累积:将损失除以累积步数后反向传播
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
# 打印损失信息
print(f'training loss: {ar_loss.item():.4f}\t| kl loss: {vae_kl_loss.item():.4f}')
# 梯度裁剪(防止梯度爆炸)
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# 参数更新
optim.step()
# 清空梯度
optim.zero_grad()
# 验证
if i % VALIDATE_EVERY == 0:
model.eval() # 设置为评估模式
with torch.no_grad(): # 禁用梯度计算
loss, (ar_loss, _) = model(next(val_loader), return_all_losses = True)
print(f'validation loss: {ar_loss.item():.4f}')
# 文本生成
if i % GENERATE_EVERY == 0:
model.eval()
# 从验证集随机选择提示文本(前PRIME_LENGTH个token)
inp = random.choice(val_dataset)[:PRIME_LENGTH]
prime = decode_tokens(inp) # 解码为可读字符串
print(f'%s \n\n %s', (prime, '*' * 100)) # 打印提示文本和分隔符
# 生成文本
sample = model.generate(
prompts = inp, # 输入提示
seq_len = GENERATE_LENGTH, # 生成长度
latents = one_hot_indices # 潜在变量控制生成风格
)
output_str = decode_tokens(sample) # 解码生成结果
# 打印潜在变量索引和生成的文本
print(f'\n\nlatent {one_hot_indices.tolist()} - ', output_str)
关键张量维度变化说明
-
数据加载:
- 原始数据:
[95,000,000](字节数组) - 批次数据:
[BATCH_SIZE, SEQ_LEN+1]→[4, 513]
- 原始数据:
-
模型输入/输出:
- 输入序列:
[batch_size, seq_len]→[4, 512] - 目标序列:
[batch_size, seq_len]→[4, 512](输入序列右移一位)
- 输入序列:
-
潜在变量:
- 潜在空间大小:
2^LATENT_BITS = 256 - 每个token的潜在表示:
[batch_size, seq_len, latent_dim]
- 潜在空间大小:
-
注意力机制:
- 查询(Q)、键(K)、值(V):
[batch_size, heads, seq_len, dim_per_head] - 注意力分数:
[batch_size, heads, seq_len, seq_len]
- 查询(Q)、键(K)、值(V):
训练流程总结
- 加载enwik8数据集并预处理
- 初始化FreeTransformer模型(结合AR和VAE)
- 使用梯度累积训练模型
- 定期验证模型性能
- 定期生成文本样本以观察生成质量
- 通过潜在变量控制生成文本的风格/属性
📄 文件: x_transformers/xl_autoregressive_wrapper.py
XLAutoregressiveWrapper 代码分析
1. 文件功能摘要
这个文件实现了一个用于超长序列(XL)自回归生成的包装器类,能够处理超过模型最大序列长度的文本生成任务,通过分块处理和记忆机制来实现长文本的连续生成。
2. 核心术语解释
- 自回归(Autoregressive):一种生成模型,每次生成一个token,并将之前生成的token作为输入来生成下一个token。
- 记忆(Mems):在Transformer-XL架构中,用于存储前一个片段的隐藏状态,以便在当前片段计算时利用历史信息。
- 缓存(Cache):存储前向传播中的中间计算结果,避免重复计算,提高推理效率。
- Einsum:爱因斯坦求和约定,一种简洁的张量运算表示法,用于复杂的张量乘法操作。
- Residual:残差连接,将输入直接加到输出上,有助于缓解深度网络中的梯度消失问题。
- LayerNorm:层归一化,对每个样本的特征维度进行归一化,稳定训练过程。
- top_p/top_k:采样策略,top_k保留概率最高的k个token,top_p保留累积概率达到p的最小token集合。
3. 代码逐行/逐块注释
from math import ceil
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, pack, unpack
from x_transformers.autoregressive_wrapper import top_p, top_k, eval_decorator
# 辅助函数
def exists(val):
return val is not None # 检查值是否存在(非None)
def divisible_by(numer, denom):
return (numer % denom) == 0 # 检查numer是否能被denom整除
# XL自回归包装器类
class XLAutoregressiveWrapper(nn.Module):
def __init__(
self,
net,
ignore_index = -100,
pad_value = 0
):
super().__init__()
self.pad_value = pad_value # 填充值,用于mask掉EOS之后的token
self.ignore_index = ignore_index # 忽略的索引,用于损失计算
self.net = net # 被包装的神经网络模型
self.max_seq_len = net.max_seq_len # 模型支持的最大序列长度
@torch.no_grad() # 禁用梯度计算,节省内存
@eval_decorator # 评估装饰器,确保模型处于评估模式
def generate(
self,
start_tokens, # 起始token序列,形状:[batch_size, seq_len]
seq_len, # 要生成的序列长度
eos_token = None, # 结束token,用于提前终止生成
temperature = 1., # 温度参数,控制采样随机性
filter_logits_fn = top_k, # 日志过滤函数,默认top_k采样
filter_kwargs: dict = dict(), # 过滤函数的参数
mems = None, # 初始记忆(前一片段的隐藏状态)
**kwargs # 传递给网络的额外参数
):
# 获取设备信息和最大序列长度
device, max_seq_len = start_tokens.device, self.max_seq_len
# 使用pack处理可能的额外维度,确保形状为 [*, n]
start_tokens, ps = pack([start_tokens], '* n')
# 获取batch大小和当前序列长度
b, t = start_tokens.shape
# 将起始token按最大序列长度分割,获取所有前导片段
# 最后一个片段可能不足max_seq_len,用于后续生成
*all_leading_tokens, _ = start_tokens.split(max_seq_len, dim = -1)
# 处理前导片段,更新记忆
for leading_tokens in all_leading_tokens:
_, mems = self.net(
leading_tokens,
mems = mems, # 使用前一片段的记忆
return_mems = True, # 返回更新后的记忆
**kwargs
)
# 开始从当前片段采样
curr_pos = len(all_leading_tokens) * max_seq_len # 当前在完整序列中的位置
curr_mems = mems # 当前记忆
cache = None # 缓存,用于存储当前片段的中间结果
out = start_tokens # 输出序列,初始为起始token
# 循环生成指定长度的token
for _ in range(seq_len):
curr_segment_len = out.shape[-1] # 当前输出序列长度
# 检查是否是最后一个片段的token(长度能被max_seq_len整除)
is_last_segment_tokens = divisible_by(curr_segment_len, max_seq_len)
# 获取当前需要处理的token(从curr_pos开始)
x = out[:, curr_pos:]
# 前向传播,获取logits和缓存
logits, cache = self.net(
x,
mems = curr_mems, # 使用当前记忆
cache = cache, # 使用之前的缓存
return_mems = True,
return_intermediates = True, # 返回中间结果(缓存)
**kwargs
)
mems = cache.mems # 更新记忆
# 取最后一个位置的logits用于生成下一个token
logits = logits[:, -1]
# 应用过滤函数(如top_k或top_p)
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
# 应用softmax和温度参数得到概率分布
probs = F.softmax(filtered_logits / temperature, dim=-1)
# 从概率分布中采样一个token
sample = torch.multinomial(probs, 1)
# 如果是最后一个片段的token,更新位置和记忆
if is_last_segment_tokens:
curr_pos = curr_segment_len
curr_mems = mems
# 将采样的token添加到输出序列
out = torch.cat((out, sample), dim=-1)
# 如果设置了EOS token,检查是否应该提前终止
if exists(eos_token):
is_eos_tokens = (out == eos_token) # 标记EOS token的位置
# 如果所有序列都生成了EOS token
if is_eos_tokens.any(dim = -1).all():
# 创建mask,将EOS之后的所有位置标记为True
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
# 用pad_value填充mask为True的位置
out = out.masked_fill(mask, self.pad_value)
break # 提前终止生成
# 只返回新生成的token(去掉起始token)
out = out[:, t:]
# 恢复原始的形状(如果有额外的维度)
out, = unpack(out, ps, '* n')
return out
def forward(
self,
x, # 输入序列,形状:[batch_size, seq_len]
mems = None, # 初始记忆
**kwargs # 传递给网络的额外参数
):
ignore_index, max_seq_len = self.ignore_index, self.max_seq_len
# 将输入分为输入序列和标签序列(用于训练)
# x: 除最后一个token外的所有token
# labels: 除第一个token外的所有token(用于预测下一个token)
x, labels = x[:, :-1], x[:, 1:]
seq_len = x.shape[1] # 序列长度
# 将输入和标签按最大序列长度分块
split_x = x.split(max_seq_len, dim = -1)
split_labels = labels.split(max_seq_len, dim = -1)
# 计算每个块的损失权重(基于长度比例)
loss_weights = tuple((t.shape[-1] / seq_len) for t in split_x)
# 选择损失函数:如果网络输出已经是log概率,使用NLLLoss;否则使用CrossEntropy
loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
# 遍历每个块,计算加权损失
total_loss = 0.
for chunk, chunk_labels, loss_weight in zip(split_x, split_labels, loss_weights):
# 前向传播,获取logits和更新后的记忆
logits, mems = self.net(
chunk,
mems = mems, # 使用前一个块的记忆
return_mems = True, # 返回更新后的记忆
**kwargs
)
# 计算损失
# rearrange将logits从 [batch, seq_len, vocab] 重排为 [batch, vocab, seq_len]
# 以符合PyTorch交叉熵损失的输入要求
loss = loss_fn(
rearrange(logits, 'b n c -> b c n'),
chunk_labels,
ignore_index = ignore_index # 忽略指定索引的标签
)
# 累加加权损失
total_loss = total_loss + loss * loss_weight
return total_loss
关键张量变换说明:
-
生成过程中的维度变化:
logits[:, -1]:从形状[batch, seq_len, vocab]变为[batch, vocab],取最后一个位置的预测torch.cat((out, sample), dim=-1):将形状[batch, n]和[batch, 1]拼接为[batch, n+1]
-
训练过程中的维度变化:
rearrange(logits, 'b n c -> b c n'):将logits从[batch, seq_len, vocab_size]重排为[batch, vocab_size, seq_len],以适应PyTorch交叉熵损失的输入格式
-
分块处理:
- 使用
split(max_seq_len, dim=-1)将长序列分割为多个最大长度为max_seq_len的块 - 每个块独立处理,但通过
mems传递跨块的信息
- 使用
这个包装器的核心创新在于能够处理超过模型最大序列长度的生成任务,通过分块处理和记忆机制,实现了对任意长度序列的自回归生成。
📄 文件: x_transformers/neo_mlp.py
NeoMLP 代码分析
1. 文件功能摘要
这个文件实现了一个名为 NeoMLP 的神经网络架构,它将传统的多层感知机(MLP)结构重新设计为基于注意力机制的图神经网络,通过将输入、隐藏层和输出层都视为图中的节点,并使用 Transformer 编码器在这些节点之间进行信息传递。
2. 核心术语解释
- RandomFourierEmbed:随机傅里叶嵌入,一种将连续特征映射到高维空间的技术,通过随机投影和余弦变换实现。
- Encoder:Transformer 编码器,来自 x_transformers 库,用于处理序列数据。
- Einsum:爱因斯坦求和约定,一种简洁的张量运算表示法,用于执行复杂的张量乘法、转置和收缩操作。
- Residual:残差连接,神经网络中的一种技术,将输入直接加到输出上,有助于缓解梯度消失问题。
- LayerNorm:层归一化,对每个样本的特征维度进行归一化,稳定训练过程。
- Attention:注意力机制,让模型能够关注输入的不同部分,计算不同位置之间的相关性。
- Parameter:可训练参数,在 PyTorch 中通过
nn.Parameter封装,会在训练过程中被优化。 - ReLU:修正线性单元,一种常用的激活函数。
3. 代码逐行/逐块注释
from collections import namedtuple
import torch
from torch import nn, tensor, pi, is_tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, einsum, pack, unpack
from x_transformers.x_transformers import (
Encoder
)
# helpers
def exists(v):
return v is not None # 检查变量是否存在(非None)
def default(v, d):
return v if exists(v) else d # 如果v存在则返回v,否则返回默认值d
# random fourier
class RandomFourierEmbed(Module):
"""随机傅里叶嵌入层,用于处理连续特征"""
def __init__(self, dim):
super().__init__()
self.proj = nn.Linear(1, dim) # 从1维映射到dim维的线性层
self.proj.requires_grad_(False) # 冻结投影层参数,不参与训练
def forward(
self,
times,
):
times = rearrange(times, '... -> ... 1') # 在最后添加一个维度,从(...)变为(... 1)
rand_proj = self.proj(times) # 线性投影
return torch.cos(2 * pi * rand_proj) # 应用余弦变换,生成傅里叶特征
# class
class NeoMLP(Module):
""" https://openreview.net/forum?id=A8Vuf2e8y6 """
""" https://haian-jin.github.io/projects/LVSM/ """
def __init__(
self,
*,
dim_in, # 输入维度
dim_hidden, # 隐藏层维度(节点数量)
dim_out, # 输出维度
dim_model, # 模型内部表示维度
depth, # Transformer编码器深度(层数)
encoder_kwargs: dict = dict(
attn_dim_head = 16, # 注意力头维度
heads = 4 # 注意力头数量
)
):
super().__init__()
# input and output embeddings
# 输入、隐藏层和输出的嵌入参数
self.input_embed = nn.Parameter(torch.zeros(dim_in, dim_model))
self.hidden_embed = nn.Parameter(torch.zeros(dim_hidden, dim_model))
self.output_embed = nn.Parameter(torch.zeros(dim_out, dim_model))
# 使用正态分布初始化嵌入参数
nn.init.normal_(self.input_embed, std = 0.02)
nn.init.normal_(self.hidden_embed, std = 0.02)
nn.init.normal_(self.output_embed, std = 0.02)
# 对连续特征使用随机傅里叶嵌入
self.random_fourier = nn.Sequential(
RandomFourierEmbed(dim_model), # 随机傅里叶嵌入
nn.Linear(dim_model, dim_model) # 线性变换层
)
# 将MLP的隐藏层替换为具有消息传递的节点
# 这些节点通过自注意力机制形成全连接图
self.transformer = Encoder(
dim = dim_model, # 特征维度
depth = depth, # 层数
**encoder_kwargs # 其他编码器参数
)
# 输出层参数
self.to_output_weights = nn.Parameter(torch.randn(dim_out, dim_model))
self.to_output_bias = nn.Parameter(torch.zeros(dim_out))
def forward(
self,
x, # 输入张量,形状为(batch, dim_in)或(dim_in)
return_embeds = False # 是否返回嵌入表示
):
no_batch = x.ndim == 1 # 检查是否没有批次维度
if no_batch:
x = rearrange(x, '... -> 1 ...') # 添加批次维度,从(dim_in)变为(1, dim_in)
batch = x.shape[0] # 获取批次大小
# 对输入应用随机傅里叶变换
fouriered_input = self.random_fourier(x)
# 将傅里叶变换后的输入与输入嵌入相加
# fouriered_input形状: (batch, dim_in, dim_model)
# self.input_embed形状: (dim_in, dim_model),通过广播变为(batch, dim_in, dim_model)
input_embed = fouriered_input + self.input_embed
# 重复隐藏层和输出层嵌入以匹配批次大小
# self.hidden_embed形状: (dim_hidden, dim_model) -> (batch, dim_hidden, dim_model)
# self.output_embed形状: (dim_out, dim_model) -> (batch, dim_out, dim_model)
hidden_embed, output_embed = tuple(
repeat(t, '... -> b ...', b = batch)
for t in (self.hidden_embed, self.output_embed)
)
# 将所有输入打包成一个token序列用于自注意力
# input_embed: (batch, dim_in, dim_model)
# hidden_embed: (batch, dim_hidden, dim_model)
# output_embed: (batch, dim_out, dim_model)
# 打包后: (batch, total_tokens, dim_model),其中total_tokens = dim_in + dim_hidden + dim_out
embed, packed_shape = pack([input_embed, hidden_embed, output_embed], 'b * d')
# 通过Transformer编码器进行信息传递
embed = self.transformer(embed)
# 解包回原来的三个部分
input_embed, hidden_embed, output_embed = unpack(embed, packed_shape, 'b * d')
# 投影到输出空间
# output_embed: (batch, dim_out, dim_model)
# self.to_output_weights: (dim_out, dim_model)
# 结果: (batch, dim_out)
output = einsum(output_embed, self.to_output_weights, 'b n d, n d -> b n')
output = output + self.to_output_bias # 添加偏置
if no_batch:
output = rearrange(output, '1 ... -> ...') # 移除批次维度
if not return_embeds:
return output # 只返回输出
# 返回输出和所有嵌入表示
return output, (input_embed, hidden_embed, output_embed)
关键张量维度变化说明:
-
输入处理阶段:
- 原始输入
x:(batch, dim_in)或(dim_in) - 经过随机傅里叶变换后:
(batch, dim_in, dim_model) - 与输入嵌入相加后:
(batch, dim_in, dim_model)
- 原始输入
-
嵌入重复阶段:
- 隐藏层嵌入:
(dim_hidden, dim_model)→(batch, dim_hidden, dim_model) - 输出层嵌入:
(dim_out, dim_model)→(batch, dim_out, dim_model)
- 隐藏层嵌入:
-
打包阶段:
- 三个嵌入拼接:
(batch, dim_in+dim_hidden+dim_out, dim_model)
- 三个嵌入拼接:
-
Transformer处理:
- 输入/输出形状保持不变:
(batch, total_tokens, dim_model)
- 输入/输出形状保持不变:
-
输出投影:
- 输出嵌入:
(batch, dim_out, dim_model) - 权重矩阵:
(dim_out, dim_model) - 最终输出:
(batch, dim_out)
- 输出嵌入:
这个架构的创新之处在于将传统的MLP重新构想为一个图神经网络,其中输入、隐藏和输出节点通过Transformer编码器进行全连接的消息传递,从而能够捕捉更复杂的特征交互。
📄 文件: x_transformers/init.py
1. 文件功能摘要
这个文件是 x_transformers 库的模块初始化文件,负责导入并公开库中所有主要的类和函数,使其可以通过 from x_transformers import ... 的方式直接访问。
2. 核心术语解释
- XTransformer: 一个完整的 Transformer 模型,通常包含编码器和解码器。
- Encoder: Transformer 编码器,用于将输入序列编码为特征表示。
- Decoder: Transformer 解码器,用于自回归生成序列。
- PrefixDecoder: 支持前缀缓存的解码器,用于高效生成。
- CrossAttender: 交叉注意力模块,常用于编码器-解码器架构。
- AttentionPool: 使用注意力机制进行池化操作。
- Attention: 自注意力机制,用于捕捉序列内部依赖。
- FeedForward: 前馈神经网络,通常用于注意力层后的特征变换。
- RMSNorm: 均方根归一化,一种替代 LayerNorm 的归一化方法。
- AdaptiveRMSNorm: 自适应 RMSNorm,可能根据输入动态调整参数。
- TransformerWrapper: Transformer 模型的包装器,简化模型构建。
- ViTransformerWrapper: 视觉 Transformer 包装器,适用于图像任务。
- AutoregressiveWrapper: 自回归包装器,用于语言建模等任务。
- NonAutoregressiveWrapper: 非自回归包装器,用于并行生成序列。
- BeliefStateWrapper: 信念状态包装器,可能用于多轮对话或状态跟踪。
- ContinuousTransformerWrapper: 连续数据 Transformer 包装器,适用于连续输入(如音频、时间序列)。
- ContinuousAutoregressiveWrapper: 连续数据的自回归包装器。
- MultiInputTransformerWrapper: 多输入 Transformer 包装器,支持多个输入源。
- XValTransformerWrapper: 交叉验证 Transformer 包装器,可能用于评估或集成。
- XValAutoregressiveWrapper: 交叉验证的自回归包装器。
- XLAutoregressiveWrapper: 超长序列自回归包装器,可能支持更长的上下文。
- DPO: 直接偏好优化(Direct Preference Optimization),用于强化学习对齐。
- NeoMLP: 一种新型的多层感知机结构,可能具有特殊设计。
- EntropyBasedTokenizer: 基于熵的分词器,可能用于动态分词或词汇表优化。
3. 代码逐行/逐块注释
# 从 x_transformers.x_transformers 模块导入核心 Transformer 相关类
from x_transformers.x_transformers import (
XTransformer, # 完整的 Transformer 模型
Encoder, # 编码器部分
Decoder, # 解码器部分
PrefixDecoder, # 支持前缀缓存的解码器
CrossAttender, # 交叉注意力模块
AttentionPool, # 注意力池化层
Attention, # 自注意力机制
FeedForward, # 前馈网络
RMSNorm, # 均方根归一化
AdaptiveRMSNorm, # 自适应 RMSNorm
TransformerWrapper, # 通用 Transformer 包装器
ViTransformerWrapper, # 视觉 Transformer 包装器
)
# 导入自回归和非自回归包装器,用于序列生成任务
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
from x_transformers.belief_state_wrapper import BeliefStateWrapper
# 导入连续数据相关的 Transformer 包装器,适用于非离散输入(如音频、时间序列)
from x_transformers.continuous import (
ContinuousTransformerWrapper,
ContinuousAutoregressiveWrapper
)
# 导入多输入 Transformer 包装器,支持多个输入源(如多模态数据)
from x_transformers.multi_input import MultiInputTransformerWrapper
# 导入交叉验证相关的 Transformer 包装器,可能用于模型评估或集成学习
from x_transformers.xval import (
XValTransformerWrapper,
XValAutoregressiveWrapper
)
# 导入超长序列自回归包装器,可能支持更长的上下文长度(如 XLNet 风格)
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
# 导入直接偏好优化(DPO)模块,用于强化学习对齐任务
from x_transformers.dpo import (
DPO
)
# 导入新型多层感知机(NeoMLP)模块,可能具有创新的网络结构
from x_transformers.neo_mlp import (
NeoMLP
)
# 导入基于熵的分词器,可能用于动态分词或优化词汇表
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
📄 文件: x_transformers/belief_state_wrapper.py
Belief State Transformer 代码分析
1. 文件功能摘要
这个文件实现了一个"信念状态变换器"(Belief State Transformer),这是一种特殊的Transformer架构,能够同时学习序列的前向和后向生成,通过结合前缀和后缀信息来预测序列中的缺失部分。
2. 核心术语解释
- Belief State(信念状态):模型对序列中某个位置的前向和后向信息的联合表示,结合了前缀(前向)和后缀(后向)的上下文信息。
- Forward Decoder(前向解码器):标准的自回归Transformer,从左到右生成序列。
- Backward Decoder(后向解码器):反向的自回归Transformer,从右到左生成序列。
- Suffix Token(后缀标记):特殊的标记,用于指示后向解码的开始位置。
- Cartesian Product(笛卡尔积):生成所有可能的前向-后向位置对,用于训练信念状态。
- Detach Multiple(分离多个张量):自定义的PyTorch函数,用于同时分离多个张量以减少内存使用。
- Fill in the Middle(中间填充):一种序列生成任务,给定前缀和后缀,预测中间的缺失部分。
3. 代码逐行/逐块注释
# Belief State Transformer
# Hu et al. https://arxiv.org/abs/2410.23506
# https://www.youtube.com/watch?v=aqhbRtB2Fyg
from __future__ import annotations
from random import random
import torch
from torch.autograd import Function
from torch.nn import Module, ModuleList
from torch import nn, cat, stack, tensor, Tensor, arange, cartesian_prod
import torch.nn.functional as F
from x_transformers.autoregressive_wrapper import (
eval_decorator,
min_p,
)
from x_transformers.x_transformers import (
Decoder,
TransformerWrapper
)
import einx
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
导入部分:导入必要的库和模块,包括PyTorch基础组件、x_transformers库中的解码器和包装器,以及einops用于张量操作。
# helper functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
辅助函数:
exists():检查变量是否为Nonedefault():如果变量存在则返回变量,否则返回默认值
# a custom flip that can handle variable lengths across batch
def flip(x, dim = 1, lens = None):
if not exists(lens):
return x.flip(dim)
batch, seq_len, device = *x.shape[:2], x.device
seq = arange(seq_len, device = device)
mask = einx.less('j, i -> i j', seq, lens)
masked_seq = einx.where('i j, j,', mask, seq, -1)
flip_indices = masked_seq.argsort(dim = -1, descending = True)
if x.ndim == 3:
flip_indices = repeat(flip_indices, '... -> ... d', d = x.shape[-1])
return x.gather(dim, flip_indices)
自定义翻转函数:
- 处理变长序列的翻转操作
- 对于每个批次,根据序列长度
lens创建掩码 - 使用
gather操作按照计算出的翻转索引重新排列张量
# detach multiple tensors and backward the gradients once
class DetachMultiple(Function):
@classmethod
def forward(self, ctx, *tensors):
detached_tensors = tuple(t.detach() for t in tensors)
for detached_tensor in detached_tensors:
detached_tensor.requires_grad_()
return detached_tensors
@classmethod
def backward(self, ctx, *grads):
return grads
detach_multiple = DetachMultiple.apply
分离多个张量的自定义函数:
- 同时分离多个张量以减少内存使用
- 在反向传播时一次性处理所有梯度
class BeliefStateWrapper(Module):
"""
Figure 13. in https://arxiv.org/abs/2410.23506
"""
def __init__(
self,
forward_decoder: TransformerWrapper,
backward_decoder: TransformerWrapper | None = None,
train_frac_forward_backward_pairs: float = 1.,
text_head: Module | None = None,
backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
pred_distance = False,
pred_distance_loss_weight: float = 1.,
cond_on_distance = False,
cond_on_distance_prob = 0.5,
max_pred_distance = None
):
super().__init__()
backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'
dim = forward_decoder.emb_dim
num_tokens = forward_decoder.num_tokens
max_seq_len = forward_decoder.max_seq_len
self.num_tokens = num_tokens
# the suffix token
self.suffix_token = nn.Parameter(torch.zeros(dim))
nn.init.normal_(self.suffix_token, std = 0.02)
# the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences
if not exists(text_head):
text_head = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.LeakyReLU(),
nn.Linear(dim, num_tokens * 2),
)
self.text_head = text_head
# predicting terminal state (when suffix and prefix predict the same token)
self.max_pred_distance = default(max_pred_distance, max_seq_len)
self.to_distance_logits = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.LeakyReLU(),
nn.Linear(dim, self.max_pred_distance),
) if pred_distance else None
self.pred_distance_loss_weight = pred_distance_loss_weight
# conditioning on distance
assert 0. < cond_on_distance_prob < 1.
self.cond_on_distance = cond_on_distance
self.cond_on_distance_prob = cond_on_distance_prob
if cond_on_distance:
self.to_distance_cond = nn.Sequential(
Rearrange('... -> ... 1'),
nn.Linear(1, dim),
nn.LeakyReLU(),
nn.Linear(dim, dim * 2),
)
# the two decoders, one which is causal forward, the other causal backwards
self.forward_decoder = forward_decoder
self.backward_decoder = backward_decoder
# what fraction of forward backward pairs to train on
# for further memory efficiency
assert 0 < train_frac_forward_backward_pairs <= 1.
self.train_frac_fb_pairs = train_frac_forward_backward_pairs
self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
# loss weighting
self.backward_ar_loss_weight = backward_ar_loss_weight
self.needs_loss_weight = backward_ar_loss_weight != 1.
self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
# sampling
self.max_seq_len = self.forward_decoder.max_seq_len
BeliefStateWrapper初始化:
- 接收前向和后向解码器(默认共享)
- 初始化后缀标记(可学习参数)
- 设置文本预测头(预测前向和后向的下一个标记)
- 可选的距离预测模块(预测前缀和后缀之间的距离)
- 可选的距离条件模块(根据距离调整嵌入)
- 设置训练参数和损失权重
@torch.no_grad()
@eval_decorator
def generate_with_suffix_cond(
self,
prompts,
seq_len,
temperature = 1.25,
cache_kv = False,
suffix: Tensor | None = None, # the goal conditioning
filter_logits_fn = min_p,
filter_kwargs = dict(
min_p = 0.1
),
decode_backwards = False,
**kwargs
):
max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
prompts, batch_ps = pack([prompts], '* d')
batch, orig_seq_len = prompts.shape
# allow for decoding backwards, to make sure it is working
main_decoder = self.forward_decoder
if decode_backwards:
prompts = prompts.flip(1)
main_decoder = self.backward_decoder
out = prompts
# kv caches
cache = None
# get the encoded suffix token once
suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
if not decode_backwards:
if exists(suffix):
if suffix.ndim == 1:
suffix = repeat(suffix, 'n -> b n', b = batch)
suffix = suffix.flip(1) # reverse autoregressive
suffix_embed = self.backward_decoder(
suffix,
prepend_embeds = suffix_sos_tokens,
return_embeddings = True
)
# pick out the last embedding for fill in the middle
suffix_embed = suffix_embed[:, -1:]
else:
# just grab a random token for now for prefix
prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device)
prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True)
# sampling up to seq_len
for _ in range(seq_len):
embeds, new_cache = main_decoder(
out,
prepend_embeds = suffix_sos_tokens if decode_backwards else None,
return_intermediates = True,
return_embeddings = True,
cache = cache,
**kwargs
)
last_embeds = embeds[:, -1:]
if not decode_backwards:
embeds = cat((last_embeds, suffix_embed), dim = -1)
else:
embeds = cat((prefix_embed, last_embeds), dim = -1)
if cache_kv and self.forward_decoder.can_cache_kv:
cache = new_cache
forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1)
logits = forward_logits if not decode_backwards else backward_logits
logits = logits[:, -1]
if greedy:
sample = logits.argmax(dim = -1, keepdim = True)
else:
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
probs = F.softmax(filtered_logits / temperature, dim = -1)
sample = torch.multinomial(probs, 1)
# concat sample
out = torch.cat((out, sample), dim = -1)
out = out[:, orig_seq_len:]
out, = unpack(out, batch_ps, '* n')
return out
生成函数:
- 支持前向和后向生成
- 使用后缀条件进行中间填充生成
- 实现KV缓存以提高生成效率
- 支持贪婪采样和温度采样
def forward(
self,
seq,
lens: Tensor | None = None, # Int['b']
loss_weight_by_fb_indices: callable | None = None
):
batch, seq_len, device = *seq.shape, seq.device
# handle variable length sequences
seq_for_labels = seq
if exists(lens):
mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens)
seq_for_labels = torch.where(mask, seq, -1)
# forward autoregressive
forward_embeds = self.forward_decoder(seq, return_embeddings = True)
# backward autoregressive
backward_seq = flip(seq, lens = lens)
suffix_tokens = repeat(self.suffix_token, 'd -> b 1 d', b = batch)
backward_embeds = self.backward_decoder(
backward_seq,
prepend_embeds = suffix_tokens,
return_embeddings = True
)
backward_embeds = flip(backward_embeds, lens = lens)
# trick to reduce memory on backwards pass
forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)
# belief state objective
seq_arange = arange(seq_len, device = device)
fb_pairs = cartesian_prod(seq_arange, seq_arange + 1) # plus one for suffix token
# filter down to valid pairs, as in figure 11
# f - forward, b - backward, i - indices
fi, bi = fb_pairs.unbind(dim = -1)
valid_mask = (bi - fi) >= 2
fb_pairs = fb_pairs[valid_mask]
# maybe subsample fb pairs
if self.needs_subsample_fb_pairs:
num_pairs = fb_pairs.shape[0]
num_subsampled = max(int(num_pairs * self.train_frac_fb_pairs), 1)
rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled]
fb_pairs = fb_pairs[rand_subsampled_indices]
# get labels for both
fi, bi = fb_pairs.unbind(dim = -1)
labels_fi, labels_bi = (fi + 1), (bi - 1)
forward_labels, backward_labels = seq_for_labels[:, labels_fi], seq_for_labels[:, labels_bi]
labels = cat((forward_labels, backward_labels), dim = -1)
# get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
fb_embeds = cat((
forward_embeds[:, fi],
backward_embeds[:, bi]
), dim = -1)
logits = self.text_head(fb_embeds)
# cross entropy loss
loss = F.cross_entropy(
rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
labels,
reduction = 'none' if self.needs_loss_weight else 'mean',
ignore_index = -1
)
# maybe condition on distance
cond_on_distance = self.cond_on_distance and (random() < self.cond_on_distance_prob)
if cond_on_distance:
distance = (bi - fi).float()
distance_cond = self.to_distance_cond(distance)
fb_embeds = fb_embeds * distance_cond
# maybe predict distance
if exists(self.to_distance_logits) and not cond_on_distance:
distance_logits = self.to_distance_logits(fb_embeds)
distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
distance_labels = repeat(distance_labels, 'n -> b n', b = batch)
pred_dist_loss = F.cross_entropy(
rearrange(distance_logits, 'b n l -> b l n'),
distance_labels
)
loss = (
loss +
pred_dist_loss * self.pred_distance_loss_weight
)
# maybe loss weighting
needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
if needs_loss_weight:
loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
if self.needs_loss_weight:
loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
# allow researcher to pass in a function that acts on the the forward backward indices Int['n fb']
# the reason this may be needed is because the earlier tokens will have more eligible pairs for training, and perhaps this could be normalized
if exists(loss_weight_by_fb_indices):
loss_weight = loss_weight_by_fb_indices(fb_pairs)
if loss_weight.ndim == 1:
loss = einx.multiply('b fb n, n', loss, loss_weight)
elif loss_weight.ndim == 2:
loss = einx.multiply('b fb n, n fb', loss, loss_weight)
else:
raise ValueError('invalid loss weight dims')
loss = loss.mean()
return loss
前向传播函数:
- 处理变长序列:创建掩码处理填充标记
- 前向自回归:获取前向嵌入
- 后向自回归:翻转序列,添加后缀标记,获取后向嵌入
- 生成有效位置对:使用笛卡尔积生成所有可能的前向-后向位置对,过滤无效对(距离至少为2)
- 下采样:可选地减少训练对的数量以节省内存
- 计算损失:
- 结合前向和后向嵌入预测下一个标记
- 可选的距离条件或距离预测
- 灵活的损失加权机制
关键张量变换:
fb_embeds: 形状为[batch, num_pairs, dim*2],结合了前向和后向嵌入logits: 形状为[batch, num_pairs, num_tokens*2],同时预测前向和后向的下一个标记
📄 文件: x_transformers/autoregressive_wrapper.py
代码分析报告
1. 文件功能摘要
这个文件实现了一个自回归包装器(AutoregressiveWrapper),用于将任意神经网络模型包装成能够进行自回归生成(如文本生成)的模型,支持多种解码策略(贪婪搜索、束搜索、对比解码等)和训练功能。
2. 核心术语解释
- 自回归(Autoregressive):一种生成模型,每次生成一个token,并将之前生成的所有token作为输入来预测下一个token。
- 束搜索(Beam Search):一种启发式搜索算法,在每一步保留多个候选序列(束),而不是只保留一个最优序列。
- 对比解码(Contrastive Decoding):使用专家模型和业余模型输出的差异来改进生成质量的技术。
- KV缓存(Key-Value Cache):在Transformer解码过程中缓存注意力机制的key和value,避免重复计算。
- Top-p采样(Nucleus Sampling):从累积概率超过阈值p的token中进行采样,实现动态词汇表大小。
- Top-k采样:只从概率最高的k个token中进行采样。
- Gumbel噪声:用于Gumbel-Softmax采样的噪声,实现可微分的离散采样。
- 注意力Z损失(Attention Z-loss):一种正则化损失,用于稳定注意力权重的训练。
- 位置编码(Positional Embedding):为token添加位置信息的嵌入,可以是绝对位置编码或相对位置编码(如旋转位置编码)。
3. 代码逐行/逐块注释
辅助函数部分
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def identity(t, *args, **kwargs):
return t
def join(arr, delimiter = ', '):
return delimiter.join(arr)
def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else (t,) * length
注释:这些是通用的辅助函数,用于处理空值、默认值、元组转换等。
def eval_decorator(fn):
def inner(self, *args, **kwargs):
was_training = self.training
self.eval()
out = fn(self, *args, **kwargs)
self.train(was_training)
return out
return inner
注释:装饰器函数,用于在调用方法时临时将模型切换到评估模式,调用结束后恢复原来的训练状态。
Gumbel采样相关函数
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
def gumbel_noise(t):
return -log(-log(torch.rand_like(t)))
def gumbel_sample(logits, temperature = 1., eps = 1e-6):
noise = gumbel_noise(logits)
return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
注释:
log:安全的对数函数,防止对0取对数gumbel_noise:生成Gumbel分布噪声,用于Gumbel-Softmax采样gumbel_sample:使用Gumbel噪声进行采样,temperature参数控制采样随机性
缓存操作函数
def modify_cached_kv(cache, fn):
for inter in cache.attn_intermediates:
if inter.layer_type == 'a':
inter.cached_kv = [fn(t) for t in inter.cached_kv]
注释:修改缓存中的key-value对,fn是应用于每个缓存张量的函数。
序列对齐函数
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
if pad == (0, 0):
return t
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value = value)
def align_right(t, lens, pad_id = 0):
batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
assert lens.ndim == 1 and lens.shape[0] == batch
assert lens.amax() <= seq_len
pad_lens = seq_len - lens
max_pad_len = pad_lens.amax()
batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
offset = max_pad_len - pad_lens
aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
return aligned
注释:
pad_at_dim:在指定维度上填充张量align_right:将变长序列右对齐,用于处理不同长度的提示(prompt)
采样策略函数
def top_p(logits, thres = 0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending = True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
sorted_indices_to_remove = cum_probs > thres
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
sorted_logits[sorted_indices_to_remove] = float('-inf')
return sorted_logits.scatter(1, sorted_indices, sorted_logits)
注释:Top-p(核)采样,只保留累积概率超过阈值thres的token,将其余token的概率设为负无穷。
def top_k(logits, frac_num_tokens = 0.1, k = None):
num_tokens = logits.shape[-1]
k = default(k, ceil(frac_num_tokens * num_tokens))
k = min(k, num_tokens)
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(1, ind, val)
return probs
注释:Top-k采样,只保留概率最高的k个token。
def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
probs = logits.softmax(dim = -1)
max_probs = probs.amax(dim = -1, keepdim = True)
limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
return torch.where(probs < limit, float('-inf'), logits)
注释:Top-a采样,基于最大概率的幂次和比例来过滤token。
def min_p(logits, min_p = 0.1):
probs = logits.softmax(dim = -1)
max_probs = probs.amax(dim = -1, keepdim = True)
limit = min_p * max_probs
return torch.where(probs < limit, float('-inf'), logits)
注释:Min-p采样,保留概率至少为最大概率min_p倍的token。
对比解码函数
def contrastive_decode_fn(
expert_logits,
amateur_logits,
alpha = 0.1,
beta = 0.5
):
cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
diffs = (1 + beta) * expert_logits - beta * amateur_logits
contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
return contrastive_decode_logits
注释:对比解码算法,通过专家模型和业余模型的logits差异来改进生成质量。首先计算一个截断值,然后计算加权差异,最后将低于截断值的token设为负无穷。
AutoregressiveWrapper类
初始化方法
class AutoregressiveWrapper(Module):
def __init__(
self,
net,
ignore_index = -100,
pad_value = 0,
mask_prob = 0.,
add_attn_z_loss = False,
next_embed_loss_weight = 0.1
):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index
self.net = net
self.max_seq_len = net.max_seq_len
assert mask_prob < 1.
self.mask_prob = mask_prob
self.add_attn_z_loss = add_attn_z_loss
self.add_continuous_pred_head = net.add_continuous_pred_head
self.next_embed_loss_weight = next_embed_loss_weight
注释:初始化自回归包装器,包装一个神经网络模型,设置各种参数如填充值、忽略索引、掩码概率等。
束搜索方法
@torch.no_grad()
@eval_decorator
def beam_search(
self,
prompts,
seq_len,
beams = 4,
return_beams_and_scores = False,
eos_token = None,
temperature = 1.,
stochastic = False,
prompt_lens: Tensor | None = None,
filter_logits_fn: str | Callable = identity,
restrict_to_max_seq_len = True,
filter_kwargs: dict = dict(),
cache_kv = True,
**kwargs
):
注释:束搜索生成方法,支持束宽、温度采样、随机束搜索等。
关键代码块分析:
# 处理变长提示
if exists(prompt_lens):
prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
seq_start_pos = orig_seq_len - prompt_lens
注释:如果提供了提示长度,将提示右对齐,并计算序列起始位置。
# 主生成循环
for i in range(seq_len):
is_first = i == 0
if restrict_to_max_seq_len:
max_len_exceeded = out.shape[-1] > max_seq_len
assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), '错误信息'
x = out[:, -max_seq_len:]
if exists(cache):
modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])
注释:如果限制最大序列长度,只取最后max_seq_len个token作为输入,并相应截断缓存。
logits, new_cache = self.net(
x,
return_intermediates = True,
cache = cache,
seq_start_pos = seq_start_pos,
**kwargs
)
注释:前向传播获取下一个token的logits和新的缓存。
# 扩展束
scores = repeat(scores, 'b -> b beams', beams = beams)
scores = scores + next_scores
out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
samples = rearrange(samples, 'b beams -> (b beams) 1')
if should_cache and is_first:
modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
注释:将分数和输出序列扩展为束的维度,如果是第一步还需要扩展缓存。
# 排序并剪枝束
scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)
curr_num_beams = scores.shape[-1]
if curr_num_beams > beams:
scores, sort_indices = scores.sort(dim = -1, descending = True)
scores = scores[:, :beams]
top_beams_indices = sort_indices[:, :beams]
top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
out = out[flattened_beam_indices]
注释:将分数重新排列,按分数排序,只保留前beams个最好的束,并选择对应的输出序列。
生成方法
@torch.no_grad()
@eval_decorator
def generate(
self,
prompts: list[Tensor] | Tensor,
seq_len,
eos_token = None,
temperature = 1.,
prompt_lens: Tensor | None = None,
filter_logits_fn: str | Callable = top_k,
restrict_to_max_seq_len = True,
amateur_model: Module | Tuple[Module] | None = None,
filter_kwargs: dict = dict(),
contrastive_decode_kwargs: dict | Tuple[dict] = dict(
beta = 0.5,
alpha = 0.1
),
cache_kv = True,
**kwargs
):
注释:通用的生成方法,支持多种采样策略和对比解码。
关键代码块分析:
# 处理对比解码
if exists(amateur_model):
amateur_model = cast_tuple(amateur_model)
contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
assert len(amateur_model) == len(contrastive_decode_kwargs)
amateur_caches = [None] * len(amateur_model)
filter_logits_fn = identity
for i, module in enumerate(amateur_model):
if isinstance(module, AutoregressiveWrapper):
amateur_model[i] = module.net
module.eval()
注释:如果提供了业余模型,设置对比解码。将业余模型转换为元组,关闭过滤函数,并将业余模型设置为评估模式。
# 对比解码前向传播
if exists(amateur_model):
for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
amateur_logits, next_amateur_cache = amateur(
x,
return_intermediates = True,
cache = amateur_cache,
seq_start_pos = seq_start_pos,
**kwargs
)
amateur_logits = amateur_logits[:, -1]
assert amateur_logits.shape == logits.shape, 'logits维度不一致'
logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
if cache_kv and amateur.can_cache_kv:
amateur_caches[i] = next_amateur_cache
注释:对每个业余模型进行前向传播,获取logits,然后使用对比解码函数结合专家和业余模型的logits。
# 采样
if greedy:
sample = logits.argmax(dim = -1, keepdim = True)
else:
filtered_logits = filter_logits_fn(logits, **filter_kwargs)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
注释:根据是否贪婪选择采样方式:贪婪选择直接取argmax,否则使用过滤后的logits进行多项式采样。
前向传播方法
def forward(
self,
x,
return_outputs = False,
prepend_embeds = None,
**kwargs
):
注释:训练时的前向传播方法,计算交叉熵损失和其他辅助损失。
关键代码块分析:
# 掩码语言模型(MLM)训练
if self.mask_prob > 0.:
rand = torch.randn(inp.shape, device = x.device)
rand[:, 0] = -torch.finfo(rand.dtype).max # 第一个token不应该被掩码
num_mask = min(int(seq * self.mask_prob), seq - 1)
indices = rand.topk(num_mask, dim = -1).indices
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
kwargs.update(self_attn_kv_mask = mask)
注释:如果设置了掩码概率,随机选择一些位置进行掩码,用于掩码语言模型训练。
# 前向传播获取输出
out, cache = self.net(
inp,
return_intermediates = True,
return_attn_z_loss = add_attn_z_loss,
return_next_embed_pred = add_next_embed_loss,
prepend_embeds = prepend_embeds,
**kwargs
)
注释:调用包装的网络进行前向传播,返回logits和缓存。
# 计算交叉熵损失
loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
loss = loss_fn(
rearrange(logits, 'b n c -> b c n'),
target,
ignore_index = ignore_index
)
注释:根据网络输出类型选择损失函数,并计算交叉熵损失。使用rearrange将logits从[batch, seq_len, vocab]转换为[batch, vocab, seq_len]以符合PyTorch交叉熵损失的输入格式。
# 添加注意力Z损失
if add_attn_z_loss:
loss = loss + cache.attn_z_loss
注释:如果启用了注意力Z损失,将其添加到总损失中。
# 添加连续嵌入预测损失
if add_next_embed_loss:
mask = target != ignore_index
embed_pred = next_embed_pred[:, :-1]
cont_targets = init_embeds[:, 1:].detach()
cont_loss = F.l1_loss(embed_pred, cont_targets, reduction = 'none')
cont_loss = cont_loss[mask].mean()
loss = loss +
---
## 📄 文件: x_transformers/multi_input.py
# MultiInputTransformerWrapper 代码分析
## 1. 文件功能摘要
这是一个支持多输入类型的Transformer包装器,能够处理多种不同类型的输入标记(如文本、图像、音频等),为每种输入类型提供独立的嵌入层和输出头,适用于多模态或多任务学习场景。
## 2. 核心术语解释
### **LayerNorm (层归一化)**
- **解释**: 对神经网络层的输出进行归一化,使其均值为0,方差为1,有助于稳定训练过程。
- **在代码中**: 用于`post_emb_norm`,对嵌入后的表示进行归一化。
### **Residual (残差连接)**
- **解释**: 将层的输入直接加到输出上,形成`输出 = F(输入) + 输入`的结构,有助于梯度流动和深层网络训练。
- **在代码中**: 虽然未直接出现,但`AttentionLayers`内部通常会使用残差连接。
### **Einsum (爱因斯坦求和约定)**
- **解释**: 一种简洁的张量运算表示法,用于指定多维数组的乘积和求和操作。
- **在代码中**: 通过`einops`库的`rearrange`函数实现类似功能,用于张量形状变换。
### **AttentionLayers (注意力层)**
- **解释**: Transformer的核心组件,包含多头自注意力机制和前馈网络。
- **在代码中**: 作为`attn_layers`参数传入,是模型的主要计算模块。
### **ScaledSinusoidalEmbedding (缩放正弦位置编码)**
- **解释**: 使用正弦函数生成的位置编码,并进行缩放以适应模型维度。
- **在代码中**: 作为位置编码的一种选项。
### **AbsolutePositionalEmbedding (绝对位置编码)**
- **解释**: 可学习的位置编码,为每个位置分配一个可学习的向量。
- **在代码中**: 默认的位置编码方式。
### **Memory Tokens (记忆令牌)**
- **解释**: 类似BERT的[CLS]令牌,用于聚合序列信息,可放置在序列开头或间隔插入。
- **在代码中**: 通过`memory_tokens`参数实现,用于增强模型表示能力。
### **KV Caching (键值缓存)**
- **解释**: 在自回归生成中缓存注意力层的键值对,避免重复计算,提高推理效率。
- **在代码中**: 通过`can_cache_kv`属性控制是否支持缓存。
## 3. 代码逐行/逐块注释
```python
from __future__ import annotations # 启用延迟类型注解,允许在类型提示中引用类自身
import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleDict
import torch.nn.functional as F
from typing import Dict # 类型提示:字典类型
from einops import pack, repeat, unpack # 张量操作库,用于形状变换
from x_transformers.x_transformers import (
AttentionLayers, # 注意力层模块
ScaledSinusoidalEmbedding, # 缩放正弦位置编码
AbsolutePositionalEmbedding, # 绝对位置编码
LayerIntermediates, # 层中间结果容器
LayerNorm, # 层归一化
always, # 返回常量的函数
pad_at_dim, # 维度填充函数
is_empty, # 检查是否为空
)
# 辅助函数
def exists(val):
return val is not None # 检查值是否存在(非None)
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d # 如果val不存在,返回默认值d
class MultiInputTransformerWrapper(Module):
def __init__(
self,
*,
num_tokens: Dict[str, int] = dict(), # 字典:输入类型->词汇表大小
max_seq_len, # 最大序列长度
attn_layers: AttentionLayers, # 注意力层模块
emb_dim = None, # 嵌入维度
max_mem_len = 0, # 最大记忆长度
shift_mem_down = 0, # 记忆下移层数
emb_dropout = 0., # 嵌入dropout率
post_emb_norm = False, # 是否在嵌入后使用层归一化
num_memory_tokens = None, # 记忆令牌数量
memory_tokens_interspersed_every = None, # 记忆令牌间隔插入频率
return_only_embed = False, # 是否只返回嵌入表示
use_abs_pos_emb = True, # 是否使用绝对位置编码
scaled_sinu_pos_emb = False, # 是否使用缩放正弦位置编码
emb_frac_gradient = 1., # 嵌入梯度分数(GLM-130B和Cogview使用)
attn_z_loss_weight = 1e-4, # 注意力Z损失权重
):
super().__init__()
dim = attn_layers.dim # 注意力层的维度
emb_dim = default(emb_dim, dim) # 默认嵌入维度等于注意力层维度
self.emb_dim = emb_dim
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.shift_mem_down = shift_mem_down
# 判断是否需要绝对位置编码
# 条件:max_seq_len为0 或 不使用绝对位置编码 或 注意力层禁用绝对位置编码
no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
# 位置编码选择
if no_abs_pos_emb:
self.pos_emb = always(0) # 总是返回0,即无位置编码
elif scaled_sinu_pos_emb:
self.pos_emb = ScaledSinusoidalEmbedding(emb_dim) # 缩放正弦位置编码
else:
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) # 绝对位置编码
# 为每种输入类型创建独立的嵌入层
# 例如:{"text": 50000, "image": 1000} -> text_embed和image_embed两个嵌入层
self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim)
for name, one_num_tokens in num_tokens.items()})
self.emb_frac_gradient = emb_frac_gradient # 嵌入梯度分数
# 嵌入后归一化
self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
self.emb_dropout = nn.Dropout(emb_dropout) # 嵌入dropout
# 如果嵌入维度与注意力层维度不同,需要线性投影
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers # 注意力层
# 输出头:为每种输入类型创建独立的线性层
if return_only_embed:
self.to_logits = None # 只返回嵌入,不计算logits
else:
self.to_logits = ModuleDict({name: nn.Linear(dim, logits_dim, bias=False)
for name, logits_dim in num_tokens.items()})
# 记忆令牌(类似BERT的[CLS])
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) # 可学习的记忆令牌
self.memory_tokens_interspersed_every = memory_tokens_interspersed_every # 记忆令牌间隔插入频率
# 是否支持KV缓存(用于自回归生成)
self.can_cache_kv = self.num_memory_tokens == 0 # 有记忆令牌时不支持缓存
self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb # 无位置编码时可在最大序列长度外缓存
def forward(
self,
x: Dict[str, Tensor], # 输入字典:输入类型->输入张量
return_embeddings = False, # 是否返回嵌入表示
return_logits_and_embeddings = False, # 是否同时返回logits和嵌入
return_intermediates = False, # 是否返回中间结果
mask = None, # 注意力掩码
return_mems = False, # 是否返回记忆
return_attn = False, # 是否返回注意力图
mems = None, # 外部记忆
mem_masks = None, # 记忆掩码
pos = None, # 外部位置编码
prepend_embeds = None, # 预追加的嵌入(如图像嵌入)
prepend_mask = None, # 预追加嵌入的掩码
sum_embeds = None, # 需要求和的嵌入(用于自条件训练)
return_attn_z_loss = False, # 是否返回注意力Z损失
attn_z_loss_weight = 1e-4, # 注意力Z损失权重
seq_start_pos = None, # 序列起始位置(用于相对位置编码)
cache: LayerIntermediates | None = None, # 缓存中间结果
**kwargs
):
assert not is_empty(x) # 确保输入非空
first_input = list(x.values())[0] # 获取第一个输入用于获取形状信息
# 解包变量:批次大小b,序列长度n,设备device等
b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *first_input.shape, first_input.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
# 判断是否需要返回隐藏状态
return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
return_embeddings = return_embeddings | (not exists(self.to_logits)) # 如果没有输出头,则必须返回嵌入
# 1. 令牌嵌入
assert len(x) == len(self.embeds) # 确保输入类型数量与嵌入层数量匹配
token_emb = 0. # 初始化令牌嵌入
# 对每种输入类型分别进行嵌入并求和
for name, embed_id in x.items():
embed_key = f'{name}_embed'
assert embed_key in self.embeds # 确保有对应的嵌入层
embed = self.embeds[embed_key](embed_id) # 形状: (b, n, emb_dim)
token_emb = token_emb + embed # 多模态嵌入求和
# 2. 位置编码
external_pos_emb = exists(pos) and pos.dtype != torch.long # 判断是否为外部提供的浮点型位置编码
# 获取位置编码:如果提供了外部位置编码且为浮点型,则使用外部编码;否则使用内部位置编码器
pos_emb = self.pos_emb(first_input, pos=pos, seq_start_pos=seq_start_pos) if not external_pos_emb else pos
token_emb = token_emb + pos_emb # 添加位置编码
# 3. 添加外部求和嵌入(用于自条件训练)
if exists(sum_embeds):
token_emb = token_emb + sum_embeds
# 更新x为令牌嵌入
x = token_emb
# 4. 嵌入后归一化
x = self.post_emb_norm(x)
# 5. 预追加嵌入(如PaLI中的图像嵌入)
if exists(prepend_embeds):
prepend_seq, prepend_dim = prepend_embeds.shape[1:] # 获取预追加序列长度和维度
assert prepend_dim == x.shape[-1], '预追加嵌入需要与文本模型维度相同'
# 在序列维度拼接:形状从(b, n, d)变为(b, prepend_seq+n, d)
x = torch.cat((prepend_embeds, x), dim=-2)
# 处理掩码
if exists(prepend_mask) or exists(mask):
mask = default(mask, lambda: torch.ones((b, n), device=device, dtype=torch.bool))
prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device=device, dtype=torch.bool))
# 拼接掩码:形状从(b, n)变为(b, prepend_seq+n)
mask = torch.cat((prepend_mask, mask), dim=-1)
# 6. 嵌入梯度分数(部分梯度回传)
if emb_frac_gradient < 1:
assert emb_frac_gradient > 0
# 公式:x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
# 只有emb_frac_gradient比例的梯度会回传到嵌入层
x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
# 7. 嵌入dropout
x = self.emb_dropout(x)
# 8. 投影到注意力层维度(如果需要)
x = self.project_emb(x)
# 9. 处理记忆令牌
if has_memory_tokens:
mem_every = self.memory_tokens_interspersed_every # 记忆令牌间隔插入频率
if exists(mem_every):
assert mem_every > 0
# 填充序列长度到mem_every的倍数
next_seq_len = math.ceil(n / mem_every) * mem_every
x = pad_at_dim(x, (0, next_seq_len - n), dim=-2, value=0.)
# 形状变换:(b, next_seq_len, d) -> (b*n_groups, mem_every, d)
x = rearrange(x, 'b (n m) d -> (b n) m d', m=mem_every)
# 重复记忆令牌以匹配批次大小
mem = repeat(self.memory_tokens, 'n d -> b n d', b=x.shape[0])
# 打包记忆令牌和输入:在序列维度拼接
x, mem_packed_shape = pack((mem, x), 'b * d')
# 自动处理掩码:在掩码开头添加记忆令牌的位置
if not exists(mem_every) and exists(mask):
mask = pad_at_dim(mask, (num_mems, 0), dim=-1, value=True)
if exists(mem_every):
# 恢复原始形状
x = rearrange(x, '(b n) m d -> b (n m) d', b=b)
# 10. 记忆下移(如果启用)
if self.shift_mem_down and exists(mems):
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
mems = [*mems_r, *mems_l] # 交换记忆顺序
# 11. 通过注意力层
x, intermediates = self.attn_layers(
x,
mask=mask,
mems=mems,
mem_masks=mem_masks,
cache=cache,
return_hiddens=True,
seq_start_pos=seq_start_pos,
**kwargs
)
# 12. 注意力层后的记忆令牌处理
if has_memory_tokens:
if exists(mem_every):
# 形状变换以分离记忆令牌
x = rearrange(x, 'b (n m) d -> (b n) m d', m=(mem_every + num_mems))
# 解包:分离记忆令牌和原始输入
mem, x = unpack(x, mem_packed_shape, 'b * d')
intermediates.memory_tokens = mem # 保存记忆令牌到中间结果
if exists(mem_every):
# 恢复形状
x = rearrange(x, '(b n) m d -> b (n m) d', b=b)
x = x[:, :n] # 截取原始序列长度(去掉填充部分)
# 13. 投影到logits
if not return_embeddings:
# 为每种输入类型计算logits
logits = {name: fn(x) for name, fn in self.to_logits.items()}
# 14. 根据返回标志选择输出
if return_logits_and_embeddings:
out = (logits, x) # 返回(logits, 嵌入)
elif return_embeddings:
out = x # 只返回嵌入
else:
out = logits # 只返回logits
# 15. 辅助损失:注意力Z损失
if return_attn_z_loss:
pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight=attn_z_loss_weight)
return_intermediates = True
# 16. 记忆处理
if return_mems:
hiddens = intermediates.hiddens # 获取各层隐藏状态
# 将新隐藏状态与旧记忆拼接
new_mems = [torch.cat(pair, dim=-2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems] # 截取并分离梯度
if not return_intermediates:
return out, new_mems # 返回输出和新
---
## 📄 文件: x_transformers/x_transformers.py
# x_transformers.py 代码分析
## 1. 文件功能摘要
这是一个完整的Transformer架构实现库,提供了多种现代Transformer变体、注意力机制、位置编码和归一化方法,支持自注意力、交叉注意力、编码器-解码器架构以及各种先进的训练技巧。
## 2. 核心术语解释
### **LayerNorm** (层归一化)
- **解释**: 对每个样本的特征维度进行归一化,使激活值保持稳定分布
- **变体**:
- <code>RMSNorm</code>: 只计算均方根,不减去均值
- <code>ScaleNorm</code>: 基于向量长度的归一化
- <code>AdaptiveLayerNorm</code>: 根据条件动态调整的层归一化
### **Einsum** (爱因斯坦求和约定)
- **解释**: 使用简洁的符号表示张量运算,如 <code>'b h n d, b h d m -> b h n m'</code> 表示注意力计算
- **用途**: 在代码中广泛用于矩阵乘法、转置等操作
### **Residual** (残差连接)
- **解释**: 将输入直接加到输出上,缓解梯度消失问题
- **变体**:
- <code>GRUGating</code>: 使用GRU门控的残差连接
- <code>HyperConnection</code>: 超连接,支持多残差流
### **Attention** (注意力机制)
- **解释**: 计算查询(Query)、键(Key)、值(Value)之间的相关性权重
- **特性**: 支持多头注意力、旋转位置编码、稀疏注意力等
### **RoPE** (旋转位置编码)
- **解释**: 通过旋转矩阵将位置信息编码到注意力计算中
- **特点**: 相对位置编码,支持长度外推
### **Flash Attention**
- **解释**: 高效的内存感知注意力算法,减少GPU内存访问
- **优势**: 更快的训练速度和更低的内存占用
### **ALiBi** (注意力线性偏置)
- **解释**: 在注意力分数中添加线性偏置,实现相对位置编码
- **特点**: 支持长度外推,无需训练位置编码
## 3. 代码逐行/逐块注释
### 3.1 导入和配置部分
```python
from __future__ import annotations
from typing import Callable
import math
from copy import deepcopy
from random import random, randrange
from functools import partial, wraps
from itertools import chain
from collections import namedtuple
from contextlib import nullcontext
from dataclasses import dataclass
from packaging import version
import torch
from torch.amp import autocast
import torch.nn.functional as F
from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from torch.nn import Module, ModuleList, ModuleDict
from loguru import logger
from x_transformers.attend import Attend, Intermediates
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import einx
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
注释:
- 导入必要的PyTorch和第三方库
einops用于张量操作的重排和重塑Attend是核心注意力计算模块AutoregressiveWrapper用于自回归生成
3.2 爱因斯坦符号约定
# einstein notation
# b - batch
# n - sequence
# d - feature dimension
# h - attention heads
# i, j - sequence (source, target)
注释: 定义了代码中使用的张量维度符号约定
3.3 辅助函数
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def identity(t, *args, **kwargs):
return t
def cast_tuple(val, depth = 1):
return val if isinstance(val, tuple) else (val,) * depth
注释: 常用的工具函数,用于处理可选参数和类型转换
3.4 位置编码类
RotaryEmbedding (旋转位置编码)
class RotaryEmbedding(Module):
def __init__(
self,
dim,
use_xpos = False,
scale_base = 512,
interpolation_factor = 1.,
base = 10000,
base_rescale_factor = 1.
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
注释:
- 实现旋转位置编码(RoPE)
use_xpos: 是否使用xPos(可学习的位置缩放)base_rescale_factor: NTK-aware缩放,支持长度外推- 计算逆频率向量用于生成旋转矩阵
apply_rotary_pos_emb (应用旋转位置编码)
@autocast('cuda', enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
freqs = freqs[:, -seq_len:, :]
scale = scale[:, -seq_len:, :] if is_tensor(scale) else scale
if t.ndim == 4 and freqs.ndim == 3:
freqs = rearrange(freqs, 'b n d -> b 1 n d')
# partial rotary embeddings, Wang et al. GPT-J
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
out = cat((t, t_unrotated), dim = -1)
return out.type(orig_dtype)
注释:
- 将旋转位置编码应用到查询和键张量
- 支持部分旋转(partial rotary),如GPT-J
- 使用
cos和sin函数实现旋转操作 rotate_half函数交换向量的后半部分
3.5 注意力机制类
Attention (核心注意力模块)
class Attention(Module):
def __init__(
self,
dim,
dim_head = DEFAULT_DIM_HEAD,
dim_context = None,
heads = 8,
causal = False,
flash = False,
# ... 大量参数省略
):
super().__init__()
dim_kv = default(dim_context, dim)
self.scale = dim_head ** -0.5 # 缩放因子
self.heads = heads
self.causal = causal
# 查询、键、值投影
self.to_q = LinearNoBias(dim_q_input, q_dim)
self.to_k = LinearNoBias(dim_kv_input, k_dim)
self.to_v = LinearNoBias(dim_kv_input, v_dim)
# 注意力计算模块
self.attend = Attend(
heads = heads,
causal = causal,
flash = flash,
# ... 其他参数
)
注释:
- 实现多头注意力机制
dim_head: 每个注意力头的维度causal: 是否因果(掩码未来位置)flash: 是否使用Flash Attention- 支持分组查询注意力(GQA)、键值头共享等
Attention.forward (前向传播)
def forward(
self,
x,
context = None,
mask = None,
context_mask = None,
attn_mask = None,
rel_pos = None,
attn_bias = None,
rotary_pos_emb = None,
# ... 其他参数
):
# 投影查询、键、值
q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input)
# 分割为多头
q = self.split_q_heads(q) # [b, n, h*d] -> [b, h, n, d]
k = self.split_k_heads(k)
v = self.split_v_heads(v)
# 应用旋转位置编码
if exists(rotary_pos_emb):
q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
# 计算注意力
out, intermediates = self.attend(
q, k, v,
mask = final_attn_mask,
attn_bias = attn_bias,
prev_attn = prev_attn
)
# 合并多头
out = self.merge_heads(out) # [b, h, n, d] -> [b, n, h*d]
# 输出投影
out = self.to_out(out)
注释:
- 完整的注意力计算流程
- 支持缓存键值对用于自回归生成
- 处理各种掩码(因果掩码、填充掩码等)
- 返回中间结果用于调试或辅助损失
3.6 前馈网络
FeedForward (前馈网络)
class FeedForward(Module):
def __init__(
self,
dim,
dim_out = None,
mult = 4,
glu = False,
swish = False,
relu_squared = False,
dropout = 0.,
zero_init_output = False,
):
super().__init__()
inner_dim = int(dim * mult) # 扩展维度
dim_out = default(dim_out, dim)
# 激活函数选择
if relu_squared:
activation = ReluSquared()
elif swish:
activation = nn.SiLU()
else:
activation = nn.GELU()
# GLU门控线性单元
if glu:
proj_in = GLU(dim, inner_dim, activation)
else:
proj_in = nn.Sequential(
nn.Linear(dim, inner_dim),
activation
)
# 前馈网络结构
self.ff = Sequential(
proj_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out),
)
注释:
- 标准的Transformer前馈网络
mult: 隐藏层维度扩展倍数(通常为4)glu: 使用门控线性单元- 支持多种激活函数(GELU、SiLU、ReLU²等)
3.7 注意力层堆叠
AttentionLayers (注意力层堆叠)
class AttentionLayers(Module):
def __init__(
self,
dim,
depth = None,
heads = 8,
causal = False,
cross_attend = False,
pre_norm = True,
rotary_pos_emb = False,
# ... 大量参数
):
super().__init__()
self.dim = dim
self.causal = causal
self.layers = ModuleList([])
# 位置编码
self.rotary_pos_emb = RotaryEmbedding(...) if rotary_pos_emb else None
# 构建层序列
for ind, layer_type in enumerate(layer_types):
if layer_type == 'a': # 自注意力
layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
elif layer_type == 'c': # 交叉注意力
layer = Attention(dim, heads = heads, **cross_attn_kwargs)
elif layer_type == 'f': # 前馈网络
layer = FeedForward(dim, **ff_kwargs)
# 残差连接
residual = Residual(dim, scale_residual = scale_residual)
# 归一化层
pre_branch_norm = norm_fn() if pre_norm else None
post_main_norm = norm_fn() if not pre_norm else None
self.layers.append(ModuleList([
norms, # [pre_norm, post_branch_norm, post_main_norm]
layer,
residual
]))
注释:
- 堆叠多个注意力层和前馈层
- 支持预归一化(pre-norm)和后归一化(post-norm)
- 灵活的层配置(自注意力、交叉注意力、前馈网络)
- 支持残差连接和层缩放
3.8 Transformer包装器
TransformerWrapper (Transformer包装器)
class TransformerWrapper(Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers: AttentionLayers,
emb_dim = None,
emb_dropout = 0.,
tie_embedding = False,
# ... 其他参数
):
super().__init__()
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
# 词嵌入
self.token_emb = TokenEmbedding(emb_dim, num_tokens)
# 位置编码
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
# 嵌入层归一化和dropout
self.post_emb_norm = LayerNorm(emb_dim)
self.emb_dropout = nn.Dropout(emb_dropout)
# 维度投影
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
# 注意力层
self.attn_layers = attn_layers
# 输出层
if tie_embedding:
self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
else:
self.to_logits = nn.Linear(dim, num_tokens)
注释:
- 完整的Transformer模型包装
- 处理词嵌入、位置编码、归一化
- 支持权重绑定(tie embedding)
- 连接注意力层和输出层
3.9 编码器-解码器架构
XTransformer (编码器-解码器)
class XTransformer(Module):
def __init__(
self,
*,
dim,
tie_token_emb = False,
**kwargs
):
super().__init__()
# 分离编码器和解码器参数
enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
# 编码器
self.encoder = TransformerWrapper(
attn_layers = Encoder(dim = dim, **enc_kwargs)
)
# 解码器
self.decoder = TransformerWrapper(
attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
)
# 权重共享
if tie_token_emb:
self.decoder.token_emb = self.encoder.token_emb
# 自回归包装
self.decoder = AutoregressiveWrapper(self.decoder)
注释:
- 完整的编码器-解码器Transformer
- 支持参数前缀分离(enc*, dec*)
- 可选的词嵌入权重共享
- 使用AutoregressiveWrapper实现自回归生成
3.10 张量维度变换示例
注意力计算中的维度变换
# 查询投影和分割
q = self.to_q(x) # [batch, seq_len, dim] -> [batch, seq_len, heads * dim_head]
q = self.split_q_heads(q) # [batch, seq_len, heads * dim_head] -> [batch, heads, seq_len, dim_head]
# 注意力分数计算
attn_scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
# q: [batch, heads, query_len, dim_head]
# k: [batch, heads, key_len, dim_head]
# 输出: [batch, heads, query_len, key_len]
# 注意力加权
out = einsum('b h i j, b h j d -> b h i d', attn_weights, v)
# attn_weights: [batch, heads, query_len, key_len]
# v: [batch, heads, key_len, dim_head]
# 输出: [batch, heads, query_len, dim_head]
# 合并多头
out = self.merge_heads(out) # [batch, heads, seq_len, dim_head] -> [batch, seq_len, heads * dim_head]
旋转位置编码维度变换
# 应用旋转位置编码
def apply_rotary_pos_emb(t, freqs, scale = 1):
# t: [batch, heads, seq_len, dim]
# freqs: [batch, seq_len, rot_dim] 或 [1, seq_len, rot_dim]
# 部分旋转(GPT-J风格)
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
# t: [batch, heads, seq_len, rot_dim]
# t_unrotated: [batch, heads, seq_len, dim-rot_dim]
# 旋转计算
t_rotated = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
# t_rotated: [batch, heads, seq_len, rot_dim]
# 合并
out = cat((t_rotated, t_unrotated
---
## 📄 文件: x_transformers/entropy_based_tokenizer.py
# 代码分析报告
## 1. **文件功能摘要**
这个文件实现了一个基于熵的变长分词器,通过计算解码器输出的熵值来动态地将输入序列分割成不同长度的token,主要用于字节-潜在变换器(byte-latent transformer)中的序列分割。
## 2. **核心术语解释**
| 术语 | 解释 |
|------|------|
| **熵 (Entropy)** | 信息论中的概念,衡量随机变量的不确定性。在NLP中,熵值高表示模型对下一个token的预测不确定性大。 |
| **Logits** | 模型输出的原始分数(未经过softmax),通常表示每个可能token的得分。 |
| **Softmax** | 将logits转换为概率分布的函数,确保所有概率之和为1。 |
| **LayerNorm** | 层归一化,对每个样本的特征进行归一化,稳定训练过程。 |
| **Einsum** | Einstein求和约定,用于简洁地表示张量运算(如矩阵乘法、转置等)。 |
| **Residual** | 残差连接,将输入直接加到输出上,帮助梯度流动,缓解深度网络中的梯度消失问题。 |
| **Decoder** | 解码器,通常用于生成任务,这里用于预测下一个token并计算熵。 |
| **Tokenization** | 分词,将输入文本分割成更小的单元(token)的过程。 |
| **Cumsum** | 累积和,计算张量沿某一维度的累积和。 |
| **Pad Sequence** | 填充序列,将不同长度的序列填充到相同长度,便于批量处理。 |
## 3. **代码逐行/逐块注释**
```python
from __future__ import annotations
from itertools import zip_longest
import torch
from torch import tensor
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.utils.rnn import pad_sequence
import einx
from einops import repeat, rearrange, pack, unpack
# helper functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def log(t, eps = 1e-20):
return t.clamp(min = eps).log()
def calc_entropy_from_logits(logits):
prob = logits.softmax(dim = -1) # 将logits转换为概率分布,沿最后一个维度
return -(prob * log(prob)).sum(dim = -1) # 计算熵:-Σ p * log(p),沿最后一个维度求和
# entropy based tokenizer applied in byte-latent transformer paper
# they use a simple entropy threshold for segmenting a string into variable sized tokens
# https://arxiv.org/abs/2412.09871
class EntropyBasedTokenizer(Module):
def __init__(
self,
decoder: Module, # 解码器模型,用于预测下一个token
entropy_threshold: float, # 熵阈值,高于此值则分割token
max_token_size: int | None = None # 最大token长度限制
):
super().__init__()
self.decoder = decoder
self.entropy_threshold = entropy_threshold
self.max_token_size = max_token_size
@torch.no_grad() # 禁用梯度计算,推理模式
def forward(
self,
seq, # 输入序列,形状为 [batch, seq_len] 或 [seq_len]
lens = None, # 每个序列的实际长度(用于变长序列)
return_segmented_seq = False, # 是否返回分割后的序列
decoder_forward_kwargs: dict = dict() # 解码器的额外参数
):
no_batch_dim = seq.ndim == 1 # 检查是否无批次维度
seq, maybe_batch_ps = pack((seq,), '* n') # 统一为批次格式,保留原始形状信息
self.decoder.eval() # 设置解码器为评估模式
is_var_length = exists(lens) # 检查是否提供了变长序列
batch, seq_len, device, max_token_size = *seq.shape, seq.device, self.max_token_size
arange = torch.arange(seq_len, device = device) # 创建位置索引 [0, 1, ..., seq_len-1]
# 通过解码器前向传播,获取logits
logits = self.decoder(seq, **decoder_forward_kwargs)
# 计算每个位置的熵值
entropies = calc_entropy_from_logits(logits)
# 创建长度掩码(用于变长序列)
mask = tensor(True, device = device)
if is_var_length:
# 生成掩码:位置索引 < 序列长度 的位置为True
mask = einx.less('n, b -> b n', arange, lens)
# 创建超过熵阈值的掩码
over_thres_mask = (entropies >= self.entropy_threshold) & mask
# 准备索引张量(每个位置加1)
arange_plus_one = arange + 1
arange_plus_one = repeat(arange_plus_one, 'n -> b n', b = batch) # 扩展为批次维度
# 复制边界掩码
boundaries = over_thres_mask.clone()
# 设置最后一个token的边界
if not is_var_length:
boundaries[..., -1] = True # 固定长度:最后一个位置总是边界
else:
# 变长序列:在每个序列的实际末尾设置边界
scatter_indices = rearrange(lens - 1, 'b -> b 1') # 调整形状为 [batch, 1]
boundaries.scatter_(-1, scatter_indices, True) # 在最后一个有效位置设置True
# 处理最大token大小限制(防止重复子序列被合并成一个大token)
if exists(max_token_size):
# 为每个token分配ID(通过累积和)
token_ids = boundaries.cumsum(dim = -1) # 边界处ID增加
token_ids = F.pad(token_ids, (1, -1), value = 0) # 左移一位,使边界位于token开始
# 计算最大token数量
max_num_tokens = boundaries.sum(dim = -1).amax().item()
token_ids_seq = torch.arange(max_num_tokens, device = device)
# 创建token掩码:标记每个位置属于哪个token
token_mask = einx.equal('j, b i -> b j i', token_ids_seq, token_ids)
# 计算每个token内的子序列位置
token_sub_seq_arange = token_mask.cumsum(dim = -1)
# 创建子序列边界:每max_token_size个位置设置一个边界
sub_seq_boundaries = (token_sub_seq_arange % max_token_size == 0)
sub_seq_boundaries = (sub_seq_boundaries & token_mask).any(dim = 1)
# 合并原始边界和子序列边界
boundaries = boundaries | sub_seq_boundaries
# 应用原始掩码
if exists(mask):
boundaries = boundaries & mask
# 计算每个序列的token数量
num_tokens = boundaries.sum(dim = -1)
# 获取边界位置的索引
indices = arange_plus_one[boundaries].split(num_tokens.tolist())
# 计算每个token的长度
token_lengths = []
for one_indices in indices:
padded_indices = F.pad(one_indices, (1, 0), value =0.) # 在开头填充0
one_token_lengths = padded_indices[1:] - padded_indices[:-1] # 计算相邻索引差作为长度
token_lengths.append(one_token_lengths)
# 将不同长度的token长度填充为相同长度
token_lengths = pad_sequence(token_lengths, batch_first = True)
# 如果不需要返回分割后的序列,直接返回token长度
if not return_segmented_seq:
token_lengths, = unpack(token_lengths, maybe_batch_ps, '* num_tokens')
return token_lengths
# 根据token长度分割原始序列
lens = default(lens, (None,))
segmented_seq = []
for one_seq, one_len, one_token_length in zip_longest(seq, lens, token_lengths):
if exists(one_len):
one_seq = one_seq[:one_len] # 截取有效长度
one_token_length = one_token_length[one_token_length > 0] # 移除填充的0
# 按token长度分割序列
splitted_seq = one_seq.split(one_token_length.tolist())
segmented_seq.append(splitted_seq)
# 如果输入无批次维度,返回单个序列
if no_batch_dim:
segmented_seq = segmented_seq[0]
return segmented_seq</code></pre>
<h3>关键张量变换说明:</h3>
<ol>
<li>
<p><strong>熵计算</strong>:</p>
<ul>
<li>`logits`形状: `[batch, seq_len, vocab_size]`</li>
<li>`prob = logits.softmax(dim=-1)`: 沿vocab_size维度归一化为概率</li>
<li>`entropies = -(prob * log(prob)).sum(dim=-1)`: 沿vocab_size求和,得到`[batch, seq_len]`</li>
</ul>
</li>
<li>
<p><strong>边界检测</strong>:</p>
<ul>
<li>`over_thres_mask`: 形状`[batch, seq_len]`,True表示该位置熵超过阈值</li>
<li>`boundaries`: 在`over_thres_mask`基础上添加序列末尾边界</li>
</ul>
</li>
<li>
<p><strong>最大token大小处理</strong>:</p>
<ul>
<li>`token_ids`: 通过`cumsum`为每个token分配唯一ID,形状`[batch, seq_len]`</li>
<li>`token_mask`: 形状`[batch, max_num_tokens, seq_len]`,标记每个位置属于哪个token</li>
<li>`token_sub_seq_arange`: 在每个token内计数位置</li>
<li>`sub_seq_boundaries`: 当位置计数是`max_token_size`的倍数时设置边界</li>
</ul>
</li>
<li>
<p><strong>token长度计算</strong>:</p>
<ul>
<li>边界索引相减得到每个token的长度</li>
<li>使用`pad_sequence`将不同数量的token填充到相同长度</li>
</ul>
</li>
</ol>
<p>这个分词器的核心思想是:当解码器对下一个token的预测不确定性(熵)超过阈值时,就在当前位置进行分割,从而将序列动态地划分为信息量相对均匀的token。</p>
<hr />
<h2>📄 文件: x_transformers/gpt_vae.py</h2>
<h1>GPT-VAE 代码分析</h1>
<h2>1. 文件功能摘要</h2>
<p>这是一个基于变分自编码器(VAE)架构的GPT模型,用于将文本序列编码为潜在空间表示,并支持通过潜在向量控制文本生成。</p>
<h2>2. 核心术语解释</h2>
<ul>
<li><strong>VAE(变分自编码器)</strong>:一种生成模型,通过编码器将输入映射到潜在空间的概率分布,然后从该分布采样并解码重建输入。</li>
<li><strong>KL散度(KL Divergence)</strong>:衡量两个概率分布差异的指标,在VAE中用于约束潜在空间接近标准正态分布。</li>
<li><strong>重参数化技巧(Reparametrization Trick)</strong>:VAE中用于使采样操作可微的技术,通过从标准正态分布采样并变换得到目标分布的样本。</li>
<li><strong>自回归(Autoregressive)</strong>:模型在生成序列时,每个位置的输出依赖于之前所有位置的输出。</li>
<li><strong>潜在空间(Latent Space)</strong>:编码器将高维输入压缩到的低维表示空间。</li>
<li><strong>Einsum</strong>:爱因斯坦求和约定,用于简洁表达张量运算。</li>
<li><strong>Residual Connection</strong>:残差连接,将输入直接加到网络层的输出上,缓解梯度消失问题。</li>
<li><strong>LayerNorm</strong>:层归一化,对每个样本的特征维度进行归一化。</li>
<li><strong>Attention</strong>:注意力机制,让模型能够关注输入序列的不同部分。</li>
</ul>
<h2>3. 代码逐行/逐块注释</h2>
<pre><code class="language-python">from __future__ import annotations
# 应用ACT(Zhou等人)中的CVAE + DETR设计到GPT
# 用于转向、多样性强化学习、EPO中的map-elites和其他可能性
import torch
from torch import nn, Tensor, is_tensor, tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from x_transformers.x_transformers import (
Encoder,
Decoder,
TransformerWrapper
)
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
from einops.layers.torch import Rearrange
from einops import rearrange, reduce, repeat
# 辅助函数
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# 类定义
class GPTVAE(Module):
def __init__(
self,
*,
num_tokens, # 词汇表大小
dim, # 模型维度
depth, # 解码器层数
enc_depth, # 编码器层数
max_seq_len, # 最大序列长度
dim_latent = None, # 潜在空间维度,默认为模型维度
attn_dim_head = 64, # 注意力头维度
heads = 8, # 注意力头数量
enc_kwargs: dict = dict(), # 编码器额外参数
dec_kwargs: dict = dict(), # 解码器额外参数
vae_kl_loss_weight = 1., # KL散度损失权重
vae_kl_div_floor = 0., # KL散度下限,来自Free Transformer和Kingma 2016
latents_dropout_prob = 0.5, # 完全丢弃潜在向量的概率
pad_id = -1, # 填充token的ID
encoder: Module | None = None, # 可自定义编码器
**kwargs
):
super().__init__()
dim_latent = default(dim_latent, dim) # 如果未指定潜在维度,使用模型维度
# 如果没有提供自定义编码器,创建默认编码器
if not exists(encoder):
encoder = TransformerWrapper(
num_tokens = num_tokens,
max_seq_len = max_seq_len + 1, # +1为潜在token预留位置
return_only_embed = True, # 只返回嵌入表示
average_pool_embed = True, # 对嵌入进行平均池化
attn_layers = Encoder(
dim = dim,
depth = enc_depth,
attn_dim_head = attn_dim_head,
heads = heads,
**kwargs,
**enc_kwargs
),
)
self.encoder = encoder
# 将编码器输出映射到潜在空间的均值和方差
# 输出形状: [batch_size, dim*2] -> 重排为 [2, batch_size, dim_latent]
self.to_latent_mean_log_variance = nn.Sequential(
nn.Linear(dim, dim_latent * 2), # 线性层输出均值和方差
Rearrange('b (two d) -> two b d', two = 2) # 重排维度
)
# 将潜在向量转换为解码器的前置token嵌入
# 形状: [batch_size, dim_latent] -> [batch_size, 1, dim]
self.from_latent_to_prepend_token = nn.Sequential(
nn.Linear(dim_latent, dim), # 线性变换
Rearrange('b d -> b 1 d') # 增加序列维度
)
# 创建解码器
self.decoder = TransformerWrapper(
num_tokens = num_tokens,
max_seq_len = max_seq_len,
attn_layers = Decoder(
dim = dim,
depth = depth,
attn_dim_head = attn_dim_head,
heads = heads,
**kwargs,
**dec_kwargs
),
)
# 用自回归包装器包装解码器
self.ar_wrapped_decoder = AutoregressiveWrapper(self.decoder, ignore_index = pad_id)
self.pad_id = pad_id
# 损失权重配置
self.vae_kl_div_floor = vae_kl_div_floor
self.vae_kl_loss_weight = vae_kl_loss_weight
# 潜在向量dropout
self.latents_dropout = nn.Dropout(latents_dropout_prob)
@property
def device(self):
# 获取模型所在设备
return next(self.parameters()).device
def encode_to_latents(
self,
seq, # 输入序列 [batch_size, seq_len]
return_mean_log_var = False # 是否返回均值和方差
):
# 创建mask,标识非填充token
mask = seq != self.pad_id
# 编码器处理序列,返回池化后的表示 [batch_size, dim]
pooled = self.encoder(seq, mask = mask)
# 获取潜在空间的均值和方差 [2, batch_size, dim_latent]
latents_mean, latents_log_var = self.to_latent_mean_log_variance(pooled)
# 计算标准差
latents_std = (0.5 * latents_log_var).exp()
# 重参数化技巧:从标准正态分布采样并变换
# latents = μ + σ * ε, 其中 ε ~ N(0, I)
latents = latents_mean + latents_std * torch.randn_like(latents_mean)
if not return_mean_log_var:
return latents # 只返回采样后的潜在向量
return latents, (latents_mean, latents_log_var)
@torch.no_grad() # 禁用梯度计算,用于推理
def generate(
self,
prompts, # 提示序列
seq_len, # 生成序列长度
latents = None, # 可选潜在向量
seq_for_latents = None, # 用于推导潜在向量的序列
**generate_kwargs # 生成参数
):
# 验证输入维度
assert prompts.ndim in {1, 2}
batch = prompts.shape[0] if prompts.ndim == 2 else 1
# 如果提供了seq_for_latents,从中推导潜在向量
if exists(seq_for_latents):
assert not exists(latents), 'latents should not be passed in if given the seq from which to derive them'
latents = self.encode_to_latents(seq_for_latents)
# 准备前置嵌入
prepend_embeds = None
if exists(latents):
if not is_tensor(latents):
latents = tensor(latents, device = self.device)
# 处理潜在向量维度
if latents.ndim == 1: # 如果是1D,重复到batch维度
latents = repeat(latents, 'd -> b d', b = batch)
# 将潜在向量转换为前置token嵌入 [batch_size, 1, dim]
prepend_embeds = self.from_latent_to_prepend_token(latents)
# 使用自回归解码器生成文本
generated = self.ar_wrapped_decoder.generate(
prompts,
seq_len,
prepend_embeds = prepend_embeds, # 前置潜在嵌入
**generate_kwargs
)
return generated
def forward(
self,
seq, # 目标序列 [batch_size, seq_len]
seq_for_latents = None, # 用于编码的序列(默认与目标序列相同)
return_all_losses = False # 是否返回所有损失分量
):
batch, device = seq.shape[0], seq.device
# 默认使用相同序列进行编码
seq_for_latents = default(seq_for_latents, seq)
# 编码到潜在空间,返回潜在向量和分布参数
latents, (latents_mean, latents_log_var) = self.encode_to_latents(seq_for_latents, return_mean_log_var = True)
# 创建dropout mask:随机丢弃部分batch的潜在向量
dropped_latents = ~self.latents_dropout(torch.ones((batch,), device = device)).bool()
# 将潜在向量转换为前置嵌入 [batch_size, 1, dim]
prepend_embeds = self.from_latent_to_prepend_token(latents)
# 计算自回归损失(语言建模损失)
# seq_start_pos: 序列起始位置,当潜在向量被dropout时为0,否则为1
ar_loss = self.ar_wrapped_decoder(
seq,
prepend_embeds = prepend_embeds,
seq_start_pos = dropped_latents.long() # 控制是否关注第一个风格潜在token
)
# 计算VAE KL散度损失
# KL(N(μ,σ) || N(0,I)) = 0.5 * (σ² + μ² - log(σ²) - 1)
vae_kl_loss = 0.5 * (
latents_log_var.exp() # σ²
+ latents_mean.square() # μ²
- latents_log_var # -log(σ²)
- 1. # -1
)
# 应用KL散度下限(防止过正则化)
vae_kl_loss = F.relu(vae_kl_loss - self.vae_kl_div_floor)
# 求和并取平均 [batch_size, dim_latent] -> 标量
vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
# 计算总损失
total_loss = (
ar_loss + # 自回归重建损失
vae_kl_loss * self.vae_kl_loss_weight # 加权KL散度损失
)
if not return_all_losses:
return total_loss
losses = (ar_loss, vae_kl_loss)
return total_loss, losses</code></pre>
<h2>关键张量变换维度说明:</h2>
<ol>
<li>
<p><strong>编码器输出变换</strong>:</p>
<ul>
<li>输入序列: `[batch_size, seq_len]`</li>
<li>编码器输出: `[batch_size, dim]`</li>
<li>线性层输出: `[batch_size, dim_latent * 2]`</li>
<li>重排后: `[2, batch_size, dim_latent]`(第一维是均值,第二维是方差)</li>
</ul>
</li>
<li>
<p><strong>潜在向量采样</strong>:</p>
<ul>
<li>均值: `[batch_size, dim_latent]`</li>
<li>方差: `[batch_size, dim_latent]`</li>
<li>采样噪声: `[batch_size, dim_latent]`</li>
<li>最终潜在向量: `[batch_size, dim_latent]`</li>
</ul>
</li>
<li>
<p><strong>潜在向量到前置token</strong>:</p>
<ul>
<li>输入: `[batch_size, dim_latent]`</li>
<li>线性层: `[batch_size, dim]`</li>
<li>重排: `[batch_size, 1, dim]`(作为序列的第一个token)</li>
</ul>
</li>
<li>
<p><strong>KL散度计算</strong>:</p>
<ul>
<li>每个样本的KL散度: `[batch_size, dim_latent]`</li>
<li>求和后: `[batch_size]`</li>
<li>平均后: 标量</li>
</ul>
</li>
</ol>
<hr />
<h2>📄 文件: x_transformers/up_wrapper.py</h2>
<h1>代码分析报告</h1>
<h2>1. 文件功能摘要</h2>
<p>这个文件实现了一个<strong>通用预训练包装器</strong>,通过合成数据生成器(模拟图灵机)为Transformer模型生成训练数据,用于无监督预训练。</p>
<h2>2. 核心术语解释</h2>
<table>
<thead>
<tr>
<th>术语</th>
<th>解释</th>
</tr>
</thead>
<tbody>
<tr>
<td><strong>TransformerWrapper</strong></td>
<td>对Transformer模型的封装类,提供标准接口</td>
</tr>
<tr>
<td><strong>AutoregressiveWrapper</strong></td>
<td>自回归包装器,用于处理序列生成任务</td>
</tr>
<tr>
<td><strong>SyntheticDataGenerator</strong></td>
<td>合成数据生成器,使用LSTM/GRU网络模拟简单图灵机</td>
</tr>
<tr>
<td><strong>UniversalPretrainWrapper</strong></td>
<td>通用预训练包装器,协调数据生成和模型训练</td>
</tr>
<tr>
<td><strong>Residual</strong></td>
<td>残差连接,将输入直接加到输出上,缓解梯度消失</td>
</tr>
<tr>
<td><strong>Buffer</strong></td>
<td>数据缓冲区,存储生成的合成数据用于训练</td>
</tr>
<tr>
<td><strong>Causal Attention</strong></td>
<td>因果注意力,确保每个位置只能看到前面的位置</td>
</tr>
<tr>
<td><strong>Temperature</strong></td>
<td>温度参数,控制softmax分布的平滑程度</td>
</tr>
</tbody>
</table>
<h2>3. 代码逐行/逐块注释</h2>
<h3>导入和工具函数</h3>
<pre><code class="language-python"># https://arxiv.org/abs/2506.20057
# Peter Bloem
from __future__ import annotations
from functools import partial
from random import randrange, uniform
import torch
from torch import nn, cat, tensor, randperm
from torch.nn import LSTM, GRU, Module
from x_transformers.x_transformers import (
TransformerWrapper,
AutoregressiveWrapper
)
# 工具函数
def exists(v):
return v is not None # 检查变量是否存在(非None)
def default(v, d):
return v if exists(v) else d # 如果v存在则返回v,否则返回默认值d
def divisible_by(num, den):
return (num % den) == 0 # 检查num是否能被den整除</code></pre>
<h3>随机序列生成函数</h3>
<pre><code class="language-python">def random_sequences(
num_tokens, # 词汇表大小
seq_len, # 序列长度
num_samples_random, # 随机样本数量
num_samples_constant, # 常数样本数量
shuffle = True, # 是否打乱
device = None # 设备
):
assert num_samples_random > 0 or num_samples_constant > 0
# 生成随机序列:形状为(num_samples_random, seq_len),值在[0, num_tokens)之间
rand_seq = torch.randint(0, num_tokens, (num_samples_random, seq_len))
# 生成常数序列:所有位置填充相同的随机token
const_seq = torch.full((num_samples_constant, seq_len), randrange(num_tokens))
# 合并两种序列
all_seq = cat((rand_seq, const_seq))
if exists(device):
all_seq = all_seq.to(device) # 移动到指定设备
if not shuffle:
return all_seq
# 使用随机排列打乱序列
rand_indices = randperm(all_seq.shape[0], device = all_seq.device)
return all_seq[rand_indices] # 返回打乱后的序列</code></pre>
<h3>合成数据生成器类</h3>
<pre><code class="language-python">class SyntheticDataGenerator(Module):
def __init__(
self,
dim, # 嵌入维度
num_tokens, # 词汇表大小
max_seq_len = 512, # 最大序列长度
hidden_size = None, # 隐藏层大小
use_gru = False, # 是否使用GRU(默认LSTM)
network_klass = None # 自定义网络类
):
super().__init__()
self.max_seq_len = max_seq_len
# 词嵌入层:将token索引映射为dim维向量
self.embed = nn.Embedding(num_tokens, dim)
# 设置隐藏层大小,默认为dim
hidden_size = default(hidden_size, dim)
# 默认使用LSTM或GRU,batch_first=True表示输入形状为(batch, seq, feature)
default_network_klass = partial(LSTM if not use_gru else GRU, batch_first = True)
network_klass = default(network_klass, default_network_klass)
# 创建循环神经网络
self.net = network_klass(dim, hidden_size)
# 线性层:将隐藏状态映射回词汇表logits
self.to_logits = nn.Linear(dim, num_tokens, bias = False)
# 初始化权重
self.apply(self.init_)</code></pre>
<h3>权重初始化和重置方法</h3>
<pre><code class="language-python"> def reset_(self):
# 重置所有可重置参数的模块
for m in self.modules():
if hasattr(m, 'reset_parameters'):
m.reset_parameters()
# 重新应用初始化
self.apply(self.init_)
@torch.no_grad()
def init_(self, m):
# 初始化线性层权重:乘以0到1.1之间的随机数
if isinstance(m, nn.Linear):
m.weight *= uniform(0., 1.1)</code></pre>
<h3>序列生成方法</h3>
<pre><code class="language-python"> @torch.inference_mode()
@torch.compile
def generate(
self,
length, # 要生成的序列长度
seed = None, # 种子序列
condition = None, # 条件序列
temperature = 1e-4 # 温度参数(接近贪婪采样)
):
assert exists(seed) or exists(condition)
# 合并种子和条件序列
prefix = [*filter(exists, (seed, condition))]
seq_len = self.max_seq_len
# 初始序列
seq = torch.cat(prefix, dim = -1)
net_input = seq
hiddens = None # 初始隐藏状态
# 自回归生成循环
for _ in range(length):
# 前向传播获取logits和新的隐藏状态
logits, hiddens = self.forward(net_input, hiddens)
# 取最后一个位置的logits
last_logit = logits[:, -1]
# 应用温度缩放并计算概率分布
prob = (last_logit / temperature).softmax(dim = -1)
# 从分布中采样下一个token
sampled = torch.multinomial(prob, 1)
# 将采样的token作为下一时间步的输入
net_input = sampled
# 将新token添加到序列中
seq = torch.cat((seq, sampled), dim = -1)
# 返回最后seq_len个token(保持序列长度一致)
return seq[:, -seq_len:]</code></pre>
<h3>前向传播方法</h3>
<pre><code class="language-python"> def forward(
self,
input, # 输入token序列,形状(batch, seq_len)
hiddens = None # 循环网络的隐藏状态
):
# 词嵌入:将token索引转换为向量,形状(batch, seq_len, dim)
tokens = self.embed(input)
# 循环网络处理:返回输出和新的隐藏状态
# embed形状: (batch, seq_len, hidden_size)
embed, hidden = self.net(tokens, hiddens)
# 将隐藏状态映射回词汇表logits
logits = self.to_logits(embed)
return logits, hidden</code></pre>
<h3>通用预训练包装器类</h3>
<pre><code class="language-python">class UniversalPretrainWrapper(Module):
def __init__(
self,
model: TransformerWrapper, # 要训练的Transformer模型
data_generator: SyntheticDataGenerator | Module | None = None,
buffer_size = None, # 数据缓冲区大小
num_reset = 20, # 每次重置的缓冲区样本数
batch_size = 32, # 批次大小
seq_len = 512, # 序列长度
seed_length = 8, # 种子序列长度
reset_turing_machine_every = 0, # 重置数据生成器的频率
keep_buffer_on_cpu = False # 是否将缓冲区保留在CPU上
):
super().__init__()
self.model = model
# 将模型包装为自回归版本
self.ar_wrapped = AutoregressiveWrapper(model)
# 确保模型使用因果注意力(只能看到前面的token)
assert model.attn_layers.causal
# 获取模型参数
num_tokens = model.num_tokens
dim = model.attn_layers.dim
# 如果没有提供数据生成器,创建默认的
if not exists(data_generator):
data_generator = SyntheticDataGenerator(
num_tokens = num_tokens,
dim = dim,
max_seq_len = seq_len
)
# 设置各种参数
self.reset_turing_machine_every = reset_turing_machine_every
self.seq_len = seq_len
self.data_generator = data_generator
self.seed_length = seed_length
self.batch_size = batch_size
# 设置缓冲区大小(默认为batch_size的20倍)
buffer_size = default(buffer_size, batch_size * 20)
assert buffer_size > batch_size, f'data buffer size must be greater than batch size'
# 确保重置数量是偶数
assert divisible_by(num_reset, 2)
self.num_reset = num_reset
self.buffer_size = buffer_size
# 创建部分应用的随机序列生成函数
self.random_sequences_fn = partial(random_sequences, num_tokens, seq_len)
# 初始化数据缓冲区:一半随机序列,一半常数序列
init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
# 根据设置决定缓冲区存储位置
if keep_buffer_on_cpu:
self.synth_data_buffer = init_data_buffer
else:
self.register_buffer('synth_data_buffer', init_data_buffer)
# 注册步数计数器
self.register_buffer('step', tensor(0))</code></pre>
<h3>设备属性和缓冲区采样方法</h3>
<pre><code class="language-python"> @property
def device(self):
return self.step.device # 返回模型所在的设备
def get_rand_sequences_from_buffer(self, size = None):
size = default(size, self.batch_size)
# 生成随机索引:从缓冲区中随机选择size个样本
rand_indices = randperm(self.buffer_size, device = self.device)[:size]
return self.synth_data_buffer[rand_indices]</code></pre>
<h3>核心训练流程</h3>
<pre><code class="language-python"> def forward(self):
# 算法1的实现
# 1. 从缓冲区获取条件序列
conditions = self.get_rand_sequences_from_buffer()
# 2. 获取种子序列:从缓冲区随机裁剪seed_length长度的片段
seeds = self.get_rand_sequences_from_buffer()
# 创建序列位置索引 [0, 1, ..., seed_length-1]
seq_arange = torch.arange(self.seed_length)
# 为每个批次生成随机起始位置
rand_offset = torch.randint(0, self.seq_len - self.seed_length, (self.batch_size,))
# 计算每个批次中种子序列的起始位置
seq_start_pos = rand_offset[:, None] + seq_arange
# 创建批次索引
batch_arange = torch.arange(self.batch_size, device = self.device)[:, None]
# 从缓冲区提取种子序列
seeds = seeds[batch_arange, seq_start_pos]
# 3. 使用数据生成器生成新序列
# 输入:条件序列 + 种子序列 → 输出:生成的序列
generated = self.data_generator.generate(
self.seq_len,
condition = conditions.to(self.device),
seed = seeds.to(self.device)
)
# 步数加1
self.step.add_(1)
# 4. 定期重置数据生成器(模拟图灵机重置)
if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
self.data_generator.reset_()
# 5. 重置部分缓冲区内容
if self.num_reset > 0:
# 获取要重置的缓冲区位置
buffer_to_reset = self.get_rand_sequences_from_buffer(self.num_reset)
with torch.no_grad():
# 生成新的随机序列(一半随机,一半常数)
reset_sequences = self.random_sequences_fn(self.num_reset // 2, self.num_reset // 2, device = self.device)
# 替换缓冲区中的内容
buffer_to_reset.copy_(reset_sequences)
# 6. 将生成的"增强"序列放回缓冲区
with torch.no_grad():
conditions.copy_(generated)
# 7. 从缓冲区采样训练数据
data = self.get_rand_sequences_from_buffer().to(self.device)
# 8. 使用自回归包装器训练Transformer模型
return self.ar_wrapped(data)</code></pre>
<h2>关键流程总结</h2>
<ol>
<li><strong>数据生成循环</strong>:使用简单的LSTM/GRU网络作为"图灵机"生成合成数据</li>
<li><strong>缓冲区管理</strong>:维护一个数据缓冲区,混合随机序列和生成序列</li>
<li><strong>训练流程</strong>:从缓冲区采样→生成新数据→更新缓冲区→训练Transformer</li>
<li><strong>自回归训练</strong>:使用`AutoregressiveWrapper`进行标准的语言模型训练</li>
</ol>
<p>这种方法的核心思想是<strong>通过简单的合成数据生成器为复杂的Transformer模型提供训练数据</strong>,实现无监督预训练。</p>
<hr />
<h2>📄 文件: x_transformers/attend.py</h2>
<h1>文件功能分析</h1>
<h2>1. 文件功能摘要</h2>
<p>这是一个实现多种注意力机制(包括标准注意力、Flash Attention、稀疏注意力等)的PyTorch模块,提供了丰富的注意力变体和配置选项。</p>
<h2>2. 核心术语解释</h2>
<ul>
<li><strong>Flash Attention</strong>: 一种高效的注意力计算算法,通过分块计算减少内存访问,显著提升计算速度</li>
<li><strong>Talking Heads</strong>: 注意力头之间的信息交互机制,通过1x1卷积在不同注意力头之间传递信息</li>
<li><strong>Causal Mask</strong>: 因果掩码,用于自回归任务,确保每个位置只能看到当前位置及之前的位置</li>
<li><strong>Multi-Query Attention</strong>: 多查询注意力,多个查询头共享相同的键值对,减少计算量</li>
<li><strong>Grouped Attention</strong>: 分组注意力,将注意力头分组,每组共享键值投影</li>
<li><strong>Sparse Top-k Attention</strong>: 稀疏注意力,只保留top-k个最大的注意力分数进行计算</li>
<li><strong>Gumbel Softmax</strong>: 使用Gumbel噪声的softmax,用于可微分的离散采样</li>
<li><strong>Selective Attention</strong>: 选择性注意力,允许token阻止自己被未来的token关注</li>
<li><strong>Cog Attention</strong>: 带符号的注意力,允许负的注意力权重增加表达能力</li>
<li><strong>L2 Distance Attention</strong>: 使用L2距离代替点积计算相似度</li>
<li><strong>Attention Sink</strong>: 注意力汇聚点,额外的可学习token用于稳定注意力分布</li>
</ul>
<h2>3. 代码逐行/逐块注释</h2>
<pre><code class="language-python">from __future__ import annotations
from functools import partial
from typing import Tuple, Callable
import torch
from torch.nn import Module, Parameter
from torch import cat, nn, einsum, Tensor
import torch.nn.functional as F
from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass
from einops import rearrange, repeat, pack, unpack
# 常量定义
@dataclass
class Intermediates:
"""存储注意力计算过程中的中间结果"""
qk_similarities: Tensor | None = None # QK相似度矩阵
pre_softmax_attn: Tensor | None = None # softmax前的注意力分数
post_softmax_attn: Tensor | None = None # softmax后的注意力权重
values: Tensor | None = None # 值向量
cached_kv: tuple[Tensor, Tensor] | None = None # 缓存的键值对
layer_type: str | None = None # 层类型
hybrid_hidden: Tensor | None = None # 混合隐藏状态
def to_tuple(self):
return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
# 辅助函数
def exists(val):
"""检查值是否存在(非None)"""
return val is not None
def default(val, d):
"""如果值存在则返回该值,否则返回默认值"""
return val if exists(val) else d
def at_most_one_of(*bools):
"""检查最多只有一个布尔值为True"""
return sum([*map(int, bools)]) <= 1
def compact(arr):
"""过滤掉None值"""
return [*filter(exists, arr)]
@torch.jit.script
def softclamp(t: Tensor, value: float):
"""软截断函数,使用tanh进行平滑限制"""
return (t / value).tanh() * value
def pack_one(t, pattern):
"""使用einops打包张量"""
return pack([t], pattern)
def unpack_one(t, ps, pattern):
"""使用einops解包张量"""
return unpack(t, ps, pattern)[0]
def once(fn):
"""确保函数只执行一次的装饰器"""
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print) # 只打印一次的print函数
# Gumbel softmax相关函数
def log_prob_from_hard_attend(intermeds: Intermediates):
"""从硬注意力中计算对数概率"""
log_probs = intermeds.pre_softmax_attn.log_softmax(dim = -1)
# 获取one-hot索引
one_hot = intermeds.post_softmax_attn.argmax(dim = -1, keepdim = True)
log_prob = log_probs.gather(-1, one_hot)
return rearrange(log_prob, 'b h i 1 -> b h i')
# 选择性注意力
# https://arxiv.org/abs/2410.02703 - section 3.3
# 允许每个token阻止自己被未来的token关注
def selective_attn(
sim,
sim_head_gate = None,
no_mask_sos = True
):
"""选择性注意力实现"""
i, j, device = *sim.shape[-2:], sim.device
sim_head_gate = default(sim_head_gate, sim[:, 0]) # 默认使用第一个头的相似度
gate = F.relu(sim_head_gate) # 只保留正值
if no_mask_sos:
gate = gate.clone()
gate[..., -i] = 0. # 不对序列开始token进行掩码
eye = torch.eye(i, device = device)
if j > i:
eye = F.pad(eye, (j - i, 0), value = 1.) # 填充单位矩阵
gate = (1. - eye) * gate # 不对角线元素应用门控
gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # 只允许掩码未来位置
gate = gate.cumsum(dim = -2) # 累积求和
return sim - rearrange(gate, 'b i j -> b 1 i j') # 从相似度中减去门控值
# 替代的距离函数
def qk_l2_dist_squared(q, k):
"""计算Q和K之间的L2距离平方"""
if k.ndim == 3:
k = repeat(k, 'b j d -> b h j d', h = q.shape[1]) # 扩展维度
q, packed_shape = pack_one(q, '* i d')
k, _ = pack_one(k, '* j d')
l2_dist_squared = torch.cdist(q, k) ** 2 # 计算L2距离平方
return unpack_one(l2_dist_squared, packed_shape, '* i j')
# one-hot直通softmax
def one_hot_straight_through(logits, temperature = 1.):
"""one-hot直通softmax,用于硬注意力"""
one_hot_indices = logits.argmax(dim = -1, keepdim = True)
one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.) # 创建one-hot
soft_attn = (logits / temperature).softmax(dim = -1)
# 直通技巧:one-hot + soft_attn - soft_attn.detach()
return one_hot + soft_attn - soft_attn.detach()
# 稀疏topk注意力
def sparse_topk_attn(
logits,
sparse_topk,
temperature = 1.,
straight_through = False
):
"""稀疏topk注意力实现"""
orig_logits = logits
mask_value = -torch.finfo(logits.dtype).max
top_values, _ = logits.topk(sparse_topk, dim = -1) # 获取topk值
sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
logits = logits.masked_fill(~sparse_topk_mask, mask_value) # 掩码非topk位置
topk_attn = logits.softmax(dim = -1)
if not straight_through:
return topk_attn
# 直通版本
soft_attn = (orig_logits / temperature).softmax(dim = -1)
return topk_attn.detach() + soft_attn - soft_attn.detach()
# 创建因果掩码的函数
def create_causal_mask(i, j, device):
"""创建因果掩码(上三角矩阵)"""
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
def onnx_create_causal_mask(i, j, device):
"""ONNX兼容的因果掩码创建"""
r = torch.arange(i, device = device)
causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j') # 广播比较
causal_mask = F.pad(causal_mask, (j - i, 0), value = False) # 右侧填充
return causal_mask
# 主类
class Attend(Module):
"""注意力机制主类,支持多种注意力变体"""
def __init__(
self,
*,
dropout = 0., # dropout概率
causal = False, # 是否因果注意力
heads = None, # 注意力头数
pre_talking_heads = False, # softmax前的talking heads
post_talking_heads = False, # softmax后的talking heads
pre_scale_post_talking_heads = False, # 预缩放post talking heads
sparse_topk = None, # 稀疏topk参数
sparse_topk_straight_through = False, # 稀疏topk直通
scale = None, # 缩放因子
qk_norm = False, # 是否对QK进行归一化
l2_distance = False, # 使用L2距离
sigmoid = False, # 使用sigmoid激活
gumbel_softmax = False, # 使用Gumbel softmax
gumbel_softmax_temp = 1., # Gumbel温度
gumbel_softmax_hard = True, # 硬Gumbel softmax
cog_signed = False, # Cog带符号注意力
custom_attn_fn: Callable | None = None, # 自定义注意力函数
flash = False, # 使用Flash Attention
softclamp_logits = False, # 软截断logits
logit_softclamp_value = 50., # logits截断值
add_zero_kv = False, # 添加零KV
head_learned_sink = False, # 学习注意力汇聚点
selective = False, # 选择性注意力
hard = False, # 硬注意力
cope = None, # 上下文位置编码
onnxable = False, # ONNX兼容
sdp_kwargs: dict = dict( # Flash Attention参数
enable_flash = True,
enable_math = True,
enable_mem_efficient = True
)
):
super().__init__()
self.scale = scale
# 因果相关
self.causal = causal
self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
# 注意力类型选择
is_sparse_topk_attn = exists(sparse_topk)
# 参数兼容性检查
assert not (flash and sigmoid), 'sigmoid attention not available for flash'
assert not (flash and hard), 'hard attention not available for flash'
assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
assert at_most_one_of(sigmoid, hard, l2_distance, gumbel_softmax, is_sparse_topk_attn)
# 设置注意力函数
if exists(custom_attn_fn):
self.attn_fn = custom_attn_fn
elif sigmoid:
self.attn_fn = F.sigmoid
elif hard:
self.attn_fn = one_hot_straight_through
elif is_sparse_topk_attn:
self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
elif gumbel_softmax:
self.attn_fn = partial(F.gumbel_softmax, dim = -1, tau = gumbel_softmax_temp, hard = gumbel_softmax_hard)
else:
softmax_fn = partial(F.softmax, dim = -1)
self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
# dropout
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
# talking heads初始化
assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
# 初始化talking heads权重为单位矩阵
if exists(self.pre_softmax_talking_heads):
nn.init.dirac_(self.pre_softmax_talking_heads.weight)
if exists(self.post_softmax_talking_heads):
nn.init.dirac_(self.post_softmax_talking_heads.weight)
if exists(self.pre_scale_post_talking_heads):
nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
# 选择性注意力
assert not (flash and selective), 'selective attention cannot work on flash attention'
assert not (selective and not causal), 'selective attention is designed for autoregressive'
self.selective = selective
# Cog注意力
assert not (flash and cog_signed), 'cog attention not available for flash'
self.cog_signed = cog_signed
# L2距离注意力
self.l2_distance = l2_distance
# 添加零KV
self.add_zero_kv = add_zero_kv
# 学习注意力汇聚点
assert not (head_learned_sink and flash), f'not supported for flash attention yet'
self.head_learned_sink = head_learned_sink
self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
# 软截断logits
if softclamp_logits:
assert not flash, 'flash attention not compatible with logit softclamp value yet'
assert logit_softclamp_value > 0.
self.softclamp_logits = softclamp_logits
self.logit_softclamp_value = logit_softclamp_value
# 上下文位置编码
self.cope = cope
# Flash Attention
self.flash = flash
torch_version = version.parse(torch.__version__)
assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# 设置Flash Attention后端
if self.flash:
if torch_version >= version.parse('2.3'):
from torch.nn.attention import SDPBackend
str_to_backend = dict(
enable_flash = SDPBackend.FLASH_ATTENTION,
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
enable_math = SDPBackend.MATH,
enable_cudnn = SDPBackend.CUDNN_ATTENTION
)
sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
else:
self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
def flash_attn(
self,
q, k, v,
mask = None,
attn_bias = None
):
"""Flash Attention实现"""
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
# 处理多查询注意力:将KV扩展到多头
if k.ndim == 3:
k = repeat(k, 'b ... -> b h ...', h = q.shape[1])
if v.ndim == 3:
v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
# 处理L2距离注意力
if self.l2_distance:
# 扩展维度以计算L2距离
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k = F.pad(k, (0, 1), value = -1.)
k = cat((k, k_norm_sq), dim = -1)
q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q = cat((2 * q, q_norm_sq), dim = -1)
q = F.pad(q, (0, 1), value = -1.)
# 处理缩放因子
if exists(self.scale):
---
## 📄 文件: x_transformers/continuous.py
# 连续Transformer代码分析
## 1. 文件功能摘要
这个文件实现了一个连续值输出的Transformer模型,支持概率性输出(均值和方差),可用于时间序列预测、连续值生成等任务。
## 2. 核心术语解释
- **LayerNorm**: 层归一化,对每个样本的特征维度进行归一化,稳定训练过程
- **Einsum**: Einstein求和约定,一种简洁的张量运算表示法,用于复杂的张量变换
- **Residual**: 残差连接,将输入直接加到输出上,缓解梯度消失问题
- **Attention**: 注意力机制,计算输入序列中不同位置之间的相关性权重
- **Positional Embedding**: 位置编码,为序列中的每个位置添加位置信息
- **Memory Tokens**: 记忆令牌,可学习的特殊令牌,用于存储全局信息
- **Probabilistic Output**: 概率性输出,模型输出均值和方差,表示预测的不确定性
- **Autoregressive**: 自回归,当前时刻的输出作为下一时刻的输入
- **Rollout**: 展开预测,进行多步预测时逐步生成后续值
- **KV Cache**: 键值缓存,在自回归生成时缓存之前的计算结果以加速推理
## 3. 代码逐行/逐块注释
```python
from __future__ import annotations
import torch
from torch import nn, cat, stack, arange
from torch.nn import Module
import torch.nn.functional as F
from torch.distributions import Normal
import einx
from einops import rearrange, reduce, pack, repeat, unpack
from x_transformers.autoregressive_wrapper import align_right
from x_transformers.x_transformers import (
Attention,
AttentionLayers,
ScaledSinusoidalEmbedding,
AbsolutePositionalEmbedding,
LayerNorm,
masked_mean,
always,
pad_at_dim
)
# helper functions
def exists(val):
"""检查值是否存在(非None)"""
return val is not None
def default(val, d):
"""如果val存在则返回val,否则返回默认值d"""
if exists(val):
return val
return d() if not isinstance(d, Module) and callable(d) else d
def sample_from_mean_variance(
mean,
variance,
eps = 1e-5,
temperature = 1.
):
"""从均值和方差中采样,支持温度调节"""
std = variance.clamp(min = eps).sqrt() # 计算标准差,确保数值稳定性
return torch.normal(mean, std * temperature) # 从正态分布采样
def masked_mean(t, mask):
"""计算带掩码的平均值"""
t = einx.where('b n, b n d, -> b n d', mask, t, 0.) # 将掩码为False的位置置0
num = reduce(t, 'b n d -> b', 'sum') # 计算每个batch的求和
den = mask.sum(dim = -1) # 计算每个batch的有效token数
masked_average = num / den.clamp(min = 1.) # 计算平均值,防止除0
return masked_average
# probabilistic loss fn
class GaussianNLL(Module):
"""高斯负对数似然损失函数,用于概率性输出"""
def forward(self, pred, target):
mean, var = pred # pred包含均值和方差
return F.gaussian_nll_loss(mean, target, var, reduction = 'none')
# main classes
class ContinuousTransformerWrapper(Module):
"""连续值Transformer包装器,支持概率性输出"""
def __init__(
self,
*,
max_seq_len,
attn_layers: AttentionLayers,
dim_in = None,
dim_out = None,
emb_dim = None,
max_mem_len = 0,
num_memory_tokens = None,
post_emb_norm = False,
emb_dropout = 0.,
use_abs_pos_emb = True,
scaled_sinu_pos_emb = False,
average_pool_embed = False,
probabilistic = False,
):
super().__init__()
dim = attn_layers.dim
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
# 判断是否使用绝对位置编码
no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
if no_abs_pos_emb:
self.pos_emb = always(0) # 不使用位置编码
elif scaled_sinu_pos_emb:
self.pos_emb = ScaledSinusoidalEmbedding(dim) # 缩放正弦位置编码
else:
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) # 绝对位置编码
self.post_emb_norm = LayerNorm(dim) if post_emb_norm else nn.Identity() # 嵌入后归一化
self.emb_dropout = nn.Dropout(emb_dropout) # 嵌入dropout
# memory tokens
num_memory_tokens = default(num_memory_tokens, 0)
self.has_memory_tokens = num_memory_tokens > 0
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) # 可学习的记忆令牌
# attention layers
self.attn_layers = attn_layers
# average pool
self.average_pool_embed = average_pool_embed
# project in and out
self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
# 概率性输出时,输出维度乘以2(均值和log方差)
self.probabilistic = probabilistic
self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity()
# 检查是否支持KV缓存
self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
def forward(
self,
x,
return_embeddings = False,
return_intermediates = False,
return_mems = False,
mask = None,
lens = None,
return_attn = False,
mems = None,
mem_masks = None,
pos = None,
sum_embeds = None,
prepend_embeds = None,
prepend_mask = None,
cache: LayerIntermediates | None = None,
input_not_include_cache = False,
seq_start_pos = None,
**kwargs
):
batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
# 如果传入序列长度,则生成掩码
if exists(lens):
assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
seq_arange = arange(seq, device = device)
# 创建下三角掩码:位置j < 长度i时为True
mask = einx.less('j, i -> i j', seq_arange, lens)
# 处理缓存时的位置偏移
seq_pos_offset = 0
if exists(cache) and input_not_include_cache:
seq_pos_offset = cache.cache_length
# 输入投影 + 位置编码
x = self.project_in(x)
x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos, offset = seq_pos_offset)
if exists(sum_embeds):
x = x + sum_embeds # 添加额外的嵌入
x = self.post_emb_norm(x)
# 添加记忆令牌
if self.has_memory_tokens:
m = repeat(self.memory_tokens, 'm d -> b m d', b = batch) # 扩展记忆令牌到batch维度
x, mem_ps = pack([m, x], 'b * d') # 打包:[记忆令牌, 输入序列]
if exists(mask):
num_mems = m.shape[-2]
mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True) # 在掩码前部填充True
# 添加前置嵌入(如图像嵌入)
if exists(prepend_embeds):
prepend_seq, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
x = cat((prepend_embeds, x), dim = -2) # 拼接:[前置嵌入, 输入]
if exists(prepend_mask) or exists(mask):
mask = default(mask, lambda: torch.ones((batch, seq), device = device, dtype = torch.bool))
prepend_mask = default(prepend_mask, lambda: torch.ones((batch, prepend_seq), device = device, dtype = torch.bool))
mask = cat((prepend_mask, mask), dim = -1) # 拼接掩码
x = self.emb_dropout(x)
# 注意力层前向传播
x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, input_not_include_cache = input_not_include_cache, seq_pos_offset = seq_pos_offset, return_hiddens = True, **kwargs)
# 分离记忆令牌
if self.has_memory_tokens:
m, x = unpack(x, mem_ps, 'b * d') # 解包:分离记忆令牌和输出
intermediates.memory_tokens = m
# 平均池化
if self.average_pool_embed:
x = masked_mean(x, mask = orig_mask)
# 输出投影
out = self.project_out(x) if not return_embeddings else x
# 处理概率性输出:分离均值和方差
if not return_embeddings and self.probabilistic:
# 重排:将最后维度拆分为均值和log方差
mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)
variance = log_var.exp() # 将log方差转换为方差
out = stack((mean, variance)) # 堆叠为元组
# 返回中间结果
if return_intermediates:
return out, intermediates
# 返回记忆
if return_mems:
hiddens = intermediates.hiddens
new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens) # 取最后max_mem_len个token作为记忆
return out, new_mems
# 返回注意力图
if return_attn:
attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
return out, attn_maps
return out
class ContinuousAutoregressiveWrapper(Module):
"""连续值自回归包装器,用于训练和生成"""
def __init__(
self,
net: ContinuousTransformerWrapper,
loss_fn: Module | None = None,
use_l1_loss = False,
equal_loss_weight_batch = False, # 如果为True,每个序列的损失权重相同(而不是每个token)
):
super().__init__()
self.net = net
self.max_seq_len = net.max_seq_len
probabilistic = net.probabilistic
self.probabilistic = probabilistic
# 默认损失函数
if not exists(loss_fn):
if probabilistic:
loss_fn = GaussianNLL() # 概率性输出使用高斯NLL
elif use_l1_loss:
loss_fn = nn.L1Loss(reduction = 'none') # L1损失
else:
loss_fn = nn.MSELoss(reduction = 'none') # MSE损失
self.loss_fn = loss_fn
self.equal_loss_weight_batch = equal_loss_weight_batch
@torch.no_grad()
def generate(
self,
start_tokens,
seq_len,
temperature = 1.,
cache_kv = True,
**kwargs
):
"""自回归生成序列"""
should_cache_kv = cache_kv and self.net.can_cache_kv
device = start_tokens.device
was_training = self.net.training
num_dims = start_tokens.ndim
assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
no_batch = num_dims == 2
if no_batch:
start_tokens = rearrange(start_tokens, 'n d -> 1 n d') # 添加batch维度
b, t, _, device = *start_tokens.shape, start_tokens.device
self.net.eval()
out = start_tokens
cache = None
# 自回归生成循环
for _ in range(seq_len):
x = out[:, -self.max_seq_len:] # 取最后max_seq_len个token
net_out, new_cache = self.net(x, cache = cache, return_intermediates = True, **kwargs)
last_output = net_out[..., -1:, :] # 取最后一个token的输出
# 概率性输出需要采样
if self.probabilistic:
mean, var = last_output
last_output = sample_from_mean_variance(mean, var, temperature = temperature)
out = cat((out, last_output), dim = -2) # 将新生成的token添加到序列
if should_cache_kv:
cache = new_cache # 更新缓存
out = out[:, t:] # 只返回新生成的部分
if no_batch:
out = rearrange(out, '1 n d -> n d') # 移除batch维度
self.net.train(was_training)
return out
def forward_rollout(
self,
x,
rollout_steps = 2,
**kwargs
):
"""展开预测:进行多步预测"""
assert rollout_steps > 1
steps = rollout_steps
device = x.device
# 验证输入
assert 'prepend_embeds' not in kwargs
# 处理序列长度
lens = kwargs.pop('lens', None)
if exists(lens):
assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
seq_len, device = inp.shape[1], inp.device
seq_arange = arange(seq_len, device = device)
mask = einx.less('j, i -> i j', seq_arange, lens) # 创建掩码
kwargs['mask'] = mask
if not exists(lens):
batch, seq_len = x.shape[:2]
lens = torch.full((batch,), seq_len, device = device)
# 手动处理掩码
mask = kwargs.pop('mask', None)
# 为每个batch样本选择随机范围,并对齐序列用于展开损失
valid_tokens_for_rollout = (lens - steps).clamp(min = 0)
valid_sample = valid_tokens_for_rollout > 0
x = x[valid_sample] # 移除无效序列(长度小于展开步数)
if exists(mask):
mask = mask[valid_sample]
batch = x.shape[0]
# 随机选择起始位置
seq_start_pos = (torch.rand((batch,), device = device) * valid_tokens_for_rollout).floor().long()
batch_arange = torch.arange(batch, device = device)
batch_arange = rearrange(batch_arange, 'b -> b 1')
# 裁剪序列
seq_end_pos = seq_start_pos + steps
max_end_pos = seq_end_pos.amax().item()
x = x[:, :max_end_pos]
x = align_right(x, seq_end_pos) # 右对齐序列
# 分割输入和目标
inp, targets = x[:, :-steps], x[:, -steps:]
# 展开预测循环
cache = None
preds = []
for _ in range(steps):
out, cache = self.net(
inp,
seq_start_pos = seq_start_pos,
return_intermediates = True,
**kwargs
)
last_pred = out[..., -1:, :] # 取最后一个预测
# 概率性输出需要采样
if self.probabilistic:
mean, var = last_pred
inp = sample_from_mean_variance(mean, var)
else:
inp = last_pred
preds.append(last_pred)
# 堆叠预测
preds = cat(preds, dim = 1)
# 计算损失
loss = self.loss_fn(preds, targets)
return loss.mean()
def forward(
self,
x,
rollout_steps = 1, # 在成功的世界模型论文中使用了2步展开
**kwargs
):
"""前向传播,支持单步和多步展开"""
if rollout_steps > 1:
return self.forward_rollout(x, rollout_steps = rollout_steps, **kwargs)
# 标准自回归训练:输入为前n-1个token,目标为后n-1个token
inp, target = x[:, :-1], x[:, 1:]
assert 'prepend_embeds' not in kwargs
# 处理序列长度
lens = kwargs.pop('lens', None)
if exists(lens):
assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
seq_len, device = inp.shape[1], inp.device
seq_arange = torch.arange(seq_len, device = device)
mask = einx.less('j
---
## 📄 文件: x_transformers/nonautoregressive_wrapper.py
# 非自回归包装器代码分析
## 1. 文件功能摘要
这个文件实现了一个非自回归Transformer包装器,用于训练和生成基于掩码预测的序列生成模型,支持自条件生成和令牌评判器机制。
## 2. 核心术语解释
### **非自回归 (Non-Autoregressive)**
- **解释**: 与自回归模型(逐个生成令牌)不同,非自回归模型可以同时生成所有位置的令牌,显著提高生成速度。
### **掩码语言建模 (Masked Language Modeling, MLM)**
- **解释**: BERT等模型使用的预训练技术,随机掩码输入序列中的部分令牌,让模型预测被掩码的令牌。
### **自条件生成 (Self-Conditioning)**
- **解释**: 使用模型自身的先前预测作为条件输入来生成后续预测,类似于扩散模型中的条件机制。
### **令牌评判器 (Token Critic)**
- **解释**: 一个辅助网络,用于评估生成的令牌质量,帮助模型学习更好的生成策略。
### **Gumbel采样 (Gumbel Sampling)**
- **解释**: 一种从分类分布中采样的方法,通过添加Gumbel噪声实现可微分的采样,常用于强化学习和离散变量生成。
### **调度函数 (Schedule Function)**
- **解释**: 控制生成过程中掩码比例随时间变化的函数,如线性调度或余弦调度。
### **残差连接 (Residual)**
- **解释**: 将输入直接加到网络层的输出上,有助于梯度流动和训练深度网络。
## 3. 代码逐行/逐块注释
```python
from __future__ import annotations
import math
from random import random
from contextlib import nullcontext
from collections import namedtuple
import torch
from torch import nn, pi
from torch.nn import Module
from torch.func import grad_and_value, vmap
import torch.nn.functional as F
import einx
from einops import rearrange, repeat, pack, unpack
from x_transformers.x_transformers import TransformerWrapper
# 常量定义
Losses = namedtuple('Losses', ['loss', 'generator_loss', 'critic_loss'])
# 定义损失元组,包含总损失、生成器损失和评判器损失
# 辅助函数
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# 存在性检查和默认值设置
# 采样辅助函数
def top_k(logits, thres = 0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, ind = logits.topk(k, dim = -1)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(2, ind, val)
return probs
# 保留top-k个最高logits,其余设为负无穷,用于过滤低概率令牌
def log(t, eps = 1e-10):
return torch.log(t + eps)
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
# 生成Gumbel噪声,用于Gumbel-Softmax采样
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
# Gumbel采样:添加噪声后取argmax,实现可微分采样
# 概率辅助函数
def sample_prob(prob):
return random() < prob
def coin_flip():
return sample_prob(0.5)
# 概率采样和抛硬币函数
# 张量辅助函数
def get_mask_subset_prob(mask, prob, min_mask = 0):
batch, seq, device = *mask.shape, mask.device
num_to_mask = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)
logits = torch.rand((batch, seq), device = device)
logits = logits.masked_fill(~mask, -1)
# 生成随机数,非掩码位置设为-1
randperm = logits.argsort(dim = -1).argsort(dim = -1).float()
# 获取随机排列顺序
num_padding = (~mask).sum(dim = -1, keepdim = True)
randperm -= num_padding
# 调整排列顺序,考虑填充位置
subset_mask = randperm < num_to_mask
subset_mask.masked_fill_(~mask, False)
return subset_mask
# 从掩码中按概率选择子集
# 调度函数
def linear_schedule(t):
return 1 - t
# 线性调度:掩码比例随时间线性减少
def cosine_schedule(t):
""" https://arxiv.org/abs/2202.04200 """
return torch.cos(t * pi / 2)
# 余弦调度:掩码比例随时间按余弦函数减少
# 自令牌评判器
class SelfCritic(Module):
def __init__(self, net):
super().__init__()
self.net = net
dim = net.attn_layers.dim
self.to_logits = nn.Linear(dim, 1)
# 线性层将嵌入映射到单个评分值
def forward(self, x):
embed = self.net(x, return_embeddings = True)
return self.to_logits(embed)
# 前向传播:获取嵌入并计算评分
class NonAutoregressiveWrapper(Module):
"""
https://arxiv.org/abs/1904.09324 - Mask-Predict论文
https://arxiv.org/abs/2202.04200 - MaskGIT论文
"""
def __init__(
self,
net, # 基础Transformer网络
*,
mask_id, # 掩码令牌ID
steps = 18, # 生成步数
self_cond = False, # 是否使用自条件生成
self_cond_train_prob = 0.75, # 自条件训练概率
no_replace_prob = 0.15, # 保持原令牌的概率(BERT MLM中的设置)
random_token_prob = 0.1, # 替换为随机令牌的概率
schedule = 'linear', # 调度函数类型
can_mask_prev_unmasked = False, # 是否可以重新掩码已解掩码的令牌
token_critic: TransformerWrapper | None = None, # 外部令牌评判器
self_token_critic = False, # 是否使用自令牌评判器
critic_loss_weight = 1., # 评判器损失权重
use_simple_mdlm_loss_weight = True # 是否使用简单MDLM损失加权
):
super().__init__()
assert not (self_token_critic and exists(token_critic))
# 确保不自相矛盾:不能同时使用自评判器和外部评判器
self.net = net
self.dim = net.emb_dim
self.num_tokens = net.num_tokens
self.mask_id = mask_id
# MLM相关概率设置
self.no_replace_prob = no_replace_prob
self.random_token_prob = random_token_prob
self.max_seq_len = net.max_seq_len
self.steps = steps
# 调度函数设置
if callable(schedule):
self.schedule_fn = schedule
elif schedule == 'linear':
self.schedule_fn = linear_schedule
elif schedule == 'cosine':
self.schedule_fn = cosine_schedule
else:
raise ValueError(f'invalid schedule {schedule}')
# 损失加权函数(简单MDLM论文中的方法)
self.loss_weight_fn = None
if use_simple_mdlm_loss_weight:
grad_and_value_schedule_fn = vmap(grad_and_value(self.schedule_fn))
# 使用vmap批量计算梯度和值
def loss_weight_fn(times):
grad, value = grad_and_value_schedule_fn(times)
return grad / (1. - value)
# 公式(10):grad(t) / (1 - s(t))
self.loss_weight_fn = loss_weight_fn
self.can_mask_prev_unmasked = can_mask_prev_unmasked
# 自条件生成设置
self.self_cond = self_cond
if self_cond:
self.null_embed = nn.Parameter(torch.randn(dim))
self.to_self_cond = nn.Linear(dim, dim, bias = False) if self_cond else None
self.self_cond_train_prob = self_cond_train_prob
# 令牌评判器设置
self.token_critic = token_critic
if self_token_critic:
self.token_critic = SelfCritic(net)
self.critic_loss_weight = critic_loss_weight
@torch.no_grad()
def generate(
self,
batch_size = None,
start_temperature = 1., # 起始温度
filter_thres = 0.7, # 过滤阈值
noise_level_scale = 1., # 噪声缩放因子
**kwargs
):
sample_one = not exists(batch_size)
batch_size = default(batch_size, 1)
device = next(self.net.parameters()).device
was_training = self.training
self.eval() # 切换到评估模式
# 生成时间步
times = torch.linspace(0., 1., self.steps + 1)
# 初始化全掩码序列
shape = (batch_size, self.max_seq_len)
seq = torch.full(shape, self.mask_id, device = device)
mask = torch.full(shape, True, device = device)
# 计算每个时间步的掩码数量
all_mask_num_tokens = (self.schedule_fn(times[1:]) * self.max_seq_len).long()
# 自条件生成初始化
has_self_cond = self.self_cond
last_embed = self.null_embed if has_self_cond else None
# 迭代生成过程
for mask_num_tokens, steps_until_x0 in zip(all_mask_num_tokens.tolist(), reversed(range(self.steps))):
# 准备自条件
self_cond = self.to_self_cond(last_embed) if has_self_cond else None
# 前向传播获取logits和嵌入
logits, embeds = self.net(
seq,
sum_embeds = self_cond,
return_logits_and_embeddings = True,
**kwargs
)
if has_self_cond:
last_embed = embeds # 更新自条件嵌入
# 过滤低概率令牌
if exists(filter_thres):
logits = top_k(logits, filter_thres)
# 退火温度计算
annealing_scale = steps_until_x0 / self.steps
temperature = start_temperature * annealing_scale
# 计算概率分布
probs = (logits / max(temperature, 1e-3)).softmax(dim = -1)
# Gumbel采样
sampled_ids = gumbel_sample(logits, temperature = max(temperature, 1e-3))
# 更新掩码位置的令牌
seq = torch.where(mask, sampled_ids, seq)
# 计算评分(用于选择要掩码的令牌)
if exists(self.token_critic):
scores = self.token_critic(seq)
scores = rearrange(scores, 'b n 1 -> b n')
scores = scores + noise_level_scale * gumbel_noise(scores) * annealing_scale
else:
# 使用预测不确定性作为评分
scores = 1 - logits.softmax(dim = -1)
scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1'))
scores = rearrange(scores, 'b n 1 -> b n')
if mask_num_tokens == 0:
pass # 不掩码任何令牌
# 如果不允许掩码已解掩码的令牌
if not self.can_mask_prev_unmasked:
scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)
# 选择要掩码的令牌
mask_indices = scores.topk(mask_num_tokens, dim = -1).indices
mask = torch.zeros_like(scores, dtype = torch.bool).scatter(1, mask_indices, True)
# 掩码选中的令牌
seq = seq.masked_fill(mask, self.mask_id)
self.train(was_training) # 恢复原始训练状态
if sample_one:
seq = rearrange(seq, '1 n -> n')
return seq
def forward(
self,
x, # 输入序列 [batch_size, seq_len]
only_train_generator = False, # 仅训练生成器
only_train_critic = False, # 仅训练评判器
generator_sample_temperature = None, # 生成器采样温度
**kwargs
):
b, n, device = *x.shape, x.device
assert n == self.max_seq_len
orig_seq = x.clone() # 保存原始序列
# 生成随机时间和随机排列
rand_times = torch.empty(b, device = device).uniform_(0, 1)
batched_randperm = torch.rand((b, n), device = device).argsort(dim = -1).float()
# 计算掩码
rand_probs = self.schedule_fn(rand_times)
num_tokens_mask = (rand_probs * n).clamp(min = 1.)
mask = batched_randperm < rearrange(num_tokens_mask, 'b -> b 1')
# 维度变化:num_tokens_mask从[b]变为[b, 1],与batched_randperm比较
# 创建掩码输入(类似BERT MLM)
replace_mask_id_mask = mask.clone()
frac_seq_left = 1.
# 部分令牌保持原样(no replace)
if self.no_replace_prob > 0. and coin_flip():
frac_seq_left -= self.no_replace_prob
no_replace_prob_mask = get_mask_subset_prob(mask, self.no_replace_prob)
replace_mask_id_mask &= ~no_replace_prob_mask
# 部分令牌替换为随机令牌
if self.random_token_prob > 0. and coin_flip():
random_token_prob_mask = get_mask_subset_prob(replace_mask_id_mask, self.random_token_prob * frac_seq_left)
random_tokens = torch.randint(0, self.num_tokens, (b, n), device = device)
x = torch.where(random_token_prob_mask, random_tokens, x)
replace_mask_id_mask &= ~random_token_prob_mask
# 创建掩码输入序列
masked = torch.where(replace_mask_id_mask, self.mask_id, x)
# 自条件生成
if self.self_cond:
self_cond = self.null_embed
if sample_prob(self.self_cond_train_prob):
with torch.no_grad():
self_cond = self.net(masked, return_embeddings = True, **kwargs).detach()
kwargs.update(sum_embeds = self.to_self_cond(self_cond))
# 前向传播获取logits
context = torch.no_grad if only_train_critic else nullcontext
with context():
logits = self.net(masked, **kwargs)
# 选择损失函数
loss_fn = F.cross_entropy if not self.net.output_is_log_prob else F.nll_loss
# 计算生成器损失
if exists(self.loss_weight_fn):
# 使用简单MDLM损失加权
loss = loss_fn(
rearrange(logits, 'b n l -> b l n'), # 维度变化:[batch, seq_len, vocab] -> [batch, vocab, seq_len]
orig_seq,
reduction = 'none'
)
loss_weights = self.loss_weight_fn(rand_times) # 计算损失权重
loss = einx.multiply('b n, b', loss, loss_weights) # 应用损失权重
loss = loss[mask].mean() # 只计算掩码位置的损失
else:
# 标准交叉熵损失
loss = loss_fn(
logits[mask], # 只计算掩码位置的logits
orig_seq[mask], # 只计算掩码位置的目标
)
# 如果没有评判器或只训练生成器,直接返回
if not exists(self.token_critic) or only_train_generator:
return Losses(loss, loss, None)
# 采样生成序列
sampled_ids = gumbel_sample(logits, temperature = default(generator_sample_temperature, random()))
generated = torch.where(mask, sampled_ids, orig_seq)
# 计算评判器损失
critic_logits = self.token_critic(generated)
critic_labels = (sampled_ids != orig_seq).float() # 标签:预测错误的为1,正确的为0
critic_loss = F.binary_cross_entropy_with_logits(
rearrange(critic_logits, '... 1 -> ...'), # 去掉最后一个维度
critic_labels
)
# 根据训练模式确定总损失
if only_train_critic:
total_loss = critic_loss
loss = None
else:
total_loss = loss + critic_loss * self.critic_loss_weight
return Losses(total_loss, loss, critic_loss)
📄 文件: x_transformers/dpo.py
1. 文件功能摘要
这个文件实现了直接偏好优化(Direct Preference Optimization, DPO)算法,用于基于人类偏好数据微调语言模型,使其生成更符合人类偏好的文本。
2. 核心术语解释
- DPO (Direct Preference Optimization):一种强化学习算法,通过直接优化策略模型与参考模型的对数概率比值来对齐人类偏好,避免了传统的奖励模型训练。
- TransformerWrapper:x_transformers库中的Transformer包装器,包含完整的Transformer架构(嵌入层、注意力层、前馈层等)。
- LayerNorm:层归一化,用于稳定深度神经网络的训练,对每个样本的特征进行归一化。
- Einsum:爱因斯坦求和约定,用于简洁表达张量运算(如矩阵乘法、转置、求和等)。
- Residual Connection:残差连接,将输入直接加到输出上,缓解梯度消失问题。
- Logits:模型输出的原始分数(未归一化),通常在线性层后得到。
- Log Probability:对数概率,对softmax输出取对数,用于数值稳定性。
- Mask:掩码,用于标记序列中的有效位置(如非填充位置、非提示部分)。
- Beta (β):DPO中的温度参数,控制策略模型与参考模型差异的强度。
3. 代码逐行/逐块注释
from copy import deepcopy
import torch
from torch.nn import Module
import torch.nn.functional as F
from x_transformers.x_transformers import TransformerWrapper
import einx
from einops import rearrange
# helper functions
def exists(v):
return v is not None # 检查变量是否为None的辅助函数
def freeze_all_layers_(module):
for param in module.parameters():
param.requires_grad = False # 冻结模型所有参数,使其在训练中不更新
def log_prob_from_model_and_seq(model, seq):
src_seq, tgt_seq = seq[:, :-1], seq[:, 1:] # 将序列分割为输入(src_seq)和目标(tgt_seq)
logits = model(src_seq) # 模型前向传播,得到每个位置的logits
log_prob = logits.log_softmax(dim = -1) # 计算对数概率(沿最后一个维度,即词汇表维度)
return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq) # 提取目标token对应的对数概率
# einx.get_at: 从log_prob中按tgt_seq索引取值,b=batch, n=序列长度, l=词汇表大小
def masked_mean(log_probs, mask = None):
if not exists(mask):
return log_probs.mean(dim = -1) # 若无掩码,直接计算均值
if mask.shape[-1] == (log_probs.shape[-1] + 1):
mask = mask[:, :-1] # 如果掩码长度比log_probs多1(包含EOS),则去掉最后一个位置
log_probs = log_probs.masked_fill(~mask, 0.) # 将掩码为False的位置置0
num = log_probs.sum(dim = -1) # 计算有效位置的对数概率和
den = mask.sum(dim = -1) # 计算有效位置数量
return num / den.clamp(min = 1e-5) # 计算均值,分母最小为1e-5避免除零
def maybe_and_mask(*masks):
masks = [*filter(exists, masks)] # 过滤掉None的掩码
if len(masks) == 0:
return None # 如果没有有效掩码,返回None
mask, *rest_masks = masks
for rest_mask in rest_masks:
mask = mask & rest_mask # 将所有掩码进行逻辑与操作,得到交集
return mask
# main class
class DPO(Module):
def __init__(
self,
model: TransformerWrapper,
*,
beta = 0.1,
pad_id = None
):
super().__init__()
self.policy_model = model # 策略模型(待优化的模型)
self.ref_model = deepcopy(model) # 参考模型(初始模型的深拷贝)
freeze_all_layers_(self.ref_model) # 冻结参考模型参数
self.beta = beta # DPO温度参数,控制策略与参考模型的差异程度
self.pad_id = pad_id # 填充token的ID,用于生成掩码
def parameters(self):
return self.policy_model.parameters() # 只返回策略模型的参数,用于优化器更新
def forward(
self,
preferred_seq, # 偏好序列(人类更喜欢的回复)
unpreferred_seq, # 非偏好序列(人类不喜欢的回复)
*,
prompt_mask, # 提示部分的掩码(True表示提示部分)
preferred_seq_mask = None, # 偏好序列的有效token掩码
unpreferred_seq_mask = None, # 非偏好序列的有效token掩码
):
assert preferred_seq.ndim == 2 # 确保输入是二维张量 [batch, seq_len]
assert preferred_seq.shape == unpreferred_seq.shape # 确保两个序列形状相同
# 如果提供了pad_id,自动生成掩码(非填充位置为True)
if exists(self.pad_id):
if not exists(preferred_seq_mask):
preferred_seq_mask = preferred_seq != self.pad_id
if not exists(unpreferred_seq_mask):
unpreferred_seq_mask = unpreferred_seq != self.pad_id
"""
Following Appendix B in https://arxiv.org/abs/2305.18290
DPO算法实现,参考原论文附录B
"""
# 计算参考模型的对数概率(不计算梯度)
with torch.no_grad():
self.ref_model.eval() # 将参考模型设为评估模式
ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq)
ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq)
# 计算策略模型的对数概率(计算梯度)
policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)
# 对对数概率进行掩码平均(只计算非提示部分)
preferred_seq_mask = maybe_and_mask(~prompt_mask, preferred_seq_mask) # 排除提示部分
unpreferred_seq_mask = maybe_and_mask(~prompt_mask, unpreferred_seq_mask)
# 对每个序列的对数概率进行掩码平均,得到每个样本的标量值
ref_preferred_logprob, policy_preferred_logprob = map(
lambda t: masked_mean(t, preferred_seq_mask),
(ref_preferred_logprob, policy_preferred_logprob)
)
ref_unpreferred_logprob, policy_unpreferred_logprob = map(
lambda t: masked_mean(t, unpreferred_seq_mask),
(ref_unpreferred_logprob, policy_unpreferred_logprob)
)
# DPO核心公式
policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob # 策略模型的偏好对数比
ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob # 参考模型的偏好对数比
# DPO损失函数:-log σ(β * (π_θ - π_ref))
losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))
return losses.mean() # 返回批次平均损失
关键张量维度变化说明:
- 输入序列:`preferred_seq` 和 `unpreferred_seq` 形状为 `[batch_size, seq_length]`
- 模型输出:`logits` 形状为 `[batch_size, seq_length-1, vocab_size]`
- 对数概率:`log_prob` 形状为 `[batch_size, seq_length-1, vocab_size]`
- 目标对数概率:`log_prob_from_model_and_seq` 返回形状为 `[batch_size, seq_length-1]`
- 掩码平均后:每个序列的对数概率变为 `[batch_size]` 形状
- 最终损失:`losses` 形状为 `[batch_size]`,经 `.mean()` 后变为标量
DPO算法核心思想:
通过最大化以下目标函数来优化策略模型:
[
L{DPO} = -\mathbb{E} \left[ \log \sigma \left( \beta \left( \log \frac{\pi\theta(yw|x)}{\pi\theta(yl|x)} - \log \frac{\pi{ref}(yw|x)}{\pi{ref}(y_l|x)} \right) \right) \right]
]
其中:
- (y_w) 是偏好序列,(y_l) 是非偏好序列
- (\pi\theta) 是策略模型,(\pi{ref}) 是参考模型
- (\beta) 是温度参数
- (\sigma) 是sigmoid函数
该损失函数鼓励策略模型对偏好序列给出比非偏好序列更高的相对概率(相对于参考模型)。
📄 文件: x_transformers/xval.py
XValTransformer 代码分析
1. 文件功能摘要
这个文件实现了一个特殊的 Transformer 模型(XValTransformer),它能够同时处理离散的 token 和连续的数值数据,特别适用于需要混合符号和数值计算的场景(如数学推理、算术运算等)。
2. 核心术语解释
| 术语 | 解释 |
|---|---|
| XVal | "Cross-Valued" 的缩写,指模型能同时处理离散 token 和连续数值 |
| LayerNorm | 层归一化,用于稳定神经网络训练,对每个样本的特征进行归一化 |
| Einsum | Einstein 求和约定,用于简洁表达张量运算(通过 einops 库实现) |
| Residual | 残差连接,将输入直接加到输出上,有助于梯度流动和深层网络训练 |
| AttentionLayers | 注意力层堆叠,包含多头注意力机制和前馈网络 |
| TokenEmbedding | Token 嵌入层,将离散 token ID 映射为连续向量 |
| ScaledSinusoidalEmbedding | 缩放的正弦位置编码,相对位置编码的一种变体 |
| AbsolutePositionalEmbedding | 绝对位置编码,为每个位置学习独立的嵌入向量 |
| Memory Tokens | 记忆 token,可学习的特殊 token,用于存储全局信息 |
| Autoregressive | 自回归,模型逐个生成输出,每个输出依赖于之前的输出 |
3. 代码逐行/逐块注释
"""
regular transformer with discrete tokens, but continuous for number
generalizes better for arithmetic
https://arxiv.org/abs/2310.02989
"""
# 文件说明:这是一个能同时处理离散 token 和连续数值的 Transformer 模型
# 特别适用于算术运算,参考论文 arXiv:2310.02989
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Callable
from collections import namedtuple
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
from x_transformers.x_transformers import (
AttentionLayers,
TokenEmbedding,
ScaledSinusoidalEmbedding,
AbsolutePositionalEmbedding,
always
)
from x_transformers.autoregressive_wrapper import (
top_k,
top_p
)
# 定义命名元组用于结构化返回
LossBreakdown = namedtuple('LossBreakdown', ['cross_entropy_loss', 'numerical_mse_loss'])
GenerateReturn = namedtuple('GenerateReturn', ['sampled_token_ids', 'sampled_numbers', 'is_number_mask'])
# 辅助函数
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
XValTransformerWrapper 类
class XValTransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens, # 词汇表大小(离散 token 数量)
max_seq_len, # 最大序列长度
numerical_token_id, # 表示数值的特殊 token ID
attn_layers: AttentionLayers, # 注意力层模块
emb_dim = None, # 嵌入维度
logits_dim = None, # 输出 logits 维度
tie_embedding = False, # 是否绑定输入输出嵌入权重
max_mem_len = 0, # 最大记忆长度
num_memory_tokens = None, # 记忆 token 数量
emb_dropout = 0., # 嵌入 dropout 率
use_abs_pos_emb = True, # 是否使用绝对位置编码
scaled_sinu_pos_emb = False # 是否使用缩放正弦位置编码
):
super().__init__()
dim = attn_layers.dim
emb_dim = default(emb_dim, dim) # 默认使用注意力层的维度
self.emb_dim = emb_dim
self.token_emb = TokenEmbedding(emb_dim, num_tokens) # token 嵌入层
self.numerical_token_id = numerical_token_id # 数值 token 的特殊标识
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
# 位置编码配置
if not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb):
self.pos_emb = always(0) # 不使用位置编码
elif scaled_sinu_pos_emb:
self.pos_emb = ScaledSinusoidalEmbedding(dim) # 缩放正弦位置编码
else:
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) # 绝对位置编码
self.emb_dropout = nn.Dropout(emb_dropout) # 嵌入 dropout
# 记忆 tokens 初始化
num_memory_tokens = default(num_memory_tokens, 0)
self.has_memory_tokens = num_memory_tokens > 0
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
# 注意力层
self.attn_layers = attn_layers
# 输出层
logits_dim = default(logits_dim, num_tokens)
# 如果绑定嵌入权重,使用 token 嵌入矩阵的转置作为输出层
self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
# 数值输出层:将隐藏状态映射为单个数值
self.to_numerical_output = nn.Sequential(
nn.Linear(dim, 1), # 线性层输出单个值
Rearrange('... 1 -> ...') # 移除最后一个维度
)
forward 方法
def forward(
self,
x: Tensor, # 离散 token 输入 [batch, seq_len]
x_num: Tensor, # 连续数值输入 [batch, seq_len]
return_embeddings = False, # 是否返回嵌入而不是 logits
return_intermediates = False, # 是否返回中间结果
return_mems = False, # 是否返回记忆
mask = None, # 注意力掩码
return_attn = False, # 是否返回注意力图
mems = None, # 外部记忆
pos = None, # 自定义位置编码
prepend_embeds = None, # 预添加的嵌入(如图像特征)
**kwargs
):
assert x.shape == x_num.shape # 确保两个输入形状一致
batch = x.shape[0]
# 创建数值 token 的掩码:标记哪些位置是数值 token
is_number_mask = x == self.numerical_token_id
# 1. token 嵌入
x = self.token_emb(x) # [batch, seq_len] -> [batch, seq_len, emb_dim]
# 2. 数值缩放:对数值 token 的位置进行缩放
scale = torch.where(is_number_mask, x_num, 1.) # 数值位置用 x_num,其他位置用 1
scale = rearrange(scale, '... -> ... 1') # [batch, seq_len] -> [batch, seq_len, 1]
x = x * scale # 对嵌入进行缩放
# 3. 添加位置编码
x = x + self.pos_emb(x, pos = pos)
# 4. 添加记忆 tokens(如果有)
if self.has_memory_tokens:
m = repeat(self.memory_tokens, 'm d -> b m d', b = batch) # 扩展记忆 tokens
x, mem_ps = pack([m, x], 'b * d') # 将记忆 tokens 拼接到序列开头
if exists(mask):
num_mems = m.shape[-2]
mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
# 5. 预添加其他嵌入(如 PaLI 中的图像嵌入)
if exists(prepend_embeds):
_, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], '预添加嵌入的维度必须与模型维度相同'
x = torch.cat((prepend_embeds, x), dim = -2) # 在序列维度拼接
# 6. 应用 dropout
x = self.emb_dropout(x)
# 7. 通过注意力层
x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
# 8. 分离记忆 tokens(如果有)
if self.has_memory_tokens:
m, x = unpack(x, mem_ps, 'b * d') # 分离记忆 tokens
intermediates.memory_tokens = m # 保存到中间结果
# 9. 生成输出
if not return_embeddings:
logits = self.to_logits(x) # 离散 token 的 logits [batch, seq_len, num_tokens]
numerical_pred = self.to_numerical_output(x) # 数值预测 [batch, seq_len]
out = (logits, numerical_pred) # 返回两个输出
else:
out = x # 直接返回嵌入
# 10. 根据参数返回不同结果
if return_intermediates:
return out, intermediates
if return_mems:
hiddens = intermediates.hiddens
new_mems = tuple(t[..., -self.max_mem_len:, :].detach() for t in hiddens)
return out, new_mems
if return_attn:
attn_maps = tuple(t.post_softmax_attn for t in intermediates.attn_intermediates)
return out, attn_maps
return out
XValAutoregressiveWrapper 类
class XValAutoregressiveWrapper(nn.Module):
def __init__(
self,
net: XValTransformerWrapper, # XValTransformer 模型
ignore_index = -100, # 忽略的索引(用于损失计算)
pad_value = 0, # 填充值
numerical_loss_weight = 1. # 数值损失的权重
):
super().__init__()
self.net = net
self.max_seq_len = net.max_seq_len
self.numerical_loss_weight = numerical_loss_weight
self.ignore_index = ignore_index
generate 方法(自回归生成)
@torch.no_grad()
def generate(
self,
start_tokens: Tensor, # 起始 tokens [batch, seq_len]
start_numbers: Tensor, # 起始数值 [batch, seq_len]
seq_len, # 要生成的序列长度
filter_logits_fn: Callable = top_k, # logits 过滤函数
filter_kwargs: dict = dict(), # 过滤参数
temperature = 1., # 采样温度
**kwargs
):
device = start_tokens.device
was_training = self.net.training
num_dims = len(start_tokens.shape)
assert num_dims >= 2, '起始 tokens 的维度必须大于等于 2'
assert start_tokens.shape == start_numbers.shape
b, t, device = *start_tokens.shape, start_tokens.device
self.net.eval() # 切换到评估模式
out = start_tokens # 生成的 tokens
num_out = start_numbers # 生成的数值
# 自回归生成循环
for _ in range(seq_len):
# 取最后 max_seq_len 个 tokens 作为输入
x = out[:, -self.max_seq_len:]
x_num = num_out[:, -self.max_seq_len:]
# 前向传播
logits, numerical_pred = self.net(x, x_num, **kwargs)
# 取最后一个位置的输出
last_logits = logits[:, -1] # [batch, num_tokens]
last_num_pred = numerical_pred[:, -1:] # [batch, 1]
# 过滤 logits(如 top-k, top-p)
filtered_logits = filter_logits_fn(last_logits, **filter_kwargs)
# 计算概率分布
probs = F.softmax(filtered_logits / temperature, dim=-1)
# 采样下一个 token
sample = torch.multinomial(probs, 1) # [batch, 1]
# 添加到输出序列
out = torch.cat((out, sample), dim = -1)
num_out = torch.cat((num_out, last_num_pred), dim = -1)
# 只返回新生成的部分
out = out[:, t:]
num_out = num_out[:, t:]
# 标记哪些位置是数值 token,非数值位置用 nan 填充
is_number = out == self.net.numerical_token_id
num_out = torch.where(is_number, num_out, float('nan'))
self.net.train(was_training) # 恢复原来的训练状态
return GenerateReturn(out, num_out, is_number) # 返回结构化结果
forward 方法(训练)
def forward(
self,
x: Tensor, # 完整序列 tokens [batch, seq_len]
x_num: Tensor, # 完整序列数值 [batch, seq_len]
return_loss_breakdown = False, # 是否返回损失分解
**kwargs
):
# 1. 准备输入和目标(自回归:预测下一个 token)
inp, target = x[:, :-1], x[:, 1:] # tokens
x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:] # 数值
# 2. 创建目标掩码(忽略特定索引)
target_mask = target != self.ignore_index
# 3. 处理注意力掩码
mask = kwargs.get('mask', None)
if exists(mask):
target_mask &= mask # 结合两种掩码
if mask.shape[1] == x.shape[1]:
mask = mask[:, :-1] # 调整掩码形状与输入匹配
kwargs['mask'] = mask
# 4. 前向传播
logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)
# 5. 计算交叉熵损失(离散 token 损失)
logits = rearrange(logits, 'b n c -> b c n') # 调整为 PyTorch cross_entropy 需要的形状
cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)
# 6. 准备数值目标
target_is_number_mask = target == self.net.numerical_token_id # 标记数值位置
x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.) # 非数值位置填充 0
# 7. 计算数值 MSE 损失
numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
numerical_mse_loss = numerical_mse_loss * target_mask # 应用目标掩码
numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.) # 非数值位置填充 0
# 8. 组合损失
loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
loss = loss[target_mask] # 只计算有效位置的损失
loss = loss.mean() # 取平均
# 9. 返回结果
if not return_loss_breakdown:
return loss
return loss, LossBreakdown(cross_entropy_loss, numerical_mse_loss)
关键张量变换维度变化总结
-
输入阶段:
- `x`: [batch, seq_len] → `token_emb(x)`: [batch, seq_len, emb_dim]
- `x_num`: [batch, seq_len] → `rearrange(x_num, '... -> ... 1')`: [batch, seq_len, 1]
- 缩放操作:`x * scale`: [batch, seq_len, emb_dim] × [batch, seq_len, 1] → [batch, seq_len, emb_dim]
-
记忆 tokens 添加:
- `memory_tokens`: [num_memory_tokens, dim] → `repeat(...)`: [batch, num_memory_tokens, dim]
- 拼接后:`x`: [batch, seq_len + num_memory_tokens, dim]
-
注意力层:
- 输入输出维度不变:[batch, seq_len, dim]
-
输出层:
- `to_logits(x)`: [batch, seq_len, dim] → [batch, seq_len, num_tokens]
- `to_numerical_output(x)`: [batch, seq_len, dim] → [batch, seq_len, 1] → [batch, seq_len]
-
损失计算:
- `logits`: [batch, seq_len, num_tokens] → `rearrange(...)`: [batch, num_tokens, seq_len]
- `cross_entropy_loss`: [batch, seq_len]
- `numerical_mse_loss`: [batch, seq_len]
这个模型的核心创新在于能够同时处理离散符号和连续数值,通过特殊的数值 token 标识和缩放机制,使得模型在算术推理等任务上表现更好。
📄 文件: x_transformers/free_transformer.py
FreeTransformer 代码分析
1. 文件功能摘要
这是一个基于"Free Transformer"论文(arXiv:2510.17558)实现的变分自编码器-解码器架构,通过二进制潜在表示进行序列生成,支持条件生成和自回归训练。
2. 核心术语解释
- LayerNorm/RMSNorm: 层归一化,用于稳定神经网络训练,RMSNorm是LayerNorm的变体,只对输入进行缩放而不进行中心化
- Einsum: Einstein求和约定,用于简洁表达张量运算,特别是多维数组的乘积和求和
- Residual Connection: 残差连接,将输入直接加到输出上,有助于梯度流动和深层网络训练
- Rotary Positional Embedding: 旋转位置编码,通过旋转矩阵将位置信息注入到注意力机制中
- Straight-Through Estimator: 直通估计器,在离散采样中允许梯度反向传播的技巧
- KL Divergence: KL散度,衡量两个概率分布差异的指标,这里用于潜在表示的熵正则化
- KV Cache: 键值缓存,在自回归生成中缓存之前的键值对以提高推理效率
- Binary Entropy: 二元熵,衡量二元随机变量的不确定性
- Gumbel Sample: Gumbel采样,用于从分类分布中可微分地采样
3. 代码逐行/逐块注释
导入和辅助函数
from __future__ import annotations
# https://arxiv.org/abs/2510.17558
# François Fleuret
# https://www.youtube.com/watch?v=Nao16-6l6dQ
import math
import torch
from torch import nn, Tensor, is_tensor, tensor, arange
import torch.nn.functional as F
from torch.nn import Module, ModuleList
# 导入x_transformers库中的组件
from x_transformers.x_transformers import (
Encoder,
Decoder,
TransformerWrapper
)
from x_transformers.autoregressive_wrapper import (
gumbel_sample,
top_p,
top_k
)
# 导入einops库用于张量操作
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat, einsum, pack, unpack
# 辅助函数
def exists(v):
return v is not None # 检查变量是否存在(非None)
def default(v, d):
return v if exists(v) else d # 如果v存在则返回v,否则返回默认值d
def log(t, eps = 1e-20):
return t.clamp_min(eps).log() # 安全的对数计算,防止数值下溢
def pack_with_inverse(t, pattern):
packed, ps = pack([t], pattern) # 按照pattern打包张量
def inverse(out, inv_pattern = None):
inv_pattern = default(inv_pattern, pattern)
unpacked, = unpack(out, ps, inv_pattern) # 按照原始形状解包
return unpacked
return packed, inverse # 返回打包后的张量和解包函数
BinaryMapper类 - 二进制映射器
NAT = math.log(2) # 自然对数2,表示1比特的信息量
def binary_entropy(logits):
# 计算二元分布的熵
prob = logits.sigmoid() # 将logits转换为概率
not_prob = 1. - prob # 互补概率
# 熵公式: -Σ p * log(p)
return -(prob * F.logsigmoid(logits) + not_prob * F.logsigmoid(-logits)).sum(dim = -1)
class BinaryMapper(Module):
def __init__(
self,
bits = 1,
kl_loss_threshold = NAT # 1 bit
):
super().__init__()
self.bits = bits # 每个潜在变量的比特数
self.num_codes = 2 ** bits # 可能的编码数量
# 生成所有可能的二进制编码
power_two = 2 ** arange(bits) # [1, 2, 4, ...] 用于二进制到十进制的转换
# 生成所有2^bits个编码,每个编码长度为bits
codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
# 注册为缓冲区(不参与梯度计算但保存在模型中)
self.register_buffer('power_two', power_two, persistent = False)
self.register_buffer('codes', codes, persistent = False)
# 辅助损失相关
self.kl_loss_threshold = kl_loss_threshold # KL损失阈值
self.register_buffer('zero', tensor(0.), persistent = False) # 零张量
def forward(
self,
logits,
temperature = 1.,
straight_through = None,
calc_aux_loss = None
):
# 设置默认值:训练时使用直通估计器并计算辅助损失
straight_through = default(straight_through, self.training)
calc_aux_loss = default(calc_aux_loss, self.training)
# 验证输入维度
assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
# 温度采样:调整概率分布
prob_for_sample = (logits / temperature).sigmoid()
# 采样:根据概率生成二进制位
sampled_bits = (torch.rand_like(logits) <= prob_for_sample).long()
# 将二进制转换为十进制索引
indices = (self.power_two * sampled_bits).sum(dim = -1)
# 生成one-hot编码
one_hot = F.one_hot(indices, self.num_codes).float()
# 计算辅助KL损失
aux_kl_loss = self.zero
if calc_aux_loss:
# 计算负熵:bits * log(2) - 实际熵
kl_div = self.bits * NAT - binary_entropy(logits)
# 只惩罚超过阈值的部分
aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
# 直通估计器:允许梯度通过离散采样
if straight_through:
# 计算soft G:通过logits计算每个编码的soft概率
soft_G = (
# 计算编码中位为1的概率贡献
einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
# 计算编码中位为0的概率贡献
einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
).exp() # 取指数得到概率
# 直通技巧:one_hot + soft_G - detach(soft_G)
# 前向传播使用one_hot,反向传播使用soft_G的梯度
one_hot = one_hot + soft_G - soft_G.detach()
return one_hot, aux_kl_loss # 返回one-hot编码和KL损失
FreeTransformer类 - 主模型
class FreeTransformer(Module):
def __init__(
self,
*,
num_tokens, # 词汇表大小
dim, # 模型维度
dec_head_depth, # 解码器头部深度
dec_tail_depth, # 解码器尾部深度
max_seq_len, # 最大序列长度
enc_depth = 1, # 编码器深度
dim_latent = None, # 潜在空间维度
attn_dim_head = 64, # 注意力头维度
heads = 8, # 注意力头数
latent_bits = 16, # 潜在变量比特数
per_token_latents = True, # 是否为每个token生成潜在变量
kl_loss_threshold = NAT, # KL损失阈值
binary_mapper_kwargs: dict = dict(), # 二进制映射器参数
enc_kwargs: dict = dict(), # 编码器参数
dec_kwargs: dict = dict(), # 解码器参数
kl_loss_weight = 1., # KL损失权重
latent_dropout_prob = 0., # 潜在变量dropout概率
pad_id = -1, # 填充token ID
**kwargs
):
super().__init__()
dim_latent = default(dim_latent, dim) # 默认潜在维度等于模型维度
# 词嵌入层
self.token_emb = nn.Embedding(num_tokens, dim)
# 输出层(无偏置)
self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
# 用于查询潜在变量的可学习token
self.query_token_for_latents = nn.Parameter(torch.randn(dim) * 1e-2)
self.per_token_latents = per_token_latents # 是否每个token都有潜在变量
# 编码器:用于从输入生成潜在变量
self.encoder = Encoder(
dim = dim,
depth = enc_depth,
attn_dim_head = attn_dim_head,
heads = heads,
only_cross = True, # 只使用交叉注意力
cross_attend = True, # 启用交叉注意力
use_rmsnorm = True, # 使用RMSNorm
rotary_pos_emb = True, # 使用旋转位置编码
pre_norm_has_final_norm = True, # 预归一化有最终归一化
**kwargs,
**enc_kwargs
)
# 将编码器输出映射到潜在比特logits
self.to_latent_bit_logits = nn.Linear(dim, latent_bits, bias = False)
# 二进制映射器
self.binary_mapper = BinaryMapper(
latent_bits,
kl_loss_threshold,
**binary_mapper_kwargs
)
# 从潜在变量到条件向量的映射
self.from_latent_to_condition = nn.Linear(self.binary_mapper.num_codes, dim, bias = False)
# 潜在变量dropout
self.latent_dropout = nn.Dropout(latent_dropout_prob)
# 解码器头部(可选)
self.decoder_head = Decoder(
dim = dim,
depth = dec_head_depth,
attn_dim_head = attn_dim_head,
heads = heads,
rotary_pos_emb = True,
use_rmsnorm = True,
pre_norm_has_final_norm = False,
**kwargs,
**dec_kwargs
) if dec_head_depth > 0 else None
# 解码器尾部(必须存在)
assert dec_tail_depth > 0
self.decoder_tail = Decoder(
dim = dim,
depth = dec_tail_depth,
attn_dim_head = attn_dim_head,
heads = heads,
rotary_pos_emb = True,
use_rmsnorm = True,
pre_norm_has_final_norm = True,
**kwargs,
**dec_kwargs
)
self.pad_id = pad_id # 填充ID
self.kl_loss_weight = kl_loss_weight # KL损失权重
@property
def device(self):
return next(self.parameters()).device # 获取模型所在设备
def encode_to_latents(
self,
decoder_head_embeds, # 解码器头部输出
mask = None, # 注意力掩码
return_kl_loss = False, # 是否返回KL损失
per_token_latents = None # 是否每个token都有潜在变量
):
per_token_latents = default(per_token_latents, self.per_token_latents)
batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device
# 复制查询token
query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
encoder_kwargs = dict()
# 处理每个token都有潜在变量的情况
if per_token_latents:
# 为每个token位置复制查询token
query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)
# 设置旋转位置编码
rotary_pos = torch.arange(seq_len, device = device)
encoder_kwargs.update(
pos = rotary_pos, # 查询token的位置
context_pos = rotary_pos # 上下文token的位置
)
# 编码器前向传播
pooled = self.encoder(
query_tokens,
context = decoder_head_embeds, # 上下文是解码器头部输出
context_mask = mask, # 上下文掩码
**encoder_kwargs
)
# 生成潜在比特logits
bit_logits = self.to_latent_bit_logits(pooled)
# 通过二进制映射器得到one-hot潜在变量和KL损失
one_hot_latents, kl_loss = self.binary_mapper(bit_logits, calc_aux_loss = return_kl_loss)
if not return_kl_loss:
return one_hot_latents
return one_hot_latents, kl_loss
@torch.no_grad() # 推理时不计算梯度
def generate(
self,
prompts, # 提示序列
seq_len, # 生成序列长度
latents = None, # 潜在变量(用于条件生成)
filter_logits_fn = top_p, # logits过滤函数
logit_filter_kwargs: dict = dict(thres = 0.9), # 过滤参数
use_kv_cache = True # 是否使用KV缓存
):
# 打包输入并获取解包函数
prompts, inverse_pack = pack_with_inverse(prompts, '* n')
batch = prompts.shape[0]
# 处理潜在变量条件
condition = None
if exists(latents):
if not is_tensor(latents):
latents = tensor(latents, device = self.device)
if latents.dtype in (torch.int, torch.long):
# 如果给定的是索引,转换为one-hot
latents = F.one_hot(latents, self.binary_mapper.num_codes).float()
# 调整潜在变量维度
if latents.ndim == 1: # 重复潜在变量
latents = repeat(latents, 'd -> b 1 d', b = batch)
elif latents.ndim == 2:
latents = rearrange(latents, 'b d -> b 1 d')
# 将潜在变量映射为条件向量
condition = self.from_latent_to_condition(latents)
# KV缓存初始化
head_cache = tail_cache = None
# 生成过程
prompt_len = prompts.shape[-1]
generated = prompts # 已生成的序列
tokens = self.token_emb(generated) # 词嵌入
# 自回归生成
for _ in range(max(0, seq_len - prompt_len)):
# 解码器头部(如果存在)
if exists(self.decoder_head):
head_embed, next_head_cache = self.decoder_head(tokens, cache = head_cache, return_hiddens = True)
else:
head_embed, next_head_cache = tokens, None
# 处理旋转位置编码的偏移
seq_pos_offset = head_cache.cache_length if exists(head_cache) else 0
# 解码器尾部
tail_embed, next_tail_cache = self.decoder_tail(
head_embed,
cache = tail_cache,
seq_pos_offset = seq_pos_offset,
self_attn_kv_residuals = condition, # 条件向量作为自注意力的KV残差
return_hiddens = True
)
tail_embed = tail_embed[:, -1] # 只取最后一个token
# 生成logits并采样
logits = self.token_unembed(tail_embed)
logits = filter_logits_fn(logits, **logit_filter_kwargs) # 过滤logits
sampled = gumbel_sample(logits) # Gumbel采样
# 更新生成序列
generated, _ = pack((generated, sampled), 'b *')
tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
# 更新KV缓存
if use_kv_cache:
head_cache = next_head_cache
tail_cache = next_tail_cache
# 返回解包后的生成序列
return inverse_pack(generated)
def forward(
self,
seq, # 输入序列
seq_for_latents = None, # 用于生成潜在变量的序列(可选)
return_all_losses = False # 是否返回所有损失
):
batch, device = seq.shape[0], seq.device
# 准备自回归训练:输入为seq[:-1],标签为seq[1:]
seq, labels = seq[:, :-1], seq[:, 1:]
tokens = self.token_emb(seq) # 词嵌入
# 解码器头部(如果存在)
if exists(self.decoder_head):
tokens = self.decoder_head(tokens)
# 确定用于编码潜在变量的序列
if exists(seq_for_latents):
# 使用单独的序列生成潜在变量
tokens_for_latents = self.token_emb(seq_for_latents)
if exists(self.decoder_head):
tokens_for_latents = self.decoder_head(tokens_for_latents)
encoder_mask = seq_for_latents != self.pad
---
