5.12【机器学习】卷积模型搭建

softmax输出时不可能为所有模型提供精确且数值稳定的损失计算

model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10)
])
mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

 打开一个遍历个周期的FOR循环

对于每个周期,打开一个分批遍历数据集的FOR循环

glob,返回所有匹配的文件路径列表,需要一个参数用来指定匹配的路径字符串(字符串可以为绝对路径,也可以为相对路径),其返回的文件名只包括当前目录里的文件名,不包括子文件夹里的文件

glob.glob(r'c:*.txt')

可以根据层将要运算的输入的形状启用变量创建,根据层将要运算的输入的形状启用变量创建

而在__init__则意味着需要指定创建变量所需的形状

卷积、批次归一化和捷径的组合

_ = layer(tf.zeros([10, 5])) # Calling the layer `.builds` it.
class ResnetIdentityBlock(tf.keras.Model):def __init__(self, kernel_size, filters):super(ResnetIdentityBlock, self).__init__(name='')filters1, filters2, filters3 = filtersself.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))self.bn2a = tf.keras.layers.BatchNormalization()self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')self.bn2b = tf.keras.layers.BatchNormalization()self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))self.bn2c = tf.keras.layers.BatchNormalization()def call(self, input_tensor, training=False):x = self.conv2a(input_tensor)x = self.bn2a(x, training=training)x = tf.nn.relu(x)x = self.conv2b(x)x = self.bn2b(x, training=training)x = tf.nn.relu(x)x = self.conv2c(x)x = self.bn2c(x, training=training)x += input_tensorreturn tf.nn.relu(x)block = ResnetIdentityBlock(1, [1, 2, 3])

自己的训练循环分为三个步骤,迭代Python生成器或tf.data.Dataset获得样本批次

使用tf.G收集梯度

tf.opt将权重更新应用于模型

tf.random.set_seed(2345)
current_time = datetime.datetime.now().strftime(('%Y%m%d-%H%M%S'))
log_dir = 'logs/'+current_time
summary_writer = tf.summary.create_file_writer(log_dir)def preprocess(x, y):# [0~1]x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1y = tf.cast(y, dtype=tf.int32)return x, ydata_dir = 'D:\\MachineLearning\\exp3\\flowers'batch_size = 32
img_height = 32
img_width = 32
#从磁盘中获取数据并进行划分
train_ds = tf.keras.utils.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.utils.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)
# # Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
# train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
# val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
# for image, label in train_ds.take(1):
#   print("Image shape: ", image.numpy().shape)
#   print("Label: ", label.numpy())
# for image_batch, labels_batch in train_ds:
#   print(image_batch.shape)
#   print(labels_batch.shape)
#   break
def configure_for_performance(ds):ds = ds.cache()ds = ds.shuffle(buffer_size=1000)ds = ds.batch(batch_size)ds = ds.prefetch(buffer_size=AUTOTUNE)return ds
# train_ds = configure_for_performance(train_ds)
# val_ds = configure_for_performance(val_ds)
# train_ds= tf.squeeze(train_ds, axis=1)
# val_ds= tf.squeeze(val_ds, axis=1)
# (x, y), (x_test, y_test) = datasets.cifar10.load_data()
# y = tf.squeeze(y, axis=1)
# y_test = tf.squeeze(y_test, axis=1)
# print(x.shape, y.shape, x_test.shape, y_test.shape)
#
# train_db = tf.data.Dataset.from_tensor_slices((x, y))
# train_db = train_db.shuffle(1000).map(preprocess).batch(256)
#
# test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
# test_db = test_db.map(preprocess).batch(256)
#
# sample = next(iter(train_db))
# print('sample:', sample[0].shape, sample[1].shape,
#       tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))
def main():# [b, 32, 32, 3] => [b, 1, 1, 512]model = ResNetmodel()model.build([None, 32, 32, 3])# model.summary() # 统计网络参数optimizer = optimizers.Adam(learning_rate=1e-3)# [1, 2] + [3, 4] => [1, 2, 3, 4]variables = model.trainable_variablesfor epoch in range(100):for step, (x, y) in enumerate(train_ds):with tf.GradientTape() as tape:# [b, 32, 32, 3] => [b, 1, 1, 512]out = model(x)# [b] => [b, 5]y_onehot = tf.one_hot(y, depth=5)# compute lossloss = tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True)loss = tf.reduce_mean(loss)grads = tape.gradient(loss, variables)optimizer.apply_gradients(zip(grads, variables))if step % 100 == 0:with summary_writer.as_default():tf.summary.scalar('loss', loss, step=step)total_num = 0total_correct = 0for x, y in val_ds:out = model(x)prob = tf.nn.softmax(out, axis=1)pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)total_num += x.shape[0]total_correct += int(correct)acc = total_correct / total_numwith summary_writer.as_default():tf.summary.scalar('acc', float(acc), step=epoch)if __name__ == '__main__':main()

 需要在每个周期之间对指标调用

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):with tf.GradientTape() as tape:predictions = model(inputs, training=True)regularization_loss=tf.math.add_n(model.losses)pred_loss=loss_fn(labels, predictions)total_loss=pred_loss + regularization_lossgradients = tape.gradient(total_loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))for epoch in range(NUM_EPOCHS):for inputs, labels in train_ds:train_step(inputs, labels)print("Finished epoch", epoch)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/33579.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【H2O2|全栈】Node.js与MySQL连接

目录 前言 开篇语 准备工作 初始配置 创建连接池 操作数据库 封装方法 结束语 前言 开篇语 本节讲解如何使用Node.js实现与MySQL数据库的连接,并将该过程进行函数封装。 与基础部分的语法相比,ES6的语法进行了一些更加严谨的约束和优化&#…

Stable Diffusion Controlnet常用控制类型解析与实战课程 2

本节内容,给大家带来的是stable diffusion Controlnet常用控制类型解析与实战的第二节课程。在上期课程中,我们已经了解了关于线稿类控制类型的特征和用法,本节课程,我们将继续讲解一些常用的控制类型。 一:OpenPose …

TC3xx系列芯片--GPT12模块介绍

1、模块介绍 GPT1/2(General Purpose Timer Unit)是 Aurix TC3XX 内部的通用定时器模块,提供高精度定时功能,GPT1/2 包含 GPT1 和 GPT2 两个子模块,通用定时器单元块 GPT1 和 GPT2 具有非常灵活的多功能定时器结构,可…

uniapp 添加loading

在uniapp中添加loading可以使用uni的API uni.showLoading 方法。以下是一个简单的示例代码 // 显示loading uni.showLoading({title: 加载中 });// 假设这里是异步操作,比如网络请求 setTimeout(function () {// 隐藏loadinguni.hideLoading(); }, 2000);

基于反射内存的光纤交换机

在当今高度信息化的社会中,数据的高速传输与处理已成为各行各业不可或缺的一部分。特别是在航空航天、工业自动化、金融交易及高性能计算等领域,对数据实时性和可靠性的要求尤为严格。为满足这些需求,基于反射内存(Reflective Mem…

前端上传后端接收参数为null

记录一下工作中的问题 前端明明把文件传到后台了,但是后台接收参数为null 原因: 前端上传文件的name和后端接收参数名称不匹配 前端 后端 把前端上传的name由upfile改为file即可 本来是很基本的小问题,但因为自己钻了牛角尖一直没搞定&…

Web3的技术栈详解:解读区块链、智能合约与分布式存储

随着数字时代的不断发展,Web3作为下一代互联网的核心理念逐渐走进了大众视野。它承载着去中心化、用户主权以及更高效、更安全的网络环境的期望。Web3不再是由少数中心化机构主导的网络,而是通过一系列核心技术的支撑,给每个用户赋予了更多的…

芯食代冻干科技研究院:创新与品质并重,推动家用冻干机高质量发展

11月25日,芯食代首届食品冻干前沿与智能化升级创新大会在江苏常州成功举办。本次大会由芯食代冻干科技研究院(江苏)有限公司与芯食代(上海)科技发展有限公司联合主办,云集学界专家教授、商界企业精英,共议家用冻干机的未来创新发展。作为创新大会,芯食代冻干科技研究院也在本次…

相交的链表

力扣链接:160. 相交链表 - 力扣(LeetCode) 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据…

PETRv2: A Unified Framework for 3D Perception from Multi-Camera Images

全文摘要 本文介绍了一种名为PETRv2的统一框架,用于从多视图图像中进行三维感知。该框架基于先前提出的PETR框架,并探索了时间建模的有效性,利用前一帧的时间信息来提高三维物体检测效果。作者在PETR的基础上扩展了三维位置嵌入(…

项目基于oshi库快速搭建一个cpu监控面板

后端&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>com.github.oshi</groupId><artifactId>oshi-…

设计模式——Chain(责任链)设计模式

摘要 责任链设计模式是一种行为设计模式&#xff0c;通过链式调用将请求逐一传递给一系列处理器&#xff0c;直到某个处理器处理了请求或所有处理器都未能处理。它解耦了请求的发送者和接收者&#xff0c;允许动态地将请求处理职责分配给多个对象&#xff0c;支持请求的灵活传…

【Nacos02】消息队列与微服务之Nacos 单机部署

Nacos 部署 Nacos 部署说明 Nacos 快速开始 Nacos 快速开始 版本选择 当前推荐的稳定版本为2.X Releases alibaba/nacos GitHuban easy-to-use dynamic service discovery, configuration and service management platform for building cloud native applications. - Re…

查看 tomcat信息 jconsole.exe

Where is the jconsole.exe? location: JDK/bin/jconsole.exe

大数据新视界 -- Hive 元数据管理:核心元数据的深度解析(上)(27 / 30)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

大数据实验E5HBase:安装配置,shell 命令和Java API使用

实验目的 熟悉HBase操作常用的shell 命令和Java API使用&#xff1b; 实验要求 掌握HBase的基本操作命令和函数接口的使用&#xff1b; 实验平台 操作系统&#xff1a;Linux&#xff08;建议Ubuntu16.04或者CentOS 7 以上&#xff09;&#xff1b;Hadoop版本&#xff1a;3…

【Linux系统】Linux内核框架(详细版本)

Linux体系结构&#xff1a;Linux操作系统的组件详细介绍 Linux 是一个开源的类 UNIX 操作系统&#xff0c;由多个组件组成&#xff0c;具有模块化和层次化的体系结构。它的设计实现了内核、用户空间和硬件的高效协作&#xff0c;支持多用户、多任务操作&#xff0c;广泛应用于…

如何使用apache部署若依前后端分离项目

本章教程介绍,如何在apache上部署若依前后端分离项目 一、教程说明 本章教程,不介绍如何启动后端以及安装数据库等步骤,着重介绍apache的反向代理如何配置。 参考此教程,默认你已经完成了若依后端服务的启动步骤。 前端打包命令使用以下命令进行打包之后会生成一个dist目录…

oracle 11g中如何快速设置表分区的自动增加

在很多业务系统中&#xff0c;一些大表一般通过分区表的形式来实现数据的分离管理&#xff0c;进而加快数据查询的速度。分区表运维管理的时候&#xff0c;由于人为操作容易忘记添加分区&#xff0c;导致业务数据写入报错。所以我们一般通过配置脚本或者利用oracle内置功能实现…

【不稳定的BUG】__scrt_is_managed_app()中断

【不稳定的BUG】__scrt_is_managed_app函数中断 参考问题详细的情况临时解决方案 参考 发现出现同样问题的文章: 代码运行完所有功能&#xff0c;仍然会中断 问题详细的情况 if (!__scrt_is_managed_app())exit(main_result);这里触发了一个断点很奇怪,这中断就发生了一次,代…