Source code analysis of the Qwen2 ("千问") model in PyTorch


# A normalization technique intended to replace traditional Layer Normalization (LN).
# The core idea is to rescale each token's feature vector by its root mean square (RMS);
# unlike LN it does not subtract the mean, so the output is not forced to zero mean / unit variance.
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6): # hidden_size: size of the hidden dimension
        super().__init__()
        # Learnable scale parameter, initialized to all ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Small constant to avoid division by zero.
        self.variance_epsilon = eps
    def forward(self, hidden_states):
        # Remember the input dtype so the output can be cast back at the end.
        input_dtype = hidden_states.dtype
        # Upcast to float32 for numerical stability.
        hidden_states = hidden_states.to(torch.float32)
        # Mean of squares over the last (feature) dimension; note there is no mean subtraction.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Divide each token's feature vector by its RMS value (rsqrt = 1/sqrt).
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Apply the learnable scale γ (self.weight) and cast back to the original dtype.
        return self.weight * hidden_states.to(input_dtype)
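
A quick sanity check makes the difference from LayerNorm concrete. The snippet below is illustrative only (it assumes torch is imported as in this listing and uses made-up shapes): the normalized features have RMS close to 1, but their mean is not forced to zero.

norm = Qwen2RMSNorm(hidden_size=8)
x = torch.randn(2, 4, 8)            # (batch, seq_len, hidden_size)
y = norm(x)
print(y.pow(2).mean(-1))            # all values close to 1 (self.weight is initialized to ones)
print(y.mean(-1))                   # generally not zero: RMSNorm does not subtract the mean
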
# Generates rotary position embeddings (RoPE). This scheme is used in Transformer models to capture
# positional information and is well suited to long sequences.
# Position information is encoded into the embedding vectors by rotating them. The main steps are:
# generate a set of frequencies via an exponential schedule; compute sines and cosines from those
# frequencies; and rotate the input vectors by the resulting angles to embed the position.
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        # Maximum number of position embeddings, default 2048; base: the rotary base, default 10000.
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq: inverse frequencies.
        # arange(0, dim, 2)/dim normalizes the even indices into [0, 1); base**(...) then grows from 1
        # towards `base`, and taking the reciprocal yields frequencies that decay from 1 towards 1/base.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # register_buffer registers inv_freq as a buffer: it belongs to the module but is never updated by gradients.
        # persistent=True: the buffer appears in the state dict and is serialized/loaded with the model.
        # persistent=False: the buffer is excluded from the state dict and is simply rebuilt in __init__.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work: pre-compute the sin/cos cache.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # t is a tensor of position indices with shape (seq_len,).
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        # torch.outer computes the outer product, giving a tensor of shape (seq_len, dim/2).
        freqs = torch.outer(t, self.inv_freq) # angle (position * frequency) for every position/frequency pair
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # Concatenate the frequencies: emb has shape (seq_len, dim).
        # In rotary position embedding (RoPE) the embedding vector is split into two halves, to which the
        # cosine and sine are applied respectively. Concretely:
        # for each position t the angles t * f form a (seq_len, dim/2) tensor;
        # concatenating that tensor with itself gives a (seq_len, dim) tensor,
        # so both halves of the head dimension share the same set of frequencies.
        emb = torch.cat((freqs, freqs), dim=-1)
        # cos_cached and sin_cached: register the cosine and sine caches as non-persistent buffers.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None): # x: input tensor
        # x: [bs, num_attention_heads, seq_len, head_size]
        # If seq_len exceeds the cached maximum length, rebuild the cache.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return ( # Return slices of the cosine and sine caches.
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
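
The forward above only returns the cos/sin caches; the rotation itself happens in apply_rotary_pos_emb, which is called later but not included in this excerpt. For reference, a sketch close to the transformers helper of this version (treat it as illustrative rather than the authoritative source):

def rotate_half(x):
    # Split the last dimension in half and swap the halves with a sign flip: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # Gather the cos/sin rows for the given positions and broadcast over the head dimension.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)   # (b, 1, s, dk)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    # Each feature pair is rotated by its position-dependent angle.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
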

class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size # d
        self.intermediate_size = config.intermediate_size # hd
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # d --> hd
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # d --> hd
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # hd --> d
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, hidden_state): # (b, s, d)
        # Gating signal: act_fn(gate_proj(hidden_state)) produces the gate.
        # Feature modulation: the gate is multiplied element-wise with up_proj(hidden_state).
        # The gating mechanism dynamically decides which features pass through and which are suppressed.
        # The activation comes from ACT2FN[config.hidden_act]; for Qwen2 this is typically SiLU,
        # making the block a SwiGLU-style feed-forward layer.
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
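
To see the gated feed-forward in action, here is a small illustrative example; the tiny sizes are made up and SimpleNamespace merely stands in for a real Qwen2Config:

from types import SimpleNamespace
from transformers.activations import ACT2FN  # maps "silu" to the SiLU activation, etc.

toy_config = SimpleNamespace(hidden_size=8, intermediate_size=16, hidden_act="silu")
mlp = Qwen2MLP(toy_config)
x = torch.randn(2, 4, 8)            # (b, s, d)
out = mlp(x)                        # SwiGLU-style: down_proj(silu(gate_proj(x)) * up_proj(x))
print(out.shape)                    # torch.Size([2, 4, 8])
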

class Qwen2Attention(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__() # call the parent constructor
        self.config = config # the config instance
        self.layer_idx = layer_idx # index of this layer
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        
        self.hidden_size = config.hidden_size # d
        self.num_heads = config.num_attention_heads # q_h
        self.head_dim = self.hidden_size // self.num_heads # dk
        self.num_key_value_heads = config.num_key_value_heads # kv_h
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads # how many query heads share one kv head
        self.max_position_embeddings = config.max_position_embeddings # p
        self.rope_theta = config.rope_theta # base
        self.is_causal = True # whether a causal mask is used
        self.attention_dropout = config.attention_dropout # dropout probability
        # The hidden size must be divisible by the number of heads.
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Linear projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        # Note that the key/value projection width may differ from the query projection width (grouped-query attention).
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        # Final output projection
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # Rotary position embedding
        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim, # dk
            max_position_embeddings=self.max_position_embeddings, # max positions
            base=self.rope_theta, # base
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None, # optional
        position_ids: Optional[torch.LongTensor] = None, # optional
        past_key_value: Optional[Cache] = None, # optional: the kv cache
        output_attentions: bool = False, # whether to return the attention weights
        use_cache: bool = False, # whether to use the kv cache
        cache_position: Optional[torch.LongTensor] = None, # positions within the cache
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size() # b, s, d
        # Projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # (b, q_len, q_h, dk) --> (b, q_h, q_len, dk); transpose swaps the sequence and head axes
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (b, kv_h, k_len, dk)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2] # k_len
        # Key/value states cached from previous time steps
        if past_key_value is not None: # if a kv cache is used
            if self.layer_idx is None: # a layer index is then required, otherwise raise
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # Rotary position embedding, built for kv_seq_len positions:
        # kv_seq_len is the total key/value length (cached tokens plus the current ones),
        # while q_len is only the length of the current query block and may be shorter;
        # using the key/value length guarantees the cos/sin cache covers every position that will be attended to.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # Inject position information into the query and key states
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # If a kv cache is present, append the new key/value states to it
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            # Update the cache and get back the full key/value states
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # repeat k/v heads if n_kv_heads < n_heads
        # If there are fewer key/value heads than query heads, repeat them so the head counts match.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # (b, h, q_len, dk) @ (b, h, dk, k_len) --> (b, h, q_len, k_len)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        # Slice the mask to the key length on the last dimension
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            # Add the mask; masked positions hold a large negative value
            attn_weights = attn_weights + causal_mask
        # upcast attention to fp32
        # Softmax over the key dimension (dim=-1): for each query token this yields a distribution over
        # the key tokens, with larger weights on keys that are more relevant to that query token.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # dropout
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        # (b, h, q_len, k_len) @ (b, h, k_len, dk) --> (b, h, q_len, dk)
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        # (b, h, q_len, dk) --> (b, q_len, h, dk); .contiguous() makes the tensor contiguous in memory
        attn_output = attn_output.transpose(1, 2).contiguous()
        # (b, q_len, h, dk) --> (b, q_len, d)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        # Final output projection
        attn_output = self.o_proj(attn_output)
        # Do not return the attention weights unless requested
        if not output_attentions:
            attn_weights = None
        # Return the multi-head attention output, the attention weights, and the kv cache
        return attn_output, attn_weights, past_key_value
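
repeat_kv is used above but not defined in this excerpt. The transformers helper is roughly the following (illustrative sketch): each key/value head is repeated num_key_value_groups times so that grouped-query attention can reuse the ordinary matmul path.

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (b, kv_h, s, dk) -> (b, kv_h * n_rep, s, dk); equivalent to torch.repeat_interleave on dim=1.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
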
class Qwen2FlashAttention2(Qwen2Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs) # call the parent constructor
        # For flash-attn >= 2.1.0 this flag is False (newer versions use bottom-right aligned causal masks).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
    def forward(
        self, # the current instance
        hidden_states: torch.Tensor, # output of the previous layer, or the token embeddings for the first layer
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        bsz, q_len, _ = hidden_states.size() # b, s, d
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # (b, h, q_len, dk)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2] # k_len
        # If a kv cache is used, the layer index is required
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # Because the input can be padded, the absolute sequence length depends on the max position id.

        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        # Get cos and sin
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
        # Get the query/key states with position information applied
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # Sliding-window attention is used only when flash-attn supports it, the config defines a sliding
        # window, the key/value length exceeds that window, and use_sliding_window is enabled.
        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
            and self.config.use_sliding_window
        )
        # Warn if the installed flash-attn version does not support sliding windows
        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )
        # If a kv cache is present
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads, so that q, k, v have the same number of heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # Apply dropout only in training mode, not in eval mode
        dropout_rate = 0.0 if not self.training else self.attention_dropout
        # In PEFT (parameter-efficient fine-tuning), layer norms are usually trained in float32 for
        # stability, so the hidden states may have been silently upcast to float32. Cast them back to
        # the compute dtype (e.g. float16/bfloat16) so that flash-attn works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2) # (b, q_len, h, dk)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )
        # (b,s,d)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Decide whether to use SWA or not by layer index:
        # if the layer index reaches config.max_window_layers, disable sliding-window attention.
        if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
            use_sliding_windows = False
        # If an attention mask is provided, the batch contains at least one padding token, so take the unpad path
        if attention_mask is not None:
            batch_size = query_states.shape[0] # b
            # query_states: (len(indices_q), h, dk) after removing the padding tokens
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
            # Without sliding windows
            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )
            # attn_output_unpad: (len(indices_q), h, dk); indices_q is a 1-D tensor of non-padding token indices.
            # pad_input restores the unpadded tensor back to its original padded shape.
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else: # no attention_mask was provided
            if not use_sliding_windows: # without sliding windows
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else: # with sliding windows
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )
        # Return the attention output with shape (b, s, h, dk)
        return attn_output
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # b, k_len, h, dk
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
        # On the first iteration we need to properly re-create the padding mask by slicing it at the right position
        if kv_seq_len != attention_mask.shape[-1]:
            # Size of the last dimension of the mask
            attention_mask_num_tokens = attention_mask.shape[-1]
            # Keep only the last kv_seq_len positions
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
        # Indices of the non-padding tokens, cumulative sequence lengths per batch, and the longest sequence length in the batch
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        # index_first_axis = IndexFirstAxis.apply
        # Calling .apply: IndexFirstAxis.apply(input_tensor, indices) invokes the forward method.
        # Forward pass: forward performs the gather along the first axis and returns the result.
        # Backward pass: PyTorch automatically calls backward to compute the gradients.
        # The returned key_layer has shape (len(indices_k), h, dk), i.e. the embeddings with padding removed;
        # len(indices_k) is the total number of non-padding tokens.
        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        if query_length == kv_seq_len:
            # Gather along the first axis with the same indices as the keys
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            # In this case the query side reuses the key-side values
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k

        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            # Cumulative sequence lengths per batch element
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            # Indices of the non-padding tokens
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1) # (b, h, dk)
        else:
            # The -query_length: slice assumes left padding
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer, # (len(indices_q), h, dk)
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
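
_get_unpad_data is likewise not shown in this excerpt. A sketch close to the transformers helper (illustrative) makes clear where indices_k, cu_seqlens_k and max_seqlen_in_batch_k come from:

import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    # Number of real (non-padding) tokens per sequence.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flattened indices of all non-padding tokens.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Longest real sequence in the batch.
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths, padded with a leading 0, as expected by flash_attn_varlen_func.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
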
class Qwen2SdpaAttention(Qwen2Attention):
    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # SDPA cannot return attention weights; if they are requested, warn and fall back to the parent (eager) forward.
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        bsz, q_len, _ = hidden_states.size() # b,q_len,d
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # (b,s,d)-->(b,s,h,dk)-->(b,h,s,dk)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2] # k_len
        # If a kv cache is used
        if past_key_value is not None:
            # Extend kv_seq_len by the number of cached tokens
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # Rotary position embedding
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # Queries and keys with position information applied
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # If a kv cache is used, append the new key/value states
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # Repeat the k/v heads so they match the number of query heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        causal_mask = attention_mask
        # If a mask was passed in, slice it to the key length
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            # Make the tensors contiguous in memory
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        # No causal mask is needed when q_len == 1 (single-token decoding) or when an explicit causal_mask is already supplied.
        is_causal = True if causal_mask is None and q_len > 1 else False
        
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # (b,h,s,dk)-->(b,s,h,dk)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value
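
For intuition, the SDPA call above matches the eager computation. A minimal illustrative comparison on random tensors (shapes made up, default 1/sqrt(dk) scaling):

q = torch.randn(1, 2, 5, 4)   # (b, h, s, dk)
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
# Eager equivalent with an explicit additive causal mask:
mask = torch.full((5, 5), float("-inf")).triu(1)
weights = torch.softmax(q @ k.transpose(-2, -1) / 4 ** 0.5 + mask, dim=-1)
manual_out = weights @ v
print(torch.allclose(sdpa_out, manual_out, atol=1e-5))   # True
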

# The Qwen2 decoder layer, a subclass of nn.Module
class Qwen2DecoderLayer(nn.Module):
    # Constructor arguments: self (the current instance), config (a Qwen2Config instance), layer_idx (the layer index)
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__() # call the parent constructor
        self.hidden_size = config.hidden_size # d
        # If sliding-window attention is enabled, _attn_implementation should be "flash_attention_2"; otherwise warn.
        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        # Multi-head self-attention layer (eager, flash-attention-2 or SDPA, chosen by _attn_implementation)
        self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        # Feed-forward (MLP) layer
        self.mlp = Qwen2MLP(config)
        # RMSNorm layers
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    def forward(
        self, # the current instance
        hidden_states: torch.Tensor, # output of the previous decoder layer, or the embeddings for the first layer
        # Optional attention mask
        attention_mask: Optional[torch.Tensor] = None,
        # Position ids
        position_ids: Optional[torch.LongTensor] = None,
        # Optional tuple of tensors: during inference this holds the cached key/value
        # representations of all tokens before the current one for this layer
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        # Optional bool: whether to return the attention weights
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False, # whether to use the kv cache
        # Cache positions, optional LongTensor
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs, # additional arguments; the annotations above are the type hints that well-formed code should carry
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states # residual branch: the embeddings for the first layer, otherwise the previous layer's output
        hidden_states = self.input_layernorm(hidden_states) # pre-attention RMSNorm
        # Self-attention over the target sequence
        # Returns the attention output, the attention weights, and the updated key/value cache
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask, # attention mask
            position_ids=position_ids, # position ids
            past_key_value=past_key_value, # cached key/value states from previous steps
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        # Residual connection around the self-attention block
        hidden_states = residual + hidden_states
        # Reset the residual branch
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states) # post-attention RMSNorm
        hidden_states = self.mlp(hidden_states) # feed-forward layer
        hidden_states = residual + hidden_states # residual connection around the feed-forward block
        # Collect the decoder layer outputs in a tuple
        outputs = (hidden_states,)
        if output_attentions:  # if attention weights are requested
            outputs += (self_attn_weights,)
        if use_cache: # if the kv cache is used
            outputs += (present_key_value,)
        return outputs # return the tuple
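
A minimal smoke test for the decoder layer, assuming the full modeling module (with its imports and helpers such as QWEN2_ATTENTION_CLASSES) is available; the tiny config values are illustrative and the exact config handling may differ across transformers versions:

import torch
from transformers import Qwen2Config

config = Qwen2Config(
    hidden_size=64, intermediate_size=128, num_attention_heads=8,
    num_key_value_heads=4, max_position_embeddings=128,
)
config._attn_implementation = "eager"                            # pick the plain Qwen2Attention path
layer = Qwen2DecoderLayer(config, layer_idx=0)
hidden = torch.randn(2, 10, config.hidden_size)                  # (b, s, d)
position_ids = torch.arange(10).unsqueeze(0).expand(2, -1)       # (b, s), needed for RoPE
out = layer(hidden, position_ids=position_ids)[0]
print(out.shape)                                                 # torch.Size([2, 10, 64])
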
