目录
张量
正态分布与随机分布:
张量的尺寸:
张量的基本属性:
创建张量:
张量构造器:
其他张量创建方法:
索引与切片
索引(Indexing)
单维张量索引
多维张量索引
切片(Slicing)
单维张量切片
多维张量切片
布尔索引(Boolean Indexing)
高级索引(Advanced Indexing)
联合索引
广播机制
广播规则
示例
示例 1:简单的一维张量
示例 2:多维张量
不适用的情况
合并与分割
合并
torch.cat (Concatenate)
torch.stack
总结
分割
torch.chunk
torch.split
总结
矩阵乘法
1. torch.matmul (点积)
2. torch.mm (矩阵乘法)
3. torch.bmm (批量矩阵乘法)
总结
数据处理
1.torch.norm 函数
常见的范数类型
2. torch.max 和 torch.min(最大值和最小值)
3. torch.mean(均值)
4. torch.prod(乘积)
5. torch.argmax 和 torch.argmin(最大值和最小值的索引)
6. torch.topk(最大值及其索引)
7. torch.kthvalue(第k个值)
8. 比较运算(大于、小于、等于)
9. keepdim 参数
高级用法(where,gather)
torch.where 函数的基本语法
示例
示例 1:基于条件选择元素
示例 2:基于条件修改元素
多维张量示例
总结
torch.gather 的基本语法
示例
示例 1:一维索引
示例 2:二维索引
实际应用示例
示例 3:在自然语言处理中的应用
总结
张量
-
正态分布与随机分布:
torch.randn(size)
生成一个给定形状size
的张量,其中每个元素都来自一个标准正态分布(均值为0,标准差为1)。注意这里的元素值不是限制在0-1之间的。torch.rand(size)
生成一个给定形状size
的张量,其中每个元素都来自一个0到1之间的均匀分布。
-
张量的尺寸:
a.size()
和a.shape
都是用来获取张量尺寸的方法,其中a.shape
是属性而a.size()
是一个方法。两者都会返回一个元组,表示张量的各个维度大小。a.size(1)
或a.shape[1]
获取张量的第二个维度的长度(在Python中索引是从0开始的,因此第二个维度的索引是1)。
-
张量的基本属性:
a.numel()
返回张量中元素的总数。a.dim()
返回张量的维度数。
-
创建张量:
torch.tensor(5)
创建一个标量张量,即一个0维张量,其值为5。torch.from_numpy(a)
将NumPy数组转换为PyTorch张量,其中a
是NumPy数组。
-
张量构造器:
torch.Tensor
通常是一个构造函数,用于创建一个新的张量。如果没有提供参数,默认创建一个浮点张量。torch.full(size, fill_value)
创建一个给定形状size
的张量,并用fill_value
填充所有元素。torch.ones(size)
创建一个给定形状size
的张量,所有元素均为1。torch.zeros(size)
创建一个给定形状size
的张量,所有元素均为0。torch.eye(n)
创建一个n×n的单位矩阵,适用于二维张量。
-
其他张量创建方法:
torch.arange(start, end, step)
生成一个从start
到end
(不包括end
)的等差序列。torch.linspace(start, end, steps)
在start
和end
之间生成steps
个等间距的样本,包括start
和end
。torch.randperm(n)
返回一个长度为n的张量,包含从0到n-1的随机排列。
请注意:
torch.range
已经被弃用,推荐使用torch.arange
。torch.normal
方法用于生成正态分布的样本,而不是压缩张量。torch.randperm
可以用来生成一个随机排列的长整型张量。
索引与切片
索引(Indexing)
索引允许您通过指定一个或多个索引来访问张量中的单个元素或多维子集。
单维张量索引
对于一维张量,索引就像列表一样简单:
1import torch
2
3t = torch.tensor([1, 2, 3, 4])
4print(t[0]) # 访问第一个元素
5print(t[-1]) # 访问最后一个元素
多维张量索引
对于多维张量,您需要为每个维度提供一个索引:
1t = torch.tensor([[1, 2], [3, 4]])
2print(t[0, 1]) # 访问第一行第二列的元素
切片(Slicing)
切片允许您访问张量的一个连续子集。切片可以通过提供起始位置、结束位置和步长来实现。
单维张量切片
1t = torch.tensor([1, 2, 3, 4, 5])
2print(t[1:3]) # 访问索引1到2的元素 (不包括3)
3print(t[:3]) # 访问前三个元素
4print(t[3:]) # 访问从索引3到最后的所有元素
5print(t[::2]) # 每隔一个元素取一个
多维张量切片
对于多维张量,可以在每个维度上进行切片:
1t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2print(t[0:2, :]) # 访问前两行
3print(t[:, 1:3]) # 访问每行的第二列和第三列
4print(t[1:, 1:]) # 访问除了第一行和第一列外的所有元素
布尔索引(Boolean Indexing)
布尔索引允许您通过一个布尔掩码来选择张量中的元素:
1t = torch.tensor([1, 2, 3, 4, 5])
2mask = t > 3
3print(mask) # 输出一个布尔张量 [False False False True True]
4print(t[mask]) # 选择大于3的元素 [4 5]
高级索引(Advanced Indexing)
高级索引允许您使用整数数组来进行索引,这可以用来选择特定位置的元素:
1t = torch.tensor([[1, 2], [3, 4], [5, 6]])
2rows = torch.tensor([0, 2])
3cols = torch.tensor([1, 0])
4print(t[rows, cols]) # 选择第一行第二列和第三行第一列的元素 [2, 5]
联合索引
您可以结合使用上述不同的索引方式来完成更复杂的任务。例如,您可以同时使用切片和布尔索引:
1t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
2print(t[1:, mask]) # 结合使用切片和布尔索引
广播机制
在PyTorch中,广播机制允许不同形状的张量进行数学运算,例如加法、减法、乘法等。广播机制是一种自动扩展张量的形状以匹配另一个张量的形状的过程。当两个张量的形状不完全相同,但可以通过一定的规则来扩展其中一个张量的形状,使得它们的形状兼容时,就可以使用广播机制。
广播规则
广播机制遵循以下规则:
- 维度匹配:从最后的维度开始比较,如果两个张量的维度数量不同,那么少的那个张量在前面添加维度,使得维度数量相同。
- 维度相等或为1:对于每个维度,如果两个张量在这个维度上的大小相等,或者其中一个张量在这个维度上的大小为1,那么这两个张量可以广播。如果两个张量在同一个维度上的大小既不相等,也不是1,则无法进行广播。
示例
假设我们有两个张量 tensor_a
和 tensor_b
,我们想要执行加法运算。
示例 1:简单的一维张量
1import torch
2
3tensor_a = torch.tensor([1, 2, 3])
4tensor_b = torch.tensor(10)
5result = tensor_a + tensor_b
6print(result)
输出结果将是:
1tensor([11, 12, 13])
在这个例子中,标量 tensor_b
被扩展成与 tensor_a
相同的形状 [10, 10, 10]
,然后与 tensor_a
进行逐元素的加法运算。
示例 2:多维张量
1tensor_a = torch.tensor([[1, 2], [3, 4]])
2tensor_b = torch.tensor([10, 20])
3result = tensor_a + tensor_b
4print(result)
1tensor([[11, 22],
2 [13, 24]])
在这个例子中,张量 tensor_b
被扩展成与 tensor_a
相同的形状 [[10, 20], [10, 20]]
,然后与 tensor_a
进行逐元素的加法运算。
不适用的情况
广播机制不适用的情况通常发生在以下几个方面:
-
维度不匹配:如果两个张量的形状完全不同,且没有任何维度可以扩展为相同大小,那么就不能进行广播。例如,张量
tensor_a
的形状为(3, 4)
,而tensor_b
的形状为(5, 6)
,则无法通过广播机制直接进行运算。1tensor_a = torch.randn(3, 4) 2tensor_b = torch.randn(5, 6) 3# 尝试广播将导致错误 4try: 5 result = tensor_a + tensor_b 6except RuntimeError as e: 7 print(e)
-
非1维度不相等:如果两个张量在任何一个维度上的大小不相等,并且不为1,那么也不能进行广播。例如,张量
tensor_a
的形状为(3, 4)
,而tensor_b
的形状为(3, 5)
,则无法通过广播机制直接进行运算。1tensor_a = torch.randn(3, 4) 2tensor_b = torch.randn(3, 5) 3# 尝试广播将导致错误 4try: 5 result = tensor_a + tensor_b 6except RuntimeError as e: 7 print(e)
-
内存消耗:对于非常大的张量,即使广播机制理论上可行,但由于需要扩展张量的大小,可能导致内存消耗过高,从而引发内存不足的问题。
通过了解这些规则和限制,你可以更好地理解和使用PyTorch中的广播机制,以有效地进行张量运算。
合并与分割
合并
torch.cat
(Concatenate)
torch.cat
用于沿着指定的维度连接多个张量。它将这些张量拼接在一起,形成一个新的张量。
语法:
1torch.cat(tensors, dim=0, out=None)
tensors
:一个张量的序列。dim
:沿着哪个维度进行拼接,默认为0。out
:可选参数,用于指定输出张量的位置。
示例:
假设我们有两个张量 t1
和 t2
,我们将它们沿着不同的维度进行拼接。
1import torch
2
3# 创建两个张量
4t1 = torch.tensor([[1, 2], [3, 4]])
5t2 = torch.tensor([[5, 6], [7, 8]])
6
7print("张量 t1:")
8print(t1)
9print("张量 t2:")
10print(t2)
11
12# 沿着行(第0维)拼接
13result_cat_row = torch.cat((t1, t2), dim=0)
14print("沿着行拼接的结果:")
15print(result_cat_row)
16
17# 沿着列(第1维)拼接
18result_cat_col = torch.cat((t1, t2), dim=1)
19print("沿着列拼接的结果:")
20print(result_cat_col)
torch.stack
torch.stack
用于在给定的新维度上堆叠一系列张量。它创建一个新的维度,并将输入张量沿着这个新维度堆叠起来。
语法:
1torch.stack(tensors, dim=0, out=None)
tensors
:一个张量的序列。dim
:新维度的位置,默认为0。out
:可选参数,用于指定输出张量的位置。
示例:
同样的,我们使用上面定义的 t1
和 t2
张量来演示 torch.stack
的用法。
1# 沿着新的第0维堆叠
2result_stack_new_dim = torch.stack((t1, t2), dim=0)
3print("沿着新的第0维堆叠的结果:")
4print(result_stack_new_dim)
5
6# 沿着新的第1维堆叠
7result_stack_new_dim_1 = torch.stack((t1, t2), dim=1)
8print("沿着新的第1维堆叠的结果:")
9print(result_stack_new_dim_1)
总结
torch.cat
:用于沿着现有维度拼接张量。输入张量必须具有相同的形状,除了要拼接的那个维度。torch.stack
:用于在新维度上堆叠张量。输入张量必须具有相同的形状。
这两种方法都可以用来组合多个张量,但具体使用哪种取决于你要实现的功能。如果你希望合并张量而不改变它们的维度结构,可以使用 torch.cat
。如果你希望在张量之间增加一个新的维度,并在该维度上堆叠张量,应该使用 torch.stack
。
分割
在PyTorch中,split
和 chunk
都可以用来分割张量,但它们有不同的用法和语义。下面是详细的
解释和示例代码。
torch.chunk
torch.chunk
方法用于将一个张量分割成多个连续的部分。它根据指定的数量(chunks
参数)来分割张量,并返回一个张量列表。每个张量的大小尽可能相同,但如果有剩余的部分,最后一部分的大小可能会稍小一些。
语法:
1torch.chunk(input, chunks, dim=0)
input
:要分割的张量。chunks
:分割成的块数。dim
:沿着哪个维度进行分割,默认为0(即第一维)。
示例:
假设我们有一个形状为 (4, 4)
的张量,并希望将其分割成两部分:
1import torch
2
3# 创建一个4x4的张量
4tensor = torch.arange(16).view(4, 4)
5print("原始张量:")
6print(tensor)
7
8# 沿着第一维度(行)分割成两部分
9result_chunk = torch.chunk(tensor, chunks=2, dim=0)
10print("分割后的张量(按照行分割成两部分):")
11for r in result_chunk:
12 print(r)
13
14# 沿着第二维度(列)分割成两部分
15result_chunk_col = torch.chunk(tensor, chunks=2, dim=1)
16print("分割后的张量(按照列分割成两部分):")
17for r in result_chunk_col:
18 print(r)
torch.split
torch.split
方法用于将一个张量按照指定的大小进行分割。它可以接受一个整数或一个整数列表作为 split_size
参数,表示每个部分的大小。如果提供了整数列表,则表示每个部分的具体大小。
语法:
1torch.split(tensor, split_size, dim=0)
tensor
:要分割的张量。split_size
:每个部分的大小,可以是一个整数或一个整数列表。dim
:沿着哪个维度进行分割,默认为0。
示例:
假设我们有一个形状为 (4, 4)
的张量,并希望将其分割成大小不同的部分:
1# 沿着第一维度(行)分割,每部分大小为2
2result_split = torch.split(tensor, split_size=2, dim=0)
3print("分割后的张量(每部分大小为2):")
4for r in result_split:
5 print(r)
6
7# 沿着第一维度(行)分割,前两部分大小为2,最后一部分大小为剩余的行数
8result_split_varied = torch.split(tensor, split_size_or_sections=[2, 2], dim=0)
9print("分割后的张量(前两部分大小为2,最后一部分大小为剩余的行数):")
10for r in result_split_varied:
11 print(r)
12
13# 沿着第二维度(列)分割,每部分大小为2
14result_split_col = torch.split(tensor, split_size=2, dim=1)
15print("分割后的张量(每部分大小为2,按照列分割):")
16for r in result_split_col:
17 print(r)
总结
torch.chunk
主要用于根据块数分割张量,返回的每个张量大小尽可能相同。torch.split
主要用于根据指定的大小分割张量,可以指定每个部分的大小,也可以指定一个列表来表示每个部分的具体大小
矩阵乘法
在PyTorch中,矩阵相乘是一个常见的操作,用于各种线性代数任务,例如神经网络中的前向传播。PyTorch提供了几种不同的方法来执行矩阵乘法,每种方法都有其特定的用途和规则。以下是常用的几种矩阵乘法操作及其规则:
1. torch.matmul
(点积)
torch.matmul
用于执行两个张量之间的矩阵乘法。如果两个张量都是二维的,则执行标准的矩阵乘法。如果其中一个或两个张量具有更多维度,则torch.matmul
会在相应的维度上执行矩阵乘法。
语法:
1torch.matmul(tensor1, tensor2, out=None)
tensor1
:第一个张量。tensor2
:第二个张量。out
:可选参数,用于指定输出张量的位置。
规则:
- 如果
tensor1
的形状为(n, m)
,tensor2
的形状为(m, p)
,则结果张量的形状为(n, p)
。 - 如果
tensor1
的形状为(a, b, n, m)
,tensor2
的形状为(a, b, m, p)
,则结果张量的形状为(a, b, n, p)
。
示例:
1import torch
2
3# 两个二维张量
4tensor1 = torch.tensor([[1, 2], [3, 4]])
5tensor2 = torch.tensor([[5, 6], [7, 8]])
6
7# 执行矩阵乘法
8result = torch.matmul(tensor1, tensor2)
9print("结果张量:")
10print(result)
输出结果将是:
1结果张量:
2tensor([[19, 22],
3 [43, 50]])
2. torch.mm
(矩阵乘法)
torch.mm
专门用于两个二维张量之间的矩阵乘法。如果张量不是二维的,将会引发错误。
语法:
1torch.mm(tensor1, tensor2, out=None)
tensor1
:第一个二维张量。tensor2
:第二个二维张量。out
:可选参数,用于指定输出张量的位置。
规则:
tensor1
的形状为(n, m)
,tensor2
的形状为(m, p)
,则结果张量的形状为(n, p)
。
示例:
1# 同样的两个二维张量
2result_mm = torch.mm(tensor1, tensor2)
3print("使用torch.mm的结果张量:")
4print(result_mm)
3. torch.bmm
(批量矩阵乘法)
torch.bmm
用于执行两个三维张量之间的批量矩阵乘法。每个三维张量中的每个切片(在第一个维度上)被视为单独的矩阵。
语法:
1torch.bmm(batch_tensor1, batch_tensor2, out=None)
batch_tensor1
:第一个三维张量。batch_tensor2
:第二个三维张量。out
:可选参数,用于指定输出张量的位置。
规则:
batch_tensor1
的形状为(batch_size, n, m)
,batch_tensor2
的形状为(batch_size, m, p)
,则结果张量的形状为(batch_size, n, p)
。
示例:
1batch_tensor1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
2batch_tensor2 = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
3
4# 执行批量矩阵乘法
5result_bmm = torch.bmm(batch_tensor1, batch_tensor2)
6print("批量矩阵乘法的结果张量:")
7print(result_bmm)
总结
torch.matmul
是最通用的矩阵乘法操作,适用于多种形状的张量。torch.mm
专门用于二维张量的矩阵乘法。torch.bmm
用于执行批量矩阵乘法,适用于三维张量。
PS:a@b等同于torch.matmul
数据处理
在PyTorch中,有许多方法可以用来处理数据,包括计算范数(norm)、最大值(max)、最小值(min)、均值(mean)、乘积(prod)、最大值索引(argmax)、最小值索引(argmin)、最大值及其索引(topk)、第k个值(kthvalue)、比较运算(大于、小于、等于)等。下面将逐一介绍这些操作及其使用方法,并提供示例代码。
1.torch.norm
函数
torch.norm
是一个通用的函数,用于计算张量的各种范数。它支持多种类型的范数计算,并且可以指定范数的类型和计算范数的维度。
基本语法:
1torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
input
:输入的张量。p
:范数的类型。默认为'fro'
(Frobenius范数)。可以是'inf'
、'-inf'
、整数(例如1或2)、浮点数(例如1.5)。dim
:计算范数的维度。默认为None
,此时计算整个张量的范数。keepdim
:布尔值,如果设置为True
,则保留原维度,即输出的张量将在指定的维度上保持为1。out
:可选参数,用于指定输出张量的位置。dtype
:可选参数,指定输出张量的数据类型。
常见的范数类型
- L1范数 (
p=1
):所有元素绝对值之和。 - L2范数 (
p=2
):所有元素平方和的平方根。 - 无穷范数 (
p='inf'
):最大绝对值元素。 - Frobenius范数 (
p='fro'
):矩阵中所有元素的平方和的平方根。
1import torch
2
3# 创建一个张量
4x = torch.tensor([[1, 2], [3, 4]])
5
6# 计算Frobenius范数
7norm_fro = torch.norm(x, p='fro')
8print("Frobenius范数:", norm_fro)
9
10# 计算L2范数
11norm_l2 = torch.norm(x, p=2)
12print("L2范数:", norm_l2)
2. torch.max
和 torch.min
(最大值和最小值)
用于计算张量的最大值或最小值。
示例:
1# 计算整个张量的最大值
2max_val = torch.max(x)
3print("最大值:", max_val)
4
5# 计算每一列的最大值
6max_vals_per_column = torch.max(x, dim=0)
7print("每列的最大值:", max_vals_per_column.values)
8print("每列最大值的索引:", max_vals_per_column.indices)
9
10# 计算每一行的最大值
11max_vals_per_row = torch.max(x, dim=1)
12print("每行的最大值:", max_vals_per_row.values)
13print("每行最大值的索引:", max_vals_per_row.indices)
14
15# 计算整个张量的最小值
16min_val = torch.min(x)
17print("最小值:", min_val)
18
19# 计算每一列的最小值
20min_vals_per_column = torch.min(x, dim=0)
21print("每列的最小值:", min_vals_per_column.values)
22print("每列最小值的索引:", min_vals_per_column.indices)
23
24# 计算每一行的最小值
25min_vals_per_row = torch.min(x, dim=1)
26print("每行的最小值:", min_vals_per_row.values)
27print("每行最小值的索引:", min_vals_per_row.indices)
3. torch.mean
(均值)
用于计算张量的平均值。
示例:
1# 计算整个张量的平均值
2mean_val = torch.mean(x)
3print("平均值:", mean_val)
4
5# 计算每一列的平均值
6means_per_column = torch.mean(x, dim=0)
7print("每列的平均值:", means_per_column)
8
9# 计算每一行的平均值
10means_per_row = torch.mean(x, dim=1)
11print("每行的平均值:", means_per_row)
4. torch.prod
(乘积)
用于计算张量的元素乘积。
示例:
1# 计算整个张量的乘积
2prod_val = torch.prod(x)
3print("乘积:", prod_val)
4
5# 计算每一列的乘积
6prods_per_column = torch.prod(x, dim=0)
7print("每列的乘积:", prods_per_column)
8
9# 计算每一行的乘积
10prods_per_row = torch.prod(x, dim=1)
11print("每行的乘积:", prods_per_row)
5. torch.argmax
和 torch.argmin
(最大值和最小值的索引)
用于找到张量中的最大值或最小值的索引。
示例:
1# 找到整个张量的最大值索引
2argmax_val = torch.argmax(x)
3print("最大值索引:", argmax_val)
4
5# 找到每一列的最大值索引
6argmax_vals_per_column = torch.argmax(x, dim=0)
7print("每列最大值的索引:", argmax_vals_per_column)
8
9# 找到每一行的最大值索引
10argmax_vals_per_row = torch.argmax(x, dim=1)
11print("每行最大值的索引:", argmax_vals_per_row)
12
13# 找到整个张量的最小值索引
14argmin_val = torch.argmin(x)
15print("最小值索引:", argmin_val)
16
17# 找到每一列的最小值索引
18argmin_vals_per_column = torch.argmin(x, dim=0)
19print("每列最小值的索引:", argmin_vals_per_column)
20
21# 找到每一行的最小值索引
22argmin_vals_per_row = torch.argmin(x, dim=1)
23print("每行最小值的索引:", argmin_vals_per_row)
6. torch.topk
(最大值及其索引)
用于找到张量中最大的k个值及其索引。
示例:
1# 找到整个张量中最大的3个值及其索引
2topk_val, topk_indices = torch.topk(x.flatten(), k=3)
3print("最大的3个值:", topk_val)
4print("最大的3个值的索引:", topk_indices)
5
6# 找到每一列中最大的2个值及其索引
7topk_val_per_column, topk_indices_per_column = torch.topk(x, k=2, dim=0)
8print("每列最大的2个值:", topk_val_per_column)
9print("每列最大的2个值的索引:", topk_indices_per_column)
10
11# 找到每一行中最大的2个值及其索引
12topk_val_per_row, topk_indices_per_row = torch.topk(x, k=2, dim=1)
13print("每行最大的2个值:", topk_val_per_row)
14print("每行最大的2个值的索引:", topk_indices_per_row)
7. torch.kthvalue
(第k个值)
用于找到张量中第k个排序后的值及其索引。
示例:
1# 找到整个张量中第3个排序后的值及其索引
2kth_value, kth_index = torch.kthvalue(x.flatten(), k=3)
3print("第3个排序后的值:", kth_value)
4print("第3个排序后的值的索引:", kth_index)
5
6# 找到每一列中第2个排序后的值及其索引
7kth_value_per_column, kth_index_per_column = torch.kthvalue(x, k=2, dim=0)
8print("每列第2个排序后的值:", kth_value_per_column)
9print("每列第2个排序后的值的索引:", kth_index_per_column)
10
11# 找到每一行中第2个排序后的值及其索引
12kth_value_per_row, kth_index_per_row = torch.kthvalue(x, k=2, dim=1)
13print("每行第2个排序后的值:", kth_value_per_row)
14print("每行第2个排序后的值的索引:", kth_index_per_row)
8. 比较运算(大于、小于、等于)
用于比较张量中的元素。
1# 创建另一个张量用于比较
2y = torch.tensor([[2, 2], [2, 2]])
3
4# 大于运算
5greater_than = x > y
6print("大于比较结果:", greater_than)
7
8# 小于运算
9less_than = x < y
10print("小于比较结果:", less_than)
11
12# 等于运算
13equal_to = x == y
14print("等于比较结果:", equal_to)
9. keepdim
参数
在使用 torch.max
、torch.mean
等方法时,可以设置 keepdim=True
来保持原维度。
示例:
1# 计算每一列的平均值,并保持原维度
2means_per_column_keepdim = torch.mean(x, dim=0, keepdim=True)
3print("每列的平均值(保持维度):", means_per_column_keepdim)
4
5# 计算每一行的平均值,并保持原维度
6means_per_row_keepdim = torch.mean(x, dim=1, keepdim=True)
7print("每行的平均值(保持维度):", means_per_row_keepdim)
这些方法和操作在实际的数据处理和分析中非常有用,可以帮助你更好地理解和操作张量数据。通过这些操作,你可以执行复杂的统计分析、特征工程以及其他数据预处理任务。
在PyTorch中,torch.where
是一个非常有用的函数,用于根据条件选择元素。它类似于NumPy中的 np.where
函数,可以在条件表达式为真时选择一个张量中的值,否则选择另一个张量中的值。这对于基于条件进行数据筛选或修改非常有用。
高级用法(where,gather)
torch.where
函数的基本语法
torch.where
的基本语法如下:
1torch.where(condition, x, y)
condition
:一个布尔张量,用于指定选择x
或y
中的元素。x
:当condition
为True
时选择的张量。y
:当condition
为False
时选择的张量。
示例
下面通过一些具体的例子来说明 torch.where
的用法。
示例 1:基于条件选择元素
假设我们有两个张量 A
和 B
,我们想要根据某个条件来选择 A
或 B
中的元素。
1import torch
2
3# 创建两个张量
4A = torch.tensor([[1, 2], [3, 4]])
5B = torch.tensor([[5, 6], [7, 8]])
6
7# 定义一个条件张量
8condition = A > 2
9
10# 使用 torch.where 根据条件选择元素
11result = torch.where(condition, A, B)
12
13print("条件张量 condition:")
14print(condition)
15print("结果张量 result:")
16print(result)
输出将是:
1条件张量 condition:
2tensor([[False, False],
3 [ True, True]])
4结果张量 result:
5tensor([[5, 6],
6 [3, 4]])
在这个例子中,当 A
中的元素大于2时,选择 A
中的元素;否则选择 B
中对应的元素。
示例 2:基于条件修改元素
假设我们有一个张量 C
,我们想要将其中大于某个阈值的元素替换为另一个值。
python
深色版本
1# 创建一个张量
2C = torch.tensor([[1, 2], [3, 4]])
3
4# 定义一个条件
5threshold = 2
6condition = C > threshold
7
8# 将大于阈值的元素设置为 10
9modified_C = torch.where(condition, 10, C)
10
11print("原始张量 C:")
12print(C)
13print("修改后的张量 modified_C:")
14print(modified_C)
输出将是:
1原始张量 C:
2tensor([[1, 2],
3 [3, 4]])
4修改后的张量 modified_C:
5tensor([[1, 2],
6 [10, 10]])
在这个例子中,当 C
中的元素大于2时,将其替换为10;否则保持不变。
多维张量示例
假设我们有一个三维张量 D
,我们想要根据某个条件来选择元素。
1# 创建一个三维张量
2D = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
3
4# 定义一个条件
5condition = D > 4
6
7# 将大于4的元素设置为 10
8modified_D = torch.where(condition, 10, D)
9
10print("原始三维张量 D:")
11print(D)
12print("修改后的三维张量 modified_D:")
13print(modified_D)
输出将是:
1原始三维张量 D:
2tensor([[[1, 2],
3 [3, 4]],
4 [[5, 6],
5 [7, 8]]])
6修改后的三维张量 modified_D:
7tensor([[[ 1, 2],
8 [ 3, 4]],
9 [[10, 10],
10 [10, 10]]])
在这个例子中,当 D
中的元素大于4时,将其替换为10;否则保持不变。
总结
torch.where
是一个非常强大的函数,可以用于基于条件选择张量中的元素或修改张量中的元素。它广泛应用于数据处理、特征工程以及机器学习任务中,帮助实现条件选择和数据转换等功能。通过上述示例,你应该能够更好地理解和应用 torch.where
函数。
在PyTorch中,torch.gather
是一个用于从源张量中收集元素的方法,根据索引张量来选取元素。它常用于根据索引选择特定维度上的元素,这对于实现某些复杂的索引和选择操作非常有用。
torch.gather
的基本语法
torch.gather
的基本语法如下:
1torch.gather(input, dim, index, out=None)
input
:源张量,从中提取元素。dim
:沿着哪个维度进行收集。index
:索引张量,指示从源张量的哪个位置提取元素。out
:可选参数,用于指定输出张量的位置。
示例
下面通过一些具体的例子来说明 torch.gather
的用法。
示例 1:一维索引
假设我们有一个二维张量 A
,我们想要根据另一个索引张量 idx
来选择 A
中的元素。
1import torch
2
3# 创建一个二维张量
4A = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
5
6# 创建一个索引张量
7idx = torch.tensor([[0, 2], [1, 0], [2, 1]])
8
9# 沿着第一个维度收集元素
10result = torch.gather(A, 0, idx)
11
12print("源张量 A:")
13print(A)
14print("索引张量 idx:")
15print(idx)
16print("结果张量 result:")
17print(result)
输出将是:
1源张量 A:
2tensor([[1, 2, 3],
3 [4, 5, 6],
4 [7, 8, 9]])
5索引张量 idx:
6tensor([[0, 2],
7 [1, 0],
8 [2, 1]])
9结果张量 result:
10tensor([[ 1, 3],
11 [ 4, 4],
12 [ 7, 8]])
在这个例子中,我们沿着第一个维度(即行)进行收集。idx
张量中的每个元素指示从 A
的对应行中选择哪个元素。例如,idx[0, 0]
的值为0,意味着从 A
的第一行(即索引为0的行)选择第一个元素(即索引为0的元素);idx[0, 1]
的值为2,意味着从 A
的第一行选择第三个元素(即索引为2的元素)。
示例 2:二维索引
假设我们有一个三维张量 B
,我们想要根据索引张量 idx
来选择 B
中的元素。
1# 创建一个三维张量
2B = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]])
3
4# 创建一个索引张量
5idx = torch.tensor([[0, 1], [1, 0], [1, 1]])
6
7# 沿着第二个维度收集元素
8result = torch.gather(B, 1, idx)
9
10print("源张量 B:")
11print(B)
12print("索引张量 idx:")
13print(idx)
14print("结果张量 result:")
15print(result)
输出将是:
1源张量 B:
2tensor([[[ 1, 2],
3 [ 3, 4]],
4 [[ 5, 6],
5 [ 7, 8]],
6 [[ 9, 10],
7 [11, 12]]])
8索引张量 idx:
9tensor([[0, 1],
10 [1, 0],
11 [1, 1]])
12结果张量 result:
13tensor([[[ 1, 4],
14 [ 6, 5],
15 [10, 12]]])
在这个例子中,我们沿着第二个维度(即列)进行收集。idx
张量中的每个元素指示从 B
的对应行中选择哪个列的元素。
实际应用示例
示例 3:在自然语言处理中的应用
在自然语言处理中,torch.gather
常用于根据预测的概率分布来选择单词。假设我们有一个批处理的句子,每个句子都有一个预测的概率分布,我们想要根据每个句子的概率分布选择最可能的单词。
1# 假设我们有一个批处理的句子,每个句子都有一个预测的概率分布
2probs = torch.tensor([
3 [0.1, 0.2, 0.3, 0.4],
4 [0.4, 0.1, 0.2, 0.3],
5 [0.2, 0.3, 0.1, 0.4]
6])
7
8# 获取每个句子中概率最高的单词的索引
9max_prob_indices = torch.argmax(probs, dim=1, keepdim=True)
10
11# 选择每个句子中最可能的单词
12selected_words = torch.gather(probs, 1, max_prob_indices)
13
14print("概率分布 probs:")
15print(probs)
16print("概率最高的单词的索引 max_prob_indices:")
17print(max_prob_indices)
18print("选择的单词 selected_words:")
19print(selected_words)
输出将是:
1概率分布 probs:
2tensor([[0.10, 0.20, 0.30, 0.40],
3 [0.40, 0.10, 0.20, 0.30],
4 [0.20, 0.30, 0.10, 0.40]])
5概率最高的单词的索引 max_prob_indices:
6tensor([[3],
7 [0],
8 [3]])
9选择的单词 selected_words:
10tensor([[0.40],
11 [0.40],
12 [0.40]])
在这个例子中,我们首先找到了每个句子中概率最高的单词的索引,然后使用 torch.gather
选择了每个句子中最可能的单词。
总结
torch.gather
是一个非常有用的函数,可以用于从张量中根据索引选择特定的元素。通过指定索引张量和维度,你可以实现复杂的数据选择和索引操作。这对于数据处理、特征工程以及模型实现都非常重要。