Pytorch学习--神经网络--损失函数与反向传播

一、对于损失函数的理解

  • 计算实际输出和目标之间的差距
  • 为我们更新输出提供一定的依据

二、头文件

nn.L1Loss

在这里插入图片描述
在这里插入图片描述
大概含义:
在这里插入图片描述
代码:

import torch
from torch.nn import L1Lossoutput = torch.tensor([1,2,3],dtype=float)
target = torch.tensor([1,2,5],dtype=float)output = torch.reshape(output,(1,1,1,3))
target = torch.reshape(target,(1,1,1,3))loss = L1Loss()
result = loss(output,target)
print(result)

输出:

tensor(0.6667, dtype=torch.float64)

nn.MSELoss

在这里插入图片描述
大概含义:
在这里插入图片描述
代码:

import torch
from torch.nn import L1Loss, MSELossoutput = torch.tensor([1,2,3],dtype=float)
target = torch.tensor([1,2,5],dtype=float)output = torch.reshape(output,(1,1,1,3))
target = torch.reshape(target,(1,1,1,3))loss = L1Loss()
result = loss(output,target)loss2 = MSELoss()
result2 = loss2(output,target)
print(result2)

输出:

tensor(1.3333, dtype=torch.float64)

nn.CrossEntropyLoss

在这里插入图片描述
在这里插入图片描述

import torch
from torch.nn import L1Loss, MSELoss, CrossEntropyLossoutput = torch.tensor([1,2,3],dtype=float)
target = torch.tensor([1,2,5],dtype=float)output = torch.reshape(output,(1,1,1,3))
target = torch.reshape(target,(1,1,1,3))loss = L1Loss()
result = loss(output,target)loss2 = MSELoss()
result2 = loss2(output,target)x = torch.tensor([0.1,0.2,0.3])
y = torch.tensor(1)  #正确的类别
loss3 = CrossEntropyLoss()
result3 = loss3(x,y)
print(result3)

输出:

tensor(1.1019)

包含 batch_size 的代码:

import torch
import torch.nn as nn# 假设有 3 个样本,每个样本有 5 个类别
x = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],   # 第一个样本的 logits[0.3, 0.4, 0.2, 0.5, 0.1],   # 第二个样本的 logits[0.5, 0.1, 0.2, 0.4, 0.3]])  # 第三个样本的 logits# 每个样本的真实类别标签
y = torch.tensor([1, 2, 0])  # 第一个样本属于类别 1,第二个样本属于类别 2,第三个样本属于类别 0# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 计算损失
loss = loss_fn(x, y)print(f"Loss: {loss.item()}")

三、搭建模型与损失函数的结合

搭建模型
代码:

import torchvision
from torch import nn
from torch.nn import MaxPool2d, Conv2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("datasets",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)class Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return x
Yorelee = Mary()for data in dataloader:img,target = dataoutput = Yorelee(img)print(output)print(target)loss = nn.CrossEntropyLoss()result_loss = loss(output,target)print(result_loss)print("***********************")result_loss.backward()

输出:

tensor([[ 0.0956, -0.0318, -0.0674, -0.0565,  0.1168,  0.0389,  0.0496, -0.0039,-0.0221,  0.1028]], grad_fn=<AddmmBackward0>)
tensor([3])
tensor(2.3833, grad_fn=<NllLossBackward0>)

打个断点,单步执行前,注意此时还没有 grad:
在这里插入图片描述
单步执行后,此时就有了 grad,为后续优化器的选择做了铺垫:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/9381.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

多功能声学气膜馆:打造移动歌剧院新体验—轻空间

在广场、公园&#xff0c;甚至是郊野之间&#xff0c;多功能声学气膜馆为歌剧表演带来了全新的移动体验。作为高品质演出场馆&#xff0c;它不仅具备卓越的声学效果&#xff0c;还拥有灵活的搭建与拆卸能力&#xff0c;使艺术表演不再受限于固定的场地&#xff0c;让更多人得以…

构建智能防线 灵途科技光电感知助力轨交全向安全防护

10月27日&#xff0c;在南京南站至紫金山东站间的高铁联络线上&#xff0c;一头野猪侵入轨道&#xff0c;与D5515次列车相撞&#xff0c;导致设备故障停车。 事故不仅造成南京南站部分列车晚点&#xff0c;还在故障排查过程中导致随车机械师因被邻线限速通过的列车碰撞而不幸身…

DMX配置文件生成工具使用举例

最新软件版本: MaintainTool-v0.0.1-20241107. 如所用软件低于该版本,本文档描述可能有所差异,请索取最新版本软件再阅读本文档. (软件右上角处查看软件版本) 一、基础知识 这里结合一个示例工程来进行说明. 所用灯带: 使用两种型号的线条灯, XT1(一米八段&#xff0c;即8个…

kafka安装部署--详细教程

2.1 安装部署 每次进入 linux 都会自动进入 base 环境&#xff0c;如何关闭 base conda deactivate 手动关闭 conda config --set auto_activate_base false 关闭自动进入 2.1.1 集群规划 bigdata01 bigdata02 bigdata03 zk zk zk kafka kafka kafka 2.1.2 集群部…

工具进阶:如何利用 MAT 找到问题发生的根本原因

深入浅出 Java 虚拟机 作者&#xff1a; 李国 我们知道&#xff0c;在存储用户输入的密码时&#xff0c;会使用一些 hash 算法对密码进行加工&#xff0c;比如 SHA-1。这些信息同样不允许在日志输出里出现&#xff0c;必须做脱敏处理&#xff0c;但是对于一个拥有系统权限的攻击…

当AI遇上时尚:未来的衣橱会由机器人来打理吗?

内容概要 在当今这个快速发展的时代&#xff0c;人工智能与时尚的结合正在逐渐改写我们对衣橱管理的认知。传统的衣橱管理常常面临着空间不足、穿搭单调及库存过多等挑战&#xff0c;许多人在挑选服饰时难以做出决策。然而&#xff0c;随着技术的进步&#xff0c;智能推荐和自…

[OpenGL]使用OpenGL实现硬阴影效果

一、简介 本文介绍了如何使用OpenGL实现硬阴影效果&#xff0c;并在最后给出了全部的代码。本文基于[OpenGL]渲染Shadow Map&#xff0c;实现硬阴影的流程如下&#xff1a; 首先&#xff0c;以光源为视角&#xff0c;渲染场景的深度图&#xff0c;将light space中的深度图存储…

Kafka中如何做到数据唯一,即数据去重?

数据传递语义 至少一次&#xff08;At Least Once&#xff09; ACK级别设置为-1 分区副本大于等于2 ISR里应答的最小副本数量大于等于2 可以保障数据可靠 • 最多一次&#xff08;At Most Once&#xff09; ACK级别设置为0 • 总结&#xff1a; At Least Once可以保证数据不…

惊爆:抖音小程序广告掘金计划,游戏+广告双赢新趋势!

惊爆&#xff1a;抖音小程序广告掘金计划&#xff0c;游戏广告双赢新趋势&#xff01; 在当今信息爆炸的时代&#xff0c;抖音小程序广告掘金计划犹如一股清流&#xff0c;为游戏开发者、广告商以及广大用户带来了前所未有的机遇与财富。这一计划不仅融合了游戏的趣味性和广告的…

黑豹X2 armbian 编译rkmpp ffmpeg 实现CPU视频转码

硬件 arm64 4核cpu 4G内存 rk3566 1.编译rockmpp git clone https://gitee.com/hermanchen82/mpp.git 下载之后 进到 rkmpp\build\linux\aarch64目录 armbian的不需要做任何修改 直接执行 make-Makefiles.bash make && make install 编译完成后 会安装到…

《深度学习》YOLO V4 整体架构的由来及用法 详解

目录 一、关于YOLOv4 1、什么是YOLOv4 2、相较于YOLOv3 二、YOLO v4数据增强的做法 1、 Bag of freebies 2、关于数据增强 1&#xff09;概念 2&#xff09;种类 3、v4数据增强方法 1&#xff09;马赛克数据增强 • 由来 • 关于CutMix&#xff1a; 2&#xff09;…

【VR】PICO 手部追踪 steamvr内无法识别,依旧识别手柄的解决方案

一、问题描述 && 原因分析 1.PICO4 手部追踪 steamvr内无法识别&#xff0c;依旧识别手柄的解决方案 尽管平放&#xff08;或关闭手柄连接&#xff09;之后&#xff0c;在 PICO 一体机中进入了手部追踪状态&#xff0c; 但只要进入 steamvr&#xff0c;就无法正确识别…

Go 中的泛型,日常如何使用

泛型从 go 的 1.18 开始支持 什么是泛型编程 在泛型出现之前&#xff0c;如果需要计算两数之和&#xff0c;可能会这样写&#xff1a; func Add(a, b int) int {returb a b } 这个很简单&#xff0c;但是只能两个参数都是 int 类型的时候才能调用 如果想要计算两个浮点数…

年度目标5w浏览量达成

目录 前言&#xff1a;目标展示&#xff1a;达成展示&#xff1a; 前言&#xff1a; 去年定了一个目标&#xff0c;今年实现了&#xff0c;以后继续加油&#xff0c;争取2025可以获得15w的阅览量&#xff0c;3000的粉丝数量。 目标展示&#xff1a; 达成展示&#xff1a;

【Python TensorFlow】进阶指南(续篇一)

在前两篇文章中&#xff0c;我们介绍了TensorFlow的基础知识及其在实际应用中的初步使用&#xff0c;并探讨了更高级的功能和技术细节。本篇将继续深入探讨TensorFlow的高级应用&#xff0c;包括但不限于模型压缩、模型融合、迁移学习、强化学习等领域&#xff0c;帮助读者进一…

你不得不知的几种常见的向量数据库产品

产品介绍 在使用 LLM&#xff08;大型语言模型&#xff09;知识库时&#xff0c;经常会用到以下几种向量数据库&#xff1a; Milvus&#xff1a;这是一款开源的向量数据库&#xff0c;具有高度可扩展性和高性能。它支持多种向量相似性搜索算法&#xff0c;适用于大规模数据处理…

企业IT架构转型之道:阿里巴巴中台战略思想与架构实战感想

文章目录 第一章&#xff1a;数据库水平扩展第二章&#xff1a;中台战略第三章&#xff1a;阿里分布式服务架构HSF&#xff08;high speed Framework&#xff09;、早期Dubbo第四章&#xff1a;共享服务中心建设原则第五章&#xff1a;数据拆分实现数据库能力线性扩展第六章&am…

征程 6 工具链性能分析与优化 2|模型性能优化建议

01 引言 为了应对低、中、高阶智驾场景&#xff0c;以及当前 AI 模型在工业界的应用趋势&#xff0c;地平线推出了征程 6 系列芯片。 在软硬件架构方面&#xff0c;征程 6 不仅保持了对传统 CNN 网络的高效支持能力&#xff0c;还强化了对 Transformer 类型网络的支持&#xf…

字符编码和字符集

1. 字符编码和字符集 1.1. 字符编码 编码&#xff1a;字符 –>字节解码&#xff1a;字节 –>字符字符编码Character Encoding : 就是一套自然语言的字符与二进制数之间的对应规则。 1.2. 字符集 字符集 Charset&#xff1a;是一个系统支持的所有字符的集合&#xff0…

Kafka面试题解答(二)

1.怎么尽可能保证 Kafka 的可靠性 kafka是可能会出现数据丢失问题的&#xff0c;Leader维护了一个动态的in-sync replica set&#xff08;ISR&#xff09;&#xff0c;意为和 Leader保持同步的FollowerLeader集合(leader&#xff1a;0&#xff0c;isr:0,1,2)。 如果Follower长…