从0开始机器学习--Day17--神经网络反向传播作业

题目:识别数字0-9,做梯度检测来验证是否在梯度下降过程中存在问题,并可视化隐藏层

代码:

import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from scipy.optimize import minimizedef sigmoid(z):return 1/(1+np.exp(-z))def sigmoid_derivation(z): # sigmoid函数求导return sigmoid(z)*(1-sigmoid(z))def one_hot(raw_y):result = []for i in raw_y: # 1-10y_temp = np.zeros(10)y_temp[i-1] = 1result.append(y_temp)return np.array(result) #返回成数组的格式def sequence(theta1,theta2): #序列化return np.append(theta1.flatten(), theta2.flatten())  # 方便后续调用scipy库方便, minimize要求初始化参数x0# 只有1列,def return_sequence(theta_sequence): #解序列化theta1 = theta_sequence[:25*401].reshape(25, 401)  # 解序列化 保证后续矩阵运算维度是一致的theta2 = theta_sequence[25*401:].reshape(10, 26)return theta1, theta2def forward_propagation(theta_sequence, X):theta1, theta2 = return_sequence(theta_sequence)a1 = Xz2 = a1@theta1.Ta2 = sigmoid(z2)a2 = np.insert(a2, 0, values=1, axis=1)z3 = a2@theta2.Th = sigmoid(z3)return a1, z2, a2, z3, hdef cost_function(theta_sequence, X, y):a1, z2, a2, z3, h = forward_propagation(theta_sequence, X)J = (-np.sum(y*np.log(h)+(1-y)*np.log(1-h)))/len(X)return Jdef reg_cost_function(theta_sequence, X, y, l=1):first = np.sum(np.power(theta1[:, 1:], 2))second =np.sum(np.power(theta2[:, 1:], 2))reg = (first + second) * l / (2 * len(X))return reg + cost_function(theta_sequence, X, y)def gradient(theta_sequence, X, y): # 反向传播计算误差deltatheta1, theta2 = return_sequence(theta_sequence)a1, z2, a2, z3, h = forward_propagation(theta_sequence, X)d3 = h-yd2 = d3@theta2[:,1:]*sigmoid_derivation(z2)D2 = (d3.T@a2) / len(X)D1 = (d2.T@a1) / len(X)return sequence(D1, D2)def reg_gradient(theta_sequence, X, y, l=1): # 正则化D = gradient(theta_sequence, X, y)D1, D2 = return_sequence(D)theta1, theta2 = return_sequence(theta_sequence)D1[:, 1:] = D1[:, 1:] + theta1[:, 1:] * l / len(X)D2[:, 1:] = D2[:, 1:] + theta2[:, 1:] * l / len(X)return sequence(D1, D2)def neutral_network(X, y, l):init_theta = np.random.uniform(-0.5, 0.5, 10285) # 随机化初始值,避免全为0结果只有一个特征res = minimize(fun=reg_cost_function,x0=init_theta,args=(X, y, l),method='TNC',jac=reg_gradient,options={'maxiter': 300}) # 设置最大迭代次数为300return resdata = sio.loadmat('ex4data1.mat')
raw_x = data['X']
raw_y = data['y']
print(raw_y)
X = np.insert(raw_x, 0, values=1, axis=1) # 添加偏置单元
print(X.shape)y = one_hot(raw_y)
print(y)
print(y.shape)theta = sio.loadmat('ex4weights.mat')
theta1 = theta['Theta1']
theta2 = theta['Theta2']
print(theta1.shape)
print(theta2.shape)theta_sequence = sequence(theta1, theta2)print(reg_cost_function(theta_sequence, X, y, l=1))l = 10
res = neutral_network(X, y, l)
raw_y = data['y'].reshape(5000) # 降为一维方便后面进行梯度检验时的比较
a1, z2, a2, z3, h = forward_propagation(res.x, X)
y_pred = np.argmax(h, axis=1)+1 # 取最大
accrancy = np.mean(y_pred == raw_y)
print(accrancy)def hidden_layer(theta):theta1, theta2 = return_sequence(theta)hidden_layer = theta1[:, 1:]fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(8, 8), sharex=True, sharey=True)for r in range(5):for c in range(5):ax[r, c].imshow(hidden_layer[5 * r + c].reshape(20, 20).T,cmap='gray_r')plt.xticks([])plt.yticks([])plt.show()hidden_layer(res.x)

输出:

[[10][10][10]...[ 9][ 9][ 9]]
(5000, 401)
[[0. 0. 0. ... 0. 0. 1.][0. 0. 0. ... 0. 0. 1.][0. 0. 0. ... 0. 0. 1.]...[0. 0. 0. ... 0. 1. 0.][0. 0. 0. ... 0. 1. 0.][0. 0. 0. ... 0. 1. 0.]]
(5000, 10)
(25, 401)
(10, 26)
0.38376985909092365
0.9394进程已结束,退出代码0

可视化隐藏层

总结:与之前相比,这次代码中数学的运算多了很多,尤其是偏导部分;注意写代码前要多推导数学运算的过程不要出现差错;有所改进的是跟之前在minimize中加flatten相比,直接添加了一个函数对参数进行序列化操作来方便调用scipy库。

作业订正参考:【作业讲解】编程作业4:神经网络(2)(上)_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/7344.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

前端学习笔记-Ajax篇

第1章:原生AJAX 1.1Ajax简介 AAX 全称为 Asynchronous JavaScript And XML,就是异步的 JS 和 XML。 通过 AAX 可以在浏览器中向服务器发送异步请求,最大的优势:无刷新获取数据。 AAX 不是新的编程语言,而是一种将现有的标准组合在一起使用…

【Python爬虫实战】DrissionPage 与 ChromiumPage:高效网页自动化与数据抓取的双利器

🌈个人主页:易辰君-CSDN博客 🔥 系列专栏:https://blog.csdn.net/2401_86688088/category_12797772.html ​ 目录 前言 一、DrissionPage简介 (一)特点 (二)安装 (三…

Halcon基于laws纹理特征的SVM分类

与基于区域特征的 SVM 分类不同,针对图像特征的 SVM 分类的算子不需要直接提取 特征,下面介绍基于 Laws 纹理特征的 SVM 分类。 纹理在计算机视觉领域的图像分割、模式识别等方面都有着重要的意义和广泛的应 用。纹理是指由于物体表面的物理属性不同所…

Netty篇(入门编程)

目录 一、Hello World 1. 目标 2. 服务器端 3. 客户端 4. 流程梳理 💡 提示 5. 运行结果截图 二、Netty执行流程 1. 流程分析 2. 代码案例 2.1. 引入依赖 2.2. 服务端 服务端 服务端处理器 2.3. 客户端 客户端 客户端处理器 2.4. 代码截图 一、Hel…

从0开始学习机器学习--Day14--如何优化神经网络的代价函数

在上一篇文章中,解析了神经网络处理分类问题的过程,类似的,在处理多元分类问题时,神经网络会按照类型分成多个输出层的神经元来表示,如下: 处理4个分类问题时的神经网络 我们可以看到,相较于之…

除草机器人算法以及技术详解!

算法详解 图像识别与目标检测算法 Yolo算法:这是目标检测领域的一种常用算法,通过卷积神经网络对输入图像进行处理,将图像划分为多个网格,每个网格生成预测框,并通过非极大值抑制(NMS)筛选出最…

Android MavenCentral 仓库更新问题

MavenCentral 仓库更新问题 前言正文一、Maven central repository的账户迁移二、获取加密账户信息三、问题和解决方式① 问题1② 解决1③ 问题2④ 解决2 前言 在去年的3、4月份的时候我发布了一个开源库EasyView,在MavenCentral上,可以说当时发布的时候…

腾讯为什么支持开源?

今天看到一条新闻,感觉腾讯在 AI 大模型方面确实挺厉害的,符合它低调务实的风格,在不知不觉中一天竟然开源了两个核心的,重要的 AI 大模型。 据新闻报道,11月 5 日,腾讯混元宣布最新的 MoE 模型“混元 Larg…

学习了,踩到一个坑!

前言 踩坑了啊,最近踩了一个 lombok 的坑,有点意思,给你分享一波。 我之前写过一个公共的服务接口,这个接口已经有好几个系统对接并稳定运行了很长一段时间了,长到这个接口都已经交接给别的同事一年多了。 因为是基…

『Django』APIView基于类的用法

点赞 关注 收藏 学会了 本文简介 上一篇文章介绍了如何使用APIView创建各种请求方法,介绍的是通过函数的方式写接口。 本文要介绍 Django 提供的基于类(Class)来实现的 APIView 用法,代码写起来更简单。 APIView基于类的基…

CentOS系统查看CPU、内存、操作系统等信息

Linux系统提供了一系列命令可以用来查看系统硬件信息,如CPU的物理个数、核数、逻辑CPU数量、内存信息和操作系统版本。 查看物理CPU、核数和逻辑CPU 在多核、多线程的系统中,了解物理CPU个数、每个物理CPU的核数和逻辑CPU个数至关重要。超线程技术进一步…

DNS配置

1.搭建dns服务器能够对自定义的正向或者反向域完成数据解析查询。 2.配置从DNS服务器,对主dns服务器进行数据备份。 options {listen-on port 53 { 192.168.111.130; };directory "/var/named";allow-query { any;};zone "openlab.com&qu…

【WebRTC】WebRTC的简单使用

目录 1.下载2.官网上的使用3.本地的使用 参考: 【webRTC】一、windows编译webrtc Windows下WebRTC编译 1.下载 下载时需要注意更新python的版本和网络连接,可以先试试ping google。比较关键的步骤是 cd webrtc-checkout set https_proxy127.0.0.1:123…

使用axois自定义基础路径,自动拼接前端服务器地址怎么办

请求路径: http://localhost:5173/http://pcapi-xiaotuxian-front-devtest.itheima.net/home/category/head 很明显多拼接了路径地址 查看基础路径文件发现: //axios基础封装 import axios from axiosconst httpInstance axios.create({baseURL: /h…

第J5周:DenseNet+SE-Net实战

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 任务: ●1. 在DenseNet系列算法中插入SE-Net通道注意力机制,并完成猴痘病识别 ●2. 改进思路是否可以迁移到其他地方呢 ●3. 测试集acc…

力扣最热一百题——杨辉三角

目录 题目链接:118. 杨辉三角 - 力扣(LeetCode) 题目描述 示例 提示: 解法一:利用特性构建杨辉三角 1. 结果存储结构: 2. 初始化和循环遍历每一层: 3. 构建每一层: 4. 填充中间的元素&…

道品科技智慧农业中的自动气象检测站

随着科技的进步,智慧农业已经成为现代农业发展的重要方向。农业自动气象检测站作为智慧农业的一个关键组成部分,发挥着不可或缺的作用。本文将从工作原理、功能特点、应用场景以及主要作用等方面对农业自动气象检测站进行深入探讨。 ## 一、工作原理 农…

Android——多线程、线程通信、handler机制

Android——多线程、线程通信、handler机制 模拟网络请求&#xff0c;会阻塞主线程 private String getStringForNet() {StringBuilder stringBuilder new StringBuilder();for (int i 0; i < 100; i) {stringBuilder.append("字符串" i);}try {Thread.sleep(…

练习LabVIEW第三十三题

学习目标&#xff1a; 刚学了LabVIEW&#xff0c;在网上找了些题&#xff0c;练习一下LabVIEW&#xff0c;有不对不好不足的地方欢迎指正&#xff01; 第三十三题&#xff1a; 用labview编写一个判断素数的程序 开始编写&#xff1a; LabVIEW判断素数&#xff0c;首先要搞…

我要精通前端-布局方式理解总结

一、浮动 1、传统网页布局的三种方式 ​CSS 提供了三种传统布局方式(简单说,就是盒子如何进行排列顺序)&#xff1a; 1.普通流&#xff08;标准流&#xff09; 2.浮动 3.定位 这三种布局方式都是用来摆放盒子的&#xff0c;盒子摆放到合适位置&#xff0c;布局自然就完成了…