【PyTorch】张量操作与线性回归

张量的操作

Tensor Operation

拼接与切分

1.1 torch.cat()

torch.cat(tensors, dim=0, out=None)

功能:将张量按维度dim进行拼接

  • tensors:张量序列
  • dim:要拼接的维度

1.2 torch.stacok()

torch.stack(tensors, dim=0, out=None)

功能:在新创建的维度dim上进行拼接

  • tensors:张量序列
  • 要拼接的维度

1.3 torch.chunk()

torch.chunk(input, chunks, dim=0)

功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量

  • input:要切分的张量
  • chunks:要切分的份数
  • dim:要切分的维度

1.4 torch.split()

torch.chunk(input, chunks, dim=0)

功能:将张量按维度dim进行切分
返回值:张量列表

  • tensor:要切分的张量
  • split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
  • dim:要切分的维度

索引

2.1 torch.index_select()

torch.index_select(input, dim, index, out=None)

功能:在维度dim上,按index索引数据返回值(依index索引数据拼接的张量)

  • input:要索引的张量
  • dim:要索引的维度
  • index:要索引数据的序号

2.2 torch.masked_select()

torch.masked_select(input, mask, out=None)

功能:按mask中的True进行索引
返回值:一维张量

  • input:要索引的张量
  • mask:与input同形状的布尔类型张量

变换

3.1 torch.reshape()

torch.reshape(input, shape)

功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存(改变一个变量时,另一个变量也会被改变)。

  • input:要变换的张量
  • shape:新张量的形状

3.2 torch.transpose()

torch.transpose(input, dim0, dim1)

功能:交换张量的两个维度

  • input:要变换的张量
  • dim0:要交换的维度
  • dim1:要交换的维度

3.3 torch.t()

功能:2维张量转置,对矩阵而言,等价于

torch.transpose(input, 0, 1)

3.4 torch.squeeze()

torch.squeeze(input, dim=None, out=None)

功能:压缩长度为1的维度(轴)

  • dim:若为None,移除所有长度为1 的轴;若指定维度,当且仅当该轴长度为1时,可以被移除

3.5 torch.unsqueeze()

torch.unsqueeze(input, dim)

功能:依据dim扩展维度

  • dim:扩展的维度

张量的数学运算

Tensor Math Operations

加减乘除

torch.add()
torch.addcidv()
torch.addcmul()
torch.sub()
torch.div()
torch.mul()

对数、指数、幂函数

torch.log(input, out=None)
torch.log10(input, out=None)
torch.log2(input, out=None)
torch.exp(input, out=None)
torch.pow()

三角函数

torch.abs(input, out=None)
torch.acos(input, out=None)
torch.cosh(input, out=None)
torch.cos(input, out=None)
torch.asin(input, out=None)
torch.atan(input, out=None)
torch.atan2(input, other, out=None)

实例

torch.add()

torch.add(input, other, out=None)
torch.add(input, other, *, alpha=1, out=None)

功能:逐元素计算 input + alpha × other

  • input:第一个张量
  • alpha:乘项因子
  • other:第二个张量

Pythonic:
torch.addcdiv()

torch.addcdiv(input, tensor1, tensor2, *, value=1, out=None)

在这里插入图片描述

torch.addcmul()

torch.addcmul(input, tensor1, tensor2, *, value=1, out=None)

在这里插入图片描述

线性回归

Linear Regression

基本概念

线性回归是分析一个变量与另外一(多)个变量之间关系的方法。

因变量:y
自变量:x
关系:线性

y = wx + b

分析:求解w,b

求解步骤

  1. 确定模型
    Model:y = wx + b
  2. 选择损失函数
    MSE
    在这里插入图片描述
  3. 求解梯度并更新w,b
    w = w - LR * w.grad
    b = b - LR * w.grad

完整代码

import torch
import matplotlib.pyplot as plt
torch.manual_seed(10)lr = 0.05  # 学习率# 创建训练数据
x = torch.rand(20, 1) * 10  # x data (tensor), shape=(20, 1)
# torch.randn(20, 1) 用于添加噪声
y = 2*x + (5 + torch.randn(20, 1))  # y data (tensor), shape=(20, 1)# 构建线性回归参数
w = torch.randn((1), requires_grad=True) # 设置梯度求解为 true
b = torch.zeros((1), requires_grad=True) # 设置梯度求解为 true# 迭代训练 1000 次
for iteration in range(1000):# 前向传播,计算预测值wx = torch.mul(w, x)y_pred = torch.add(wx, b)# 计算 MSE lossloss = (0.5 * (y - y_pred) ** 2).mean()# 反向传播loss.backward()# 更新参数b.data.sub_(lr * b.grad)w.data.sub_(lr * w.grad)# 每次更新参数之后,都要清零张量的梯度w.grad.zero_()b.grad.zero_()# 绘图,每隔 20 次重新绘制直线if iteration % 20 == 0:plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)plt.text(2, 20, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color':  'red'})plt.xlim(1.5, 10)plt.ylim(8, 28)plt.title("Iteration: {}\nw: {} b: {}".format(iteration, w.data.numpy(), b.data.numpy()))plt.pause(0.5)# 如果 MSE 小于 1,则停止训练if loss.data.numpy() < 1:break

参考链接

PyTorch 学习笔记

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1540952.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

java自定义线程池详解

目录 线程池使用线程池的目的线程池工作原理线程池常用方法自定义线程池等待队列拒绝策略线程工厂 线程池 使用线程池的目的 资源复用&#xff0c;降低开销。重复利用已创建的线程&#xff0c;避免线程频繁地创建和销毁带来的性能开销。方便线程的可管理性。线程是稀缺资源&a…

C++速通LeetCode中等第14题-旋转图像

思路图解&#xff1a; class Solution { public:void rotate(vector<vector<int>>& matrix) {// 设矩阵行列数为 nint n matrix.size();// 起始点范围为 0 < i < n / 2 , 0 < j < (n 1) / 2// 其中 / 为整数除法for (int i 0; i < n / 2; i)…

传知代码-多示例AI模型实现病理图像分类

代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 概述 本文将基于多示例深度学习EPLA模型实现对乳腺癌数据集BreaKHis_v1的分类。EPLA模型是处理组织病理学图像的经典之作。EPLA模型是基于多示例学习来进行了&#xff0c;那么多示例学习模型对处理病理学图像具有…

滚动条指定距离滚动

/*** scroller 滚动条元素* to 滚动到位置* duration 滚动时间*/ function scrollLeftTo (scroller, to, duration) {let rafIdlet count 0const from scroller.scrollLeftconst frames duration 0 ? 1 : Math.round((duration * 1000) / 16)function cancel () {cancelAn…

中间件知识点-消息中间件(Kafka)二

Kafka 一、Kafka介绍及基本原理 kafka是一个分布式的、支持分区的、多副本、基于zookeeper的分布式消息系统/中间件。 kafka一般不会删除消息&#xff0c;不管这些消息有没有被消费。只会根据配置的日志保留时间(log.retention.hours)确认消息多久被删除&#xff0c;默认保留…

Navicat数据库管理工具实现Excel、CSV文件导入到MySQL数据库

1.所需要的工具和环境 navicat等第三方数据库管理工具云服务器中安装了 1Panel面板搭建的mysql数据库 2.基于 1Panel启动mysql容器 2.1 环境要求 安装前请确保您的系统符合安装条件&#xff1a; 操作系统&#xff1a;支持主流 Linux 发行版本&#xff08;基于 Debian / Re…

【Python机器学习】NLP信息提取——提取人物/事物关系

目录 词性标注 实体名称标准化 实体关系标准化和提取 单词模式 文本分割 断句 断句的方式 使用正则表达式进行断句 词性标注 词性&#xff08;POS&#xff09;标注可以使用语言模型来完成&#xff0c;这个语言模型包含词及其所有可能词性组成的字典。然后&#xff0c;该…

Jboss Administration Console弱⼝令

漏洞描述 Administration Console管理⻚⾯存在弱⼝令&#xff0c;admin:admin&#xff0c;登陆后台上传war包 , getshell 影响版本 全版本 环境搭建 因为这⾥⽤的环境是CVE-2017-12149的靶机 cd vulhub-master/jboss/CVE-2017-12149 docker-compose up -d 密码⽂件 /j…

开发易忽视的问题:InnoDB 行锁设计与实现

开发易忽视的问题&#xff1a;InnoDB 行锁设计与实现 存储模型和锁机制 存储结构 数据页&#xff1a; InnoDB 将表的数据存储在数据页中&#xff0c;每个页默认大小为 16KB。数据页中存储多个行记录&#xff0c;行记录按照主键顺序存放。 行格式&#xff1a; InnoDB 支持多种…

VSCode开发ros程序无法智能提示的解决方法(二)

VSCode开发ros程序无法智能提示的解决方法&#xff08;二&#xff09; 说明解决 说明 在Ubuntu下使用vscode开发ros程序&#xff0c;无法进行智能提示。 解决 将C/C更换为v1.20.5版本&#xff0c;如下图

sheng的学习笔记-AI-强化学习(Reinforcement Learning, RL)

AI目录&#xff1a;sheng的学习笔记-AI目录-CSDN博客 基础知识 什么是强化学习 强化学习&#xff08;Reinforcement Learning, RL&#xff09;&#xff0c;又称再励学习、评价学习或增强学习&#xff0c;是机器学习的范式和方法论之一&#xff0c;用于描述和解决智能体&#…

Trainer API训练属于自己行业的本地大语言模型 医疗本地问答大模型示例

Trainer API 是 Hugging Face transformers 库中强大而灵活的工具&#xff0c;简化了深度学习模型的训练和评估过程。通过提供高层次的接口和多种功能&#xff0c;Trainer API 使研究人员和开发者能够更快地构建和优化自然语言处理模型 文章目录 前言一、Trainer API它能做什么…

RNN的反向传播

目录 1.RNN网络&#xff1a;通过时间反向传播(through time back propagate TTBP) 2.RNN梯度分析 2.1隐藏状态和输出 2.2正向传播&#xff1a; 2.3反向传播&#xff1a; 2.4问题瓶颈&#xff1a; 3.截断时间步分类&#xff1a; 4.截断策略比较 5.反向传播的细节 ​编辑…

达梦数据库踩坑

提示&#xff1a;第一次接触达梦&#xff0c;是真的不好用&#xff0c;各种报错不提示详细信息&#xff0c;吐槽归吐槽&#xff0c;还是需要学习使用的。 前言 题主刚接触达梦数据库时&#xff0c;本来是想下载官网的连接工具进行数据库连接的&#xff0c;但是谁曾想&#xff…

监控易监测对象及指标之:全面监控GBase数据库

在数字化时代&#xff0c;数据库作为企业核心数据资产的管理中心&#xff0c;其稳定性和性能直接关系到业务的连续性和企业的运营效率。GBase数据库作为高性能的分布式数据库系统&#xff0c;广泛应用于各类业务场景。为了确保GBase数据库的稳定运行和高效性能&#xff0c;对其…

git安装包夸克网盘下载

git安装包夸克网盘下载 git夸克网盘 git网站上的安装包下载速度有点慢&#xff0c;因此为了方便以后下载就将文件保存到夸克网盘上&#xff0c;链接&#xff1a;我用夸克网盘分享了「git」&#xff0c;点击链接即可保存。 链接&#xff1a;https://pan.quark.cn/s/07c73c4a30…

C++速通LeetCode中等第12题-矩阵置零(空间O(1)含注释)

class Solution { public:void setZeroes(vector<vector<int>>& matrix) {int m matrix.size();int n matrix[0].size();int flag_col0 false, flag_row0 false;//先记录第一行和第一列是否有零for (int i 0; i < m; i) {if (!matrix[i][0]) {flag_col…

基于单片机的智能健康水杯设计

摘要&#xff1a;随着时代的发展&#xff0c;单片机领域不断扩张。人工智能产品的出现改变了人们的生活方式。智能产品不仅加快了人们的生活节奏&#xff0c;还为人们的安全提供了保障。在快节奏生活的同时&#xff0c;人们开始越来越关注自己的身体健康&#xff0c;基于 52 单…

高级java每日一道面试题-2024年9月20日-分布式篇-什么是CAP理论?

如果有遗漏,评论区告诉我进行补充 面试官: 什么是CAP理论&#xff1f; 我回答: 在Java高级面试中&#xff0c;CAP理论是一个经常被提及的重要概念&#xff0c;它对于理解分布式系统的设计和优化至关重要。CAP理论是分布式系统理论中的一个重要概念&#xff0c;它描述了一个分…

c++11右值引用和移动语义

一.左值引用和右值引用 什么是左值引用&#xff0c;什么是右值引用 左值是一个表示数据的表达式&#xff08;变量名解引用的指针&#xff09;&#xff0c;我们可以获取到它的地址&#xff0c;可以对它赋值&#xff0c;左值可以出现在符号的左边。使用const修饰后&#xff0c;…