softmax回归的从零实现(附代码)

softmax回归是一个多分类模型,但是他跟线性回归一样将输入特征与权重做线性叠加,与线性不同的是他有多个输出,输出的个数对应分类标签的个数,比如四个特征和三种输出动物类别,则权重包含12个标量(带下标的w),偏差包含三个标量(带下标的b),且对每个输入计算o1,o2,o3

然后再对这些输出值进行softmax‘运算,softmax也是单层模型


import torch
from IPython import display
from d2l import torch as d2lbatch_size=256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)#初始化模型参数
num_inputs = 784
num_outputs = 10
w = torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)
b = torch.zeros(num_outputs,requires_grad=True)X = torch.tensor([[1.,2.,3.],[4.,5.,6.]])
X.sum(0,keepdim=True),X.sum(1,keepdim=True)Output:  (tensor([[5., 7., 9.]]),tensor([[ 6.],[15.]]))#定义softmax回归
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1,keepdim=True)return X_exp/partition#这里应用了广播机制x= torch.normal(0,1,(2,5))
x_prob= softmax(x)
x_prob,x_prob.sum(1)Output:   (tensor([[0.0902, 0.0850, 0.2683, 0.1946, 0.3619],[0.0551, 0.4104, 0.2667, 0.1486, 0.1192]]),tensor([1., 1.]))#实现softmax回归
def net(X):return softmax(torch.matmul(X.reshape((-1,w.shape[0])),w)+b)y = torch.tensor([0,2])
y_hat= torch.tensor([[0.1,0.3,0.6],[0.3,0.3,0.5]])
y_hat[[0,1],y]Output:   tensor([0.1000, 0.5000])#定义损失函数
def cross_entropy(y_hat,y):return -torch.log(y_hat[range(len(y_hat)),y])
cross_entropy(y_hat,y)Output:  tensor([2.3026, 0.6931])#分类精度
def accuracy(y_hat,y):if(len(y_hat.shape)>1 and y_hat.shape[0]>1):y_hat=y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) ==yreturn float(cmp.type(y.dtype).sum())
accuracy(y_hat,y)/len(y)Output:  0.5#我们可以评估任意模型的net的准确率
def evaluate_accuracy(net,data_iter):if isinstance(net,torch.nn.Module):net.eval()#将模型设置为评估模式metric = Accumulator(2)#正确预测数,预测总数,是一个累加的迭代器for X,y in data_iter:metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1]class Accumulator:def __init__(self,n):self.data=[0.0]*ndef add(self,*args):self.data=[a+float(b) for a,b in zip(self.data,args)]def reset(self):self.data=[0.0]*len(self.data)def __getitem__(self,idx):return self.data[idx]
evaluate_accuracy(net,test_iter)Output:  0.1196def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]class Animator:  #@save"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)#训练函数
def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):animator = Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0.3,0.9],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net,train_iter,loss,updater)test_acc=evaluate_accuracy(net,test_iter)animator.add(epoch+1,train_metrics+(test_acc,))train_loss,train_acc = train_metrics
#     assert train_loss<0.5,train_loss
#     assert train_acc <=1 and train_acc>0.7,train_acc
#     assert test_acc <=1 and test_acc>0.7,test_acclr = 0.1
#设置优化函数
def updater(batch_size):return d2l.sgd([w,b],lr,batch_size)num_epochs =10
train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,updater)

Output:


#预测
def predict_ch3(net,test_iter, n=6):for X,y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds=d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + '\n'+pred for true,pred in zip(trues,preds)]d2l.show_images(X[0:n].reshape((n,28,28)),1,n,titles=titles[0:n])
predict_ch3(net,test_iter)

output:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1536873.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习之线性代数预备知识点

概念定义公式/案例标量(Scalar)一个单独的数值&#xff0c;表示单一的量。例如&#xff1a;5, 3.14, -2向量 (Vector)一维数组&#xff0c;表示具有方向和大小的量。 &#xff0c;表示三维空间中的向量 模(Magnitude)向量的长度&#xff0c;也称为范数&#xff08;通常为L2范数…

HCIA--实验十六:ACL通信实验(2)

2.高级ACL配置 一、实验内容 1.需求/要求&#xff1a; 使用三台PC和一台交换机&#xff0c;在交换机上配置高级ACL&#xff0c;测试PC1、PC2、PC3间的连通性。 二、实验过程 1.拓扑图&#xff1a; 2.步骤&#xff1a; 1.给PC3配置ip地址&#xff1a; 2.给交换机SW3配置高…

Hello,Spring Boot...

今天开启了Spring Boot学习之旅。 首先就是&#xff0c;JDK、Maven、IDEA以及各种官网的下载、安装与配置 然后通过组件创建小类&#xff0c;最让人头痛的就是&#xff0c;这个spring-boot-starter-thymeleaf&#xff0c;下错版本了 其他的一切顺利&#xff0c;自动化明显 最后…

2024最新版mysql数据库表的查询操作-总结

序言 1、MySQL表操作(创建表&#xff0c;查询表结构&#xff0c;更改表字段等)&#xff0c; 2、MySQL的数据类型(CHAR、VARCHAR、BLOB,等)&#xff0c; 本节比较重要&#xff0c;对数据表数据进行查询操作&#xff0c;其中可能大家不熟悉的就对于INNER JOIN(内连接)、LEFT JOIN…

Learn ComputeShader 15 Grass

1.Using Blender to create a single grass clump 首先blender与unity的坐标轴不同&#xff0c;z轴向上&#xff0c;不是y轴 通过小键盘的数字键可以快速切换视图&#xff0c;选中物体以后按下小键盘的点可以将物体聚焦于屏幕中心 首先我们创建一个平面&#xff0c;宽度为0.2…

SpringBoot中使用EasyExcel并行导出多个excel文件并压缩zip后下载

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

SysML图例-农业无人机

DDD领域驱动设计批评文集>> 《软件方法》强化自测题集>> 《软件方法》各章合集>>

dll修复工具4DDiG DLL Fixer,解决电脑dll丢失问题

4DDiG DLL Fixer是一款专业的DLL修复工具&#xff0c;旨在解决Windows系统中各种DLL相关问题。该工具能够快速全面地扫描计算机&#xff0c;检测并修复导致程序功能异常的DLL错误。它支持一键式操作&#xff0c;自动扫描、识别和替换缺失或损坏的DLL文件&#xff0c;从而帮助用…

推荐3款AIai论文大纲一键生成文献,精选整理!

在当前的学术写作环境中&#xff0c;AI论文大纲生成工具已经成为许多学者和学生的重要助手。这些工具不仅能够快速生成高质量的论文大纲&#xff0c;还能提供内容填充、文献引用和查重修改等全方位的服务。以下是三款值得推荐的AI论文大纲一键生成文献工具&#xff1a;千笔-AIP…

爬虫--翻页tips

免责声明&#xff1a;本文仅做分享&#xff01; 伪线程 from DrissionPage import ChromiumPage import timepage ChromiumPage() page.get("https://you.ctrip.com/sight/taian746.html") # 初始化 第0页 index_page 0# 翻页点击函数 sleep def page_turn():page…

C/C++语言基础--从C到C++的不同(下),15个部分说明C与C++的不同

本专栏目的 更新C/C的基础语法&#xff0c;包括C的一些新特性 前言 1-10在上篇C/C语言基础–从C到C的不同(上&#xff09;&#xff1b;当然C和C的不同还有很多&#xff0c;本人暂时只总结这些&#xff0c;其他的慢慢更新&#xff1b;上一篇C/C语言基础–从C到C的不同(上&…

node.js 中的进程和线程工作原理

本文所有的代码均基于 node.js 14 LTS 版本分析 概念 进程是对正在运行中的程序的一个抽象&#xff0c;是系统进行资源分配和调度的基本单位&#xff0c;操作系统的其他所有内容都是围绕着进程展开的 线程是操作系统能够进行运算调度的最小单位&#xff0c;其是进程中的一个执…

康养小站:长者舒缓疼痛的港湾

【导语】在老龄化日益加剧的当下&#xff0c;如何关爱和照顾好长者&#xff0c;成为社会关注的焦点。近日&#xff0c;笔者走进深圳宝安区一家专注于长者康养的社区小站&#xff0c;探访它如何帮助长者缓解疼痛&#xff0c;提高生活质量。 随着我国人口老龄化问题日益显著&…

算法:30.串联所有单词的子串

题目 链接&#xff1a;leetcode链接 思路分析&#xff08;滑动窗口&#xff09; 这道题目类似寻找异位词的题目&#xff0c;我认为是寻找异位词的升级版 传送门:寻找异位词 为什么说像呢&#xff1f; 注意&#xff1a;这道题目中words数组里面的字符串长度都是相同的&…

[JAVA]介绍怎样在Java中通过字节字符流实现文件读取与写入

一&#xff0c;初识File类及其常用方法 File类是java.io包下代表与平台无关的文件和目录&#xff0c;程序中操作文件和目录&#xff0c;都可以通过File类来完成。 通过这个File对象&#xff0c;可以进行一系列与文件相关的操作&#xff0c;比如判断文件是否存在&#xff0c;获…

Java毕业设计 基于SpringBoot和Vue药店管理系统

Java毕业设计 基于SpringBoot和Vue药店管理系统 这篇博文将介绍一个基于SpringBoot框架和Vue开发的药店管理系统&#xff0c;适合用于Java毕业设计。 功能介绍 首页 图片轮播 登录 注册 药品信息 药品详情 评论 收藏 购买 添加到购物车 用药指南 公告资讯 购物车 …

在深圳停车场我居然能看到很漂亮的瓦房

石岩街道在宝安确实是小透明哈&#xff0c;从市区搬到石岩快4年了&#xff0c;确实这里的建筑特别像老家的感觉&#xff0c;马路很狭窄。如果是开车的话&#xff0c;我是不会进入罗租大道来着&#xff0c;人车太复杂。由于上屋社区适合儿童的室内场所太少了&#xff0c;石岩这块…

python之模块和包的导入与使用,pip的使用(13)

文章目录 1、模块1.1 模块的分类1.1.1 内置模块1.1.2 第三方模块&#xff08;比较重要&#xff09;1.1.3 自定义模块 1.2 模块的导入1.2.1 单个模块的导入1.2.2 同时导入多个模块1.2.3 模块导入规范1.2.4 给导入的模块取别名1.2.5 同时导入模块和名字1.2.6 给导入的名字取别名扩…

【Python机器学习】序列到序列建模——使用序列到序列网络构建一个聊天机器人

为了寻聊天机器人&#xff0c;下面使用康奈尔电影对话语料库训练一个序列到序列的网络来“适当的”湖大问题或语句。以下聊天机器人示例采用的是Keras blog中的序列到序列的示例。 为训练准备语料库 首先&#xff0c;需要加载语料库并从中生成训练集&#xff0c;训练数据将决…

【刷题】Day5--数字在升序数组中出现的次数

Hi! 今日份刷题~ 数字在升序数组中出现的次数_牛客题霸_牛客网 我感觉题目简单&#xff0c;我的解答也很简单&#xff0c;二分法遗忘&#xff0c;有时间复习一下尝试新的解法。 /*** 代码中的类名、方法名、参数名已经指定&#xff0c;请勿修改&#xff0c;直接返回方法规定的…