Pytorch深度学习实践(5)逻辑回归

逻辑回归

逻辑回归主要是解决分类问题

  • 回归任务:结果是一个连续的实数
  • 分类任务:结果是一个离散的值

分类任务不能直接使用回归去预测,比如在手写识别中(识别手写 0 − − 9 0 -- 9 09),因为各个类别之间没有大小之差。

因此,对于分类问题,我们最终的输出是个概率,即属于某个类别的概率是多少,然后从概率集合里找最大值,作为当前预测的结果

下载MNIST数据集

import torchvision
train_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=True, download=True)
test_set = torchvision.dataset.MNIST(root="../dataset/mnist", train=False, dowload=True)
  • 通过train参数来指定训练集和测试集

逻辑回归

将之前的学习时长—考试分数转化为二分类任务,即学习时长—是否通过考试

x(hours)y(pass/fail)
10(fail)
20(fail)
31(pass)
4?

其中, P ( y ^ = 1 ) + P ( y ^ = 0 ) = 1 P(\hat y = 1) + P(\hat y = 0) = 1 P(y^=1)+P(y^=0)=1

当输出的概率在 0.5 0.5 0.5附近时,即模型不确定,因此通常会输出一个不确定的值

对于二分类任务,逻辑回归会先使用回归,生成一个得分值,即落在实数集区间内,然后再使用 s i g m o i d sigmoid sigmoid函数,将得分值映射到 [ 0 , 1 ] [0, 1] [0,1]区间内,得到预测概率

s i g m o i d sigmoid sigmoid函数
σ ( x ) = 1 1 + e − x \sigma (x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述

S i g m o i d Sigmoid Sigmoid常用来做二分类任务,常具备三个特征:

  • 饱和函数
  • 单调递增
  • 有极限

当我们使用线性回归来得到逻辑回归的得分值时,逻辑回归模型的函数定义就如下所示:
y ^ = σ ( x ∗ ω + b ) \hat y = \sigma (x*\omega + b) y^=σ(xω+b)

损失函数

线性回归使用的损失函数是计算预测值和真实值之差

而对于逻辑回归,由于我们得到的是概率,是一个 0 − 1 0-1 01分布,因此需要修改损失函数
l o s s = − ( y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ) loss = -(ylog\hat y + (1-y)log(1-\hat y)) loss=(ylogy^+(1y)log(1y^))
即我们比较的是分布之间的差异

交叉熵 c r o s s − e n t r o p y cross-entropy crossentropy

存在两个分布 P D 1 ( x ) P_{D1}(x) PD1(x) P D 2 ( x ) P_{D2}(x) PD2(x)

两个分布的差异程度使用公式: ∑ i = 1 n P D 1 ( x i ) l n P D 2 ( x i ) \sum_{i=1}^{n}P_{D1}(x_i)lnP_{D2}(x_i) i=1nPD1(xi)lnPD2(xi) 来衡量

上述公式越大时,两个分布的差异越小

模型的改变

模型构造的改变
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = F.sigmoid(self.linear(x))return y_pred

需要先将输入写入到linear()线性模型中,再使用Sigmoid()函数

模型损失函数的改变

使用交叉熵函数BCELoss

criterion = torch.nn.BCELoss(size_average=False)

整体代码

import torch
import matplotlib.pyplot as plt
import numpy as np########## 数据集准备 ##########
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])########## 模型定义 ##########
class LogisticRegressionModel(torch.nn.Module):def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1)# 由于逻辑回归中Sigmoid函数不需要传参 所以在forward中直接计算即可# 在这里不需要实例化def forward(self, x):y_pred = torch.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel()########## 损失函数和优化器的设置 ##########
criterion = torch.nn.BCELoss(size_average=False) # BCELoss -- 交叉熵函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)########## 模型训练 ##########
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data)print(epoch, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()########## 模型测试 ##########
x = np.linspace(0, 10, 200)
x_test = torch.Tensor(x).view((200, 1)) # view()相当于reshape
y_test = model(x_test)
y = y_test.data.numpy()  # 转化为np类型
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], 'r--')
plt.xlabel("Hours")
plt.ylabel("Probability of Pass")
plt.grid()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1487957.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【C++初阶】string类

【C初阶】string类 🥕个人主页:开敲🍉 🔥所属专栏:C🥭 🌼文章目录🌼 1. 为什么学习string类? 1.1 C语言中的字符串 1.2 实际中 2. 标准库中的string类 2.1 string类 2.…

栈-链栈的表示和实现

#include<stdio.h> typedef int Status; typedef int sElemType; //链栈 typedef struct StackNode{sElemType data;struct StackNode *next; }StackNode,*LinkStack; StackNode *p; //初始化 Status InitStack(LinkStack &S){SNULL;return 1; } //判空 Status Emp…

[用AI日进斗金系列]用码上飞在企微接单开发一个项目管理系统!

今天是【日进斗金】系列的第二期文章。 先给不了解这个系列的朋友们介绍一下&#xff0c;在这个系列的文章中&#xff0c;我们将会在企微的工作台的“需求发布页面”中寻找有软件开发需求的用户 并通过自研的L4级自动化智能软件开发平台「码上飞CodeFlying」让AI生成应用以解…

一文搞懂深度信念网络!DBN概念介绍与Pytorch实战

前言 本文深入探讨了深度信念网络DBN的核心概念、结构、Pytorch实战&#xff0c;分析其在深度学习网络中的定位、潜力与应用场景。 一、概述 1.1 深度信念网络的概述 深度信念网络&#xff08;Deep Belief Networks, DBNs&#xff09;是一种深度学习模型&#xff0c;代表了一…

Ruoyi-WMS部署

所需软件 1、JDK&#xff1a;8 安装包&#xff1a;https://www.oracle.com/java/technologies/javase/javase8-archive-downloads.htmlopen in new window 安装文档&#xff1a;https://cloud.tencent.com/developer/article/1698454open in new window 2、Redis 3.0 安装包&a…

跨域浏览器解决前端跨域问题

1.问题背景 这是一种属于非主流的解决跨域的方案&#xff0c;但是也是可以正常使用而且比较简单的。如果需要使用主流的解决前端跨域方案&#xff0c;请参考这篇文章。 我这边其实是优先建议大家使用主流的跨域方案&#xff0c;如果主流的实在不行&#xff0c;那么就使用跨域…

redis:清除缓存的最简单命令示例

清除redis缓存命令 1.打开cmd窗口&#xff0c;并cd进入redis所在目录 2.登录redis redis-cli 3.查询指定队列当前的记录数 llen 队列名称 4.清除指定队列所有记录 ltrim 队列名称 1 0 5.再次查询&#xff0c;确认队列的记录数是否已清除

用uniapp 及socket.io做一个简单聊天app 2

在这里只有群聊&#xff0c;二个好友聊天&#xff0c;可以认为是建了一个二人的群聊。 const express require(express); const http require(http); const socketIo require(socket.io); const cors require(cors); // 引入 cors 中间件const app express(); const serv…

探索算法系列 - 双指针

目录 移动零&#xff08;原题链接&#xff09; 复写零&#xff08;原题链接&#xff09; 快乐数&#xff08;原题链接&#xff09; 盛最多水的容器&#xff08;原题链接&#xff09; 有效三角形的个数&#xff08;原题链接&#xff09; 查找总价格为目标值的两个商品&…

科研绘图系列:R语言组合堆积图(stacked barplot with multiple groups)

介绍 通常堆积图的X轴表示样本,样本可能会存在较多的分组信息,通过组合堆积图和样本标签分组信息,我们可以得到一张能展示更多信息的可发表图形。 加载R包 knitr::opts_chunk$set(warning = F, message = F) library(tidyverse) library(cowplot) library(patchwork)导入…

GDAL访问HDFS集群中的数据

1.集群搭建 参考文章&#xff1a;hadoop2.10.0完全分布式集群搭建 HA(QJM)高可用集群搭建_hadoop 2.10 ha-CSDN博客 创建文件夹 hdfs dfs -mkdir -p hdfs://192.168.80.132:9000/test 开放权限 hdfs dfs -chmod -R 777 /test 上传文件 hadoop fs -put /home/wh/data/res…

JavaScript(16)——定时器-间歇函数

开启定时器 setInterval(函数,间隔时间) 作用&#xff1a;每隔一段时间调用这个函数&#xff0c;时间单位是毫秒 例如&#xff1a;每一秒打印一个hello setInterval(function () { document.write(hello ) }, 1000) 注&#xff1a;如果是具名函数的话不能加小括号&#xf…

【论文复现】Vision Transformer(ViT)

1. Transformer结构 1.1 编码器和解码器 翻译这个过程需要中间体。也就是说&#xff0c;编码&#xff0c;解码之间需要一个中介&#xff0c;英文先编码成一个意思&#xff0c;再解码成中文。 那么查字典这个过程就是编码和解码的体现。首先我们的大脑会把它编码&#xff0c;编…

数仓架构解析(第45天)

系列文章目录 经典数仓架构传统离线大数据架构 文章目录 系列文章目录烂橙子-终生成长社群群主&#xff0c;前言1. 经典数仓架构2. 传统离线大数据架构 烂橙子-终生成长社群群主&#xff0c; 采取邀约模式&#xff0c;不支持付费进入。 前言 经典数仓架构 传统离线大数据架…

2000-2023年上市公司融资约束指数FC指数(含原始数据+计算结果)

2000-2023年上市公司融资约束指数FC指数&#xff08;含原始数据计算结果&#xff09; 1、时间&#xff1a;2000-2023年 2、来源&#xff1a;上市公司年报 3、指标&#xff1a;证券代码、证券简称、统计截止日期、是否剔除ST或*ST或PT股、是否剔除上市不满一年、已经退市或被…

Linus: vim编辑器的使用,快捷键及配置等周边知识详解

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 vim的安装创建新用户 adduser 用户名Linus是个多用户的操作系统是否有创建用户的权限查看当前用户身份:whoami** 怎么创建设置密码passwdsudo提权(sudo输入的是用户…

前端网页打开PC端本地的应用程序实现方案

最近开发有一个需求&#xff0c;网页端有个入口需要跳转三维大屏&#xff0c;而这个大屏是一个exe应用程序。产品需要点击这个入口&#xff0c;并打开这个应用程序。这个就类似于百度网盘网页跳转到PC端应用程序中。 这里我们采用添加自定义协议的方式打开该应用程序。一开始可…

前端:Vue学习 - 购物车项目

前端&#xff1a;Vue学习 - 购物车项目 1. json-server&#xff0c;生成后端接口2. 购物车项目 - 实现效果3. 参考代码 - Vuex 1. json-server&#xff0c;生成后端接口 全局安装json-server&#xff0c;json-server官网为&#xff1a;json-server npm install json-server -…

vue3前端开发-小兔鲜项目-登录和非登录状态下的模板适配

vue3前端开发-小兔鲜项目-登录和非登录状态下的模板适配&#xff01;有了上次的内容铺垫&#xff0c;我们可以根据用户的token来判定&#xff0c;到底是显示什么内容了。 1&#xff1a;我们在对应的导航组件内修改完善一下内容即可。 <script setup> import { useUserSt…

深入理解TensorFlow底层架构

目录 深入理解TensorFlow底层架构 一、概述 二、TensorFlow核心概念 计算图 张量 三、TensorFlow架构组件 前端 后端 四、分布式计算 集群管理 并行计算 五、性能优化 内存管理 XLA编译 六、总结与展望 深入理解TensorFlow底层架构 一、概述 TensorFlow是一个开…