昇思25天学习打卡营第21天 | 基于MindSpore的红酒分类实验

内容简介

在这里插入图片描述

本实验介绍了使用MindSpore框架实现K近邻算法(KNN)对红酒数据集进行分类的全过程。通过数据读取、预处理、模型构建与预测,展示了KNN算法在红酒数据集上的应用。实验中详细解释了KNN的原理、距离度量方式及其在分类问题中的应用,最后通过验证集评估模型性能,验证了KNN算法在该3分类任务上的有效性。

实验代码及注释

# 导入必要的库
import os
import csv
import numpy as np
import matplotlib.pyplot as pltimport mindspore as ms
from mindspore import nn, ops# 设置MindSpore的运行环境
ms.set_context(device_target="CPU")# 读取数据集
with open('wine.data') as csv_file:data = list(csv.reader(csv_file, delimiter=','))
print(data[56:62]+data[130:133])# 数据处理
# 将数据集的13个属性作为自变量 X,将3个类别作为因变量 Y
X = np.array([[float(x) for x in s[1:]] for s in data[:178]], np.float32)
Y = np.array([s[0] for s in data[:178]], np.int32)# 可视化样本分布
attrs = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols','Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue','OD280/OD315 of diluted wines', 'Proline']
plt.figure(figsize=(10, 8))
for i in range(0, 4):plt.subplot(2, 2, i+1)a1, a2 = 2 * i, 2 * i + 1plt.scatter(X[:59, a1], X[:59, a2], label='1')plt.scatter(X[59:130, a1], X[59:130, a2], label='2')plt.scatter(X[130:, a1], X[130:, a2], label='3')plt.xlabel(attrs[a1])plt.ylabel(attrs[a2])plt.legend()
plt.show()# 划分训练集和测试集
train_idx = np.random.choice(178, 128, replace=False)
test_idx = np.array(list(set(range(178)) - set(train_idx)))
X_train, Y_train = X[train_idx], Y[train_idx]
X_test, Y_test = X[test_idx], Y[test_idx]# 构建KNN模型
class KnnNet(nn.Cell):def __init__(self, k):super(KnnNet, self).__init__()self.k = kdef construct(self, x, X_train):x_tile = ops.tile(x, (128, 1))  # 平铺输入x以匹配X_train中的样本数square_diff = ops.square(x_tile - X_train)square_dist = ops.sum(square_diff, 1)dist = ops.sqrt(square_dist)values, indices = ops.topk(-dist, self.k)  # -dist表示值越大,样本就越接近return indicesdef knn(knn_net, x, X_train, Y_train):x, X_train = ms.Tensor(x), ms.Tensor(X_train)indices = knn_net(x, X_train)topk_cls = [0]*len(indices.asnumpy())for idx in indices.asnumpy():topk_cls[Y_train[idx]] += 1cls = np.argmax(topk_cls)return cls# 模型预测
acc = 0
knn_net = KnnNet(5)
for x, y in zip(X_test, Y_test):pred = knn(knn_net, x, X_train, Y_train)acc += (pred == y)print('label: %d, prediction: %s' % (y, pred))
print('Validation accuracy is %f' % (acc/len(Y_test)))

学习心得

通过本次实验,我深入了解了K近邻算法(KNN)及其在分类任务中的应用。KNN是一种基于实例的学习算法,利用训练样本的多数表决结果来对新样本进行分类。它具有简单、直观的优点,但在大规模数据集上计算复杂度较高,因此需要在应用时进行适当的优化和改进。

在实验过程中,首先进行了数据读取与处理。Wine数据集包含13个属性,每个属性都对分类结果有不同程度的影响。通过可视化展示了不同类别样本在某两个属性上的分布,帮助我们直观地理解数据的可分性。接下来,通过划分训练集和测试集,保证了模型的训练和验证能够在不同的数据上进行,从而提高模型的泛化能力。

模型构建部分,通过使用MindSpore框架,利用其提供的高效算子如tile、square、ReduceSum等,构建了KNN模型。在计算距离时,选择了欧氏距离,并使用TopK算子找出距离最近的k个邻居。对于分类决策,采用了多数表决的方式,即统计k个邻居中每个类别的数量,选择最多的类别作为预测结果。
在这里插入图片描述

在验证阶段,取k值为5进行模型预测,验证精度约为70%。虽然准确率不算特别高,但对于一个简单的3分类任务,KNN算法仍然展现出了其有效性。通过调整k值或者加入样本权重,可以进一步优化模型性能。

label: 2, prediction: 3
label: 3, prediction: 2
label: 1, prediction: 1
label: 3, prediction: 3
label: 1, prediction: 1
label: 1, prediction: 1
label: 3, prediction: 3
label: 3, prediction: 3
label: 1, prediction: 1
label: 1, prediction: 1
label: 1, prediction: 1
label: 3, prediction: 3
label: 1, prediction: 1
label: 1, prediction: 1
label: 3, prediction: 1
label: 1, prediction: 3
label: 1, prediction: 3
label: 3, prediction: 3
label: 1, prediction: 1
label: 3, prediction: 3
label: 3, prediction: 3
label: 3, prediction: 3
label: 1, prediction: 1
label: 3, prediction: 3
label: 3, prediction: 1
label: 1, prediction: 1
label: 1, prediction: 1
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 3
label: 2, prediction: 1
label: 2, prediction: 1
label: 2, prediction: 3
label: 2, prediction: 1
label: 2, prediction: 3
label: 2, prediction: 2
label: 2, prediction: 3
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 3
label: 2, prediction: 3
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
Validation accuracy is 0.700000

本次实验不仅让我掌握了KNN算法的实现过程,还了解了MindSpore框架在机器学习任务中的应用。通过实验操作,进一步巩固了机器学习理论知识,提升了编程实战能力。同时,也深刻认识到在处理实际问题时,数据预处理和特征工程的重要性。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1483871.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

项目实用linux 操作详解-轻松玩转linux

我之前写过完整的linux系统详解介绍: LInux操作详解一:vmware安装linux系统以及网络配置 LInux操作详解二:linux的目录结构 LInux操作详解三:linux实际操作及远程登录 LInux操作详解四:linux的vi和vim编辑器 LInux操作…

商业数据分析思维的培训PTT制作大纲分享

商业数据分析思维的培训PTT制作大纲: 基本步骤: 明确PPT的目的和主题 收集并整理相关内容资料 构思并确定PPT的框架大纲 编写PPT的内容文字 插入图片、图表等视觉元素 设计PPT的版式和模板 排练并修改PPT 输出并备份最终版本 目的:数据思维培养; 主题:商业数据分…

【吊打面试官系列-ZooKeeper面试题】zookeeper 是如何保证事务的顺序一致性的?

大家好,我是锋哥。今天分享关于 【zookeeper 是如何保证事务的顺序一致性的?】面试题,希望对大家有帮助; zookeeper 是如何保证事务的顺序一致性的? zookeeper 采用了全局递增的事务 Id 来标识,所有的 prop…

Seaborn库学习之heatmap()函数

Seaborn库学习之heatmap(函数) 一、简介 seaborn.heatmap是Seaborn库中用于绘制热图(Heatmap)的函数。热图是一种数据可视化技术,通过颜色的变化来展示数据矩阵中的数值大小。这种图表非常适合展示数值数据的分布和关系,尤其是在…

韦东山嵌入式linux系列-驱动进化之路:设备树的引入及简明教程

1 设备树的引入与作用 以 LED 驱动为例,如果你要更换LED所用的GPIO引脚,需要修改驱动程序源码、重新编译驱动、重新加载驱动。 在内核中,使用同一个芯片的板子,它们所用的外设资源不一样,比如A板用 GPIO A&#xff0c…

TI毫米波雷达1843 Out-of-box Demo 总结

总结 以上就是基于MATLAB实现1843 Out-of-box Demo的实时数据采集的相关内容,里面包含了 如何快速上手TI的毫米波雷达开发板;如何使用CCS构建TI的工程代码框架;如何阅读CCS源码确定串口输出的通讯协议;如何使用MATLAB实时接收串口数据;如何使用MATLAB编写上位机软件;成品…

13 循环神经网络—序列模型,语言模型

目录 1.序列模型序列数据统计工具自回归模型马尔可夫模型因果关系前向算法举例(根据过去的事件推测未来的事件)方案 A -马尔科夫假设方案 B -潜变量模型总结代码实现 使用马尔科夫假设 训练一个MLP2.文本预处理常见的文本预处理步骤代码实现3.语言模型**使用计数来建模**N 元…

大模型评测技术研讨会暨国际标准IEEE P3419第二次工作组会议成功召开

7月12日,由北京智源人工智能研究院主办的大模型评测技术研讨会暨国际标准IEEE P3419第二次工作组会议在智源大厦举办,来自百度、信通院、移动、联通、电信、浪潮、南方电网、南瑞、清华、北航等互联网大厂、科研机构、运营商、知名高校以及海外的50余位专…

Android:创建自定义View

点击查看创建自定义view官网文档 一、简介 设计良好的自定义视图与任何其他精心设计的类一样。它通过一个简单的接口封装一组特定的功能,高效使用 CPU 和内存,诸如此类。除了是一个精心设计的类之外,自定义视图还必须执行以下操作&#xff1…

vue echarts 柱状图表,点击柱子,路由代参数(X轴坐标)跳转

一 myChart.on(click, (params) > {if (params.componentType series && params.dataIndex ! undefined) {const months this.month_htqd[params.dataIndex]; // 获取点击柱状图的 X 轴坐标值alert(点击了柱状图,值为: ${months});// 根据点击的柱状图…

哪种SSL证书可以快速签发保护http安全访问?

用户访问网站,经常会遇到访问http网页时,提示网站不安全或者不是私密连接的提示,因为http是使用明文传输,数据传输中可能被篡改,数据不被保护,通常需要SSL证书来给数据加密。 SSL证书的签发速度&#xff0…

自动化测试中如何应对网页弹窗的挑战!

在自动化测试中,网页弹窗的出现常常成为测试流程中的一个难点。无论是警告框、确认框、提示框,还是更复杂的模态对话框,都可能中断测试脚本的正常执行,导致测试结果的不确定性。本文将探讨几种有效的方法来应对网页弹窗的挑战&…

Postgresql-12.5 安装及配置 -银河麒麟V10服务器版本

Postgresql-12.5 安装及配置 环境基于银河麒麟V10 服务器版本操作 此安装步骤Linux操作系统几乎通用 下载数据库安装包 链接:https://pan.baidu.com/s/1wt4Yjwv79W-fCd4tlMC4-w 提取码:0117 1.下载依赖 可以用系统自带的依赖库下载 yum install -…

基于PHP+MYSQL开发制作的趣味测试网站源码

基于PHPMYSQL开发制作的趣味测试网站源码。可在后台提前设置好缘分, 自己手动在数据库里修改数据,数据库里有就会优先查询数据库的信息, 没设置的话第一次查询缘分都是非常好的 95-99,第二次查就比较差 , 所以如果要…

什么是SQL锁

SQL锁是数据库系统中的一个重要概念,主要用于保证多用户环境下的数据库完整性和一致性。在多用户并发访问数据库时,通过加锁的方式防止其他事务访问指定的资源,从而控制并发的访问,确保数据的完整性和一致性。 SQL锁可以分为以下…

msyql (8.4,9.0) caching_sha2_password 转换 mysql_native_password用户认证

mysql 前言 caching_sha2_password 主要特性 用于增强用户账户密码的存储和验证安全性。这种插件利用 SHA-256 散列算法的变体来存储和验证密码 安全的密码散列: caching_sha2_password 使用基于 SHA-256 的算法来生成密码的散列值。这意味着即使数据库被未授权访…

地图项目涉及知识点总结

序:最近做了一个在地图上标记点的项目,用户要求是在地图上显示百万量级的标记点,并且地图仍要可用(能拖拽,能缩放)。调研了不少方法和方案,最终实现了相对流畅的地图系统,加载耗时用…

spring-boot 整合 redisson 实现延时队列(文末有彩蛋)

应用场景 通常在一些需要经历一段时间或者到达某个指定时间节点才会执行的功能,比如以下这些场景: 订单超时提醒收货自动确认会议提醒代办事项提醒 为什么使用延时队列 对于数据量小且实时性要求不高的需求来说,最简单的方法就是定时扫描数据…

【IEEE出版】第四届能源工程与电力系统国际学术会议(EEPS 2024)

第四届能源工程与电力系统国际学术会议(EEPS 2024) 2024 4th International Conference on Energy Engineering and Power Systems 重要信息 大会官网:www.iceeps.com 大会时间:2024年8月9-11日 大会…

S7-1200PLC使用西门子报文 111 和 FB38002(Easy_SINA_Pos)实现V90 PN总线伺服定位控制

1、博途1200/1500 PLC V90 PN通信 博途1200/1500PLC V90 PN通信控制 (FB284功能块)_fb284功能块文档说明-CSDN博客文章浏览阅读7k次。先简单说下如何获取FB284,一般有2种方法,Startdrive软件可以操作大部分西门子的驱动器,建议安装调试方便,缺点就是软件太大。_fb284功能…