tensorflow快速入门--如何定义张量、定义网络结构、超参数设置、模型训练???

前言

  • 由于最近学习的东西涉及到tensorflow的使用,故先简单的学习了一下tensorflow中如何定义张量、定义网络结构、超参数设置、模型训练的API调用过程;
  • 欢迎大家,收藏+关注,本人将持续更新。

文章目录

  • 1、基本操作
    • 1、张量基础操作
      • 创建0维度张量
      • 创建1维张量
      • 创建多维张量
    • 2、转换成numpy
    • 3、常用函数
    • 4、变量
  • 2、用tensorflow构建神经网络
    • 1、数据分析
      • 1、导入库
      • 2、导入数据
      • 3、数据归一化
      • 4、图片展示
      • 5、图片格式归一化
    • 2、构建神经网络
    • 3、设置超参数
    • 4、模型训练
    • 5、预测

1、基本操作

1、张量基础操作

tensorflow中定义常量张量是constant,也就是不能改变的张量。

# 导入库
import tensorflow as tf 
import numpy as np

创建0维度张量

zeros = tf.constant(3)
zeros

输出:

<tf.Tensor: shape=(), dtype=int32, numpy=3>

输出说明:

  • shape: 数据维度,0
  • numpy:数据,3

创建1维张量

tf.constant([1.0, 2.0, 3.0])

输出:

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], dtype=float32)>

输出说明:

  • shape: 数据维度,(1, 3)
  • numpy:数据,array[1., 2., 3.]
  • dtype:数据类型,浮点类型

创建多维张量

tf.constant([[1, 2],[3, 4],[5, 6]], dtype=tf.float32)

输出:

<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 2.],[3., 4.],[5., 6.]], dtype=float32)>

输出说明:

  • shape: 数据维度,(3, 2)
  • numpy:数据
  • dtype:数据类型,浮点类型

三维乃至四维都是这样

2、转换成numpy

tf1 = tf.constant([1, 2, 3, 4, 5])
tf1

输出:

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([1, 2, 3, 4, 5])>
#tensor转numpy
#方法一
np.array(tf1)

输出:

array([1, 2, 3, 4, 5])
#方法二
tf1.numpy()

输出:

array([1, 2, 3, 4, 5])

3、常用函数

注意:tensorflow默认不是浮点类型。

# 这样默认是float32
a = tf.constant([[1.0, 2],[3, 4]])b = tf.constant([[1, 1],[1, 1]], dtype=tf.float32)

注意:constant是创建不可变的张量,不能修改,一下这个不能修改

# 会报错
a[0, 0] = 2
报错如下:
---------------------------------------------------------------------------TypeError                                 Traceback (most recent call last)Cell In[9], line 21 # 会报错
----> 2 a[0, 0] = 2TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
# 加法
tf.add(a, b)

输出:

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 3.],[4., 5.]], dtype=float32)>
# 乘法
tf.matmul(a, b)

输出:

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[3., 3.],[7., 7.]], dtype=float32)>
# 最大值
tf.reduce_max

输出:

<function tensorflow.python.ops.math_ops.reduce_max(input_tensor, axis=None, keepdims=False, name=None)>
# 最大值索引
tf.argmax(a)

输出:

<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 1], dtype=int64)>
# 平均值
tf.reduce_mean(a)

输出:

<tf.Tensor: shape=(), dtype=float32, numpy=2.5>

4、变量

tf.Variable:tensorflow是可变类,ptorch没有这个,但是pytorch的张量是可以变的

var = tf.Variable([[1,2],[3,4]])
var

输出:

<tf.Variable 'Variable:0' shape=(2, 2) dtype=int32, numpy=
array([[1, 2],[3, 4]])>
# 查看变量的维度
var.shape

输出:

TensorShape([2, 2])
# 查看数据类型
var.dtype

输出:

tf.int32
# 修改变量的值
# 不能直接这样修改值:var[0, 0] = 5
var[0,0].assign(2)

输出:

<tf.Variable 'UnreadVariable' shape=(2, 2) dtype=int32, numpy=
array([[2, 2],[3, 4]])>
# 也可以整体修改
var.assign([[3,4],[5,6]])

输出:

<tf.Variable 'UnreadVariable' shape=(2, 2) dtype=int32, numpy=
array([[3, 4],[5, 6]])>

输出:

# 但是不能修改成维度不匹配的
var.assgin([2,3]) # 错误
---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)Cell In[21], line 21 # 但是不能修改成维度不匹配的
----> 2 var.assgin([2,3]) # 错误
AttributeError: 'ResourceVariable' object has no attribute 'assgin'

2、用tensorflow构建神经网络

以实现手写字体的识别为例,按步骤一步一步实现,核心:

  • 导入数据
  • 构建神经网络
  • 设置超参数
  • 模型训练

1、数据分析

1、导入库

# 导入一些必要库
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt# 查看是否支持gpu
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")print(gpus)
[]
  • tf: 导入tensorflow框架
  • datasets:提供常用的数据集,方便快速加载和使用。
  • layers:提供各种神经网络层,用于构建模型。
  • models:提供模型类,如 Sequential,用于管理和训练模型。

2、导入数据

tensorflow中可以利用API直接导入mnist数据,导入数据分别依次为训练集图片、训练集标签、测试集图片、测试集标签

(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 4s 0us/step

3、数据归一化

作用:

  • 去除数据的量纲影响,减小数据的方差,使得模型更加准确;
  • 加快算法学习速度。
# 将数据标准化在0-1区间内
train_images, test_images = train_images / 255.0, test_images / 255.0
# 查看数据维度信息
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

输出:

((60000, 28, 28), (60000,), (10000, 28, 28), (10000,))

4、图片展示

plt.figure(figsize=(20, 10))
for i in range(20):plt.subplot(2, 10, i + 1)  # 宽、高、位置plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])plt.show()


在这里插入图片描述

5、图片格式归一化

train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

输出:

((60000, 28, 28, 1), (60000,), (10000, 28, 28, 1), (10000,))

2、构建神经网络

下面是我用AI大模型搜到的一些资料:
在 TensorFlow 中,tf.keras.layers 模块提供了构建神经网络所需的各种层。这些层可以用来构建不同类型的神经网络模型,如卷积神经网络(CNN)、循环神经网络(RNN)等。下面是一些常用的层类型及其简要说明:

基础层

  • Dense:全连接层,用于构建标准的前馈神经网络。
    • 参数:units(输出维度),activation(激活函数),use_bias(是否使用偏置项)等。
  • Activation:独立的激活层,可以应用于任何其他层的输出。
    • 参数:activation(激活函数名或函数对象)。
  • Dropout:用于防止过拟合,通过随机丢弃一部分神经元来实现。
    • 参数:rate(丢弃率,即每个神经元被丢弃的概率)。
  • Flatten:将多维输入展平成一维向量,常用于连接卷积层和全连接层之间。
    • 参数:无。

卷积层

  • Conv1DConv2DConv3D:一维、二维、三维卷积层,分别适用于时间序列数据、图像数据和视频数据。
    • 参数:filters(滤波器数量),kernel_size(卷积核大小),strides(步长),padding(填充方式),activation(激活函数)等。
  • SeparableConv1DSeparableConv2D:深度可分离卷积层,用于减少计算量和参数数量。
    • 参数与普通卷积层类似。
  • DepthwiseConv2D:深度卷积层,用于处理通道之间的信息。
    • 参数:kernel_sizestridespadding等。

池化层

  • MaxPooling1DMaxPooling2DMaxPooling3D:最大池化层,用于降低数据的空间尺寸。
    • 参数:pool_size(池化窗口大小),stridespadding等。
  • AveragePooling1DAveragePooling2DAveragePooling3D:平均池化层,作用类似于最大池化层,但采用的是平均值而不是最大值。
    • 参数与最大池化层相同。

循环层

  • SimpleRNN:简单的循环神经网络层。
    • 参数:units(输出维度),activationuse_bias等。
  • LSTM:长短期记忆网络层,是一种特殊的 RNN 层,能够学习长期依赖关系。
    • 参数与 SimpleRNN 相似。
  • GRU:门控循环单元层,是 LSTM 的简化版本。
    • 参数与 SimpleRNN 和 LSTM 相似。

正则化层

  • BatchNormalization:批量归一化层,用于加速训练过程并减少内部协变量转移。
    • 参数:axis(指定要进行归一化的轴),momentum(移动平均的动量),epsilon(防止除零的小值)等。

输入层

  • InputLayer:显式定义模型的输入层。
    • 参数:input_shape(输入张量的形状),batch_size(批处理大小)等。

注意力机制相关层

  • Attention:注意力机制层,用于在模型中引入注意力机制。
    • 参数:use_scale(是否使用缩放因子),causal(是否为因果注意力)等。

当然,这些只是 tf.keras.layers 模块提供的一部分层。

model = models.Sequential([# 激活函数,relu# 池化层:平均池化layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),  # 卷积层layers.MaxPooling2D((2, 2)),  # 池化层layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层layers.MaxPooling2D((2, 2)),   # 池化层layers.Flatten(),  # 全部展开layers.Dense(64, activation='relu'),   # 降维layers.Dense(10)  # 根据需要分类数来分类
])
# 打印模型结构
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_5 (Conv2D)            (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 1600)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 64)                102464    
_________________________________________________________________
dense_5 (Dense)              (None, 10)                650       
=================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
_________________________________________________________________

3、设置超参数

model.compile 方法用于配置模型的训练过程。这个方法允许你指定损失函数、优化器和评估指标等关键参数,以便在训练过程中使用

model.compile(# 设置优化器optimizer = 'adam',# 设置损失率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# 评价指标metrics=['accuracy']
)

4、模型训练

model.fit是tensorflow中模型训练的核心方法,其主要参数如下:

  • x:训练数据。可以是 NumPy 数组、Python 列表或 TensorFlow 数据集。
  • y:训练数据的标签。可以是 NumPy 数组或 Python 列表。
  • epochs:训练的轮数。每个 epoch 表示模型将遍历整个训练数据集一次。
  • batch_size:每个批次的样本数。默认情况下,model.fit 会在每个 epoch 内将训练数据分成多个批次进行训练。
  • validation_data:验证数据,用于在每个 epoch 结束时评估模型性能。可以是一个元组 (x_val, y_val),其中 x_val 和 y_val 分别是验证数据和标签。
  • validation_split:从训练数据中划分出一部分作为验证数据的比例。如果指定了 validation_split,则不需要再提供 validation_data。
  • shuffle:是否在每个 epoch 开始前打乱训练数据。默认为 True。
  • class_weight:类别权重,用于处理类别不平衡问题。可以是一个字典,键为类别索引,值为对应的权重。
  • sample_weight:样本权重,用于处理样本不平衡问题。可以是一个与训练数据长度相同的数组。
  • initial_epoch:开始训练的初始 epoch。用于恢复中断的训练。
  • steps_per_epoch:每个 epoch 的训练步骤数。当使用 TensorFlow 数据集时,默认情况下会遍历整个数据集。如果数据集很大,可以使用 steps_per_epoch 来限制每个 epoch 的训练步骤数。
  • validation_steps:验证步骤数。类似于 steps_per_epoch,用于限制验证数据集的遍历次数。
  • callbacks:回调函数列表,用于在训练过程中执行特定的操作,如保存模型、调整学习率、记录日志等。
model.fit(x = train_images,   # 训练集数据y = train_labels,   # 训练集标签epochs = 10,        # 训练轮次validation_data = (test_images, test_labels)
)
Epoch 1/10
1875/1875 [==============================] - 17s 9ms/step - loss: 0.1485 - accuracy: 0.9548 - val_loss: 0.0455 - val_accuracy: 0.9854
Epoch 2/10
1875/1875 [==============================] - 17s 9ms/step - loss: 0.0489 - accuracy: 0.9850 - val_loss: 0.0390 - val_accuracy: 0.9876
Epoch 3/10
1875/1875 [==============================] - 16s 9ms/step - loss: 0.0350 - accuracy: 0.9892 - val_loss: 0.0396 - val_accuracy: 0.9858
Epoch 4/10
1875/1875 [==============================] - 16s 9ms/step - loss: 0.0252 - accuracy: 0.9919 - val_loss: 0.0292 - val_accuracy: 0.9906
Epoch 5/10
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0187 - accuracy: 0.9940 - val_loss: 0.0278 - val_accuracy: 0.9908
Epoch 6/10
1875/1875 [==============================] - 18s 10ms/step - loss: 0.0140 - accuracy: 0.9957 - val_loss: 0.0315 - val_accuracy: 0.9905
Epoch 7/10
1875/1875 [==============================] - 16s 9ms/step - loss: 0.0108 - accuracy: 0.9964 - val_loss: 0.0361 - val_accuracy: 0.9899
Epoch 8/10
1875/1875 [==============================] - 16s 8ms/step - loss: 0.0099 - accuracy: 0.9966 - val_loss: 0.0327 - val_accuracy: 0.9905
Epoch 9/10
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0079 - accuracy: 0.9976 - val_loss: 0.0353 - val_accuracy: 0.9905
Epoch 10/10
1875/1875 [==============================] - 17s 9ms/step - loss: 0.0072 - accuracy: 0.9977 - val_loss: 0.0349 - val_accuracy: 0.9911

5、预测

model.predict是tensorflow用于预测的API,输出所述类别的概率

# 测试集中第一张图片数据
plt.imshow(test_images[1])


在这里插入图片描述

  • 第一张为数字2
# 测试集预测
pre = model.predict(test_images)
# 输出所属类别概率
pre[1]

输出概率:

array([  4.0634604,  -1.2556026,  25.725822 , -12.4212475,  -8.13141  ,-17.905268 ,   2.7954185,  -2.617863 ,  -4.929763 , -12.776581 ],dtype=float32)

上述输出分别是:属于数字0-9的概率
分析

  • 属于数字2的类别概率最大,且远超其他,故,他预测结果是数字2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1557745.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

时间序列+Transformer席卷而来,性能秒杀传统,创新性拉满,引爆顶会!

时间序列分析与Transformer模型的结合&#xff0c;已成为深度学习领域的一大趋势。这种结合能够高效捕捉序列中的长期依赖关系&#xff0c;提升时间序列分析和预测的准确性。 时间序列Transformer技术在股票价格预测、气候预测、交通流量预测、设备故障预测、自然语言处理等多…

封装vue-cropper,图片裁剪组件

组件基本使用: 这里的action同时也可以传相对路径&#xff0c;比如封装了axios&#xff0c;那么组件源码里就不能引入元素axios&#xff0c;可以替换为封装的axios。传 action"/file/upload" 源代码&#xff1a; <script setup> import WuyuCropper from /com…

【基础算法总结】字符串篇

目录 一&#xff0c;算法简介二&#xff0c;算法原理和代码实现14.最长公共前缀5.最长回文子串67.二进制求和43.字符串相乘 三&#xff0c;算法总结 一&#xff0c;算法简介 字符串 string 是一种数据结构&#xff0c;它一般和其他的算法结合在一起操作&#xff0c;比如和模拟&…

远程控制软件推荐:亲测好用!

无论是在家办公、技术支持还是远程协助家人&#xff0c;一个好的远程控制工具都能让我们的工作更加高效。下面&#xff0c;我将分享我对几款流行的远程控制软件的个人体验&#xff0c;并给出我的推荐。 向日葵远程控制 直达链接&#xff1a;down.oray.com 向日葵远程控制是…

如何实现一个基于 HTML+CSS+JS 的任务进度条

如何实现一个基于 HTMLCSSJS 的任务进度条 在网页开发中&#xff0c;任务进度条是一种常见的 UI 组件&#xff0c;它可以直观地展示任务的完成情况。本文将向你展示如何使用 HTML CSS JavaScript 来创建一个简单的、交互式的任务进度条。用户可以通过点击进度条的任意位置来…

Spring Boot读取resources目录下文件(打成jar可用),并放入Guava缓存

1、文件所在位置&#xff1a; 2、需要Guava依赖&#xff1a; <dependency><groupId>com.google.guava</groupId><artifactId>guava</artifactId><version>23.0</version></dependency>3、启动时就读取放入缓存的代码&#xf…

10.8作业

优化登录框&#xff1a; 当用户点击取消按钮&#xff0c;弹出问题对话框&#xff0c;询问是否要确定退出登录&#xff0c;并提供两个按钮&#xff0c;yes|No&#xff0c;如果用户点击的Yes&#xff0c;则关闭对话框&#xff0c;如果用户点击的No&#xff0c;则继续登录 当用户点…

26.删除有序数组中的重复项

题目::26. 删除有序数组中的重复项 - 力扣&#xff08;LeetCode&#xff09; 思路:只要不和前面的数一样就可以移动指针&#xff0c;进行赋值 代码: class Solution { public:int removeDuplicates(vector<int>& nums) {int slow 0 ;for(int fast 1; fast < …

Sharding-JDBC笔记04-分库分表实战

文章目录 前言一、需求描述二、数据库设计三、环境说明四、环境准备4.1.mysql主从同步(windows)4.2.初始化数据库 五、实现步骤5.1 搭建maven工程引入maven依赖 5.2 实体类5.3 dao层5.4 服务类5.5 测试类总结 5.6 查询商品DaoService单元测试输出小结 5.7 统计商品Dao单元测试统…

许昌文旅助手:AI智能体在文旅领域的创新应用

哈哈&#xff0c;大家好&#xff0c;我是王帅旭&#xff0c;来自大禹智库&#xff0c;也是《实战AI智能体》一书的作者。今天&#xff0c;咱们就来聊聊一个超级有趣的案例——许昌文旅助手&#xff0c;看看AI智能体是如何在文旅领域大放异彩的&#xff01; 无限拓展的能力集&am…

10.8QTQMessageBox练习

QQ界面 widget.cpp #include "widget.h"Widget::Widget(QWidget *parent): QWidget(parent) {//设置框体的大小和颜色this->setFixedSize(350,500);this->setStyleSheet("background-color:#e5f0ff;");//创建一个LineEdit edit1edit1 new QLineEdi…

面试淘天集团大模型算法工程师, 开心到飞起!!!

应聘岗位&#xff1a;淘天集团-大模型算法工程师 面试轮数&#xff1a; 整体面试感觉&#xff1a; 1. 自我介绍 在自我介绍环节&#xff0c;我清晰地阐述了个人基本信息、教育背景、工作经历和技能特长&#xff0c;展示了自信和沟通能力。 2. 技术问题 2.1 在大模型微调过程…

全网最详细的Python Locust性能测试框架实践

Locust的介绍 Locust是一个python的性能测试工具&#xff0c;你可以通过写python脚本的方式来对web接口进行负载测试。 Locust的安装 首先你要安装python2.6以上版本&#xff0c;而且有pip工具。之后打开命令行&#xff0c;分别安装locustio和pyzmq&#xff08;命令如下&…

【技术白皮书】内功心法 | 第一部分 | 数据结构与算法基础(数据结构)

数据结构与算法基础 内容简介数据结构数据模型数据结构的表现形式 基本概念数据&#xff08;Data&#xff09;数据元素&#xff08;data element&#xff09;数据结构的定义物理结构和逻辑结构逻辑结构逻辑结构表现形式二元组模型集合结构模型线性结构模型树结构模型图结构模型…

GNURadio 平台实现AM信号调制解调实验

文章目录​​ 一、AM调制解调原理 1.调制原理 2.解调原理 二、搭建的GRC流图 1.AM调制信号流图 2.AM解调信号流图 一、AM调制解调原理 1.调制原理 幅度调制&#xff08; Amplitude modulation——AM&#xff09; &#xff0c; 是常规的双边带调制&#xff0c; 即使载波的…

和鲸科技创始人范向伟:拐点即将来临,AI产业当前的三个瓶颈

在科技迅猛发展的时代&#xff0c;人工智能&#xff08;AI&#xff09;无疑已经成为引领新一轮产业革命的核心动力之一。全球企业纷纷拥抱AI技术&#xff0c;试图借助其变革力量在竞争中突围&#xff0c;然而业界对AI产业化的拐点何时来临却众说纷纭。毕竟AI技术从实验室到商业…

[SAP ABAP] INCLUDE程序创建

在ABAP中&#xff0c;INCLUDE是一种结构化编程技术&#xff0c;它允许将一段程序代码片段包含到其他程序段中&#xff0c;以便复用和维护 INCLUDE程序创建的好处 ① 代码模块化 将常用的功能或通用的子程序存放到单独的文件中&#xff0c;使得主程序更简洁、易于理解和管理 ② …

揭秘人工智能的奇幻构造:人工智能的组成及都涉及什么

作者简介&#xff1a;我是团团儿&#xff0c;是一名专注于云计算领域的专业创作者&#xff0c;感谢大家的关注 座右铭&#xff1a; 云端筑梦&#xff0c;数据为翼&#xff0c;探索无限可能&#xff0c;引领云计算新纪元 个人主页&#xff1a;团儿.-CSDN博客 目录 前言&#…

动态内存管理练习题的反汇编代码分析(底层)

1.C语言代码 #include <stdio.h> char* GetMemory(void) {char p[] "hello world";return p; }void Test(void) {char* str NULL;str GetMemory();printf(str); }int main() {Test();return 0; } 2.反汇编代码 VS2022x64debug #include <stdio.h>…

PCB进程初识

目录 一、什么是进程 1.课本概念 2.内核观点 二、进程的描述-PCB 1.什么是PCB 2.PCB的组织方式 3.task_struct是Linux操作系统下的PCB 4.task_struct内容分类 三、进程的查看 四、进程的创建 1.查看子进程id和父进程id 演示实例1&#xff1a; 2.fork初识 演示实例…