深度学习|损失函数:网络参数优化基准

文章目录

    • 引言
    • 均方误差
      • 计算示例
      • 矩阵形式
      • 代码实现
    • 交叉熵误差
      • 计算示例
      • 代码实现
    • 绝对误差
      • 计算示例
      • 代码实现
    • Hinge Loss
      • 计算示例
      • 代码实现
    • Kullback-Leibler Divergence
      • 计算示例
      • 代码实现
    • 结语

引言

在上文「深度学习|模型训练:手写 SimpleNet」中,我们以简单的 Python 代码演示了神经网络的整个训练过程,我们知道了神经网络的学习就是从数据样例中自动学得神经网络的权重参数最优解的过程。其中不难发现,要想让模型参数在模型训练的迭代中得到一次次的优化,其中损失函数起着至关重要的作用,损失函数是衡量模型参数好坏的基准,选择合适的损失函数是决定模型可以有效训练的前提条件。

本文我们将进一步介绍更多不同的损失函数,介绍它们的定义与代表的含义,以及它们在神经网络训练中的如何起到“促进”的作用。

在这里插入图片描述

均方误差

均方误差Mean Squared Error, MSE)是深度学习和机器学习中最常用的损失函数,尤其在回归问题中。

均方误差是模型的预测值( y ^ \hat{y} y^)与实际值( y y y)之间差异的平方的平均值,其数学公式为式 1:

MSE = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 (1) \text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2 \tag{1} MSE=N1i=1N(yiy^i)2(1)

其中 N N N 是样本数量; y i y_i yi 是实际值; y ^ i \hat{y}_i y^i 是预测值。

从均方误差的定义可以看出,它对于量化预测值与实际结果之间的差异非常有效。均方误差越小,说明模型的预测越准确(衡量预测精度)。在训练神经网络时,我们通常会使用均方误差作为损失函数,通过优化该损失函数,使得模型的预测结果尽可能接近实际结果,以最小化预测误差,从而提高模型性能。

计算示例

假设我们有四组预测值和实际值如下:

  • 预测值: y ^ = [ 3.0 , − 0.5 , 2.0 , 7.0 ] \hat{y} = [3.0, -0.5, 2.0, 7.0] y^=[3.0,0.5,2.0,7.0]
  • 实际值: y = [ 2.5 , 0.0 , 2.0 , 8.0 ] y = [2.5, 0.0, 2.0, 8.0] y=[2.5,0.0,2.0,8.0]

我们首先计算每个样本的平方误差:

  1. ( 3.0 − 2.5 ) 2 = ( 0.5 ) 2 = 0.25 (3.0 - 2.5)^2 = (0.5)^2 = 0.25 (3.02.5)2=(0.5)2=0.25
  2. ( − 0.5 − 0.0 ) 2 = ( − 0.5 ) 2 = 0.25 (-0.5 - 0.0)^2 = (-0.5)^2 = 0.25 (0.50.0)2=(0.5)2=0.25
  3. ( 2.0 − 2.0 ) 2 = ( 0.0 ) 2 = 0.0 (2.0 - 2.0)^2 = (0.0)^2 = 0.0 (2.02.0)2=(0.0)2=0.0
  4. ( 7.0 − 8.0 ) 2 = ( − 1.0 ) 2 = 1.0 (7.0 - 8.0)^2 = (-1.0)^2 = 1.0 (7.08.0)2=(1.0)2=1.0

然后,将这些平方误差求和,并求其平均值:

MSE = 1 4 ( 0.25 + 0.25 + 0.0 + 1.0 ) = 1.5 4 = 0.375 \text{MSE} = \frac{1}{4} \left(0.25 + 0.25 + 0.0 + 1.0\right) = \frac{1.5}{4} = 0.375 MSE=41(0.25+0.25+0.0+1.0)=41.5=0.375

均方误差常用于回归问题,能够有效地衡量预测值与实际值之间的平均平方差。由于较大的误差值被平方,提高了其对模型训练过程中的重要性(敏感度),使模型能够更倾向于“关注”出错较大的样本,因此均方误差在深度学习和许多统计模型中被广泛应用。

矩阵形式

在大多时候模型单个样例的输出不是一个单值,而是一个包含 k 个值的数组(向量),我们可以将对上文单值输出的均方误差推广到对数组输出的均方误差计算。求一批预测结果的数组输出的均方误差,可以用式 2 表示:

MSE = 1 2 N ∑ i = 1 N ∑ j = 1 k ( y i j − y ^ i j ) 2 (2) \text{MSE} = \frac{1}{2N} \sum_{i=1}^{N} \sum_{j=1}^{k}(y_{ij} - \hat{y}_{ij})^2 \tag{2} MSE=2N1i=1Nj=1k(yijy^ij)2(2)

其中 N N N 表示样本数量,k 表示输出数组的长度; y i j y_{ij} yij 是第 i 个样例的第 j 个输出的实际值; y ^ i j \hat{y}_{ij} y^ij 是第 i 个样例的第 j 个输出的预测值。取 1 2 \frac{1}{2} 21 是为了方便后续计算方便,在求导中化掉。

在前文的手写数字识别任务中,我们知道神经网络输出的 one-hot 表示的 10 个 y 值,分别代表了推理结果为 0 ~ 9 的概率。

假设我们有两组预测和实际值如下:

  • 预测值 1:[0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0](索引位 2 的概率值最大,表示预测结果为数字 2);
  • 预测值 2:[0.5, 0.05, 0.2, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0](索引位 0 的概率值最大,表示预测结果为数字 0);
  • 实际值:[0, 0, 1, 0, 0, 0, 0, 0, 0, 0](实际结果为数字 2);
  • 实际值:[1, 0, 0, 0, 0, 0, 0, 0, 0, 0](实际结果为数字 0);

按照式 2 计算均方误差:

MSE = 1 2 ( E 1 + E 2 ) = 1 2 × 2 ( ∑ k ( y k − t k ) 2 + ∑ k ( y k − t k ) 2 ) = 1 2 × 2 ( [ ( 0.1 − 0 ) 2 + ( 0.05 − 0 ) 2 + ( 0.6 − 1 ) 2 + . . . + ( 0.0 − 0 ) 2 ] + [ ( 0.5 − 1 ) 2 + ( 0.05 − 0 ) 2 + ( 0.2 − 0 ) 2 + . . . + ( 0.0 − 0 ) 2 ] ) = 1 2 × 2 ( 0.195 + 0.315 ) = 0.1275 \text{MSE} = \frac{1}{2} (E_1 + E_2) \\ = \frac{1}{2 \times 2} (\sum_{k}(y_k - t_k)^2 + \sum_{k}(y_k - t_k)^2) \\ = \frac{1}{2 \times 2} ([(0.1 - 0)^2 + (0.05 - 0)^2 + (0.6 - 1)^2 + ... + (0.0 - 0)^2] + [(0.5 - 1)^2 + (0.05 - 0)^2 + (0.2 - 0)^2 + ... + (0.0 - 0)^2]) \\ = \frac{1}{2 \times 2} (0.195 + 0.315) \\ = 0.1275 MSE=21(E1+E2)=2×21(k(yktk)2+k(yktk)2)=2×21([(0.10)2+(0.050)2+(0.61)2+...+(0.00)2]+[(0.51)2+(0.050)2+(0.20)2+...+(0.00)2])=2×21(0.195+0.315)=0.1275

代码实现

均方误差的 Python 代码实现如下:

import numpy as npdef mean_squared_error(y, t):"""均方误差函数Args:y: 神经网络的输出t: 监督数据Returns:float: 均方误差"""batch_size = y.shape[0]return 0.5 * np.sum((y-t)**2) / batch_size# 示例数据
y_true = np.array([2.5, 0.0, 2.0, 8.0]) # 真实值
y_pred = np.array([3.0, -0.5, 2.0, 7.0]) # 预测值# 计算并输出均方误差
mse = mean_squared_error(y_pred, y_true)
print("Mean Squared Error (MSE):", mse)
# Mean Squared Error (MSE): 0.1875

这里引入了取最终结果的 1 2 \frac{1}{2} 21,所以所得 0.1875 正好是上文示例中计算结果的一半。

由于均方误差是一个连续且光滑的函数(平滑性),许多优化算法(如梯度下降法)可以有效利用其梯度信息进行参数更新,从而有效提高学习效率。

对于很多像预测房价、温度等的回归问题,均方误差是最常用的损失函数,它能够有效捕捉模型预测的偏差,从而指导模型朝着更小的误差方向调整。借助梯度下降法,可以通过对均方误差函数求关于模型权重参数的导数,然后随着梯度向下调整模型参数,使得模型的预测结果更加准确。

均方误差由于其对大偏差实例的敏感性,当存在异常值的情况下,可能导致模型不稳定。在面对异常值时,可以考虑使用其他损失函数,如均绝对误差Mean Absolute Error, MAE)或鲁棒回归损失函数

交叉熵误差

交叉熵误差Cross-Entropy Loss)可用于量化两个概率分布之间的差异,比如预测分布和真实标签分布之间的差距,也是一种很常用的损失函数,尤其在分类任务中。

在二分类问题中,交叉熵误差的计算公式如式 3:

Cross Entropy = − 1 N ∑ n = 1 N [ t n log ⁡ y ^ n + ( 1 − t n ) log ⁡ ( 1 − y ^ n ) ] (3) \text{Cross Entropy} = -\frac{1}{N} \sum_{n=1}^{N} \left[ t_n \log{\hat{y}_n} + (1 - t_n) \log{(1 - \hat{y}_n)} \right] \tag{3} Cross Entropy=N1n=1N[tnlogy^n+(1tn)log(1y^n)](3)

在多分类问题中,交叉熵误差的计算工时如式 4:

Cross Entropy = − 1 N ∑ n = 1 N ∑ k = 1 K t n k log ⁡ y ^ n k (4) \text{Cross Entropy} = -\frac{1}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} t_{nk} \log{\hat{y}_{nk}} \tag{4} Cross Entropy=N1n=1Nk=1Ktnklogy^nk(4)

其中:

  • E E E 是交叉熵误差。
  • N N N 是样本的总数。
  • t n t_n tn 表示第 n 个样本的真实标签(通常为 0 或 1)。
  • y ^ n \hat{y}_n y^n 表示模型对第 n 个样本预测为正类的概率。
  • t n k t_{nk} tnk 表示第 n 个样本的第 k 个输出的真实标签(通常为 0 或 1)。
  • y ^ n k \hat{y}_{nk} y^nk 表示模型对第 n 个样本的第 k 个输出预测为正类的概率。

交叉熵误差的目标在于最小化预测概率与真实标签之间的不一致性。当预测概率接近真实标签时,交叉熵损失较低;当预测概率远离真实标签时,损失值较高。该损失函数非常适合于处理概率输出,可以确保反馈的信息能够有效地更新模型参数。

计算示例

假设在一个二分类问题中,我们有三个样本及其真实标签和预测概率如下:

样本真实标签 t t t预测概率 y ^ \hat{y} y^
110.9
200.2
310.4

计算交叉熵误差:

  1. 对于样本 1:
    E 1 = − [ 1 ⋅ log ⁡ ( 0.9 ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − 0.9 ) ] = − log ⁡ ( 0.9 ) ≈ 0.1054 E_1 = -[1 \cdot \log(0.9) + (1-1) \cdot \log(1-0.9) ] = -\log(0.9) \approx 0.1054 E1=[1log(0.9)+(11)log(10.9)]=log(0.9)0.1054

  2. 对于样本 2:
    E 2 = − [ 0 ⋅ log ⁡ ( 0.2 ) + ( 1 − 0 ) ⋅ log ⁡ ( 1 − 0.2 ) ] = − log ⁡ ( 0.8 ) ≈ 0.2231 E_2 = -[0 \cdot \log(0.2) + (1-0) \cdot \log(1-0.2)] = -\log(0.8) \approx 0.2231 E2=[0log(0.2)+(10)log(10.2)]=log(0.8)0.2231

  3. 对于样本 3:
    E 3 = − [ 1 ⋅ log ⁡ ( 0.4 ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − 0.4 ) ] = − log ⁡ ( 0.4 ) ≈ 0.9163 E_3 = -[1 \cdot \log(0.4) + (1-1) \cdot \log(1-0.4)] = -\log(0.4) \approx 0.9163 E3=[1log(0.4)+(11)log(10.4)]=log(0.4)0.9163

总损失 E E E 为:

Cross Entropy = 1 3 ( E 1 + E 2 + E 3 ) ≈ 1 3 ( 0.1054 + 0.2231 + 0.9163 ) ≈ 1.2448 3 ≈ 0.4149 \text{Cross Entropy} = \frac{1}{3}(E_1 + E_2 + E_3) \\ \approx \frac{1}{3}(0.1054 + 0.2231 + 0.9163) \\ \approx \frac{1.2448}{3} \\ \approx 0.4149 Cross Entropy=31(E1+E2+E3)31(0.1054+0.2231+0.9163)31.24480.4149

在多分类任务中,交叉熵误差计算的是对应正确解概率输出的自然对数,如式 4 所示。

因为 t n k t_{nk} tnk 中只有正确解索引位的值为 1,其他均为 0,式 4 实际只计算了对应正确解神经元输出的自然对数。

以上文给定手写数字识别任务的预测输出 y ^ \hat{y} y^ 与正式输出 y 的示例为例,即推理结果 y ^ \hat{y} y^ 相对实际结果 y 的交叉熵误差为:

Cross Entropy = − 1 2 ( log ⁡ y 2 ^ + log ⁡ y 0 ^ ) = − 1 2 ( log ⁡ 0.6 + log ⁡ 0.5 ) ≈ 0.6 。 \text{Cross Entropy} = - \frac{1}{2} ( \log{\hat{y_2}} + \log{\hat{y_0}} ) \\ = - \frac{1}{2} (\log{0.6} + \log{0.5} ) \\ \approx 0.6。 Cross Entropy=21(logy2^+logy0^)=21(log0.6+log0.5)0.6

代码实现

均方误差的 Python 代码实现如下:

import numpy as npdef cross_entropy_error(y, t):"""交叉熵误差函数Args:y: 神经网络的输出t: 监督数据Returns:float: 交叉熵误差"""# 监督数据是 one-hot-vector 的情况下,转换为正确解标签的索引if t.size == y.size:t = t.argmax(axis=1)batch_size = y.shape[0]return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size# 示例数据
y_true = np.array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])  # 真实值
y_pred = np.array([[0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0], [0.5, 0.05, 0.2, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]])  # 预测值# 计算并输出均方误差
mse = cross_entropy_error(y_pred, y_true)
print("Cross Entropy Loss:", loss)
# Cross Entropy Loss: 0.6019862188296516

交叉熵误差 是一种重要的损失函数,主要用于分类问题,评估类别的概率分布,常用于二分类和多分类任务。比如在分类任务中,交叉熵误差能够有效量化实际分布与预测分布之间的差异,并驱动模型参数朝着最小化预测与真实标签之间不一致的方向更新,因此在深度学习和机器学习领域得到了广泛应用。

绝对误差

绝对误差Mean Absolute Error, MAE)是指预测值与实际值之间差异的平均绝对值。其计算公式如下式 5:

MAE = 1 N ∑ n = 1 N ∣ y n − y n ^ ∣ (5) \text{MAE} = \frac{1}{N} \sum_{n=1}^{N} |y_n - \hat{y_n}| \tag{5} MAE=N1n=1Nynyn^(5)

其中, y n ^ \hat{y_n} yn^ 表示第 n 个预测值, y n y_n yn 表示第 n 个实际值, N N N 是样本的总数。

绝对误差通过计算预测值与实际值之间的绝对差值来衡量误差。与均方误差不同,绝对误差没有平方运算,因此对异常值的敏感性较低。

计算示例

我们可以用一个简单的例子来说明 MAE 的计算。

假设我们有三组预测值和实际值如下:

  • 预测值: y ^ = [ 2.5 , 0.0 , 2.1 ] \hat{y} = [2.5, 0.0, 2.1] y^=[2.5,0.0,2.1]
  • 实际值: y = [ 3.0 , − 0.5 , 2.0 ] y = [3.0, -0.5, 2.0] y=[3.0,0.5,2.0]

首先计算每个样本的绝对误差:

  1. ∣ 2.5 − 3.0 ∣ = 0.5 |2.5 - 3.0| = 0.5 ∣2.53.0∣=0.5
  2. ∣ 0.0 − ( − 0.5 ) ∣ = 0.5 |0.0 - (-0.5)| = 0.5 ∣0.0(0.5)=0.5
  3. ∣ 2.1 − 2.0 ∣ = 0.1 |2.1 - 2.0| = 0.1 ∣2.12.0∣=0.1

然后,将这些绝对误差求和,并求其平均值:

E = 1 3 ( 0.5 + 0.5 + 0.1 ) = 1.1 3 ≈ 0.367 E = \frac{1}{3} \left(0.5 + 0.5 + 0.1\right) \\ = \frac{1.1}{3} \\ \approx 0.367 E=31(0.5+0.5+0.1)=31.10.367

代码实现

绝对误差的 Python 代码实现如下:

import numpy as npdef mean_absolute_error(y_true, y_pred):"""计算绝对误差(Mean Absolute Error, MAE)Args:y_true : np.array,真实值的数组y_pred : np.array,预测值的数组Returns:float: 计算得到的绝对误差(MAE)"""# 计算绝对误差absolute_errors = np.abs(y_true - y_pred)  # 计算每个样本的绝对误差mae = np.mean(absolute_errors)  # 计算平均绝对误差return mae# 示例数据
y_true = np.array([3.0, -0.5, 2.0])  # 真实值
y_pred = np.array([2.5, 0.0, 2.1])  # 预测值# 计算并输出绝对误差
mae = mean_absolute_error(y_true, y_pred)
print("Mean Absolute Error (MAE):", mae)
# Mean Absolute Error (MAE): 0.3666666666666667

绝对误差常用于回归问题,衡量预测值与实际值之间的平均绝对差。由于绝对误差对异常值的影响较小,因此在某些应用中比均方误差更为稳健。特别是在目标是最小化预测值与实际值之间误差的情况下,MAE 是一个常用的评价指标。

Hinge Loss

Hinge Loss 是一种主要用于支持向量机(SVM)和某些神经网络模型的损失函数,尤其是在二分类问题中。它旨在最大化类之间的间隔(margin),并通过对分类正确但距离决策边界不够远的样本施加惩罚来优化学习过程。

其定义公式如下式 6:

Hinge = 1 N ∑ n = 1 N max ⁡ ( 0 , 1 − y n y ^ n ) (6) \text{Hinge} = \frac{1}{N} \sum_{n=1}^{N} \max(0, 1 - y_n \hat{y}_n) \tag{6} Hinge=N1n=1Nmax(0,1yny^n)(6)

其中:

  • E E E 是总损失。
  • N N N 是样本数量。
  • y n y_n yn 是样本的真实标签,通常取值为 + 1 +1 +1 − 1 -1 1
  • y ^ n \hat{y}_n y^n 是模型对样本的预测值(即类别的未缩放输出)。

当样本被正确分类,并且预测结果与真实标签之间的距离大于 1 时,损失为 0。如果样本被错误分类,或者正确分类但距离决策边界小于 1,则会产生正的损失。

计算示例

假设我们有三个样本及其真实标签和模型预测值如下:

样本真实标签 y y y预测值 y ^ \hat{y} y^
1+10.8
2-1-0.6
3+11.2

计算 Hinge Loss:

  1. 对于样本 1:
    E 1 = max ⁡ ( 0 , 1 − ( 1 ) ( 0.8 ) ) = max ⁡ ( 0 , 0.2 ) = 0.2 E_1 = \max(0, 1 - (1)(0.8)) = \max(0, 0.2) = 0.2 E1=max(0,1(1)(0.8))=max(0,0.2)=0.2

  2. 对于样本 2:
    E 2 = max ⁡ ( 0 , 1 − ( − 1 ) ( − 0.6 ) ) = max ⁡ ( 0 , 1 − 0.6 ) = max ⁡ ( 0 , 0.4 ) = 0.4 E_2 = \max(0, 1 - (-1)(-0.6)) = \max(0, 1 - 0.6) = \max(0, 0.4) = 0.4 E2=max(0,1(1)(0.6))=max(0,10.6)=max(0,0.4)=0.4

  3. 对于样本 3:
    E 3 = max ⁡ ( 0 , 1 − ( 1 ) ( 1.2 ) ) = max ⁡ ( 0 , 1 − 1.2 ) = max ⁡ ( 0 , − 0.2 ) = 0 E_3 = \max(0, 1 - (1)(1.2)) = \max(0, 1 - 1.2) = \max(0, -0.2) = 0 E3=max(0,1(1)(1.2))=max(0,11.2)=max(0,0.2)=0

总损失 E E E 为:

E = 1 3 ( E 1 + E 2 + E 3 ) = 1 3 ( 0.2 + 0.4 + 0 ) = 0.2 E = \frac{1}{3} ( E_1 + E_2 + E_3 ) \\ = \frac{1}{3} ( 0.2 + 0.4 + 0) \\ = 0.2 E=31(E1+E2+E3)=31(0.2+0.4+0)=0.2

代码实现

import numpy as npdef hinge_loss(y_true, y_pred):"""计算 Hinge Loss参数:y_true: np.array,真实类别标签(-1 或 1)y_pred: np.array,预测值(可以是与真实值相同的分类数值)返回:float: 计算得到的 Hinge Loss"""# 确保 y_true 值为 -1 或 1assert np.all(np.isin(y_true, [-1, 1])), "y_true must contain only -1 or 1"# 计算 Hinge Losslosses = np.maximum(0, 1 - y_true * y_pred)  # Hinge Lossreturn np.mean(losses)  # 计算平均 Hinge Loss# 示例数据
y_true = np.array([1, -1, 1])  # 真实标签
y_pred = np.array([0.8, -0.6, 1.2])  # 预测值# 计算并输出 Hinge Loss
loss = hinge_loss(y_true, y_pred)
print("Hinge Loss:", loss)
# Hinge Loss: 0.19999999999999998

Hinge Loss 通常用于分类任务,特别是在支持向量机中,能够有效地处理分类边界。当使用此损失函数时,目标是使 Hinge Loss 变小,从而提升模型对分类的能力。由于它关注于决策边界的距离,因此在处理不平衡数据或存在异常值的情况下表现良好,适合处理对分类精度要求较高的任务。

Kullback-Leibler Divergence

Kullback-Leibler DivergenceKL Divergence)是一种衡量两个概率分布之间差异的非对称度量。它通常用于信息论、统计学习和机器学习中,尤其是在模型评估、生成模型、变分推断等场景下。

其定义公式如下:

D K L ( P ∣ ∣ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) (7) D_{KL}(P || Q) = \sum_{x} P(x) \log \left( \frac{P(x)}{Q(x)} \right) \tag{7} DKL(P∣∣Q)=xP(x)log(Q(x)P(x))(7)

其中:

  • D K L ( P ∣ ∣ Q ) D_{KL}(P || Q) DKL(P∣∣Q) 表示从分布 Q Q Q 到分布 P P P 的 Kullback-Leibler 散度。
  • P ( x ) P(x) P(x) Q ( x ) Q(x) Q(x) 分别表示概率分布 P P P Q Q Q 中事件 x x x 的概率。

KL 散度度量的是使用分布 Q Q Q 来近似分布 P P P 时所需的额外信息量(或熵)。越小的 KL 散度表示 Q Q Q 越接近 P P P。KL 散度总是非负的 D K L ( P ∣ ∣ Q ) ≥ 0 D_{KL}(P || Q) \geq 0 DKL(P∣∣Q)0,这是由 Gibbs 不等式保证的。KL 散度是非对称的,即 D K L ( P ∣ ∣ Q ) ≠ D K L ( Q ∣ ∣ P ) D_{KL}(P || Q) \neq D_{KL}(Q || P) DKL(P∣∣Q)=DKL(Q∣∣P),这使得它不符合距离的性质。

  • 信息量度量:KL 散度可以被视为从 Q Q Q 中取样时,关于 P P P 最优编码的额外成本。
  • 模型训练:在许多生成模型(如变分自编码器)中,KL 散度用于量化模型的近似分布与真实分布之间的差异。
  • 分类问题:在多分类任务中,KL 散度可以用来衡量预测的概率分布与真实标签的概率分布(通常是 one-hot encoding)之间的区别。

计算示例

假设有两个离散概率分布 P P P Q Q Q

  • P = [ 0.4 , 0.6 ] P = [0.4, 0.6] P=[0.4,0.6]
  • Q = [ 0.5 , 0.5 ] Q = [0.5, 0.5] Q=[0.5,0.5]

我们可以计算 KL 散度:

D K L ( P ∣ ∣ Q ) = P ( 1 ) log ⁡ ( P ( 1 ) Q ( 1 ) ) + P ( 2 ) log ⁡ ( P ( 2 ) Q ( 2 ) ) D_{KL}(P || Q) = P(1) \log\left(\frac{P(1)}{Q(1)}\right) + P(2) \log\left(\frac{P(2)}{Q(2)}\right) DKL(P∣∣Q)=P(1)log(Q(1)P(1))+P(2)log(Q(2)P(2))

代入值:

D K L ( P ∣ ∣ Q ) = 0.4 log ⁡ ( 0.4 0.5 ) + 0.6 log ⁡ ( 0.6 0.5 ) D_{KL}(P || Q) = 0.4 \log\left(\frac{0.4}{0.5}\right) + 0.6 \log\left(\frac{0.6}{0.5}\right) DKL(P∣∣Q)=0.4log(0.50.4)+0.6log(0.50.6)

计算:

  1. 0.4 log ⁡ ( 0.8 ) ≈ 0.4 × − 0.223144 = − 0.089258 0.4 \log(0.8) \approx 0.4 \times -0.223144 = -0.089258 0.4log(0.8)0.4×0.223144=0.089258
  2. 0.6 log ⁡ ( 1.2 ) ≈ 0.6 × 0.182322 = 0.109393 0.6 \log(1.2) \approx 0.6 \times 0.182322 = 0.109393 0.6log(1.2)0.6×0.182322=0.109393

合并结果:

D K L ( P ∣ ∣ Q ) ≈ − 0.089258 + 0.109393 ≈ 0.020135 (2) D_{KL}(P || Q) \approx -0.089258 + 0.109393 \approx 0.020135 \tag{2} DKL(P∣∣Q)0.089258+0.1093930.020135(2)

代码实现

import numpy as npdef kl_divergence(p, q):"""计算 Kullback-Leibler Divergence (KL Divergence)Args:p : np.array,源分布的概率值(必须为非负且总和为 1)q : np.array,目标分布的概率值(必须为非负且总和为 1)Returns:float: 计算得到的 KL Divergence"""# 确保输入分布为概率分布(非负且总和为 1)assert np.all(p >= 0) and np.isclose(np.sum(p), 1), "p must be a valid probability distribution."assert np.all(q >= 0) and np.isclose(np.sum(q), 1), "q must be a valid probability distribution."# 计算 KL Divergence# 使用 np.where 来避免对 q 中为 0 的值进行 log 计算divergence = np.sum(np.where(p != 0, p * np.log(p / q), 0))  # 对于 p=0 的项不计算return divergence# 示例数据
p = np.array([0.4, 0.6])  # 源分布
q = np.array([0.5, 0.5])  # 目标分布# 计算并输出 KL Divergence
kl = kl_divergence(p, q)
print("Kullback-Leibler Divergence (KL Divergence):", kl)
# Kullback-Leibler Divergence (KL Divergence): 0.020135513550688863

Kullback-Leibler Divergence 是一种重要的统计量,帮助我们量化两个概率分布之间的差异,常用于机器学习和信息论领域。在许多模型训练过程中,它被用作损失函数,以确保模型输出尽量接近真实分布。

结语

损失函数(loss function)是表示神经网络性能“恶劣程度”的指标,即当前神经网络对监督数据有多么不拟合,多么不一致。

在神经网络模型中,均方误差、交叉熵误差、绝对误差、Kullback-Leibler Divergence 和 Hinge Loss 等都具有实际价值,选择何种损失函数通常取决于特定问题的需求、目标以及数据的特征。


PS:感谢每一位志同道合者的阅读,欢迎关注、点赞、评论!

  • 上一篇:深度学习|模型训练:手写 SimpleNet
  • 专栏:「数智通识」 | 「机器学习」

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1538404.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Nodejs+vue+Express游戏分享网站的设计与实现 7a2s2

目录 技术栈具体实现截图系统设计思路技术可行性nodejs类核心代码部分展示可行性论证研究方法解决的思路Express框架介绍源码获取/联系我 技术栈 该系统将采用B/S结构模式,开发软件有很多种可以用,本次开发用到的软件是vscode,用到的数据库是…

Flink垃圾图片分类优胜奖比赛攻略_贪吃的小香猪-148队

关联比赛: Apache Flink极客挑战赛——垃圾图片分类 一. 赛题背景分析及理解 本次竞赛要求结合大数据计算引擎Flink和深度学习的计算平台Intel Analytics Zoo应用在图片识别场景,进行垃圾图片的分类。 目标:对提供的100类大约6000张垃圾图片进行模型训…

五星级可视化页面(30):本系列最后一期,压轴出场。

不知不觉分享了30期高品质的五星级可视化大屏界面,该系列文章也该收尾了,本期为大家分享最后一批界面,我们下一个系列专辑见。

A股上市公司企业创新能力、质量、效率-原始数据+dofile+结果(2006-2023年)

上市公司的创新能力体现在其不断研发新技术、新产品和服务的能力上,这是企业保持竞争优势的关键;质量则是指公司所提供的产品或服务达到高标准的程度,高质量是赢得客户信任和市场份额的基础;效率则涵盖了生产运营中的资源利用程度…

智能车镜头组入门(一)车模的选择

这篇文章,我会简单的介绍下车模的、轮胎和负压的选择 今年的镜头组是自制车模,这比较考验学校之前参赛的经验。我们选择了某飞的mini车模。提供智能车方案的无非就两家,某飞和某邱,我们学校之前都用的是某飞的,在某飞…

鸿蒙Harmony应用开发,数据驾驶舱 项目结构搭建

对于一个项目而言,在拿到我们的开发任务后,我们最重要的就是技术的选型。选型定下来了之后我们便开始脚手架的搭建,然后开始撸代码,开搞. 首先我们需要对一些常见依赖库的引入 我们需要再oh-package.json5的dependencies节点下面…

leetcode:最高乘法得分

用auto可以过 class Solution { public:long long maxScore(vector<int>& a, vector<int>& b) {int n b.size();vector<vector<long long>> memo(4,vector<long long>(b.size(), LLONG_MIN));auto dfs [&](auto&& dfs, i…

Qt安卓开发连接手机调试(红米K60为例)

1.前置条件 本人默认您已经完成Qt安卓环境的配置&#xff0c;若还没配置请参考链接文章&#xff1a;【Qt】最详细教程&#xff0c;如何从零配置Qt Android安卓环境_qt_七夕先生-开放原子开发者工作坊。准备一台目前主流在用的手机&#xff0c;其实自己用的就行(只要你不是某些…

LeetCode-137. 只出现一次的数字 II【位运算 数组】

LeetCode-137. 只出现一次的数字 II【位运算 数组】 题目描述&#xff1a;解题思路一&#xff1a;解题思路二&#xff1a;符号位一起判断。背诵版解题思路三&#xff1a;0 题目描述&#xff1a; 给你一个整数数组 nums &#xff0c;除某个元素仅出现 一次 外&#xff0c;其余每…

CentOS7.9环境上NFS搭建及使用

Linux环境NFS搭建及使用 1. 服务器规划2. NFS服务器配置2.1 主机名设置2.2 nfs安装2.2.1 repo文件替换2.2.2 NFS服务安装 2.3 nfs配置2.4 服务查看2.5 资源发布2.6 配置nfs服务开机自启2.7 端口开放 3.应用服务器配置3.1 主机名设置3.2 nfs安装3.2.1 repo文件替换3.2.2 NFS服务…

XML映射器-高级查询结果映射

01-高级查询结果映射 emp表 dept表 02-多表关联查询映射 多对一映射 项目中Emp类的数据 项目中dept类的数据 想要多表查询需要建个公共类里面写入两个表中的属性,如下面方法 type里要写用到的类型,由于继承Emp所有Emp里面的属性直接写,column是写数据库的别字,property是写字…

WSL中使用AMBER GPU串行版

前提是已经安装过wsl 1 在 WSL 2 中启用 NVIDIA CUDA 参考在 WSL 2 上启用 NVIDIA CUDA | Microsoft Learn 注意&#xff1a;勿在 WSL 中安装任何 Linux 显示驱动程序。Windows 显示驱动程序将同时安装本机 Windows 和 WSL 支持的常规驱动程序组件。 2 在WSL2中配置Cuda 不安…

SEO之页面优化(一-页面标题2)

初创企业搭建网站的朋友看1号文章&#xff1b;想学习云计算&#xff0c;怎么入门看2号文章谢谢支持&#xff1a; 1、我给不会敲代码又想搭建网站的人建议 2、“新手上云”能够为你开启探索云世界的第一步 博客&#xff1a;阿幸SEO~探索搜索排名之道 &#xff08;接上一篇。。…

Fiddler抓包工具实战

文章目录 &#x1f7e2; Fiddler入门到精通&#x1f449;主要功能&#x1f449;使用场景 &#x1f7e2; 一、Fiddler抓包和F12抓包对比&#x1f7e2; 二、Fiddler的核心功能&#x1f7e2; 三、Fiddler的工作原理&#x1f7e2; 四、Fiddler功能配置使用&#x1f449;规则设置&am…

信息学奥赛报考指南

近年来&#xff0c;信息学奥林匹克竞赛&#xff08;NOI&#xff09;越来越受到家长和学生的重视。这项竞赛不仅能培养孩子的编程与算法思维&#xff0c;还为优秀的选手提供了进入国内顶尖大学的保送资格&#xff0c;并有机会参加国际级赛事。因此&#xff0c;许多家长都希望了解…

Microsoft Office LTSC 2024 离线安装ISO镜像

微软近日正式发布了 Microsoft Office 2024 LTSC&#xff08;长期服务版本&#xff09;&#xff0c;为那些不适用订阅模式的企业提供了最新的 Office 版本。LTSC 版本专为需要稳定、长期支持的用户设计&#xff0c;适合需要长期部署在特定硬件环境中的企业或组织。 ​从Office …

2024年【四川省安全员B证】新版试题及四川省安全员B证考试试卷

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 四川省安全员B证新版试题参考答案及四川省安全员B证考试试题解析是安全生产模拟考试一点通题库老师及四川省安全员B证操作证已考过的学员汇总&#xff0c;相对有效帮助四川省安全员B证考试试卷学员顺利通过考试。 1、…

9.18日常记录

一.信号和槽机制 信号和槽:是对象之间通信的一种机制 信号classA不关心有多少槽函数与之绑定&#xff0c;它只管触发信号&#xff0c;具体要触发哪些槽函数&#xff0c;是由Qt的信号和槽机制来实现的。这样的话就充分的体现了面向对象的解耦原则了&#xff0c;因为对于classA来…

CSS概览

概述 是什么 cascading style css 层叠样式表 由W3C制定的网页元素定义规则 为什么 美化 怎么办 设置样式 布局 css 引入 内部样式表 在head标签内部使用style标签 <html><head><style>.id{width: 400px;height: 400px;border: 1px solid black;ma…

Revit API:Element 的分类

前言 Revit的继承体系以Element作为作为最上层的元素&#xff0c;在这个体系里面&#xff0c;所有的构件都是从 Element 派生出来的。我们可以把这个派生的关系本身当作一个分类方式&#xff0c;但这种方式分的类别太多了&#xff0c;不一定可以记住。参考&#xff1a;Revit A…