当前位置: 首页 > news >正文

深入浅出 Multi-Head Attention:原理 + 例子 + PyTorch 实现

本文带你一步步理解 Transformer 中最核心的模块:多头注意力机制(Multi-Head Attention)。从原理到实现,配图 + 举例 + PyTorch 代码,一次性说清楚!


什么是 Multi-Head Attention?

简单说,多头注意力就是一种让模型在多个角度“看”一个序列的机制。

在自然语言中,一个词的含义往往依赖于上下文,比如:

“我把苹果给了她”

模型在处理“苹果”时,需要关注“我”“她”“给了”等词,多头注意力就是这样一种机制——从多个角度理解上下文关系。


Self-Attention 是什么?为什么还要多头?

在讲“多头”之前,咱们先回顾一下基础的 Self-Attention

Self-Attention(自注意力)机制的目标是:

让每个词都能“关注”整个句子里的其他词,融合上下文。

它的核心步骤是:

  1. 对每个词生成 Query、Key、Value 向量

  2. 用 Query 和所有 Key 做点积,算出每个词对其他词的关注度(打分)

  3. 用 Softmax 得到权重,对 Value 加权平均,生成当前词的新表示

这样做的好处是:词的语义表示不再是孤立的,而是上下文相关的。


Self-Attention vs Multi-Head Attention

但问题是——单头 Self-Attention 视角有限。就像一个老师只能从一种角度讲课。

于是,Multi-Head Attention 应运而生

特性Self-Attention(单头)Multi-Head Attention(多头)
输入映射矩阵一组 Q/K/V 线性变换多组 Q/K/V,每个头一组
学习角度单一视角多角度并行理解
表达能力有限更丰富、强大
结构简单并行多个头 + 合并输出

一句话总结:

Multi-Head Attention = 多个不同“视角”的 Self-Attention 并行处理 + 合并结果


 多头注意力:8个脑袋一起思考!

多头 = 多个“单头注意力”并行处理!

每个头使用不同的线性变换矩阵,所以能从不同视角处理数据:

  • 第1个头可能专注短依赖(like 动词和主语)

  • 第2个头可能专注实体关系(我 vs 她)

  • 第3个头可能关注时间顺序(“给了”前后)

  • ……共用同一个输入,学习到不同特征!

多头的步骤:

  1. 将输入向量(如512维)拆成多个头(比如8个,每个64维)

  2. 每个头独立进行 attention

  3. 所有头的输出拼接

  4. 再过一次线性变换,融合成最终输出


 PyTorch 实现(简洁版)

我们来看下 PyTorch 中的简化实现:

import torch
import torch.nn as nn
import copydef clones(module, N):return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])def attention(query, key, value, mask=None, dropout=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = torch.softmax(scores, dim=-1)if dropout:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attnclass MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):super().__init__()assert d_model % h == 0self.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = query.size(0)query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for lin, x in zip(self.linears, (query, key, value))]x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

举个例子:多头在实际模型中的作用

假设输入是句子:

"The animal didn't cross the street because it was too tired."

多头注意力的不同头可能会:

  • 🧠 头1:关注“animal”和“it”之间的指代关系;

  • 📐 头2:识别“because”和“tired”之间的因果联系;

  • 📚 头3:注意句子的结构层次……

所以说,多头注意力本质上是一个“并行注意力专家系统”!


 总结

项目解释
目的提升模型表达能力,从多个角度理解输入
核心机制将向量分头 → 每头独立 attention → 合并输出
技术关键view, transpose, matmul, softmax, 拼接线性层

推荐学习路径

  • 🔹 理解 Self-Attention 的点积公式

  • 🔹 搞懂 view, transpose 等张量操作

  • 🔹 看 Transformer 整体结构,关注每层作用

http://www.xdnf.cn/news/17785.html

相关文章:

  • 数字信号处理技术架构与功能演进
  • 鸿蒙语言基础
  • 如何在直播App中集成美颜SDK?人脸美型功能从0到1实现指南
  • 基于 HT 数字孪生智慧交通可视化系统
  • 安卓App中调用升级接口并实现版本检查和升级功能的完整方案
  • IP检测工具“ipjiance”
  • MySQL锁详解
  • 2025年大数据实训室建设及大数据实训平台解决方案
  • Vmware esxi 查看硬盘健康状况
  • 【深度学习】张量计算:爱因斯坦求和约定|tensor系列03
  • 如何才能学会代数几何,代数几何的前置学科是什么
  • 使用Trae CN分析项目架构
  • 理解.NET Core中的配置Configuration
  • 时序逻辑电路——序列检测器
  • 【Contiki】Contiki process概述
  • 基于slimBOXtv 9.16 V2-晶晨S905L3A/ S905L3AB-Mod ATV-Android9.0-线刷通刷固件包
  • 铁氧体和纳米晶:车载定制电感的材料选择
  • 什么是Python单例模式
  • 解决方德桌面操作系统V5.0-G23没ll命令的问题
  • 以太网交换机介绍
  • Docker compose使用、容器迁移
  • 3个实用的脚本
  • Linux系统编程---多进程
  • Python3.14都有什么重要新特性
  • 聚合直播-Simple Live-v1.7.7-全网直播平台能在一个软件上看完
  • java+postgresql+swagger-多表关联insert操作(九)
  • C++ 常用的智能指针
  • 使用Docker搭建开源Email服务器
  • 高防IP如何针对DDoS攻击特点起防护作用
  • 小刚说C语言刷题——1033 判断奇偶数