Pytorch的基础用法

目录

张量 

正态分布与随机分布:

张量的尺寸:

张量的基本属性:

创建张量:

张量构造器:

其他张量创建方法:

索引与切片

索引(Indexing)

单维张量索引

多维张量索引

切片(Slicing)

单维张量切片

多维张量切片

布尔索引(Boolean Indexing)

高级索引(Advanced Indexing)

联合索引

广播机制 

广播规则

示例

示例 1:简单的一维张量

示例 2:多维张量

不适用的情况

合并与分割 

合并

torch.cat (Concatenate)

torch.stack

总结

分割

torch.chunk

torch.split

总结

矩阵乘法 

1. torch.matmul (点积)

2. torch.mm (矩阵乘法)

3. torch.bmm (批量矩阵乘法)

总结

数据处理

1.torch.norm 函数

常见的范数类型

2. torch.max 和 torch.min(最大值和最小值)

3. torch.mean(均值)

4. torch.prod(乘积)

5. torch.argmax 和 torch.argmin(最大值和最小值的索引)

6. torch.topk(最大值及其索引)

7. torch.kthvalue(第k个值)

8. 比较运算(大于、小于、等于)

9. keepdim 参数

高级用法(where,gather)

torch.where 函数的基本语法

示例

示例 1:基于条件选择元素

示例 2:基于条件修改元素

多维张量示例

总结

torch.gather 的基本语法

示例

示例 1:一维索引

示例 2:二维索引

实际应用示例

示例 3:在自然语言处理中的应用

总结


张量 

  1. 正态分布与随机分布:

    • torch.randn(size) 生成一个给定形状 size 的张量,其中每个元素都来自一个标准正态分布(均值为0,标准差为1)。注意这里的元素值不是限制在0-1之间的。
    • torch.rand(size) 生成一个给定形状 size 的张量,其中每个元素都来自一个0到1之间的均匀分布。
  2. 张量的尺寸:

    • a.size() 和 a.shape 都是用来获取张量尺寸的方法,其中 a.shape 是属性而 a.size() 是一个方法。两者都会返回一个元组,表示张量的各个维度大小。
    • a.size(1) 或 a.shape[1] 获取张量的第二个维度的长度(在Python中索引是从0开始的,因此第二个维度的索引是1)。
  3. 张量的基本属性:

    • a.numel() 返回张量中元素的总数。
    • a.dim() 返回张量的维度数。
  4. 创建张量:

    • torch.tensor(5) 创建一个标量张量,即一个0维张量,其值为5。
    • torch.from_numpy(a) 将NumPy数组转换为PyTorch张量,其中 a 是NumPy数组。
  5. 张量构造器:

    • torch.Tensor 通常是一个构造函数,用于创建一个新的张量。如果没有提供参数,默认创建一个浮点张量。
    • torch.full(size, fill_value) 创建一个给定形状 size 的张量,并用 fill_value 填充所有元素。
    • torch.ones(size) 创建一个给定形状 size 的张量,所有元素均为1。
    • torch.zeros(size) 创建一个给定形状 size 的张量,所有元素均为0。
    • torch.eye(n) 创建一个n×n的单位矩阵,适用于二维张量。
  6. 其他张量创建方法:

    • torch.arange(start, end, step) 生成一个从 start 到 end (不包括 end)的等差序列。
    • torch.linspace(start, end, steps) 在 start 和 end 之间生成 steps 个等间距的样本,包括 start 和 end
    • torch.randperm(n) 返回一个长度为n的张量,包含从0到n-1的随机排列。

请注意:

  • torch.range 已经被弃用,推荐使用 torch.arange
  • torch.normal 方法用于生成正态分布的样本,而不是压缩张量。
  • torch.randperm 可以用来生成一个随机排列的长整型张量。

索引与切片

索引(Indexing)

索引允许您通过指定一个或多个索引来访问张量中的单个元素或多维子集。

单维张量索引

对于一维张量,索引就像列表一样简单:

 
1import torch
2
3t = torch.tensor([1, 2, 3, 4])
4print(t[0])  # 访问第一个元素
5print(t[-1])  # 访问最后一个元素
多维张量索引

对于多维张量,您需要为每个维度提供一个索引:

 
1t = torch.tensor([[1, 2], [3, 4]])
2print(t[0, 1])  # 访问第一行第二列的元素

切片(Slicing)

切片允许您访问张量的一个连续子集。切片可以通过提供起始位置、结束位置和步长来实现。

单维张量切片
1t = torch.tensor([1, 2, 3, 4, 5])
2print(t[1:3])  # 访问索引1到2的元素 (不包括3)
3print(t[:3])   # 访问前三个元素
4print(t[3:])   # 访问从索引3到最后的所有元素
5print(t[::2])  # 每隔一个元素取一个
多维张量切片

对于多维张量,可以在每个维度上进行切片:

1t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2print(t[0:2, :])  # 访问前两行
3print(t[:, 1:3])  # 访问每行的第二列和第三列
4print(t[1:, 1:])  # 访问除了第一行和第一列外的所有元素

布尔索引(Boolean Indexing)

布尔索引允许您通过一个布尔掩码来选择张量中的元素:

1t = torch.tensor([1, 2, 3, 4, 5])
2mask = t > 3
3print(mask)  # 输出一个布尔张量 [False False False  True  True]
4print(t[mask])  # 选择大于3的元素 [4 5]

高级索引(Advanced Indexing)

高级索引允许您使用整数数组来进行索引,这可以用来选择特定位置的元素:

1t = torch.tensor([[1, 2], [3, 4], [5, 6]])
2rows = torch.tensor([0, 2])
3cols = torch.tensor([1, 0])
4print(t[rows, cols])  # 选择第一行第二列和第三行第一列的元素 [2, 5]

联合索引

您可以结合使用上述不同的索引方式来完成更复杂的任务。例如,您可以同时使用切片和布尔索引:

1t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2print(t[1:, mask])  # 结合使用切片和布尔索引

广播机制 

在PyTorch中,广播机制允许不同形状的张量进行数学运算,例如加法、减法、乘法等。广播机制是一种自动扩展张量的形状以匹配另一个张量的形状的过程。当两个张量的形状不完全相同,但可以通过一定的规则来扩展其中一个张量的形状,使得它们的形状兼容时,就可以使用广播机制。

广播规则

广播机制遵循以下规则:

  1. 维度匹配:从最后的维度开始比较,如果两个张量的维度数量不同,那么少的那个张量在前面添加维度,使得维度数量相同。
  2. 维度相等或为1:对于每个维度,如果两个张量在这个维度上的大小相等,或者其中一个张量在这个维度上的大小为1,那么这两个张量可以广播。如果两个张量在同一个维度上的大小既不相等,也不是1,则无法进行广播。

示例

假设我们有两个张量 tensor_atensor_b,我们想要执行加法运算。

示例 1:简单的一维张量
1import torch
2
3tensor_a = torch.tensor([1, 2, 3])
4tensor_b = torch.tensor(10)
5result = tensor_a + tensor_b
6print(result)

输出结果将是:

1tensor([11, 12, 13])

在这个例子中,标量 tensor_b 被扩展成与 tensor_a 相同的形状 [10, 10, 10],然后与 tensor_a 进行逐元素的加法运算。

示例 2:多维张量
1tensor_a = torch.tensor([[1, 2], [3, 4]])
2tensor_b = torch.tensor([10, 20])
3result = tensor_a + tensor_b
4print(result)
1tensor([[11, 22],
2        [13, 24]])

在这个例子中,张量 tensor_b 被扩展成与 tensor_a 相同的形状 [[10, 20], [10, 20]],然后与 tensor_a 进行逐元素的加法运算。

不适用的情况

广播机制不适用的情况通常发生在以下几个方面:

  1. 维度不匹配:如果两个张量的形状完全不同,且没有任何维度可以扩展为相同大小,那么就不能进行广播。例如,张量 tensor_a 的形状为 (3, 4),而 tensor_b 的形状为 (5, 6),则无法通过广播机制直接进行运算。

    1tensor_a = torch.randn(3, 4)
    2tensor_b = torch.randn(5, 6)
    3# 尝试广播将导致错误
    4try:
    5    result = tensor_a + tensor_b
    6except RuntimeError as e:
    7    print(e)
  2. 非1维度不相等:如果两个张量在任何一个维度上的大小不相等,并且不为1,那么也不能进行广播。例如,张量 tensor_a 的形状为 (3, 4),而 tensor_b 的形状为 (3, 5),则无法通过广播机制直接进行运算。

    1tensor_a = torch.randn(3, 4)
    2tensor_b = torch.randn(3, 5)
    3# 尝试广播将导致错误
    4try:
    5    result = tensor_a + tensor_b
    6except RuntimeError as e:
    7    print(e)
  3. 内存消耗:对于非常大的张量,即使广播机制理论上可行,但由于需要扩展张量的大小,可能导致内存消耗过高,从而引发内存不足的问题。

通过了解这些规则和限制,你可以更好地理解和使用PyTorch中的广播机制,以有效地进行张量运算。

合并与分割 

合并

torch.cat (Concatenate)

torch.cat 用于沿着指定的维度连接多个张量。它将这些张量拼接在一起,形成一个新的张量。

语法

1torch.cat(tensors, dim=0, out=None)
  • tensors:一个张量的序列。
  • dim:沿着哪个维度进行拼接,默认为0。
  • out:可选参数,用于指定输出张量的位置。

示例

假设我们有两个张量 t1t2,我们将它们沿着不同的维度进行拼接。

1import torch
2
3# 创建两个张量
4t1 = torch.tensor([[1, 2], [3, 4]])
5t2 = torch.tensor([[5, 6], [7, 8]])
6
7print("张量 t1:")
8print(t1)
9print("张量 t2:")
10print(t2)
11
12# 沿着行(第0维)拼接
13result_cat_row = torch.cat((t1, t2), dim=0)
14print("沿着行拼接的结果:")
15print(result_cat_row)
16
17# 沿着列(第1维)拼接
18result_cat_col = torch.cat((t1, t2), dim=1)
19print("沿着列拼接的结果:")
20print(result_cat_col)

torch.stack

torch.stack 用于在给定的新维度上堆叠一系列张量。它创建一个新的维度,并将输入张量沿着这个新维度堆叠起来。

语法

1torch.stack(tensors, dim=0, out=None)
  • tensors:一个张量的序列。
  • dim:新维度的位置,默认为0。
  • out:可选参数,用于指定输出张量的位置。

示例

同样的,我们使用上面定义的 t1t2 张量来演示 torch.stack 的用法。

1# 沿着新的第0维堆叠
2result_stack_new_dim = torch.stack((t1, t2), dim=0)
3print("沿着新的第0维堆叠的结果:")
4print(result_stack_new_dim)
5
6# 沿着新的第1维堆叠
7result_stack_new_dim_1 = torch.stack((t1, t2), dim=1)
8print("沿着新的第1维堆叠的结果:")
9print(result_stack_new_dim_1)

总结

  • torch.cat:用于沿着现有维度拼接张量。输入张量必须具有相同的形状,除了要拼接的那个维度。
  • torch.stack:用于在新维度上堆叠张量。输入张量必须具有相同的形状。

这两种方法都可以用来组合多个张量,但具体使用哪种取决于你要实现的功能。如果你希望合并张量而不改变它们的维度结构,可以使用 torch.cat。如果你希望在张量之间增加一个新的维度,并在该维度上堆叠张量,应该使用 torch.stack

分割

在PyTorch中,splitchunk 都可以用来分割张量,但它们有不同的用法和语义。下面是详细的

解释和示例代码。

torch.chunk

torch.chunk 方法用于将一个张量分割成多个连续的部分。它根据指定的数量(chunks 参数)来分割张量,并返回一个张量列表。每个张量的大小尽可能相同,但如果有剩余的部分,最后一部分的大小可能会稍小一些。

语法

1torch.chunk(input, chunks, dim=0)
  • input:要分割的张量。
  • chunks:分割成的块数。
  • dim:沿着哪个维度进行分割,默认为0(即第一维)。

示例

假设我们有一个形状为 (4, 4) 的张量,并希望将其分割成两部分:

1import torch
2
3# 创建一个4x4的张量
4tensor = torch.arange(16).view(4, 4)
5print("原始张量:")
6print(tensor)
7
8# 沿着第一维度(行)分割成两部分
9result_chunk = torch.chunk(tensor, chunks=2, dim=0)
10print("分割后的张量(按照行分割成两部分):")
11for r in result_chunk:
12    print(r)
13
14# 沿着第二维度(列)分割成两部分
15result_chunk_col = torch.chunk(tensor, chunks=2, dim=1)
16print("分割后的张量(按照列分割成两部分):")
17for r in result_chunk_col:
18    print(r)

torch.split

torch.split 方法用于将一个张量按照指定的大小进行分割。它可以接受一个整数或一个整数列表作为 split_size 参数,表示每个部分的大小。如果提供了整数列表,则表示每个部分的具体大小。

语法

1torch.split(tensor, split_size, dim=0)
  • tensor:要分割的张量。
  • split_size:每个部分的大小,可以是一个整数或一个整数列表。
  • dim:沿着哪个维度进行分割,默认为0。

示例

假设我们有一个形状为 (4, 4) 的张量,并希望将其分割成大小不同的部分:

1# 沿着第一维度(行)分割,每部分大小为2
2result_split = torch.split(tensor, split_size=2, dim=0)
3print("分割后的张量(每部分大小为2):")
4for r in result_split:
5    print(r)
6
7# 沿着第一维度(行)分割,前两部分大小为2,最后一部分大小为剩余的行数
8result_split_varied = torch.split(tensor, split_size_or_sections=[2, 2], dim=0)
9print("分割后的张量(前两部分大小为2,最后一部分大小为剩余的行数):")
10for r in result_split_varied:
11    print(r)
12
13# 沿着第二维度(列)分割,每部分大小为2
14result_split_col = torch.split(tensor, split_size=2, dim=1)
15print("分割后的张量(每部分大小为2,按照列分割):")
16for r in result_split_col:
17    print(r)

总结

  • torch.chunk 主要用于根据块数分割张量,返回的每个张量大小尽可能相同。
  • torch.split 主要用于根据指定的大小分割张量,可以指定每个部分的大小,也可以指定一个列表来表示每个部分的具体大小

矩阵乘法 

在PyTorch中,矩阵相乘是一个常见的操作,用于各种线性代数任务,例如神经网络中的前向传播。PyTorch提供了几种不同的方法来执行矩阵乘法,每种方法都有其特定的用途和规则。以下是常用的几种矩阵乘法操作及其规则:

1. torch.matmul (点积)

torch.matmul 用于执行两个张量之间的矩阵乘法。如果两个张量都是二维的,则执行标准的矩阵乘法。如果其中一个或两个张量具有更多维度,则torch.matmul会在相应的维度上执行矩阵乘法。

语法

1torch.matmul(tensor1, tensor2, out=None)
  • tensor1:第一个张量。
  • tensor2:第二个张量。
  • out:可选参数,用于指定输出张量的位置。

规则

  • 如果tensor1的形状为(n, m)tensor2的形状为(m, p),则结果张量的形状为(n, p)
  • 如果tensor1的形状为(a, b, n, m)tensor2的形状为(a, b, m, p),则结果张量的形状为(a, b, n, p)

示例

1import torch
2
3# 两个二维张量
4tensor1 = torch.tensor([[1, 2], [3, 4]])
5tensor2 = torch.tensor([[5, 6], [7, 8]])
6
7# 执行矩阵乘法
8result = torch.matmul(tensor1, tensor2)
9print("结果张量:")
10print(result)

输出结果将是:

1结果张量:
2tensor([[19, 22],
3        [43, 50]])

2. torch.mm (矩阵乘法)

torch.mm 专门用于两个二维张量之间的矩阵乘法。如果张量不是二维的,将会引发错误。

语法

1torch.mm(tensor1, tensor2, out=None)
  • tensor1:第一个二维张量。
  • tensor2:第二个二维张量。
  • out:可选参数,用于指定输出张量的位置。

规则

  • tensor1的形状为(n, m)tensor2的形状为(m, p),则结果张量的形状为(n, p)

示例

1# 同样的两个二维张量
2result_mm = torch.mm(tensor1, tensor2)
3print("使用torch.mm的结果张量:")
4print(result_mm)

3. torch.bmm (批量矩阵乘法)

torch.bmm 用于执行两个三维张量之间的批量矩阵乘法。每个三维张量中的每个切片(在第一个维度上)被视为单独的矩阵。

语法

1torch.bmm(batch_tensor1, batch_tensor2, out=None)
  • batch_tensor1:第一个三维张量。
  • batch_tensor2:第二个三维张量。
  • out:可选参数,用于指定输出张量的位置。

规则

  • batch_tensor1的形状为(batch_size, n, m)batch_tensor2的形状为(batch_size, m, p),则结果张量的形状为(batch_size, n, p)

示例

1batch_tensor1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
2batch_tensor2 = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
3
4# 执行批量矩阵乘法
5result_bmm = torch.bmm(batch_tensor1, batch_tensor2)
6print("批量矩阵乘法的结果张量:")
7print(result_bmm)

总结

  • torch.matmul 是最通用的矩阵乘法操作,适用于多种形状的张量。
  • torch.mm 专门用于二维张量的矩阵乘法。
  • torch.bmm 用于执行批量矩阵乘法,适用于三维张量。

PS:a@b等同于torch.matmul 

数据处理

在PyTorch中,有许多方法可以用来处理数据,包括计算范数(norm)、最大值(max)、最小值(min)、均值(mean)、乘积(prod)、最大值索引(argmax)、最小值索引(argmin)、最大值及其索引(topk)、第k个值(kthvalue)、比较运算(大于、小于、等于)等。下面将逐一介绍这些操作及其使用方法,并提供示例代码。

1.torch.norm 函数

torch.norm 是一个通用的函数,用于计算张量的各种范数。它支持多种类型的范数计算,并且可以指定范数的类型和计算范数的维度。

基本语法

1torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
  • input:输入的张量。
  • p:范数的类型。默认为 'fro'(Frobenius范数)。可以是 'inf''-inf'、整数(例如1或2)、浮点数(例如1.5)。
  • dim:计算范数的维度。默认为 None,此时计算整个张量的范数。
  • keepdim:布尔值,如果设置为 True,则保留原维度,即输出的张量将在指定的维度上保持为1。
  • out:可选参数,用于指定输出张量的位置。
  • dtype:可选参数,指定输出张量的数据类型。
常见的范数类型
  1. L1范数 (p=1):所有元素绝对值之和。
  2. L2范数 (p=2):所有元素平方和的平方根。
  3. 无穷范数 (p='inf'):最大绝对值元素。
  4. Frobenius范数 (p='fro'):矩阵中所有元素的平方和的平方根。

1import torch
2
3# 创建一个张量
4x = torch.tensor([[1, 2], [3, 4]])
5
6# 计算Frobenius范数
7norm_fro = torch.norm(x, p='fro')
8print("Frobenius范数:", norm_fro)
9
10# 计算L2范数
11norm_l2 = torch.norm(x, p=2)
12print("L2范数:", norm_l2)

2. torch.max 和 torch.min(最大值和最小值)

用于计算张量的最大值或最小值。

示例

1# 计算整个张量的最大值
2max_val = torch.max(x)
3print("最大值:", max_val)
4
5# 计算每一列的最大值
6max_vals_per_column = torch.max(x, dim=0)
7print("每列的最大值:", max_vals_per_column.values)
8print("每列最大值的索引:", max_vals_per_column.indices)
9
10# 计算每一行的最大值
11max_vals_per_row = torch.max(x, dim=1)
12print("每行的最大值:", max_vals_per_row.values)
13print("每行最大值的索引:", max_vals_per_row.indices)
14
15# 计算整个张量的最小值
16min_val = torch.min(x)
17print("最小值:", min_val)
18
19# 计算每一列的最小值
20min_vals_per_column = torch.min(x, dim=0)
21print("每列的最小值:", min_vals_per_column.values)
22print("每列最小值的索引:", min_vals_per_column.indices)
23
24# 计算每一行的最小值
25min_vals_per_row = torch.min(x, dim=1)
26print("每行的最小值:", min_vals_per_row.values)
27print("每行最小值的索引:", min_vals_per_row.indices)

3. torch.mean(均值)

用于计算张量的平均值。

示例

1# 计算整个张量的平均值
2mean_val = torch.mean(x)
3print("平均值:", mean_val)
4
5# 计算每一列的平均值
6means_per_column = torch.mean(x, dim=0)
7print("每列的平均值:", means_per_column)
8
9# 计算每一行的平均值
10means_per_row = torch.mean(x, dim=1)
11print("每行的平均值:", means_per_row)

4. torch.prod(乘积)

用于计算张量的元素乘积。

示例

1# 计算整个张量的乘积
2prod_val = torch.prod(x)
3print("乘积:", prod_val)
4
5# 计算每一列的乘积
6prods_per_column = torch.prod(x, dim=0)
7print("每列的乘积:", prods_per_column)
8
9# 计算每一行的乘积
10prods_per_row = torch.prod(x, dim=1)
11print("每行的乘积:", prods_per_row)

5. torch.argmax 和 torch.argmin(最大值和最小值的索引)

用于找到张量中的最大值或最小值的索引。

示例

1# 找到整个张量的最大值索引
2argmax_val = torch.argmax(x)
3print("最大值索引:", argmax_val)
4
5# 找到每一列的最大值索引
6argmax_vals_per_column = torch.argmax(x, dim=0)
7print("每列最大值的索引:", argmax_vals_per_column)
8
9# 找到每一行的最大值索引
10argmax_vals_per_row = torch.argmax(x, dim=1)
11print("每行最大值的索引:", argmax_vals_per_row)
12
13# 找到整个张量的最小值索引
14argmin_val = torch.argmin(x)
15print("最小值索引:", argmin_val)
16
17# 找到每一列的最小值索引
18argmin_vals_per_column = torch.argmin(x, dim=0)
19print("每列最小值的索引:", argmin_vals_per_column)
20
21# 找到每一行的最小值索引
22argmin_vals_per_row = torch.argmin(x, dim=1)
23print("每行最小值的索引:", argmin_vals_per_row)

6. torch.topk(最大值及其索引)

用于找到张量中最大的k个值及其索引。

示例

1# 找到整个张量中最大的3个值及其索引
2topk_val, topk_indices = torch.topk(x.flatten(), k=3)
3print("最大的3个值:", topk_val)
4print("最大的3个值的索引:", topk_indices)
5
6# 找到每一列中最大的2个值及其索引
7topk_val_per_column, topk_indices_per_column = torch.topk(x, k=2, dim=0)
8print("每列最大的2个值:", topk_val_per_column)
9print("每列最大的2个值的索引:", topk_indices_per_column)
10
11# 找到每一行中最大的2个值及其索引
12topk_val_per_row, topk_indices_per_row = torch.topk(x, k=2, dim=1)
13print("每行最大的2个值:", topk_val_per_row)
14print("每行最大的2个值的索引:", topk_indices_per_row)

7. torch.kthvalue(第k个值)

用于找到张量中第k个排序后的值及其索引。

示例

1# 找到整个张量中第3个排序后的值及其索引
2kth_value, kth_index = torch.kthvalue(x.flatten(), k=3)
3print("第3个排序后的值:", kth_value)
4print("第3个排序后的值的索引:", kth_index)
5
6# 找到每一列中第2个排序后的值及其索引
7kth_value_per_column, kth_index_per_column = torch.kthvalue(x, k=2, dim=0)
8print("每列第2个排序后的值:", kth_value_per_column)
9print("每列第2个排序后的值的索引:", kth_index_per_column)
10
11# 找到每一行中第2个排序后的值及其索引
12kth_value_per_row, kth_index_per_row = torch.kthvalue(x, k=2, dim=1)
13print("每行第2个排序后的值:", kth_value_per_row)
14print("每行第2个排序后的值的索引:", kth_index_per_row)

8. 比较运算(大于、小于、等于)

用于比较张量中的元素。

1# 创建另一个张量用于比较
2y = torch.tensor([[2, 2], [2, 2]])
3
4# 大于运算
5greater_than = x > y
6print("大于比较结果:", greater_than)
7
8# 小于运算
9less_than = x < y
10print("小于比较结果:", less_than)
11
12# 等于运算
13equal_to = x == y
14print("等于比较结果:", equal_to)

9. keepdim 参数

在使用 torch.maxtorch.mean 等方法时,可以设置 keepdim=True 来保持原维度。

示例

1# 计算每一列的平均值,并保持原维度
2means_per_column_keepdim = torch.mean(x, dim=0, keepdim=True)
3print("每列的平均值(保持维度):", means_per_column_keepdim)
4
5# 计算每一行的平均值,并保持原维度
6means_per_row_keepdim = torch.mean(x, dim=1, keepdim=True)
7print("每行的平均值(保持维度):", means_per_row_keepdim)

这些方法和操作在实际的数据处理和分析中非常有用,可以帮助你更好地理解和操作张量数据。通过这些操作,你可以执行复杂的统计分析、特征工程以及其他数据预处理任务。

在PyTorch中,torch.where 是一个非常有用的函数,用于根据条件选择元素。它类似于NumPy中的 np.where 函数,可以在条件表达式为真时选择一个张量中的值,否则选择另一个张量中的值。这对于基于条件进行数据筛选或修改非常有用。



高级用法(where,gather)

torch.where 函数的基本语法

torch.where 的基本语法如下:

1torch.where(condition, x, y)
  • condition:一个布尔张量,用于指定选择 x 或 y 中的元素。
  • x:当 condition 为 True 时选择的张量。
  • y:当 condition 为 False 时选择的张量。

示例

下面通过一些具体的例子来说明 torch.where 的用法。

示例 1:基于条件选择元素

假设我们有两个张量 AB,我们想要根据某个条件来选择 AB 中的元素。

 
1import torch
2
3# 创建两个张量
4A = torch.tensor([[1, 2], [3, 4]])
5B = torch.tensor([[5, 6], [7, 8]])
6
7# 定义一个条件张量
8condition = A > 2
9
10# 使用 torch.where 根据条件选择元素
11result = torch.where(condition, A, B)
12
13print("条件张量 condition:")
14print(condition)
15print("结果张量 result:")
16print(result)

输出将是:

1条件张量 condition:
2tensor([[False, False],
3        [ True,  True]])
4结果张量 result:
5tensor([[5, 6],
6        [3, 4]])

在这个例子中,当 A 中的元素大于2时,选择 A 中的元素;否则选择 B 中对应的元素。

示例 2:基于条件修改元素

假设我们有一个张量 C,我们想要将其中大于某个阈值的元素替换为另一个值。

 

python

深色版本

1# 创建一个张量
2C = torch.tensor([[1, 2], [3, 4]])
3
4# 定义一个条件
5threshold = 2
6condition = C > threshold
7
8# 将大于阈值的元素设置为 10
9modified_C = torch.where(condition, 10, C)
10
11print("原始张量 C:")
12print(C)
13print("修改后的张量 modified_C:")
14print(modified_C)

输出将是:

1原始张量 C:
2tensor([[1, 2],
3        [3, 4]])
4修改后的张量 modified_C:
5tensor([[1, 2],
6        [10, 10]])

在这个例子中,当 C 中的元素大于2时,将其替换为10;否则保持不变。

多维张量示例

假设我们有一个三维张量 D,我们想要根据某个条件来选择元素。

1# 创建一个三维张量
2D = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
3
4# 定义一个条件
5condition = D > 4
6
7# 将大于4的元素设置为 10
8modified_D = torch.where(condition, 10, D)
9
10print("原始三维张量 D:")
11print(D)
12print("修改后的三维张量 modified_D:")
13print(modified_D)

输出将是:

 
1原始三维张量 D:
2tensor([[[1, 2],
3         [3, 4]],
4        [[5, 6],
5         [7, 8]]])
6修改后的三维张量 modified_D:
7tensor([[[ 1,  2],
8         [ 3,  4]],
9        [[10, 10],
10         [10, 10]]])

在这个例子中,当 D 中的元素大于4时,将其替换为10;否则保持不变。

总结

torch.where 是一个非常强大的函数,可以用于基于条件选择张量中的元素或修改张量中的元素。它广泛应用于数据处理、特征工程以及机器学习任务中,帮助实现条件选择和数据转换等功能。通过上述示例,你应该能够更好地理解和应用 torch.where 函数。

在PyTorch中,torch.gather 是一个用于从源张量中收集元素的方法,根据索引张量来选取元素。它常用于根据索引选择特定维度上的元素,这对于实现某些复杂的索引和选择操作非常有用。

torch.gather 的基本语法

torch.gather 的基本语法如下:

1torch.gather(input, dim, index, out=None)
  • input:源张量,从中提取元素。
  • dim:沿着哪个维度进行收集。
  • index:索引张量,指示从源张量的哪个位置提取元素。
  • out:可选参数,用于指定输出张量的位置。

示例

下面通过一些具体的例子来说明 torch.gather 的用法。

示例 1:一维索引

假设我们有一个二维张量 A,我们想要根据另一个索引张量 idx 来选择 A 中的元素。

1import torch
2
3# 创建一个二维张量
4A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
5
6# 创建一个索引张量
7idx = torch.tensor([[0, 2], [1, 0], [2, 1]])
8
9# 沿着第一个维度收集元素
10result = torch.gather(A, 0, idx)
11
12print("源张量 A:")
13print(A)
14print("索引张量 idx:")
15print(idx)
16print("结果张量 result:")
17print(result)

输出将是:

1源张量 A:
2tensor([[1, 2, 3],
3        [4, 5, 6],
4        [7, 8, 9]])
5索引张量 idx:
6tensor([[0, 2],
7        [1, 0],
8        [2, 1]])
9结果张量 result:
10tensor([[ 1,  3],
11        [ 4,  4],
12        [ 7,  8]])

在这个例子中,我们沿着第一个维度(即行)进行收集。idx 张量中的每个元素指示从 A 的对应行中选择哪个元素。例如,idx[0, 0] 的值为0,意味着从 A 的第一行(即索引为0的行)选择第一个元素(即索引为0的元素);idx[0, 1] 的值为2,意味着从 A 的第一行选择第三个元素(即索引为2的元素)。

示例 2:二维索引

假设我们有一个三维张量 B,我们想要根据索引张量 idx 来选择 B 中的元素。

1# 创建一个三维张量
2B = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])
3
4# 创建一个索引张量
5idx = torch.tensor([[0, 1], [1, 0], [1, 1]])
6
7# 沿着第二个维度收集元素
8result = torch.gather(B, 1, idx)
9
10print("源张量 B:")
11print(B)
12print("索引张量 idx:")
13print(idx)
14print("结果张量 result:")
15print(result)

输出将是:

1源张量 B:
2tensor([[[ 1,  2],
3         [ 3,  4]],
4        [[ 5,  6],
5         [ 7,  8]],
6        [[ 9, 10],
7         [11, 12]]])
8索引张量 idx:
9tensor([[0, 1],
10        [1, 0],
11        [1, 1]])
12结果张量 result:
13tensor([[[ 1,  4],
14         [ 6,  5],
15         [10, 12]]])

在这个例子中,我们沿着第二个维度(即列)进行收集。idx 张量中的每个元素指示从 B 的对应行中选择哪个列的元素。

实际应用示例

示例 3:在自然语言处理中的应用

在自然语言处理中,torch.gather 常用于根据预测的概率分布来选择单词。假设我们有一个批处理的句子,每个句子都有一个预测的概率分布,我们想要根据每个句子的概率分布选择最可能的单词。

1# 假设我们有一个批处理的句子,每个句子都有一个预测的概率分布
2probs = torch.tensor([
3    [0.1, 0.2, 0.3, 0.4],
4    [0.4, 0.1, 0.2, 0.3],
5    [0.2, 0.3, 0.1, 0.4]
6])
7
8# 获取每个句子中概率最高的单词的索引
9max_prob_indices = torch.argmax(probs, dim=1, keepdim=True)
10
11# 选择每个句子中最可能的单词
12selected_words = torch.gather(probs, 1, max_prob_indices)
13
14print("概率分布 probs:")
15print(probs)
16print("概率最高的单词的索引 max_prob_indices:")
17print(max_prob_indices)
18print("选择的单词 selected_words:")
19print(selected_words)

输出将是:

1概率分布 probs:
2tensor([[0.10, 0.20, 0.30, 0.40],
3        [0.40, 0.10, 0.20, 0.30],
4        [0.20, 0.30, 0.10, 0.40]])
5概率最高的单词的索引 max_prob_indices:
6tensor([[3],
7        [0],
8        [3]])
9选择的单词 selected_words:
10tensor([[0.40],
11        [0.40],
12        [0.40]])

在这个例子中,我们首先找到了每个句子中概率最高的单词的索引,然后使用 torch.gather 选择了每个句子中最可能的单词。

总结

torch.gather 是一个非常有用的函数,可以用于从张量中根据索引选择特定的元素。通过指定索引张量和维度,你可以实现复杂的数据选择和索引操作。这对于数据处理、特征工程以及模型实现都非常重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1527660.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

基于javaweb的茶园茶农文化交流平台的设计与实现(源码+L文+ppt)

springboot基于javaweb的茶园茶农文化交流平台的设计与实现&#xff08;源码L文ppt&#xff09;4-20 系统功能结构 系统结构图可以把杂乱无章的模块按照设计者的思维方式进行调整排序&#xff0c;可以让设计者在之后的添加&#xff0c;修改程序内容的过程中有一个很明显的思维…

业务资源管理模式语言10

示例&#xff1a; 图15 表示RentTheResource 模式的一个实例&#xff0c;在一个录像带出租系统中&#xff0c;其中“Videotape&#xff08;录像带&#xff09;”扮演“Resource&#xff08;资源&#xff09;”&#xff0c;“Video Rental&#xff08;录像带出租&#xff09;”…

[数据集][目标检测]血细胞检测数据集VOC+YOLO格式2757张4类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;2757 标注数量(xml文件个数)&#xff1a;2757 标注数量(txt文件个数)&#xff1a;2757 标注…

浙大数据结构:02-线性结构4 Pop Sequence

这道题我们采用数组来模拟堆栈和队列。 简单说一下大致思路&#xff0c;我们用栈来存1234.....&#xff0c;队列来存输入的一组数据&#xff0c;栈与队列进行匹配&#xff0c;相同就pop 机翻 1、条件准备 stk是栈&#xff0c;que是队列。 tt指向的是栈中下标&#xff0c;fr…

【系统设计】主动查询与主动推送:如何选择合适的数据传输策略

基本描述总结 主动查询机制&#xff1a;系统A主动向系统B请求数据&#xff0c;采用严格的权限控制和身份认证&#xff0c;防止未授权的数据访问。数据在传输过程中使用TLS加密&#xff0c;并通过动态脱敏处理隐藏敏感信息。 推送机制&#xff1a;系统B在数据更新时主动向系统…

Java并发编程实战 05 | 什么是线程组?

1.线程组介绍 在 Java 中&#xff0c;ThreadGroup 用于表示一组线程。通过 ThreadGroup&#xff0c;我们可以批量控制和管理多个线程&#xff0c;使得线程管理更加方便。 ThreadGroup 和 Thread 的关系就像它们的字面意思一样简单&#xff1a;每个线程 (Thread) 必定属于一个线…

基于R语言的统计分析基础:操作XML文件与YAML文件

XML文件简介 在处理从文本文件中读取数据的任务时&#xff0c;用户承担着至关重要的责任&#xff0c;即需要充分理解和明确指定在创建该文件时所遵循的一系列约定和规范。这些约定涵盖了多个方面&#xff0c;包括但不限于&#xff1a; 注释字符&#xff1a;识别并忽略文件中用…

kubeadm 初始化 k8s 证书过期解决方案

概述 在使用 kubeadm 初始化的 Kubernetes 集群中&#xff0c;默认情况下证书的有效期为一年。当证书过期时&#xff0c;集群中的某些组件可能会停止工作&#xff0c;导致集群不可用。本文将详细介绍如何解决 kubeadm 初始化的 Kubernetes 集群证书过期的问题&#xff0c;并提…

CSP-J基础之常见的竞赛题库

文章目录 CSP-J基础之常见的竞赛题库1. 可达 (KEDA)2. 洛谷 (Luogu)3. Codeforces 洛谷账号的注册总结 CSP-J基础之常见的竞赛题库 在备战CSP-J&#xff08;Certified Software Professional Junior&#xff09;及其他信息学竞赛时&#xff0c;选手们常需要借助在线题库来进行…

android framework工程师遇到成长瓶颈迷茫怎么办?千里马经验分享

背景 近来有一些framework老司机粉丝朋友发来了一些framework工作中的一些疑问&#xff0c;具体描述如下&#xff1a; 这个同学遇到的问题&#xff0c;其实就是大部分framework开发者工作比较久后遇到的一个上升瓶颈问题。 具体总结有以下几个瓶颈问题 1、framework属于系…

Clion不识别C代码或者无法跳转C语言项目怎么办?

如果是中文会显示&#xff1a; 此时只需要右击项目&#xff0c;或者你的源代码目录&#xff0c;将这个项目或者源码目录标记为项目源和头文件即可。 英文如下&#xff1a;

STM32内部闪存FLASH(内部ROM)、IAP

1 FLASH简介 1 利用程序存储器的剩余空间来保存掉电不丢失的用户数据 2 通过在程序中编程(IAP)实现程序的自我更新 &#xff08;OTA&#xff09; 3在线编程&#xff08;ICP把整个程序都更新掉&#xff09; 1 系统的Bootloader写死了&#xff0c;只能用串口下载到指定的位置&a…

【MacOS】mac定位服务中删除已经卸载的软件

mac定位服务中删除已经卸载的软件 网上的帖子真不靠谱 直接右键 WeTypeSettings &#xff0c;查找位置&#xff0c;丢废纸篓即可&#xff01;会提示你卸载的&#xff01;

VLAN原理学习笔记

以太网是一种基于CSMA/CD的数据网络通信技术&#xff0c;其特征是共享通信介质。当主机数目较多时会导致安全隐患、广播泛滥、性能显著下降甚至造成网络不可用。 在这种情况下出现了VLAN (Virtual Local Area Network)技术解决以上问题。 1、VLAN快速配置 Vlan:Virtual Local…

C和指针:结构体(struct)和联合(union)

结构体和联合 结构体 结构体包含一些数据成员&#xff0c;每个成员可能具有不同的类型。 数组的元素长度相同&#xff0c;可以通过下标访问(转换为指针)。但是结构体的成员可能长度不同&#xff0c;所以不能用下标来访问它们。成员有自己的名字&#xff0c;可以通过名字访问…

springboot流浪天使乐园管理系统

基于springbootvue实现的流浪天使乐园管理系统&#xff08;源码L文ppt&#xff09;4-039 第4章 系统设计 4.1 总体功能设计 一般个人用户和管理者都需要登录才能进入流浪天使乐园管理系统&#xff0c;使用者登录时会在后台判断使用的权限类型&#xff0c;包括一般使用者…

以太网交换机工作原理学习笔记

在网络中传输数据时需要遵循一些标准&#xff0c;以太网协议定义了数据帧在以太网上的传输标准&#xff0c;了解以太网协议是充分理解数据链路层通信的基础。以太网交换机是实现数据链路层通信的主要设备&#xff0c;了解以太网交换机的工作原理也是十分必要的。 1、以太网协议…

Qt/C++编写的Onvif调试助手调试神器工具/支持云台控制/预置位设置等/有手机版本

一、功能特点 广播搜索设备&#xff0c;支持IPC和NVR&#xff0c;依次返回。可选择不同的网卡IP进行对应网段设备的搜索。依次获取Onvif地址、Media地址、Profile文件、Rtsp地址。可对指定的Profile获取视频流Rtsp地址&#xff0c;比如主码流地址、子码流地址。可对每个设备设…

ESP32_获取心知天气

目录 前言 一、获取心知天气API 二、编写代码 1.下载代码 2.代码讲解 1.安装Arduino.Json库 2.输入WIFI名称和密码 3.添加API 4.关于API的补充 三.数据的打印和处理 1.获取的数据 2.数据输出 总结 前言 环境&#xff1a;Arduino 芯片&#xff1a;ESP32 软件&…

基于springboot+vue实现的农家乐管理系统

基于springbootvue实现的山庄农家乐管理系统前后端分离项目&#xff08;文末查看源码lw&#xff09;4-10 系统角色&#xff1a; 管理员、用户 主要功能&#xff1a; &#xff08;1&#xff09;用户关键功能包含用户注册登陆、个人信息修改、首页、农家乐、美食信息、民宿信息…