李沐深度学习记录3:11模型选择、欠拟合和过拟合

通过多项式拟合探索欠拟合与过拟合

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l#生成数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))#生成均值为0,标准差为1的正态分布概率密度随机数
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)  #加上噪声#分别对应常数系数、x值、x的次方再除以阶乘、算得的y值
true_w,features,poly_features,labels=[torch.tensor(x,dtype=torch.float32) for x in [true_w,features,poly_features,labels]]
true_w[:2],features[:2],poly_features[:2,:],labels[:2]

在这里插入图片描述

#实现函数评估模型损失
def evaluate_loss(net,data_iter,loss):'''评估给定数据集上模型的损失'''metric=d2l.Accumulator(2) #记录 损失的总和,样本数量for X,y in data_iter:out=net(X)y=y.reshape(out.shape)l=loss(out,y)metric.add(l.sum(),l.numel())return metric[0]/metric[1]#from torch.utils import data#定义训练函数
def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss(reduction='none')input_shape = train_features.shape[-1]# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])#print(train_labels.shape)   [100]#print(train_labels.shape[0])   100#print(batch_size)   10train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)#等价于以下两行代码
#     train_dataset=data.TensorDataset(train_features,train_labels.reshape(-1,1))
#     train_iter=data.DataLoader(train_dataset,batch_size,shuffle=True)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer) #训练模型一个迭代周期if epoch == 0 or (epoch + 1) % 20 == 0:   #每20个epoch训练迭代周期,评估一次模型损失(包括训练集和测试集),画一次数据点animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),  #评估训练数据集上模型损失evaluate_loss(net, test_iter, loss)))  #评估测试数据集上模型损失print('weight:', net[0].weight.data.numpy())  #输出最终模型的参数权重
#使用三阶多项式拟合,与数据生成函数的阶数相同(正常)
# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

在这里插入图片描述

#使用线性函数拟合非线性函数(这里是三阶多项式函数),线性模型很容易欠拟合
# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

![在这里插入图片描述](https://img-blog.csdnimg.cn/10ad87a962cf40ccb39909b0fa84c7ea.png

#使用一个阶数过高的复杂多项式模型来训练会造成过拟合。在这种情况下,没有足够的数据用于学到高阶系数应该具有接近于零的值。 因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。 虽然训练损失可以有效地降低,但测试损失仍然很高。 
# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:], num_epochs=1500)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/149166.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

园林园艺服务经营小程序商城的作用是什么

园林园艺属于高单价服务,同时还有各种衍生服务,对企业来说,多数情况下都是线下生意拓展及合作等,但其实线上也有一定深度,如服务售卖或园艺产品售卖等。 基于线上发展可以增强获客引流、品牌传播、产品销售经营、会员…

很普通的四非生,保研破局经验贴

推免之路 个人情况简介夏令营深圳大学情况机试面试结果 预推免湖南师范大学面试结果 安徽大学面试结果 北京科技大学笔试面试结果 合肥工业大学南京航空航天大学面试结果 暨南大学东北大学 最终结果一些建议写在后面 个人情况简介 教育水平:某中医药院校的医学信息…

STL-stack、queue和priority_queue的模拟实现

目录 一、容器适配器 (一)什么是适配器 (二)stack和queue的底层结构 二、Stack 三、queue 四、deque双端队列 (一)优点 (二)缺陷 五、优先级队列 (一&#xff…

成都建筑模板批发市场在哪?

成都作为中国西南地区的重要城市,建筑业蓬勃发展,建筑模板作为建筑施工的重要材料之一,在成都也有着广泛的需求。如果您正在寻找成都的建筑模板批发市场,广西贵港市能强优品木业有限公司是一家值得关注的供应商。广西贵港市能强优…

华为云云耀云服务器L实例评测|Ubuntu云锁防火墙安装搭建使用

华为云云耀云服务器L实例评测|Ubuntu安装云锁防火墙对抗服务器入侵和网络攻击 1.前言概述 华为云耀云服务器L实例是新一代开箱即用、面向中小企业和开发者打造的全新轻量应用云服务器。多种产品规格,满足您对成本、性能及技术创新的诉求。云耀云服务器L…

基于阴阳对优化的BP神经网络(分类应用) - 附代码

基于阴阳对优化的BP神经网络(分类应用) - 附代码 文章目录 基于阴阳对优化的BP神经网络(分类应用) - 附代码1.鸢尾花iris数据介绍2.数据集整理3.阴阳对优化BP神经网络3.1 BP神经网络参数设置3.2 阴阳对算法应用 4.测试结果&#x…

数据结构与算法--算法

这里写目录标题 线性表顺序表链表插入删除算法 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 线性表 顺序表 链表 插入删除算法 步骤 1.通过循环到达指定位置的前一个位置 2.新建…

VS的调式技巧你真的掌握了吗?

目录 什么是bug? 调式是什么?有多重要? 调试是什么? 调试的基本步骤 debug和release的介绍 windows环境调试介绍 1.调试环境的准备 2.学会快捷键 F11 VS F10 F9 & F5 3.调试时查看程序当前信息 查看临时变量的值 查看内存信…

【物联网】STM32的中断机制不清楚?看这篇文章就足够了

在嵌入式系统中,中断是一种重要的机制,用于处理来自外部设备的异步事件。STM32系列微控制器提供了强大的中断控制器,可以方便地处理各种外部中断和内部中断。本文将详细介绍STM32中断的结构和使用方法。 文章目录 1. 什么叫中断2. 中断优先级…

<学习笔记>从零开始自学Python-之-常用库篇(十二)Matplotlib

Matplotlib 是Python中类似 MATLAB的绘图工具,Matplotlib是Python中最常用的可视化工具之一,可以非常方便地创建2D图表和一些基本的3D图表,可根据数据集(DataFrame,Series)自行定义x,y轴,绘制图…

IntelliJ IDEA配置Cplex12.6.3详细步骤

Cplex12.6.3版IntelliJ IDEA配置详细步骤 一、Cplex12.6.3版下载地址二、Cplex安装步骤三、IDEA配置CPLEX3.1 添加CPLEX安装目录的cplex.jar包到项目文件中3.2 将CPLEX的x64_win64文件夹添加到IDEA的VM options中 四、检查IDEA中Cplex是否安装成功卸载Cplex 一、Cplex12.6.3版下…

Docker通过Dockerfile创建Redis、Nginx--详细过程

创建Nginx镜像 我们先创建一个目录,在目录里创建Dockerfile [rootdocker-3 ~]# mkdir mynginx [rootdocker-3 ~]# cd mynginx [rootdocker-3 ~]# vim Dockerfile Dockerfile的内容 FROM daocloud.io/library/centos:7 RUN buildDepsreadline-devel pcre-devel o…

代码:对鱼眼相机图像进行去畸变处理

图像投影模型:针孔[fx, fy, cx, cy] 图像畸变模型:切向径向畸变[k1, k2, p1, p2] 说明:用于备忘 第一部分是常规的去畸变操作,在已知内参的情况下对鱼眼相机进行去畸变,这里使用的是remap映射在对图像去畸变后&#x…

竞赛 机器视觉的试卷批改系统 - opencv python 视觉识别

文章目录 0 简介1 项目背景2 项目目的3 系统设计3.1 目标对象3.2 系统架构3.3 软件设计方案 4 图像预处理4.1 灰度二值化4.2 形态学处理4.3 算式提取4.4 倾斜校正4.5 字符分割 5 字符识别5.1 支持向量机原理5.2 基于SVM的字符识别5.3 SVM算法实现 6 算法测试7 系统实现8 最后 0…

Windows下启动freeRDP并自适应远端桌面大小

几个二进制文件 xfreerdp # Linux下的,an X11 Remote Desktop Protocol (RDP) client which is part of the FreeRDP project wfreerdp.exe # Windows下的,freerdp2.0 主程序,freerdp3.0将废弃 sdl-freerdp.exe # Windows下的&…

appscan的两种手动探索扫描方式

文章目录 一、使用火狐FoxyProxy浏览器代理探索二、使用appscan内置浏览器探索 一、使用火狐FoxyProxy浏览器代理探索 首先火狐浏览器需安装FoxyProxy 先在扩展和主题里搜FoxyProxy 选FoxyProxy Standard,然后添加到浏览器就行 添加后浏览器右上角会有这个插件 打开apps…

【算法学习】-【双指针】-【快乐数】

LeetCode原题链接:202. 快乐数 下面是题目描述: 「快乐数」 定义为: 对于一个正整数,每一次将该数替换为它每个位置上的数字的平方和。 然后重复这个过程直到这个数变为 1,也可能是 无限循环 但始终变不到 1。 如果…

cad图纸如何防止盗图(一个的制造设计型企业如何保护设计图纸文件)

在现代企业中,设计图纸是公司的重要知识产权,关系到公司的核心竞争力。然而,随着技术的发展,员工获取和传播设计图纸的途径越来越多样化,如何有效地防止员工复制设计图纸成为了企业管理的一大挑战。本文将从技术、管理…

【动手学深度学习-Pytorch版】Transformer代码总结

本文是纯纯的撸代码讲解,没有任何Transformer的基础内容~ 是从0榨干Transformer代码系列,借用的是李沐老师上课时讲解的代码。 本文是根据每个模块的实现过程来进行讲解的。如果您想获取关于Transformer具体的实现细节(不含代码)可…

MySQL的复合查询

文章目录 1. 多表查询2. 自连接3. 子查询3.1 单行子查询3.2 多行单列子查询3.3 单行多列子查询3.4 在from子句中使用子查询 4. 合并查询4.1 union all4.2 union 5. 内连接6. 外连接6.1 左外连接6.2 右外连接 1. 多表查询 前面我们讲解的mysql表的查询都是对一张表进行查询&…