PyTorch应用实战一:实现卷积操作

文章目录

  • 实验环境
  • 0.卷积定义
  • 1.利用张量操作实现卷积
    • 1.1 unfold函数
    • 1.2 张量分片
  • 2.实现卷积操作
    • 2.1 编写卷积函数
    • 2.2 对编写的卷积函数举例分析
    • 2.3 验证编写卷积函数的正确性
  • 附:系列文章

实验环境

python3.6 + pytorch1.8.0

import torch
print(torch.__version__)
1.8.0

0.卷积定义

卷积操作是指两个函数f和g之间的一种数学运算,它在信号处理、图像处理、机器学习等领域中广泛应用。在离散情况下,卷积操作可以表示为:

( f ∗ g ) [ n ] = ∑ m = − ∞ ∞ f [ m ] g [ n − m ] (f * g)[n] = \sum_{m=-\infty}^{\infty}f[m]g[n-m] (fg)[n]=m=f[m]g[nm]

其中, f f f g g g是离散函数, ∗ * 表示卷积操作, n n n是离散的变量。卷积操作可以看作是将函数 g g g沿着 n n n轴翻转,然后平移,每次和函数 f f f相乘并求和,最后得到一个新的函数。这种操作可以实现信号的滤波、特征提取等功能,是数字信号处理中非常重要的基础操作。

1.利用张量操作实现卷积

1.1 unfold函数

PyTorch的unfold函数用于对张量进行展开操作。torch.unfold()可以理解为将一个高维的张量展开成一个二维矩阵的操作。即将原来的张量沿着指定的维度展开成一个二维矩阵,其中第一维对应原来张量的维度,第二维对应展开的位置。

函数原型如下:

torch.unfold(input, dimension, size, step)

参数说明:

  • input (Tensor) – 要展开的张量
  • dimension (int) – 沿着哪个维度展开
  • size (int) – 展开窗口的大小
  • step (int) – 两个相邻窗口之间的步长

1.2 张量分片

import torch
a = torch.arange(16).view(4, 4)
a
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15]])
b = a.unfold(0, 3, 1)
b
tensor([[[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]],[[ 4,  8, 12],[ 5,  9, 13],[ 6, 10, 14],[ 7, 11, 15]]])
b.shape
torch.Size([2, 4, 3])
c = b.unfold(1, 3, 1)
c
tensor([[[[ 0,  1,  2],[ 4,  5,  6],[ 8,  9, 10]],[[ 1,  2,  3],[ 5,  6,  7],[ 9, 10, 11]]],[[[ 4,  5,  6],[ 8,  9, 10],[12, 13, 14]],[[ 5,  6,  7],[ 9, 10, 11],[13, 14, 15]]]])
c.shape
torch.Size([2, 2, 3, 3])

完整程序

import torch
a = torch.arange(16).view(4, 4)
b = a.unfold(0, 3, 1)
c = b.unfold(1, 3, 1)
c.shape
torch.Size([2, 2, 3, 3])

这段代码定义了三个变量。假设我们将其分别命名为abc,则:

  • 变量a是一个4x4的张量,其中包含了0到15的整数值,它通过torch.arange(16).view(4, 4)两个函数调用来实现。
  • 变量b是通过对变量a进行折叠操作得到的一个张量,具体来说,它是将变量a沿着第0维(即行)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量b打印出来,会得到:
tensor([[[ 0,  1,  2],[ 4,  5,  6],[ 8,  9, 10],[12, 13, 14]],[[ 1,  2,  3],[ 5,  6,  7],[ 9, 10, 11],[13, 14, 15]]])

其中,第一个子张量的值为[[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]],第二个子张量的值为[[1, 2, 3], [5, 6, 7], [9, 10, 11], [13, 14, 15]]。注意,这个张量的形状为(2, 4, 3),即它包含2个子张量,每个子张量的形状为(4, 3)

  • 变量c是对变量b进行类似的操作得到的,但是它是在第1维(即列)上展开并取子张量。具体来说,它是将变量b沿着第1维(即列)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量c打印出来,会得到:
tensor([[[[ 0,  1,  2],[ 4,  5,  6],[ 8,  9, 10]],[[ 1,  2,  3],[ 5,  6,  7],[ 9, 10, 11]]],[[[ 4,  5,  6],[ 8,  9, 10],[12, 13, 14]],[[ 5,  6,  7],[ 9, 10, 11],[13, 14, 15]]]])

其中,第一个子张量的值为[[[0, 1, 2], [4, 5, 6], [8, 9, 10]], [[1, 2, 3], [5, 6, 7], [9, 10, 11]]],第二个子张量的值为[[[4, 5, 6], [8, 9, 10], [12, 13, 14]], [[5, 6, 7], [9, 10, 11], [13, 14, 15]]]。注意,这个张量的形状为(2, 2, 3, 3),即它包含2个子张量,每个子张量的形状为(2, 3, 3)

2.实现卷积操作

2.1 编写卷积函数

完整程序

import torch
def conv2d(x, weight, bias, stride, pad):n, c, h, w = x.shaped, c, k, j = weight.shapex_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)x_pad[:, :, pad:-pad, pad:-pad] = xx_pad = x_pad.unfold(2, k, stride)x_pad = x_pad.unfold(3, j, stride)out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)out = out + bias.view(1, -1, 1, 1)return out

该函数实现了二维卷积操作。下面对函数进行详细分析:

  1. 输入参数:
  • x: 输入张量,维度为(batch_size,in_channels,input_height,input_width)。
  • weight: 卷积核张量,维度为(out_channels,in_channels,kernel_height,kernel_width)。
  • bias: 偏置项张量,维度为(out_channels,)。
  • stride: 卷积核移动的步长,可以是一个数或是一个长度为 2 的元组,分别表示水平方向和竖直方向的步长。
  • pad: 输入张量周围要填充的零的数量。
  1. 局部填充:
  • 在进行卷积操作之前,需要在输入张量的周围按照给定的 pad 进行填充,以避免卷积核在张量边缘处超出范围的情况发生。
  • 在函数中使用 x_pad 表示经过填充后的输入张量。
  • 具体实现:将输入张量 x 在第 2 和第 3 个维度(height 和 width 维度)上分别拆分成若干个形状为(kernel_height,kernel_width)的张量,每个张量之间的跳跃长度由 stride 决定,然后在第 2 和第 3 个维度上分别进行展开。这样每个展开后的张量就可以看作一个二维卷积核作用在 x 上的局部卷积结果,这些局部结果被按照第 2 和第 3 个维度重新拼接起来,得到新的张量 x_pad。
  1. 卷积操作:
  • 在新的张量 x_pad 上使用 einsum 函数对卷积核进行卷积操作。

  • einsum 的第一个参数表示操作的规则,其中 ndhw 表示最终输出的张量的维度为(batch_size,out_channels,output_height,output_width),nchw 和 dckj 表示两个输入张量 x_pad 和 weight 的维度,其中 c k j 分别表示 input_channels、kernel_height 和 kernel_width。

  • 最终得到的输出张量形状为(batch_size,out_channels,output_height,output_width),并在每个位置上加上偏置项 bias。

  • torch: PyTorch库

  • einsum: Einstein summation notation,爱因斯坦求和约定,一种张量求和的简便表示法。

  • 'nchwkj,dckj->ndhw': 爱因斯坦求和符号,左侧的张量为x_pad,右侧的张量为weight。在左侧张量中,n, c, h, w分别表示batch size、通道数、高度、宽度。在右侧张量中,d, c, k, j分别表示输出通道数、输入通道数、卷积核高度、卷积核宽度。这个式子的意义是将x_padweight执行卷积操作,并输出结果张量,其形状为(batch_size, output_channels, height, width)

  • x_pad: 输入的张量,形状为 (batch_size, input_channels, input_height, input_width)

  • weight: 卷积核张量,形状为 (output_channels, input_channels, kernel_height, kernel_width)

  1. 返回结果:
  • 返回卷积操作后得到的输出张量。

2.2 对编写的卷积函数举例分析

# 设置测试数据
x = torch.randn(2, 3, 5, 5, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
x, weight, bias
(tensor([[[[-0.4888,  1.0257,  0.0312, -0.9026, -0.9060],[ 0.2071, -0.4962, -0.1658,  1.0919,  0.3785],[-0.4654,  1.5442,  0.6005,  0.3594, -2.6207],[ 0.5830,  0.0533,  0.5719,  1.5413,  0.5949],[-0.9152, -0.2114, -0.4888, -0.0065, -0.9767]],[[ 0.4706, -0.1108, -0.1563, -1.7946, -0.8533],[-0.2119,  0.3165, -2.2668, -0.8956,  1.0617],[-0.7809, -0.2120, -0.8592, -0.5057,  0.7954],[-2.8820, -0.6888,  0.4450, -0.3586, -0.9477],[ 0.6244,  0.4303,  1.4739,  0.2740,  1.6605]],[[-0.1501,  0.6234, -1.6086,  0.1693,  0.4932],[ 1.0611, -1.0938,  0.1695,  1.0193,  0.4263],[ 1.4681, -0.1552, -0.0667, -0.7293,  1.0816],[ 0.8972,  1.1683, -1.4757,  0.4421, -0.0355],[-2.1331,  1.4847,  0.1378, -1.6907, -0.1350]]],[[[-1.3853,  1.6396,  0.3436,  0.3841,  0.2355],[-0.2206, -0.5087, -1.6956,  1.3205,  0.7058],[ 0.0993,  0.3533, -0.2086,  0.2969,  0.2627],[ 0.3752,  0.0304,  1.2487,  1.3963, -0.0063],[-1.3758,  0.5088, -1.3849,  1.3050,  0.4150]],[[ 0.2824, -2.8634, -0.1016, -0.1627,  1.7081],[ 0.1406,  0.2220, -0.6005,  0.2997, -0.1846],[ 1.6700,  0.5787,  0.6561, -0.0236,  1.7743],[ 2.1429, -0.2838, -0.0527,  0.3504, -0.3444],[-0.9409, -0.4734, -0.4060, -0.5088, -1.8518]],[[-2.2152,  0.2104, -0.3302,  0.2036, -0.9443],[-0.6576, -0.4455,  0.5117, -2.0058, -1.3985],[-0.5688,  1.2338, -0.1832,  0.1760,  0.4506],[-0.6563,  0.4021, -1.6210,  0.5582, -0.9238],[-1.0506, -0.9638,  0.7453, -0.3535, -0.3536]]]], requires_grad=True),tensor([[[[ 0.3069,  0.2079, -0.2952],[ 1.7681,  1.1056, -1.0555],[ 1.5845,  0.8294,  0.6588]],[[ 0.2574,  0.5007,  0.2912],[-0.0210,  0.6593, -0.9691],[-0.2918,  0.5695, -1.1242]],[[ 0.7327, -0.3453,  0.7041],[-0.2236, -1.7762,  0.0190],[-1.0927, -2.9369,  0.1768]]],[[[-2.3830, -1.4807,  1.8573],[ 1.0097, -0.9640,  1.0361],[-0.5222, -1.0386, -0.4016]],[[ 0.5071,  1.1433, -0.1194],[-0.0133, -0.3878, -0.1853],[ 0.3456, -0.6502,  0.2221]],[[-1.7672, -0.0469, -0.5996],[-0.2080, -1.6209,  0.4120],[ 0.8404, -1.6748, -0.7170]]],[[[ 0.2850,  0.1691, -0.9228],[ 0.7234,  0.5582, -0.4327],[ 0.6563,  0.2941,  1.5549]],[[ 0.2642, -1.9061,  1.6212],[-0.5276, -0.5608,  0.3824],[ 0.4452, -2.5152,  0.4490]],[[-0.1276,  0.7784,  0.7998],[-0.3030, -0.9776,  0.9681],[ 1.0225,  0.8946, -0.8084]]],[[[-0.5087, -0.8345, -1.4763],[-0.4938,  1.1979, -0.1335],[ 0.5010,  0.2865,  0.0728]],[[-0.3177, -0.6937, -1.0327],[ 0.8147, -1.7101, -1.8257],[-0.1593, -1.3855, -0.0885]],[[-0.4687, -1.6307,  1.5791],[-1.3030,  0.2004, -0.7055],[ 0.0674, -0.8772,  0.1586]]]], requires_grad=True),tensor([ 1.5349, -0.5608,  0.5182,  0.3328], requires_grad=True))
n, c, h, w = x.shape
d, c, k, j = weight.shape
n, c, h, w
(2, 3, 5, 5)
d, c, k, j
(4, 3, 3, 3)
# 补零
x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
x_pad
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.]]],[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.]],[[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
x_pad.shape
torch.Size([2, 3, 9, 9])
x_pad[:, :, pad:-pad, pad:-pad] = x
x_pad
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000, -0.4888,  1.0257,  0.0312, -0.9026, -0.9060,0.0000,  0.0000],[ 0.0000,  0.0000,  0.2071, -0.4962, -0.1658,  1.0919,  0.3785,0.0000,  0.0000],[ 0.0000,  0.0000, -0.4654,  1.5442,  0.6005,  0.3594, -2.6207,0.0000,  0.0000],[ 0.0000,  0.0000,  0.5830,  0.0533,  0.5719,  1.5413,  0.5949,0.0000,  0.0000],[ 0.0000,  0.0000, -0.9152, -0.2114, -0.4888, -0.0065, -0.9767,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000]],[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.4706, -0.1108, -0.1563, -1.7946, -0.8533,0.0000,  0.0000],[ 0.0000,  0.0000, -0.2119,  0.3165, -2.2668, -0.8956,  1.0617,0.0000,  0.0000],[ 0.0000,  0.0000, -0.7809, -0.2120, -0.8592, -0.5057,  0.7954,0.0000,  0.0000],[ 0.0000,  0.0000, -2.8820, -0.6888,  0.4450, -0.3586, -0.9477,0.0000,  0.0000],[ 0.0000,  0.0000,  0.6244,  0.4303,  1.4739,  0.2740,  1.6605,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000]],[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000, -0.1501,  0.6234, -1.6086,  0.1693,  0.4932,0.0000,  0.0000],[ 0.0000,  0.0000,  1.0611, -1.0938,  0.1695,  1.0193,  0.4263,0.0000,  0.0000],[ 0.0000,  0.0000,  1.4681, -0.1552, -0.0667, -0.7293,  1.0816,0.0000,  0.0000],[ 0.0000,  0.0000,  0.8972,  1.1683, -1.4757,  0.4421, -0.0355,0.0000,  0.0000],[ 0.0000,  0.0000, -2.1331,  1.4847,  0.1378, -1.6907, -0.1350,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000]]],[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000, -1.3853,  1.6396,  0.3436,  0.3841,  0.2355,0.0000,  0.0000],[ 0.0000,  0.0000, -0.2206, -0.5087, -1.6956,  1.3205,  0.7058,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0993,  0.3533, -0.2086,  0.2969,  0.2627,0.0000,  0.0000],[ 0.0000,  0.0000,  0.3752,  0.0304,  1.2487,  1.3963, -0.0063,0.0000,  0.0000],[ 0.0000,  0.0000, -1.3758,  0.5088, -1.3849,  1.3050,  0.4150,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000]],[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.2824, -2.8634, -0.1016, -0.1627,  1.7081,0.0000,  0.0000],[ 0.0000,  0.0000,  0.1406,  0.2220, -0.6005,  0.2997, -0.1846,0.0000,  0.0000],[ 0.0000,  0.0000,  1.6700,  0.5787,  0.6561, -0.0236,  1.7743,0.0000,  0.0000],[ 0.0000,  0.0000,  2.1429, -0.2838, -0.0527,  0.3504, -0.3444,0.0000,  0.0000],[ 0.0000,  0.0000, -0.9409, -0.4734, -0.4060, -0.5088, -1.8518,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000]],[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000, -2.2152,  0.2104, -0.3302,  0.2036, -0.9443,0.0000,  0.0000],[ 0.0000,  0.0000, -0.6576, -0.4455,  0.5117, -2.0058, -1.3985,0.0000,  0.0000],[ 0.0000,  0.0000, -0.5688,  1.2338, -0.1832,  0.1760,  0.4506,0.0000,  0.0000],[ 0.0000,  0.0000, -0.6563,  0.4021, -1.6210,  0.5582, -0.9238,0.0000,  0.0000],[ 0.0000,  0.0000, -1.0506, -0.9638,  0.7453, -0.3535, -0.3536,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,0.0000,  0.0000]]]], grad_fn=<CopySlices>)
# 卷积
x_pad = x_pad.unfold(2, k, stride)
x_pad.shape
torch.Size([2, 3, 4, 9, 3])
x_pad = x_pad.unfold(3, j, stride)
x_pad.shape
torch.Size([2, 3, 4, 4, 3, 3])
out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
out.shape
torch.Size([2, 4, 4, 4])
bias.view(1, -1, 1, 1).shape
torch.Size([1, 4, 1, 1])
# 偏置
out = out + bias.view(1, -1, 1, 1)
out
tensor([[[[ 0.6573, -0.3444,  1.5693, -0.1906],[ 2.5483,  5.1142, -2.3528, -3.6162],[ 2.9913, -6.1289,  6.8200,  0.9229],[ 0.4849,  0.1813,  3.2616,  1.5637]],[[-0.1524, -1.2003, -0.3415,  0.0318],[-1.7830,  2.5286, -1.6660,  3.1253],[ 1.3314, -8.2623, -5.0055,  5.7671],[-1.0563,  5.2751, -0.4214,  2.8473]],[[ 0.0908,  2.6704,  1.0336,  0.0481],[ 0.2077,  2.0459,  1.8095, -0.7039],[ 0.9519, -4.5551,  3.7108,  0.7446],[ 0.6689,  3.9448,  2.3968,  0.6958]],[[ 0.2318, -0.3356,  2.4320,  0.0480],[ 0.2101, -1.7177,  6.3956, -0.4108],[ 8.2352, -5.8456, 12.9459, -0.8763],[-2.3292, -1.5263,  2.1349,  0.3653]]],[[[-0.0869,  1.0713,  0.1655,  2.4414],[-1.3623, -3.0759,  0.2430,  2.3259],[-0.9281,  2.1402,  7.1618,  4.1895],[ 0.9273,  1.1176,  0.7792,  0.9265]],[[ 1.6465, -1.7187, -0.7251, -0.8871],[-1.6260, -0.8628, -1.0122,  3.2737],[ 0.5831,  2.1665, -0.5353, -2.0468],[-2.3738, -0.1232, -0.0771, -1.8642]],[[ 0.2819,  6.0978,  2.9618,  0.4676],[ 1.3592,  6.7231,  3.8100,  3.6118],[ 0.9885, -5.7760,  5.4375,  0.5480],[-0.5778,  1.4657, -2.8315,  0.1923]],[[-0.1444,  3.6788,  0.3721,  0.1150],[-1.4057,  0.1613, -2.5436,  1.3156],[-6.1195,  1.8325,  3.1565,  0.8296],[ 1.6766,  6.9403,  1.3986,  0.8758]]]], grad_fn=<AddBackward0>)

2.3 验证编写卷积函数的正确性

import torch.nn.functional as F
x = torch.randn(2, 3, 5, 5, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
torch_out = F.conv2d(x, w, b, stride, pad)
my_out = conv2d(x, w, b, stride, pad)
torch_out == my_out
tensor([[[[ True,  True,  True,  True],[ True, False, False,  True],[ True,  True, False,  True],[ True,  True, False,  True]],[[ True,  True,  True,  True],[ True, False,  True,  True],[False,  True,  True,  True],[ True,  True,  True,  True]],[[ True, False, False,  True],[False, False,  True,  True],[ True, False, False,  True],[ True,  True, False,  True]],[[ True,  True, False,  True],[False, False, False,  True],[ True, False, False, False],[ True,  True,  True,  True]]],[[[ True,  True,  True,  True],[False,  True, False,  True],[ True,  True, False,  True],[ True, False,  True,  True]],[[ True,  True, False,  True],[ True, False, False, False],[ True, False, False, False],[ True, False,  True,  True]],[[ True,  True, False,  True],[False, False, False,  True],[ True, False, False, False],[ True, False,  True,  True]],[[ True, False,  True,  True],[ True, False, False,  True],[False, False, False, False],[ True,  True, False,  True]]]])
torch.allclose(torch_out, my_out, atol=1e-5)
True
  • torch.allclose是用于检查两个张量之间的数值是否相等的函数。

  • 在使用时,需要将第一个张量作为第一个参数传入(即torch_out),将第二个张量作为第二个参数传入(即my_out),并将允许的绝对误差(atol)作为第三个参数传入(默认值为1e-8)。

  • 函数将返回一个布尔值,表示两个张量是否具有相近的数值。如果返回True,则表示两个张量具有相近的数值,否则表示它们之间存在数值差异。

grad_out = torch.randn(*torch_out.shape)
grad_x = torch.autograd.grad(torch_out, x, grad_out, retain_graph=True)
my_grad_x = torch.autograd.grad(my_out, x, grad_out, retain_graph=True)
torch.allclose(grad_x[0], my_grad_x[0], atol=1e-5)
True
grad_w = torch.autograd.grad(torch_out, w, grad_out, retain_graph=True)
my_grad_w = torch.autograd.grad(my_out, w, grad_out, retain_graph=True)
torch.allclose(grad_w[0], my_grad_w[0], atol=1e-5)
True
grad_b = torch.autograd.grad(torch_out, b, grad_out, retain_graph=True)
my_grad_b = torch.autograd.grad(my_out, b, grad_out, retain_graph=True)
torch.allclose(grad_b[0], my_grad_b[0], atol=1e-5)
True

全是True,表明编写的卷积函数在一定范围内与PyTorch内置的Conv2d函数结果相近,说明了实现的正确性

附:系列文章

序号文章目录直达链接
1PyTorch应用实战一:实现卷积操作https://want595.blog.csdn.net/article/details/132575530
2PyTorch应用实战二:实现卷积神经网络进行图像分类https://want595.blog.csdn.net/article/details/132575702
3PyTorch应用实战三:构建神经网络https://want595.blog.csdn.net/article/details/132575758
4PyTorch应用实战四:基于PyTorch构建复杂应用https://want595.blog.csdn.net/article/details/132625270
5PyTorch应用实战五:实现二值化神经网络https://want595.blog.csdn.net/article/details/132625348
6PyTorch应用实战六:利用LSTM实现文本情感分类https://want595.blog.csdn.net/article/details/132625382

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/147630.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

dbeaver 1064 42000 错误 query execution failed

编辑驱动属性&#xff0c;将 allowMultiQueries 设置为 true 参考 https://blog.csdn.net/u200814342A/article/details/132458960

管道-有名管道

一、有名管道 有名管道与匿名管道的不同&#xff1a; 有名管道提供了一个路径名&#xff0c;并以FIFO的文件形式存在于文件系统中。与匿名管道不同&#xff0c;有名管道可以被不相关的进程使用&#xff0c;只要它们可以访问该路径&#xff0c;就能够通过有名管道进行通信。 FI…

安防监控产品经营商城小程序的作用是什么

安防监控产品覆盖面较大&#xff0c;监控器、门禁、对讲机、烟感等都有很高用途&#xff0c;家庭、办公单位各场景往往用量不少&#xff0c;对商家来说&#xff0c;市场高需求背景下也带来了众多生意&#xff0c;但线下门店的局限性&#xff0c;导致商家想要进一步增长不容易。…

Flutter笔记 - ListTile组件及其应用

Flutter笔记 ListTile组件及其应用 作者&#xff1a;李俊才 &#xff08;jcLee95&#xff09;&#xff1a;https://blog.csdn.net/qq_28550263 邮箱 &#xff1a;291148484163.com 本文地址&#xff1a;https://blog.csdn.net/qq_28550263/article/details/133411883 目 录 1. …

【Java每日一题】— —第十九题:用二维数组存放九九乘法表,并将其输出。(2023.10.03)

&#x1f578;️Hollow&#xff0c;各位小伙伴&#xff0c;今天我们要做的是第十九题。 &#x1f3af;问题&#xff1a; 用二维数组存放九九乘法表&#xff0c;并将其输出。 测试结果如下&#xff1a; &#x1f3af; 答案&#xff1a; System.out.println("九九乘法表如…

Python中匹配模糊的字符串

嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 如何使用thefuzz 库&#xff0c;它允许我们在python中进行模糊字符串匹配。 此外&#xff0c;我们将学习如何使用process 模块&#xff0c;该模块允许我们在模糊…

【AI视野·今日NLP 自然语言处理论文速览 第四十五期】Mon, 2 Oct 2023

AI视野今日CS.NLP 自然语言处理论文速览 Mon, 2 Oct 2023 Totally 44 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Efficient Streaming Language Models with Attention Sinks Authors Guangxuan Xiao, Yuandong Tian, Beidi C…

Redis与分布式-分布式锁

接上文 Redis与分布式-集群搭建 1.分布式锁 为了解决上述问题&#xff0c;可以利用分布式锁来实现。 重新复制一份redis&#xff0c;配置文件都是刚下载时候的不用更改&#xff0c;然后启动redis服务和redis客户。 redis存在这样的命令&#xff1a;和set命令差不多&#xff0…

淘宝天猫渠道会员购是什么意思?如何开通天猫淘宝渠道会员购有什么用?

淘宝天猫渠道会员购是什么意思&#xff1f; 淘宝天猫渠道会员购与淘宝天猫粉丝福利购意思基本相同&#xff0c;都可以领取淘宝天猫大额内部隐藏优惠券、通过草柴APP开通绑定渠道会员还可以获得购物返利。 草柴APP如何绑定开通淘宝天猫渠道会员&#xff1f; 1、手机下载安装「…

笔记二:odoo搜索、筛选和分组

一、搜索 1、xml代码 <!--搜索和筛选--><record id"view_search_book_message" model"ir.ui.view"><field name"name">book_message</field><field name"model">book_message</field><field…

“把握拐点,洞悉投资者情绪与比特币价格的未来之路!“

“本来这篇文章是昨天晚上发的&#xff0c;国庆节庆祝喝多了&#xff0c;心有余而力不足&#xff01;直接头躺马桶GG了” 标准普尔 500 指数 200 天移动平均线云是我几个月来一直分享的下行目标&#xff0c;上周正式重新测试了该目标。200 日移动平均线云表示为: 200 天指数移…

Linux(CentOS/Ubuntu)——安装nginx

如果确定你的系统是基于CentOS或RHEL&#xff0c;可以使用以下命令&#xff1a; ①、安装库文件 #安装gcc yum install gcc-c#安装PCRE pcre-devel yum install -y pcre pcre-devel#安装zlib yum install -y zlib zlib-devel#安装Open SSL yum install -y openssl openssl-de…

【JVM】垃圾回收(GC)详解

垃圾回收&#xff08;GC&#xff09;详解 一. 死亡对象的判断算法1. 引用计数算法2. 可达性分析算法 二. 垃圾回收算法1. 标记-清除算法2. 复制算法3. 标记-整理算法4. 分代算法 三. STW1. 为什么要 STW2. 什么情况下 STW 四. 垃圾收集器1. CMS收集器&#xff08;老年代收集器&…

React antd Table点击下一页后selectedRows丢失之前页选择内容的问题

一、问题 使用了React antd 的<Table>标签&#xff0c;是这样记录选中的行id与行内容的&#xff1a; <TabledataSource{data.list}rowSelection{{selectedRowKeys: selectedIdsInSearchTab,onChange: this.onSelectChange,}} // 表格是否可复选&#xff0c;加 type: …

cygwin编译haproxy

下载安装cygwin cygwin下载、安装-CSDN博客 编译haproxy 打开cygwin终端 下载程序 haproxy程序 OpenPKG Project: Download 输入下面命令下载程序 wget http://download.openpkg.org/components/cache/haproxy/haproxy-2.8.3.tar.gz 解压 tar -zxvf haproxy-2.8.3.tar.gz…

正则表达式的应用(前端写法)

文章目录 1、匹配字符串中&#xff0c;a标签的href值2、校验邮箱3、校验手机号码3、待添加... 1、匹配字符串中&#xff0c;a标签的href值 (1) 代码 /*** description 匹配字符串中&#xff0c;a标签的href值* param {string} str 匹配的字符串* return {Array} 返回href值*/…

OpenGLES:绘制一个彩色、旋转的3D立方体

一.概述 之前关于OpenGLES实战开发的博文&#xff0c;不论是实现相机滤镜还是绘制图形&#xff0c;都是在2D纬度 这篇博文开始&#xff0c;将会使用OpenGLES进入3D世界 本篇博文会实现一个颜色渐变、旋转的3D立方体 动态3D图形的绘制&#xff0c;需要具备一些基础的线性代数…

c++学习之优先级队列

目录 1.初识优先级队列 库中的实现 使用优先级队列 2.优先级队列的实现 3.仿函数 利用仿函数实现的优先级队列 迭代器区间构造&#xff08;建堆&#xff09; 1.初识优先级队列 如果我们给每个元素都分配一个数字来标记其优先级&#xff0c;不妨设较小的数字具有较…

国庆10.03

运算符重载 代码 #include <iostream> using namespace std; class Num { private:int num1; //实部int num2; //虚部 public:Num(){}; //无参构造Num(int n1,int n2):num1(n1),num2(n2){}; //有参构造~Num(){}; //析构函数const Num operator(const Num &other)co…

OpenGLES:绘制一个彩色、旋转的3D圆柱

一.概述 上一篇博文讲解了怎么绘制一个彩色旋转的立方体 这一篇讲解怎么绘制一个彩色旋转的圆柱 圆柱的顶点创建主要基于2D圆进行扩展&#xff0c;与立方体没有相似之处 圆柱绘制的关键点就是将圆柱拆解成&#xff1a;两个Z坐标不为0的圆 一个长方形的圆柱面 绘制2D圆的…