CRWU凯斯西储大学轴承数据,12k频率,十分类

在这里插入图片描述
CRWU凯斯西储大学轴承数据,12k频率,十分类。

from torch.utils.data import Dataset, DataLoader
from scipy.io import loadmat
import numpy as np
import os
from sklearn import preprocessing  # 0-1编码
from sklearn.model_selection import StratifiedShuffleSplit  # 随机划分,保证每一类比例相同
import torch
from torch import nn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch.optim as optimdef prepro(d_path, length=0, number=0, normal=True, rate=[0, 0, 0], enc=False, enc_step=28):# 获得该文件夹下所有.mat文件名filenames = os.listdir(d_path)def capture(original_path):files = {}for i in filenames:# 文件路径file_path = os.path.join(d_path, i)file = loadmat(file_path)file_keys = file.keys()for key in file_keys:if 'DE' in key:files[i] = file[key].ravel()return filesdef slice_enc(data, slice_rate= rate[1]):keys = data.keys()Train_Samples = {}Test_Samples = {}for i in keys:slice_data = data[i]all_lenght = len(slice_data)# end_index = int(all_lenght * (1 - slice_rate))samp_train = int(number * (1 - slice_rate))  # 1000(1-0.3)Train_sample = []Test_Sample = []for j in range(samp_train):sample = slice_data[j * 150: j * 150 + length]Train_sample.append(sample)# 抓取测试数据for h in range(number - samp_train):sample = slice_data[samp_train * 150 + length + h * 150: samp_train * 150 + length + h * 150 + length]Test_Sample.append(sample)Train_Samples[i] = Train_sampleTest_Samples[i] = Test_Samplereturn Train_Samples, Test_Samples# 仅抽样完成,打标签def add_labels(train_test):X = []Y = []label = 0for i in filenames:x = train_test[i]X += xlenx = len(x)Y += [label] * lenxlabel += 1return X, Ydef scalar_stand(Train_X, Test_X):# 用训练集标准差标准化训练集以及测试集data_all = np.vstack((Train_X, Test_X))scalar = preprocessing.StandardScaler().fit(data_all)Train_X = scalar.transform(Train_X)Test_X = scalar.transform(Test_X)return Train_X, Test_Xdef valid_test_slice(Test_X, Test_Y):test_size = rate[2] / (rate[1] + rate[2])ss = StratifiedShuffleSplit(n_splits=1, test_size=test_size)Test_Y = np.asarray(Test_Y, dtype=np.int32)for train_index, test_index in ss.split(Test_X, Test_Y):X_valid, X_test = Test_X[train_index], Test_X[test_index]Y_valid, Y_test = Test_Y[train_index], Test_Y[test_index]return X_valid, Y_valid, X_test, Y_test# 从所有.mat文件中读取出数据的字典data = capture(original_path=d_path)# 将数据切分为训练集、测试集train, test = slice_enc(data)# 为训练集制作标签,返回X,YTrain_X, Train_Y = add_labels(train)# 为测试集制作标签,返回X,YTest_X, Test_Y = add_labels(test)# for i in Test_X:#     print(i.shape)# for i in Train_X:#     print(i.shape)# Train_X = np.stack(Train_X,axis=0)# Test_X = np.stack(Test_X,axis=0)# print(Train_X.shape,Test_X.shape)# 训练数据/测试数据 是否标准化.if normal:Train_X, Test_X = scalar_stand(Train_X, Test_X)# 将测试集切分为验证集和测试集.# Valid_X, Valid_Y, Test_X, Test_Y = valid_test_slice(Test_X, Test_Y)return Train_X, Train_Y,  Test_X, Test_Ynum_classes = 10  # 样本类别
length = 224*224  # 样本长度
number = 140  # 每类样本的数量
normal = True  # 是否标准化
rate = [0.5, 0.25, 0.25]  # 测试集验证集划分比例class BearingDataset(Dataset):def __init__(self, data, labels):self.data = torch.tensor(data, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]def create_dataloader(data, labels, batch_size=32, shuffle=True, num_workers=0):dataset = BearingDataset(data, labels)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)return dataloader# 使用前面定义的函数处理数据
path = './data12k'  # 注意路径格式可能需要根据您的操作系统调整
x_train, y_train,  x_test, y_test = prepro(d_path=path,length=112*112,  # 样本长度number=250,  # 每类样本的数量normal=True,  # 是否标准化rate=[0.8, 0.2]  # 测试集验证集划分比例
)# 创建 DataLoader
train_loader = create_dataloader(x_train, y_train, batch_size=32, shuffle=True, num_workers=0)
test_loader = create_dataloader(x_test, y_test, batch_size=32, shuffle=False, num_workers=0)class AlexNet(nn.Module):def __init__(self):super(AlexNet, self).__init__()self.Conv1 = nn.Conv2d(1, 24, 15, 3, 2)  # [1, 48, 107, 107]self.relu1 = nn.ReLU()self.maxpool1 = nn.MaxPool2d(2)  # [1, 24, 35, 35]self.Conv2 = nn.Conv2d(24, 64, 5, 1, 2)  # [1, 64, 37, 37]self.relu2 = nn.ReLU()self.maxpool2 = nn.MaxPool2d(2)  # [1, 64, 18, 18]self.Conv3 = nn.Conv2d(64, 96, 2, 1, 1)  # [1, 96, 19, 19]self.relu3 = nn.ReLU()self.Conv4 = nn.Conv2d(96, 96, 2, 1, 1)  # [1, 96, 20, 20]self.relu4 = nn.ReLU()self.Conv5 = nn.Conv2d(96, 64, 2, 1, 1)  # [1, 64, 21, 21]self.relu5 = nn.ReLU()self.maxpool3 = nn.MaxPool2d(3)  # [1, 64, 7, 7]self.Dro1 = nn.Dropout(p=0.5)self.flatten = nn.Flatten()self.line1 = nn.Linear(64 * 3 * 3, 1000)self.relu6 = nn.ReLU()self.Dro2 = nn.Dropout(p=0.5)self.line2 = nn.Linear(1000, 1000)self.relu7 = nn.ReLU()self.line3 = nn.Linear(1000, 500)self.line4 = nn.Linear(500, 10)def forward(self, x):x = self.Conv1(x)x = self.relu1(x)x = self.maxpool1(x)x = self.Conv2(x)x = self.relu2(x)x = self.maxpool2(x)x = self.Conv3(x)x = self.relu3(x)x = self.Conv4(x)x = self.relu4(x)x = self.Conv5(x)x = self.relu5(x)x = self.maxpool3(x)x = self.Dro1(x)x = self.flatten(x)x = self.line1(x)x = self.relu6(x)x = self.Dro2(x)x = self.line2(x)x = self.relu7(x)x = self.line3(x)x = self.line4(x)return xdef train_model(model, train_loader, criterion, optimizer, device):model.train()running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs.view(-1,1,112,112))loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()return running_loss / len(train_loader)def evaluate_model(model, data_loader, device):model.eval()all_preds = []all_labels = []with torch.no_grad():for inputs, labels in data_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs.view(-1,1,112,112))_, predicted = torch.max(outputs, 1)all_preds.extend(predicted.cpu().numpy())all_labels.extend(labels.cpu().numpy())accuracy = accuracy_score(all_labels, all_preds)precision, recall, f1_score, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')return accuracy, precision, recall, f1_scoredef main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = AlexNet().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)num_epochs = 50# 假设 train_loader 和 test_loader 已经被创建best_acc=0for epoch in range(num_epochs):train_loss = train_model(model, train_loader, criterion, optimizer, device)print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}')accuracy, precision, recall, f1_score = evaluate_model(model, test_loader, device)print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1_score:.4f}')if best_acc<accuracy:best_acc = accuracytorch.save(model.state_dict(), 'best_alexnet.pth')if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1424907.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

基于HTML5和CSS3搭建一个Web网页(一)

倘若代码中有任何问题或疑问&#xff0c;欢迎留言交流~ 网页描述 创建一个包含导航栏、主内容区域和页脚的响应式网页。 需求: 导航栏: 在页面顶部创建一个导航栏&#xff0c;包含首页、关于我们、服务和联系我们等链接。 设置导航栏样式&#xff0c;包括字体、颜色和背景颜…

Unity | Spine动画动态加载

一、准备工作 Spine插件及基本知识可查看这篇文章&#xff1a;Unity | Spine动画记录-CSDN博客 二、Spine资源动态加载 1.官方说明 官方文档指出不建议这种操作。但spine-unity API允许在运行时从SkeletonDataAsset或甚至直接从三个导出的资产实例化SkeletonAnimation和Skel…

使用JasperReport工具,生成报表模版,及通过JavaBean传参,常见问题及建议

1.下载JasperReport工具 下载地址:社区版 - Jaspersoft 社区 邮箱:lorettepatri.ckoa5434gmail.com 密码:Zx123456. 2.工具使用方法注意 1.一次参数需要在左下角Parameters中新建,直接拖转右上角的TextField不会自动新建参数,到头来还是要在Parameters中新建 2.循环参数需…

ChatGPT 4o 使用案例之一

2024年GPT迎来重大更新&#xff0c;OpenAI发布GPT-4o GPT-4o&#xff08;“o”代表“全能”&#xff09; 它可以接受任意组合的文本、音频和图像作为输入&#xff0c;并生成任意组合的文本、音频和图像输出。它可以在 232 毫秒内响应音频输入&#xff0c;平均为 320 毫秒&…

Git使用(4):分支管理

一、新建分支 首先选择Git -> Branches... 然后选择 New Branch&#xff0c;输入新分支名称&#xff0c;例如dev。 可以看到右下角显示已经切换到新建的dev分支了。 push到远程仓库&#xff0c;可以看到新添加的分支。 二、切换分支与合并分支 为了演示合并分支&#xff0c…

码农慎入 | 入坑软路由,退烧IDC,Homelab折腾记

点击文末“阅读原文”即可参与节目互动 剪辑、音频 / 卷圈 运营 / SandLiu 卷圈 监制 / 姝琦 封面 / 姝琦Midjourney 产品统筹 / bobo 场地支持 / 声湃轩北京录音间 俗话说&#xff0c;入门软路由&#xff0c;退坑IDC 这一期&#xff0c;我们将深入探讨一个许多科技爱好者…

【oracle】图片转为字节、base64编码等形式批量插入oracle数据库并查询

1.熟悉、梳理、总结下Oracle相关知识体系 2.欢迎批评指正&#xff0c;跪谢一键三连&#xff01; 资源下载&#xff1a; oci.dll、oraocci11.dll、oraociei11.dll3个资源文件资源下载&#xff1a; Instant Client Setup.exe资源下载&#xff1a; oci.dll、oraocci11.dll、oraoc…

[数据集][目标检测]蕃茄核桃桔子龙眼青枣5种水果检测数据集VOC+YOLO格式270张5类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;270 标注数量(xml文件个数)&#xff1a;270 标注数量(txt文件个数)&#xff1a;270 标注类别…

人工智能|深度学习——YOLOV8结构图

YoloV8相对于YoloV5的改进点&#xff1a; Replace the C3 module with the C2f module.Replace the first 6x6 Conv with 3x3 Conv in the Backbone.Delete two Convs (No.10 and No.14 in the YOLOv5 config).Replace the first 1x1 Conv with 3x3 Conv in the Bottleneck.Use…

使用Docker进行Jmeter分布式搭建

大家好&#xff0c;随着技术的不断发展&#xff0c;对性能测试的要求也日益提高。在这样的背景下&#xff0c;如何利用 Docker 来巧妙地搭建 Jmeter 分布式成为了关键所在。现在&#xff0c;就让我们开启这场探索之旅&#xff0c;揭开其神秘的面纱。前段时间给大家分享了关于 L…

Linux上编译安装和卸载软件

在maven官网下载maven时候&#xff0c;看到maven-3.9.5这个版本有2份安装包&#xff0c;一个是binaries&#xff0c;一个是source binaries是已编译好的文件&#xff0c;可以直接使用的版本&#xff1b;source是源代码版本&#xff0c;需要自己编译 源码的安装一般由这三个步…

nn.BatchNorm中affine参数的作用

在PyTorch的nn.BatchNorm2d中&#xff0c;affine参数决定是否在批归一化&#xff08;Batch Normalization&#xff09;过程中引入可学习的缩放和平移参数。 BN层的公式如下&#xff0c; affine参数决定是否在批归一化之后应用一个可学习的线性变换&#xff0c;即缩放和平移。具…

信息系统项目管理师0601:项目立项管理 — 考点总结(可直接理解记忆)

点击查看专栏目录 项目立项管理 — 考点总结(可直接理解记忆) 1.项目建议书(又称立项申请)是项目建设单位向上级主管部门提交项目申请时所必须的文件,是对拟建项目提出的框架性的总体设想。在项目建议书批准后,方可开展对外工作(掌握)。 2.项目建议书应该包括的核心内…

k8s 二进制安装 详细安装步骤

目录 一 实验环境 二 操作系统初始化配置&#xff08;所有机器&#xff09; 1&#xff0c;关闭防火墙 2&#xff0c;关闭selinux 3&#xff0c;关闭swap 4, 根据规划设置主机名 5, 做域名映射 6&#xff0c;调整内核参数 7&#xff0c; 时间同步 三 部署 dock…

redis7基础篇2 redis的3种模式(主从,哨兵,集群)模式

一 主从复制模式 1.1 主从模式 主从模式&#xff1a; 主机可以读&#xff0c;写&#xff0c;重机只能写操作。 主机shutdown后&#xff0c;从机上位还是原地待命&#xff1a;从机不动&#xff0c;原地待命&#xff0c;数据正常使用&#xff0c;等待主机重启归来。 主机shu…

【Docker学习】查询容器镜像的docker search

这个命令是使用Docker的必备技能。我们使用的各种官方镜像&#xff0c;一般都能通过这个命令找到。 命令&#xff1a; docker search 描述&#xff1a; 在Docker Hub上查找镜像。Docker Hub是为开发者和开源贡献者设计的容器镜像注册中心&#xff0c;它允许用户查找、使用和…

《天空之城》观后感

曾经很长一段时间都着迷于《天空之城》这段旋律&#xff0c;一遍一遍不厌其烦地听&#xff0c;静谧而温馨、豪迈却苍凉&#xff0c;各种复杂的感受随着起伏的音符流淌进心里。多年之后才知道这首曲子出自宫崎骏的同名动画电影。说来也有意思&#xff0c;似乎大多数人是通过电影…

TypeScript中的泛型(Generics)

TypeScript中的泛型&#xff08;Generics&#xff09; 在前面的几篇文章中&#xff0c;我们了解了TypeScript的类、接口和基本的数据类型系统。本文将重点介绍TypeScript中的泛型&#xff0c;这是一种强大的工具&#xff0c;它允许我们创建可重用的组件&#xff0c;同时保持类…

AI网络爬虫:用kimi提取网页中的表格内容

一个网页中有一个很长的表格&#xff0c;要提取其全部内容&#xff0c;还有表格中的所有URL网址。 在kimi中输入提示词&#xff1a; 你是一个Python编程专家&#xff0c;要完成一个编写爬取网页表格内容的Python脚步的任务&#xff0c;具体步骤如下&#xff1a; 在F盘新建一个…

三.使用HashiCorp Vault工具管理数据库

三.ubuntu安装使用HashiCorp Vault工具管理数据库 HashiCorp Vault 是一个基于身份的秘密和加密管理系统。机密是您想要严格控制访问的任何内容,例如 API 加密密钥、密码和证书。Vault 提供由身份验证和授权方法门控的加密服务。使用 Vault 的 UI、CLI 或 HTTP API,可以安全…