【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算

# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包

pip install torchsummary

import torch
import torch.nn as nn
from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构# 创建神经网络模型类
class Model(nn.Module):# 初始化模型的构造函数def __init__(self):super().__init__()  # 调用父类 nn.Module 的初始化方法# 定义第一个全连接层(线性层),3个输入特征,3个输出特征self.linear1 = nn.Linear(3, 3)  # 使用 Xavier 正态分布初始化第一个全连接层的权重nn.init.xavier_normal_(self.linear1.weight)# 定义第二个全连接层,输入 3 个特征,输出 2 个特征self.linear2 = nn.Linear(3, 2)# 使用 Kaiming 正态分布初始化第二个全连接层的权重,适合 ReLU 激活函数nn.init.kaiming_normal_(self.linear2.weight)# 定义输出层,输入 2 个特征,输出 2 个特征self.out = nn.Linear(2, 2)# 定义前向传播过程 (forward 函数会自动执行,类似于模型的"推理"过程)def forward(self, x):# 第一个全连接层运算x = self.linear1(x)# 使用 Sigmoid 激活函数x = torch.sigmoid(x)# 第二个全连接层运算x = self.linear2(x)# 使用 ReLU 激活函数x = torch.relu(x)# 输出层运算x = self.out(x)# 使用 Softmax 激活函数,将输出转化为概率分布# dim=-1 表示在最后一个维度(通常是输出的类别维度)上做 softmax 归一化x = torch.softmax(x, dim=-1)return xif __name__ == '__main__':# 实例化神经网络模型my_model = Model()# 随机生成一个形状为 (5, 3) 的输入数据,表示 5 个样本,每个样本有 3 个特征my_data = torch.randn(5, 3)print("mydata shape", my_data.shape)# 通过模型进行前向传播,输出模型的预测结果output = my_model(my_data)print("output shape", output.shape)# 计算并显示模型的参数总量以及模型结构summary(my_model, input_size=(3,), batch_size=5)# 查看模型中所有的参数,包括权重和偏置项(bias)print("-----查看模型参数w 和 b  -----")for name, parameter in my_model.named_parameters():print(name, parameter)

mydata shape torch.Size([5, 3])
output shape torch.Size([5, 2])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                     [5, 3]              12
            Linear-2                     [5, 2]               8
            Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
-----查看模型参数w 和 b  -----
linear1.weight Parameter containing:
tensor([[ 0.4777, -0.2076,  0.4900],
        [-0.1776,  0.4441,  0.6924],
        [-0.5449,  1.6153,  0.0243]], requires_grad=True)
linear1.bias Parameter containing:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)
linear2.weight Parameter containing:
tensor([[-0.0510, -1.2731, -0.7253],
        [-0.6112,  0.1189, -0.4903]], requires_grad=True)
linear2.bias Parameter containing:
tensor([0.5391, 0.2552], requires_grad=True)
out.weight Parameter containing:
tensor([[-0.3271, -0.3483],
        [-0.0619, -0.0680]], requires_grad=True)
out.bias Parameter containing:
tensor([-0.5508,  0.5895], requires_grad=True)
 

 代码输出结果解读

​​​​​​​

这个代码的输出展示了两部分内容:

  1. 数据维度和模型输出维度

    • mydata shape torch.Size([5, 3])

    • output shape torch.Size([5, 2])

  2. 模型的结构、参数数量和每一层的权重与偏置

    • 模型的层结构、每一层的输出形状,以及每一层的参数数量。

    • 每层的权重(weight)和偏置(bias)的具体数值。

让我们详细分析每一部分的输出。

1. 输入数据和输出数据的形状

mydata shape torch.Size([5, 3])

这部分的输出说明:

  • 输入数据的形状(5, 3),表示有 5 个样本,每个样本有 3 个特征。这与模型定义时的输入层 nn.Linear(3, 3) 是一致的,输入层期望接收 3 个特征。

output shape torch.Size([5, 2])

这部分的输出说明:

  • 模型输出的形状 (5, 2),表示 5 个样本的输出,每个样本的输出有 2 个值。由于模型的输出层定义为 nn.Linear(2, 2),它接收 2 个输入特征并输出 2 个值,符合预期。

2. 模型结构和参数

模型结构和参数信息是通过 summary() 函数生成的,它列出了每一层的名称、输出形状和参数数量。

详细输出解释:
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1                     [5, 3]             12Linear-2                     [5, 2]               8Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
线性层 1(Linear-1
  • 层的类型Linear,这是一个全连接层,定义为 nn.Linear(3, 3)

  • 输出形状[5, 3],表示输入了 5 个样本,每个样本有 3 个特征,经过该层的输出仍然是 5 个样本,每个样本有 3 个特征。

  • 参数数量:12,其中 9 个是权重参数(3 x 3 的权重矩阵),另外 3 个是偏置项。

线性层 2(Linear-2
  • 层的类型Linear,定义为 nn.Linear(3, 2),将 3 个输入特征映射到 2 个输出特征。

  • 输出形状[5, 2],表示输入了 5 个样本,每个样本有 2 个输出特征。

  • 参数数量:8,其中 6 个是权重参数(3 x 2 的权重矩阵),另外 2 个是偏置项。

输出层(Linear-3
  • 层的类型Linear,定义为 nn.Linear(2, 2),接收 2 个输入特征,输出 2 个特征。

  • 输出形状[5, 2],表示 5 个样本,每个样本的输出为 2 个特征。

  • 参数数量:6,其中 4 个是权重参数(2 x 2 的权重矩阵),另外 2 个是偏置项。

参数统计
  • 总参数数量:26,模型中所有可训练参数(包括权重和偏置)的总数量。

  • 可训练参数:26,模型中所有参与训练的参数。这里所有的参数都是可训练的(requires_grad=True),没有非可训练的参数。

  • 非可训练参数:0,说明模型中没有被设置为不可训练的参数。

3. 查看每一层的权重和偏置

这一部分输出列出了每一层的具体参数(权重和偏置)的值。

linear1.weight:
tensor([[ 0.4777, -0.2076, 0.4900],[-0.1776, 0.4441, 0.6924],[-0.5449, 1.6153, 0.0243]], requires_grad=True)

这是 linear1 层的权重矩阵,形状是 (3, 3)。由于 linear1nn.Linear(3, 3),它的权重矩阵也是 3 行 3 列。权重参数是使用 Xavier 初始化(nn.init.xavier_normal_)初始化的。

linear1.bias:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)

这是 linear1 层的偏置项,形状是 (3,),因为每个输出特征对应一个偏置值。

linear2.weight:
tensor([[-0.0510, -1.2731, -0.7253],[-0.6112, 0.1189, -0.4903]], requires_grad=True)

这是 linear2 层的权重矩阵,形状是 (2, 3),因为 linear2nn.Linear(3, 2),需要 3 个输入特征映射到 2 个输出特征。权重是使用 Kaiming 初始化nn.init.kaiming_normal_初始化的。

linear2.bias:
tensor([0.5391, 0.2552], requires_grad=True)

这是 linear2 层的偏置项,形状是 (2,),因为每个输出特征对应一个偏置值。

out.weight:
tensor([[-0.3271, -0.3483],[-0.0619, -0.0680]], requires_grad=True)

这是输出层 out 的权重矩阵,形状是 (2, 2),因为 outnn.Linear(2, 2),接收 2 个输入特征并输出 2 个特征。

out.bias:
tensor([-0.5508, 0.5895], requires_grad=True)

这是输出层 out 的偏置项,形状是 (2,)

总结

  • 这段代码展示了一个简单的神经网络模型,包含 3 个全连接层(线性层),每层的输入输出特征数量逐步缩小。

  • 我们通过 summary() 查看了模型的整体结构,展示了每一层的输出形状和参数数量,总共有 26 个参数。

  • 每一层的权重和偏置参数值被输出,展示了它们是如何被初始化的(通过 Xavier 和 Kaiming 初始化)。

  • 该模型的前向传播通过激活函数(sigmoidReLU)以及 softmax 将输出转化为概率分布。

​​​​​​​​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1544096.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

一体化运维监控管理平台的全面监控能力

在当今高度信息化的时代,运维监控管理平台的重要性日益凸显。一个优秀的监控平台不仅要能够全面覆盖各类IT和智能设备,还需具备灵活性和可扩展性,以适应不断变化的监控需求。本文旨在深入探讨一体化运维监控管理平台的全面监控能力&#xff0…

MySQL record 08 part

数据库连接池: Java DataBase Connectivity(Java语言连接数据库) 答: 使用连接池能解决此问题, 连接池,自动分配连接对象,并对闲置的连接进行回收。 常用的数据库连接池: 建立数…

ChatGLM-6B 部署与使用——打造你的专属GLM

ChatGLM-6B 部署与使用指南 ChatGLM-6B 是清华大学与智谱 AI 开源的一款对话语言模型,基于 General Language Model (GLM) 架构,参数达到 62 亿,因其卓越的语言理解与生成能力,受到广泛关注。 一、在 DAMODEL 上部署 ChatGLM-6B…

基于单片机巡迹避障智能小车系统

文章目录 前言资料获取设计介绍设计程序具体实现截图设计获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对象是咱们…

RPA + 计算机视觉

随着超自动化成为顶级企业技术趋势之一,领先的机器人流程自动化 (RPA) 公司开始将人工智能功能集成到其自动化工具中,以创建能够自动化端到端流程并做出决策的智能机器人。计算机视觉是新一代 RPA 工具的关键 AI 功能之一。 在本文中,我们将…

2024年CSP-J认证 CCF信息学奥赛C++ 中小学初级组 第一轮真题-选择题解析

2024年 中小学信息学奥赛CSP-J真题解析 1、32 位 int 类型的存储范围是 A、 -2147483647 ~ 2147483647 B、 -2147483647 ~ 2147483648 C、 -2147483648 ~ 2147483647 D、 -2147483648 ~ 2147483648 答案:C 考点分析:主要考查小朋友们数据类型的存储…

Centos/fedora/openEuler 终端中文显示配置

注意:这里主要解决的是图形界面、远程登录界面的中文乱码问题 系统原生的终端(如虚拟机系统显示的终端),由于使用的是十分原始的 TTY 终端,使用点阵字体进行显示,点阵字体不支持中文,因此无法显…

前端——表单标签样式

1. form表单标签 块级元素 action: 表单提交地址 method: 表单提交格式 https网络协议请求格式: post/get等 通常: post方式是发送数据 而get是拿取数据 name: 表单的名称 target: 提交完表单之后 你的新页面在哪里打开 2. input输入控件 可以通过type属性 …

7,STM32CubeMX配置IIC工程(OLED显示)

1,前言 单片机型号:STM32F407 编程环境 :STM32CubeMX Keil v5 硬件连接 :串口1,ADC1CH5--->PA5 注:本工程在1,STM32CubeMX工程基础(配置Debug、时钟树)基础上…

【Unity保龄球项目】的实现逻辑以及代码解释

1.BaoLQManager.cs 这个脚本实现了基本的保龄球游戏逻辑,包括扔球功能。 using System.Collections; using System.Collections.Generic; using UnityEngine;public class BaoLQManager : MonoBehaviour {// 业务逻辑1:把保龄球扔出去// 业务逻辑2&am…

祝桥镇星光村火情闪电救援:速控之下,安全警钟长鸣

安科瑞武陈燕 在秋日的午后,阳光本应温柔地洒在浦东新区祝桥镇星光村的每一个角落,然而,一场突如其来的火灾打破了这份宁静。 9月2日中午12时许,该村1队的一户居民家中突然燃起熊熊大火,浓烟滚滚自二楼窗口腾空而起&…

gdb调试和makefile管理

一.gdb调试工具 命令 简写 作用 help h 按模块列出命令类 help class 查看某一类型的具体命令 lsit l 查看代码,可跟行号和函数名 quit q 退出gdb run r 全速运行程序 start 单步执行,运行程序,停在第一行执行语句 next …

Qt 窗口类的继承关系和作用

核心基类 [1] QObject:Qt中许多类的基类,支持Qt对象模型,包括信号和槽机制、对象树和事件系统等。虽然它本身不是直接用于创建窗口的,但它是许多窗口和控件类继承链中的重要一环。 注:如果你创建了一个自定义类&…

杰发科技——Eclipse环境安装

文件已传到网盘: 1. 安装文件准备 2. 安装Make 默认路径:C:\Program Files (x86)\GnuWin32\bin\ 不复制的话会报错 Error: Program "make" not found in PATH 3. 安装工具链 默认路径:C:\Program Files (x86)\Arm GNU Toolchain…

OpenAI converting API code from GPT-3 to chatGPT-3.5

题意:将OpenAI API代码从GPT-3转换为ChatGPT-3.5 问题背景: Below is my working code for the GPT-3 API. I am having trouble converting it to work with chatGPT-3.5. 以下是我用于GPT-3 API的工作代码。我在将其转换为适用于ChatGPT-3.5时遇到了…

Android Studio 真机USB调试运行频繁掉线问题

一、遇到问题 Android Studio使用手机运行项目时,总是频繁掉线,连接很不稳定,动不动就消失,基本上无法使用 二、问题出现原因 1、硬件问题:数据线 换条数据线试试,如果可以,那就是数据线的…

15年408-数据结构

第一题 解析: 栈第一次应该存main的信息。 然后进入到main里面,要输出S(1),将S(1)存入栈内, 进入到S(1)中,1>0,所以还要调用S(0) S(0)进入栈中,此时栈内从下至上依次是main(),S(1),S(0) 答案选A 第二题&…

昇腾AI异构计算架构CANN——高效使能AI原生创新

异构计算与人工智能的关系是什么?昇腾AI异构计算架构CANN是什么?有哪些主要特点?开发者如何利用CANN的原生能力进行大模型创新,构筑差异化竞争力?带着这些问题,我们来认识昇腾AI异构计算架构——CANN。 1 …

随机验证码验证【JavaScript】

这段 JavaScript 代码实现了随机验证码的生成和验证功能。 实现效果&#xff1a; 代码&#xff1a; <!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-…

Vue3教程 - 2 开发环境搭建

更好的阅读体验&#xff1a;点这里 &#xff08; www.foooor.com &#xff09; 2 开发环境搭建 要进行 Vue 开发&#xff0c;需要安装 Node.js&#xff0c;因为构建 Vue 项目的工具&#xff0c;例如 Webpack、Vite等&#xff0c;这些工具依赖于Node.js环境来运行。 Node.js…