nn.Linear 是 PyTorch 中的一个线性层(全连接层),用于将输入张量从一个维度空间映射到另一个维度空间。具体来说,nn.Linear 执行以下操作:
output=input×weightT+bias
其中:
input 是输入张量。
weight 是权重矩阵。
bias 是偏置项(如果 bias=True)。
-
具体作用:
输入维度:
假设键(key)的维度为 key_size,即每个键是一个形状为 (key_size,) 的向量。
输出维度:
通过 nn.Linear(key_size, num_hiddens),键被映射到一个新的维度空间,即每个键被转换为一个形状为 (num_hiddens,) 的向量。
权重矩阵:
nn.Linear 会自动创建一个形状为 (key_size, num_hiddens) 的权重矩阵 W_k。
这个权重矩阵将在训练过程中通过反向传播进行优化,以学习如何将键从 key_size 维度映射到 num_hiddens 维度。 -
示例
- import torch import torch.nn as nn# 假设 key_size = 64, num_hiddens = 128 key_size = 64 num_hiddens = 128# 定义线性层 W_k W_k = nn.Linear(key_size, num_hiddens, bias=False)# 假设 K 的形状为 (batch_size, sequence_length, key_size) batch_size = 2 sequence_length = 5 K = torch.randn(batch_size, sequence_length, key_size)# 应用线性变换 K_transformed = W_k(K)print(K_transformed.shape)
输出为torch.Size([2, 5, 128])
解释:
输入:键张量 K 的形状为 (2, 5, 64),表示批量大小为 2,序列长度为 5,每个键的维度为 64。
输出:经过线性变换后,K_transformed 的形状为 (2, 5, 128),表示每个键被映射到了 128 维的隐藏层空间。