当前位置: 首页 > news >正文

13.编码器的结构

从入门AI到手写Transformer-13.编码器的结构

  • 13.编码器的结构
  • 代码

整理自视频 老袁不说话 。

13.编码器的结构

T r a n s f o r m e r E n c o d e r : 输入 [ b , n ] TransformerEncoder:输入[b,n] TransformerEncoder:输入[b,n]

  • E m b e d d i n g : − > [ b , n , d ] Embedding:->[b,n,d] Embedding:>[b,n,d]
  • P o s i t i o n a l E n c o d e r : − > [ b , n , d ] PositionalEncoder:->[b,n,d] PositionalEncoder:>[b,n,d]
  • D r o p o u t : − > [ b , n , d ] Dropout:->[b,n,d] Dropout:>[b,n,d]
  • E n c o d e r B l o c k : [ b , n , d ] − > [ b , n , d ] EncoderBlock:[b,n,d]->[b,n,d] EncoderBlock:[b,n,d]>[b,n,d] 重复N次
    • M u l t i h e a d A t t e n t i o n : 3 ∗ [ b , n , d ] − > [ b , n , d ] MultiheadAttention:3*[b,n,d]->[b,n,d] MultiheadAttention:3[b,n,d]>[b,n,d]
    • D r o p o u t : [ b , n , d ] − > [ b , n , d ] Dropout:[b,n,d]->[b,n,d] Dropout:[b,n,d]>[b,n,d]
    • A d d N o r m : 2 ∗ [ b , n , d ] ( D r o u p o u t 输出, M u l t i h e a d A t t e n t i o n 输入 ) − > [ b , n , d ] AddNorm:2*[b,n,d](Droupout输出,MultiheadAttention输入)->[b,n,d] AddNorm:2[b,n,d](Droupout输出,MultiheadAttention输入)>[b,n,d]
    • F F N : [ b , n , d ] − > [ b , n , d ] FFN:[b,n,d]->[b,n,d] FFN:[b,n,d]>[b,n,d]
    • D r o p o u t : [ b , n , d ] − > [ b , n , d ] Dropout:[b,n,d]->[b,n,d] Dropout:[b,n,d]>[b,n,d]
    • A d d N o r m : 2 ∗ [ b , n , d ] ( D r o u p o u t 输出, F F N 输入 ) − > [ b , n , d ] AddNorm:2*[b,n,d](Droupout输出,FFN输入)->[b,n,d] AddNorm:2[b,n,d](Droupout输出,FFN输入)>[b,n,d]
      在这里插入图片描述
      编码器结构
      在这里插入图片描述
      多处执行Dropout

代码

import torch.nn as nnclass Embedding(nn.Module):def __init__(self,*args,**kwargs)->None:super().__init__(*args,**kwargs)def forward(self):print(self.__class__.__name__)
class PositionalEncoding(nn.Module):def __init__(self,*args,**kwargs)->None:super().__init__(*args,**kwargs)def forward(self):print(self.__class__.__name__)
class MultiheadAttention(nn.Module):def __init__(self,*args,**kwargs)->None:super().__init__(*args,**kwargs)def forward(self):print(self.__class__.__name__)
class Dropout(nn.Module):def __init__(self,*args,**kwargs)->None:super().__init__(*args,**kwargs)def forward(self):print(self.__class__.__name__)
class AddNorm(nn.Module):def __init__(self,*args,**kwargs)->None:super().__init__(*args,**kwargs)def forward(self):print(self.__class__.__name__)
class FFN(nn.Module):def __init__(self,*args,**kwargs)->None:super().__init__(*args,**kwargs)def forward(self):print(self.__class__.__name__)class EncoderBlock(nn.Module):def __init__(self,*args, **kwargs)->None:super().__init__(*args,**kwargs)self.mha = MultiheadAttention()self.dropout1=Dropout()self.addnorm1=AddNorm()self.ffn=FFN()self.dropout2=Dropout()self.addnorm2 = AddNorm()def forward(self):self.mha()self.dropout1()self.addnorm1()self.ffn()self.dropout2()self.addnorm2()class TransformerEncoder(nn.Module):def __init__(self,*args,**kwargs)->None:super().__init__(*args,**kwargs)self.embedding=Embedding() # 把序号转变为有语义信息的编码self.posenc=PositionalEncoding()self.dropout=Dropout()self.encblocks=nn.Sequential()for i in range(3):self.encblocks.add_module(str(i),EncoderBlock())def forward(self):self.embedding()self.posenc()self.dropout()for i,blk in enumerate(self.encblocks):print(i)blk()te=TransformerEncoder()
te()

输出结果

Embedding
PositionalEncoding
Dropout
0
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
1
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
2
MultiheadAttention
Dropout
AddNorm
FFN
Dropout
AddNorm
http://www.xdnf.cn/news/5725.html

相关文章:

  • 5.Rust+Axum:打造高效错误处理与响应转换机制
  • Wireshark 搜索组合速查表
  • HTML新标签与核心 API 实战
  • tomcat 的安装与启动
  • 具身智能机器人学习路线全解析
  • Muduo网络库实现 [十四] - HttpResponse模块
  • 【4.1.-4.20学习周报】
  • 信号的传输方式
  • JS实现RSA加密
  • Redis面试——日志
  • 《Java 泛型的作用与常见用法详解》
  • 【Linux】第九章 控制服务和守护进程
  • iptables 防火墙
  • JUC学习(1) 线程和进程
  • Springboot 自动装配原理是什么?SPI 原理又是什么?
  • 《AI大模型应知应会100篇》第23篇:角色扮演技巧:让AI成为你需要的专家
  • 【英语语法】基本句型
  • Redis面试——常用命令
  • webgl入门实例-09索引缓冲区示例
  • BH1750光照传感器---附代码
  • java + spring boot + mybatis 通过时间段进行查询
  • 【JavaScript】二十四、JS的执行机制事件循环 + location + navigator + history
  • 基于尚硅谷FreeRTOS视频笔记——13—HAL库和RTOS时钟源问题
  • UE学习记录part18
  • Java锁的分类与解析
  • LeetCode算法题(Go语言实现)_51
  • Vue3如何选择传参方式
  • C++面试
  • 【HDFS入门】HDFS核心配置与优化指南概述
  • 【Python学习笔记】Pandas实现Excel质检记录表初审、复核及质检统计