多头自注意力(基础)
一、设计
多头自注意力的核心目标是突破单头注意力的局限,捕捉文本不同维度的语义关联,比如语法层面的搭配关联、语义层面的语义相似度关联等,从而提升模型的特征提取能力和整体表现力。其核心逻辑围绕“拆分-独立计算-拼接”展开,核心前提是满足隐藏层维度D与头数h、单头维度d_k的等式关系:D = h×d_k,例如当h=8、D=512时,单头维度d_k=64,这是后续所有维度变换的基础,必须保证D能被h整除,否则无法完成均匀拆分。

二、完整计算流程
输入的Query(Q)、Key(K)、Value(V)三维矩阵维度统一为[batch_size, seq_len, D]
其中batch_size为批次大小,seq_len为序列长度,D为隐藏层维度;
超参数h为注意力头数,d_k为单头维度,由D和h计算得出(d_k = D/h)。
整个计算流程分为5个关键步骤,各步骤衔接紧密,维度变换需严格遵循规范,示例如下(batch_size=4、D=512、h=8、d_k=64):
第一步是线性投影预处理。由于原始Q、K、V的特征维度虽为D,但特征表达需适配注意力计算,因此需对三者分别执行无激活函数的线性变换,得到新的Q'、K'、V'。该步骤不改变维度,变换后Q'、K'、V'仍为[4, seq_len, 512],核心作用是调整特征分布,为后续多头拆分和注意力计算做准备。
第二步是多头拆分。这一步是多头注意力的关键操作,需沿矩阵的最后一维(即D维),将Q'、K'、V'分别拆分为h个独立的子矩阵。以Q'为例,原始维度[4, seq_len, 512],拆分后得到8个维度为[4, seq_len, 64]的子矩阵,每个子矩阵对应一个注意力头,后续将由每个头独立完成注意力计算,捕捉不同维度的关联特征。
第三步是单头缩放点积注意力计算,这是每个头独立执行的核心环节,分为4个细分操作,维度变换贯穿始终:
-
计算注意力分数:将单个头的Q_i与K_i的转置做矩阵乘法,Q_i维度为[4, seq_len, 64],K_i转置后维度为[4, 64, seq_len],相乘后得到注意力分数矩阵,维度为[4, seq_len, seq_len],该矩阵表征序列中每个位置与其他所有位置的关联强度。
-
缩放操作:将注意力分数除以√d_k(此处为√64=8),目的是缓解维度增长导致的注意力分数过大、softmax归一化后梯度消失的问题,缩放后矩阵维度保持不变,仍为[4, seq_len, seq_len]。
-
归一化操作:对缩放后的注意力分数执行softmax函数,沿序列长度维度(seq_len)归一化,使每个位置的注意力分数总和为1,便于后续加权求和,维度依旧不变。
-
加权求和:将归一化后的注意力分数矩阵与单个头的V_i做矩阵乘法,V_i维度为[4, seq_len, 64],相乘后得到该头的注意力输出,维度为[4, seq_len, 64],即每个位置的特征由序列中所有位置的V特征加权得到。
第四步是多头拼接。当h个头均完成单头注意力计算后,会得到h个维度为[4, seq_len, 64]的输出矩阵,此时需沿最后一维将这h个子矩阵拼接,还原为原始的D维(512维),拼接后输出矩阵维度为[4, seq_len, 512],与线性投影后的维度保持一致。
第五步是最终线性变换。为了进一步融合多头注意力的特征,通常会对拼接后的矩阵再执行一次线性变换,输出维度仍为[4, seq_len, 512],作为整个多头自注意力模块的最终输出
三、维度拆分C++实现
结合上述流程,实现将D维矩阵拆分为h个d_k维矩阵的功能
1. 代码实现
#include
#include
#include
#include
using Matrix3D = std::vector>>;
// 核心拆分函数
std::vector splitMultiHead(const Matrix3D& input, int h) {
if (input.empty()) {
throw std::invalid_argument("输入矩阵为空");
}
int batch_size = input.size();
int seq_len = input[0].size();
int D = input[0][0].size();
if (D % h != 0) {
throw std::invalid_argument(
"D=" + std::to_string(D) + " 不能被 h=" + std::to_string(h) + " 整除"
);
}
int d_k = D / h;
std::vector output(h);
for (int head = 0; head < h; ++head) {
output[head].resize(batch_size);
for (int b = 0; b < batch_size; ++b) {
output[head][b].resize(seq_len);
for (int s = 0; s < seq_len; ++s) {
output[head][b][s].resize(d_k);
}
}
}
for (int b = 0; b < batch_size; ++b) {
for (int s = 0; s < seq_len; ++s) {
for (int head = 0; head < h; ++head) {
for (int d = 0; d < d_k; ++d) {
int col = head * d_k + d;
output[head][b][s][d] = input[b][s][col];
}
}
}
}
return output;
}
// 矩阵创建函数
Matrix3D createTestMatrix(int batch_size, int seq_len, int D, float fill_val = 1.0f) {
Matrix3D matrix(batch_size);
for (int b = 0; b < batch_size; ++b) {
matrix[b].resize(seq_len);
for (int s = 0; s < seq_len; ++s) {
matrix[b][s].resize(D, fill_val);
}
}
return matrix;
}
// 维度打印函数
void printShape(const Matrix3D& matrix) {
if (matrix.empty()) {
std::cout << "空矩阵" << std::endl;
return;
}
std::cout << "[" << matrix.size() << ", " << matrix[0].size() << ", " << matrix[0][0].size() << "]" << std::endl;
}
// 测试函数
void testSplit(int batch_size, int seq_len, int D, int h) {
Matrix3D input = createTestMatrix(batch_size, seq_len, D);
std::cout << "输入维度: ";
printShape(input);
try {
auto split_mat = splitMultiHead(input, h);
std::cout << "头数: " << split_mat.size() << ", 单头维度: ";
printShape(split_mat[0]);
std::cout << "拆分成功" << std::endl;
}
catch (const std::exception& e) {
std::cerr << "拆分失败: " << e.what() << std::endl;
}
}
int main() {
// 配置
int batch = 4, seq = 10, D = 512, h = 8;
testSplit(batch, seq, D, h);
// 测试另一组
std::cout << std::endl;
testSplit(2, 20, 256, 4);
return 0;
}
2. 代码说明
核心逻辑围绕splitMultiHead函数展开,先通过输入矩阵动态获取batch_size、seq_len、D三个关键维度,校验D与h的整除关系,再初始化h个子矩阵,最后通过四层循环按列拆分原矩阵,将对应维度的元素赋值给每个头的子矩阵。







