[Machine Learning]pytorch手搓一个神经网络模型

因为之前虽然写过一点点关于pytorch的东西,但是用的还是他太少了。

这次从头开始,尝试着搓出一个神经网络模型

(因为没有什么训练数据,所以最后的训练部分使用可能不太好跑起来的代码作为演示,如果有需要自己连上数据集合进行修改捏)

1.先阐述一下什么是神经网络块(block)

一般来说,我们之前遇到的一些神经网络,网络中是这样子的结构

net----> layer ----> neuron

而块的存在,就是给这样一个神经网络的整体做了一个封装操作,让神经网络能复合实现一些功能。

结构就变成了这个样子(图片来自D2l)

这样子,神经网络结构就变成了四层

  block ----》 net ---》layer ---》neuron

这样子自然是可以使用诸如一些奇怪的方法,通过三层索引去进行调用什么的,不过这个我们到后面再说。先看一下如何构建一个块。

我们这里构建了一个类,这个类的计算方法实际就是实现了几个层的输入和输出,相当与封装了一个神经网络。

class MLP(nn.Module):# 用模型参数声明层。这里,我们声明两个全连接的层def __init__(self):# 调用MLP的父类Module的构造函数来执行必要的初始化。# 这样,在类实例化时也可以指定其他函数参数,例如模型参数params(稍后将介绍)super().__init__()self.hidden = nn.Linear(20, 256)  # 隐藏层self.out = nn.Linear(256, 10)  # 输出层# 定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self, X):# 注意,这里我们使用ReLU的函数版本,其在nn.functional模块中定义。return self.out(F.relu(self.hidden(X)))# 这个东西就相当与先隐藏层,然后relu,然后最后进行一次输出#创建这个神经网路块,然后开始输出
net = MLP()
net(X)

这段代码没有使用squential容器进行封装,但是可以很清楚地看到我们定义了两层(隐藏层256个神经元,输出层10个神经元,不知道为什么没用softmax函数),并且在返回函数计算的时候,中间还经过了一步‘relu’激活函数的操作

(注意和tf不同,pytorch框架下面是不能把激活函数存入层中的,需要单独作为一个‘层’来进行一个输入和输出的控制)

注意下(在后面自定义层的时候也是这样子)由于继承了nn.Module这个类 , 所以我们必须要实现两个函数,首先是_init_,这个在python中是最终要的构造函数。其次就是forward,我对py不是很了解,不过这应该是通过面向对象实现的集成。forward这个方法就是向前传播,也就是接受参数,内部计算,然后返回值传递下去。

我们直接给net对象传递我们随机生成的两条数据的时候,底层时就调用了这个函数。

其他的一些比如sequential的实现方法,在这里我们就不加以赘述了。

为了更好的解释forward这个函数的作用,在这里我们自己创建一个单层,通过类创建,仍然是获取一个集成nn.Module的类,然后内部设置好初始化(为了创建对象),设置好向前传播(为了用来调用)

# 自定义一个不需要参数的层
class CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):    #该层向前传播的方法return X - X.mean()
# 这一层最终也是返回一个张量
# sequential是一个简单的线性封装容器,所以只要是符合输入张量,输出张量
# 并且在内部会调用他们的forward方法layer = CenteredLayer()
print('自定义层,每个元素都 - 平均值2',layer(torch.FloatTensor([1, 2, 3, 4, 5])))

这个单层的效果就是对每个元素,都减去平均值。

并且如果想的话,我们也可以创建一些拥有自己属性的层


#现在创建一个带有权重和偏好
class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units)) #这个需要手动输入一下输入特征数目还有神经元数目self.bias = nn.Parameter(torch.randn(units,))            def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.data  #向前传播其实就是接受输入return F.relu(linear)
#创建一个层,这个层可以直接用在sequential之中
linear = MyLinear(5, 3) #五个输入三个神经元

这里可以看到只要重写了forward方法,那么这个类就能变成一个能用来计算的类,甚至是一个层可以单独计算。并且这样子写好以后是可以放在sequential容器中,作为一个统一训练的。

因此,如果我们有多个块的话,也是可以自己去写一个容器,进行组合。

#创建一个新类
class MySequential(nn.Module): #()就是py中的继承语法def __init__(self, *args):super().__init__()for idx, module in enumerate(args):# 这里,module是Module子类的一个实例。我们把它保存在'Module'类的成员# 变量_modules中。_module的类型是OrderedDictself._modules[str(idx)] = moduledef forward(self, X):# OrderedDict保证了按照成员添加的顺序遍历它们for block in self._modules.values():X = block(X)return X#这个类实现的效果就类似原声的sequential
net = MySequential(net1,net2,net3)
print(net(X))

这个就大概是在拼接块,层的时候,内部所做的底层原理。

当然直接用sequential容器是更加省力气的方法,对吧

2.关于参数如何进行检查

假设现在有一个单独的神经网络

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

众所周知,这个神经网络是两层(中间的一层是激活函数我们不做讨论)

我们可以通过索引来调用和获取某个层的属性

#返回结果是这个全链接层的weight和bias,正好对应八个神经元
print(net[0].state_dict())
#检查参数
print(net[2].bias) #还会返回一些具体的属性
print(net[2].bias.data) #单纯的数据

对于block组成的神经网路社区中(我也不知道很多块组在一起应该叫什么了),仍然是一个嵌套的结构,我们可以创建这样一个社区

#这段代码其实也能看出来,sequential也是一个能容纳block的东西
#     容器 --》 block --》 layer --》 神经元   这三层架构(或者说四层)
def block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())
def block2():net = nn.Sequential()for i in range(4):# 在这里嵌套net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))

然后我们对这个rgnet进行打印,可以直接看到工作状态

#这样子打印会展示整个网络的状态
print('检查工作状态',rgnet)

可以很清晰地看到,这样一个嵌套结构

所以比如说我们想要访问第一个社区中,第2个块,中的第一个层中的参数,我们可以直接这样子读取

rgnet[0][1][0].bias.data

另外如果想要对已经形成的模型做初始化,这里还有一个例子

def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)   #平均值0,标准差为0.01nn.init.zeros_(m.bias)                        #偏移直接设置为0
net.apply(init_normal)
print('手动初始化的效果为',net[0].weight.data[0],'手动初始化bias:', net[0].bias.data[0])

函数实现的功能是先检测传进来的是不是正常的线性层,然后分别初始化。

补充一下,apply函数和js里的用法差不多,对内部的每个单元进行遍历,然后做一些操作。

(当然这不是唯一一种方法,自然还有别的)。

3.关于张量的保存和获取

在pytorch中,张量的保存主要有两种形式,第一种是保存数据,用于其他模型的训练

#===========读写张量===========#
x = torch.arange(4)       #[0,1,2,3],创建了一个张量
torch.save(x, 'x-file')   #这是保存在x-file这个文件下面的
loaded_x = torch.load('x-file')  #反过来加载
print(loaded_x)                  #输出
#这样子读取列表和读出,也可以使用字典{x:x,Y:y}或者列表[x,y],反正是变成文件形式了

另一种是保存模型的参数,可以直接套在其他模型上

#=====读写参数并且保存在内存=====#class MLP(nn.Module):   #手动创建多层感知机def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256)self.output = nn.Linear(256, 10)def forward(self, x):return self.output(F.relu(self.hidden(x)))net = MLP()           #构建对象
print('MLP的参数',net.state_dict()) #这里输出一下参数#保存这个模型的参数
torch.save(net.state_dict(), 'mlp.params')#然后对一个新模型使用这个参数
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))  #内置函数加载参数clone.eval()#设置为评估模式,禁止训练什么的,这应该是module中附带的功能print('clone的参数',clone.state_dict())#可以看到参数被完全复制了

但是注意一个问题,如果使用另一个模型初始化自身的时候,要保证两个模型的结构一致

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/149311.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

linux下的永久保存行号

linux下的永久保存行号 1.首先 这里是引用 输入命令:vi ~/.vimrc 其次 这里是引用 输入命令 set number

一文看懂功率MOSFET FCP190N60 N沟道 基础知识

什么是MOSFET的原意是:MOS(Metal Oxide Semiconductor金属氧化物半导体),FET(Field Effect Transistor场效应晶体管),即以金属层(M)的栅极隔着氧化层(O&#…

8.Vue_Element

1 Ajax 1.1 Ajax介绍 1.1.1 Ajax概述 我们前端页面中的数据,如下图所示的表格中的学生信息,应该来自于后台,那么我们的后台和前端是互不影响的2个程序,那么我们前端应该如何从后台获取数据呢?因为是2个程序&#xf…

【云备份项目】【Linux】:环境搭建(g++、json库、bundle库、httplib库)

文章目录 1. g 升级到 7.3 版本2. 安装 jsoncpp 库3. 下载 bundle 数据压缩库4. 下载 httplib 库从 Win 传输文件到 Linux解压缩 1. g 升级到 7.3 版本 🔗链接跳转 2. 安装 jsoncpp 库 🔗链接跳转 3. 下载 bundle 数据压缩库 安装 git 工具 sudo yum…

Vue中如何进行响应式图像与图片懒加载优化

Vue中响应式图像与图片懒加载优化 在现代的Web开发中,图像在网站性能和用户体验方面扮演着至关重要的角色。然而,加载大量的图像可能会导致网页加载速度变慢,从而影响用户的满意度。为了解决这个问题,Vue.js提供了一些强大的工具…

Docker Tutorial

什么是Docker 为每个应用提供完全隔离的运行环境 Dockerfile, Image,Container Image: 相当于虚拟机的快照(snapshot)里面包含了我们需要部署的应用程序以及替它所关联的所有库。通过image,我们可以创建很…

SSRF+redis未授权漏洞复现

1.SSRF漏洞简介 SSRF(Server-Side Request Forgery)即服务器端请求伪造,是一种由攻击者构造攻击链传给服务器,服务器执行并发起请求造成安全问题的漏洞,一般用来在外网探测或攻击内网服务。当网站需要调用指定URL地址…

[软件工具]opencv-svm快速训练助手教程解决opencv C++ SVM模型训练与分类实现任务支持C# python调用

opencv中已经提供了svm算法可以对图像实现多分类,使用svm算法对图像分类的任务多用于场景简单且对时间有要求的场景,因为opencv的svm训练一般只需要很短时间就可以完成训练任务。但是目前网上没有一个工具很好解决训练问题,大部分需要自己编程…

网络安全的发展方向是什么?网络安全学什么内容

前言 不少小伙伴开始学习网络安全技术,但却不知道学习网络安全能找什么工作?网络安全是现下较为火热的职业岗位,吸引了许多企业和个人对网络安全技术的青睐。学习网络安全的人越来越多,网络安全也有很多发展方向。那么如何选择网…

《视觉 SLAM 十四讲》V2 第 5 讲 相机与图像

文章目录 相机 内参 && 外参5.1.2 畸变模型单目相机的成像过程5.1.3 双目相机模型5.1.4 RGB-D 相机模型 实践5.3.1 OpenCV 基础操作 【Code】OpenCV版本查看 5.3.2 图像去畸变 【Code】5.4.1 双目视觉 视差图 点云 【Code】5.4.2 RGB-D 点云 拼合成 地图【Code】 习题题…

【Linux】文件权限详解

🍁 博主 "开着拖拉机回家"带您 Go to New World.✨🍁 🦄 个人主页——🎐开着拖拉机回家_Linux,Java基础学习,大数据运维-CSDN博客 🎐✨🍁 🪁🍁 希望本文能够给您带来一定的…

在word文档里面插入漂亮的伪代码

推荐用texsword.0.8 安装与界面 下载链接:https://sourceforge.net/projects/texsword/ 极为轻便,是Word的一个宏 安装过程也是极为简单,复制解压后的 texsword.dotm 文件到 C:\Users\{YOUR_USER_NAME}\AppData\Roaming\Microsoft\Word\ST…

分布式架构篇

1、微服务 微服务架构风格,就像是把一个单独的应用程序开发为一套小服务,每个服务运行在自己的进程中,并使用轻量级机制通信,通常是 HTTP API。这些服务围绕业务能力来构建,并通过完全自动化部署机制来独立部署。这些…

React框架核心原理

一、整体架构 三大核心库与对应的组件 history -> react-router -> react-router-dom react-router 可视为react-router-dom 的核心&#xff0c;里面封装了<Router>&#xff0c;<Route>&#xff0c;<Switch>等核心组件,实现了从路由的改变到组件的更新…

Ubuntu Server CLI专业提示

基础 网络 获取所有接口的IP地址 networkctl status 显示主机的所有IP地址 hostname -I 启用/禁用接口 ip link set <interface> up ip link set <interface> down 显示路线 ip route 将使用哪条路线到达主机 ip route get <IP> 安全 显示已登录的用户 w…

MySQL数据库单表查询

素材: 表名: worker-- 表中字段均为中文&#xff0c;比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 float(8,2) NOT NULL, 政治面貌 varchar(10) NOT NULL DEFAULT 群…

想要精通算法和SQL的成长之路 - 恢复二叉搜索树和有序链表转换二叉搜索树

想要精通算法和SQL的成长之路 - 恢复二叉搜索树和有序链表转换二叉搜索树 前言一. 恢复二叉搜索树二. 有序链表转换二叉搜索树 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 恢复二叉搜索树 原题链接 首先&#xff0c;一个正常地二叉搜索树在中序遍历下&#xff0c;遍历…

Vue组件路由

1&#xff0c;安装vue-router组件&#xff0c;终端输入&#xff1a; npm i vue-router3.5.3 2&#xff0c;在src文件夹下创建router目录 3&#xff0c;创建index.js文件&#xff0c;配置路由&#xff0c;导入需要路由的组件。以后每次添加路由只要在routes中改变即可。 impo…

CTFHUB - SSRF

目录 SSRF漏洞 攻击对象 攻击形式 产生漏洞的函数 file_get_contents() fsockopen() curl_exec() 提高危害 利用的伪协议 file dict gopher 内网访问 伪协议读取文件 端口扫描 POST请求 总结 上传文件 总结 FastCGI协议 CGI和FastCGI的区别 FastCGI协议 …

如何查看postgresql中的数据库大小?

你可以使用以下命令来查看PostgreSQL数据库的大小&#xff1a; SELECT pg_database.datname as "database_name", pg_size_pretty(pg_database_size(pg_database.datname)) AS size_in_mb FROM pg_database ORDER by size_in_mb DESC;这将返回一个表格&#xff0…