用神经网络求解微分方程

微分方程是物理科学的主角之一,在工程、生物、经济甚至社会科学中都有广泛的应用。粗略地说,它们告诉我们一个量如何随时间变化(或其他参数,但通常我们对时间变化感兴趣)。我们可以了解人口、股票价格,甚至某个社会对某些主题的看法如何随时间变化。

通常,用于解决微分方程的方法不是分析性的(即没有解决方案的“封闭公式”),我们必须利用数值方法。然而,从计算的角度来看,数值方法可能很昂贵,更糟糕的是:累积误差可能非常大。

本文将展示神经网络如何成为解决微分方程的宝贵盟友,以及我们如何借用物理信息神经网络的概念来解决这个问题:我们可以使用机器学习方法来解决微分方程吗?

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割

1、物理学信息神经网络

在本节中,我将简要介绍物理学信息神经网络(PINN)。我想你知道“神经网络”部分,但是是什么让它们受到物理学的影响?好吧,它们并不是完全由物理学决定的,而是由(微分)方程决定的。

通常,神经网络经过训练可以找到模式并弄清楚一组训练数据发生了什么。但是,当你训练神经网络遵循训练数据的行为并希望拟合看不见的数据时,你的模型高度依赖于数据本身,而不是系统的底层性质。这听起来几乎像一个哲学问题,但它比这更实际:如果你的数据来自对洋流的测量,这些洋流必须遵循描述洋流的物理方程。但是请注意,你的神经网络对这些方程完全不可知,并且只试图拟合数据点。

这就是物理学信息发挥作用的地方。如果你的模型除了学习如何拟合数据之外,还学习如何拟合控制该系统的方程,那么你的神经网络的预测将更加精确,并且泛化能力会更好,这只是物理信息模型的一些优点。

请注意,系统的控制方程根本不需要涉及物理,“物理信息”只是一种命名法(而且这种技术无论如何都是物理学家最常用的)。如果你的系统是城市中的交通,并且你恰好有一个很好的数学模型,你希望神经网络的预测遵循该模型,那么物理信息神经网络非常适合你。

3、如何告知模型物理信息?

希望我已经说服了你,让模型了解控制我们系统的基础方程是值得的。但是,我们该怎么做呢?有几种方法可以做到这一点,但主要方法是调整损失函数,使其除了通常的数据相关部分之外,还有一个考虑控制方程的项。也就是说,损失函数 L 将由总和组成

这里,数据损失是通常的损失:均方差,或其他适合的损失函数形式;但方程部分是迷人的。想象一下你的系统由以下微分方程控制:

我们如何将其拟合到损失函数中?好吧,由于我们在训练神经网络时的任务是最小化损失函数,我们想要的是最小化以下表达式:

所以我们的方程相关损失函数结果是

也就是说,它是我们的 DE 的均方差。如果我们设法最小化这个值(即使这个项尽可能接近零),我们就会自动满足系统的控制方程。很聪明,对吧?

现在,需要解决损失函数中的额外项 L_IC:它考虑了系统的初始条件。如果没有提供系统的初始条件,则微分方程有无数个解。

例如,从地面扔出的球的轨迹由与从 10 楼扔出的球相同的微分方程控制;但是,我们确信这些球的路径不会相同。这里发生的变化是系统的初始条件。我们的模型如何知道我们正在讨论哪些初始条件?此时,我们自然会使用损失函数项来强制执行它!

对于我们的 DE,让我们规定当 t = 0 时,y = 1。因此,我们希望最小化初始条件损失函数,该函数的内容为:

如果我们最小化这个项,那么我们就会自动满足系统的初始条件。现在,剩下需要理解的是如何使用它来解决微分方程。

4、求解微分方程

如果神经网络既可以用损失函数的数据相关项进行训练(这通常是在经典架构中完成的),也可以用数据和方程相关项进行训练(这就是​​我刚才提到的物理信息神经网络),那么它一定可以训练为仅最小化方程相关项。这正是我们要做的!这里使用的唯一损失函数将是 L_equation。希望下面的图表能够说明我刚才所说的内容:今天我们的目标是右下角的模型类型,即我们的 DE 求解器 NN。

图 1:显示了各种神经网络及其损失函数的图表。在本文中,我们针对右下方的神经网络。

5、代码实现

为了展示我们刚刚学到的理论知识,我将使用机器学习的 PyTorch 库,在 Python 代码中实现所提出的解决方案。

首先要做的是创建一个神经网络架构:

import torch
import torch.nn as nnclass NeuralNet(nn.Module):def __init__(self, hidden_size, output_size=1,input_size=1):super(NeuralNet, self).__init__()self.l1 = nn.Linear(input_size, hidden_size)self.relu1 = nn.LeakyReLU()self.l2 = nn.Linear(hidden_size, hidden_size)self.relu2 = nn.LeakyReLU()self.l3 = nn.Linear(hidden_size, hidden_size)self.relu3 = nn.LeakyReLU()self.l4 = nn.Linear(hidden_size, output_size)def forward(self, x):out = self.l1(x)out = self.relu1(out)out = self.l2(out)out = self.relu2(out)out = self.l3(out)out = self.relu3(out)out = self.l4(out)return out

这只是具有 LeakyReLU 激活函数的简单 MLP。然后,我将定义损失函数,以便在训练循环中稍后计算它们:

# Create the criterion that will be used for the DE part of the loss
criterion = nn.MSELoss()# Define the loss function for the initial condition
def initial_condition_loss(y, target_value):return nn.MSELoss()(y, target_value)

现在,我们将创建一个用作训练数据的时间数组,并实例化模型,并选择一种优化算法:

# Time vector that will be used as input of our NN
t_numpy = np.arange(0, 5+0.01, 0.01, dtype=np.float32)
t = torch.from_numpy(t_numpy).reshape(len(t_numpy), 1)
t.requires_grad_(True)# Constant for the model
k = 1# Instantiate one model with 50 neurons on the hidden layers
model = NeuralNet(hidden_size=50)# Loss and optimizer
learning_rate = 8e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)# Number of epochs
num_epochs = int(1e4)

最后,让我们开始训练循环:

for epoch in range(num_epochs):# Randomly perturbing the training points to have a wider range of timesepsilon = torch.normal(0,0.1, size=(len(t),1)).float()t_train = t + epsilon# Forward passy_pred = model(t_train)# Calculate the derivative of the forward pass w.r.t. the input (t)dy_dt = torch.autograd.grad(y_pred, t_train, grad_outputs=torch.ones_like(y_pred), create_graph=True)[0]# Define the differential equation and calculate the lossloss_DE = criterion(dy_dt + k*y_pred, torch.zeros_like(dy_dt))# Define the initial condition lossloss_IC = initial_condition_loss(model(torch.tensor([[0.0]])), torch.tensor([[1.0]]))loss = loss_DE + loss_IC# Backward pass and weight updateoptimizer.zero_grad()loss.backward()optimizer.step()

请注意使用 torch.autograd.grad 函数自动对输出 y_pred 相对于输入 t 进行微分,以计算损失函数。

6、结果

经过训练,我们可以看到损失函数迅速收敛。图 2 显示了损失函数与 epoch 数的关系图,其中的插图显示了损失函数下降最快的区域。

图 2:按时期划分的损失函数。在插图中,我们可以看到收敛速度最快的区域。

你可能已经注意到,这个神经网络并不常见。它没有训练数据(我们的训练数据是手工制作的时间戳向量,这只是我们想要研究的时间域),因此它从系统获得的所有信息都以损失函数的形式出现。它的唯一目的是在它被设计用于解决的时间域内求解微分方程。因此,为了测试它,我们使用它训练的时间域是公平的。图 3 显示了 NN 预测与理论答案(即解析解)之间的比较。

图 3:所示神经网络预测和微分方程的解析解预测。

我们可以看到两者之间有相当好的一致性,这对神经网络来说非常好。

这种方法的一个缺点是它不能很好地概括未来的时间。图 4 显示了如果我们将时间数据点向前移动五步会发生什么,结果简直是一片混乱。

图 4:神经网络和未见数据点的解析解。

因此,这里的教训是,这种方法是作为时间域内微分方程的数值求解器,不应将其用作常规神经网络,使用未见的训练域外数据进行预测并期望它能很好地推广。

7、结束语

毕竟,还有一个问题是:

为什么要费心训练一个不能很好地推广到未见数据的神经网络,而且它显然比解析解更差,因为它有内在的统计误差?

首先,这里提供的示例是一个微分方程的示例,其解析解是已知的。对于未知的解,仍然必须使用数值方法。话虽如此,用于微分方程求解的数值方法通常会累积误差。这意味着如果你试图在许多时间步骤中求解方程,解将在此过程中失去其准确性。另一方面,神经网络求解器学习如何在其每个训练时期为所有数据点求解 DE。

另一个原因是神经网络是良好的插值器,因此如果你想知道看不见的数据中的函数值(但这种“看不见的数据”必须位于你训练的时间间隔内),神经网络将迅速为你提供经典数值方法无法迅速给出的值。


原文链接:用神经网络求解微分方程 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1486965.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Python 使用TCP\UDP协议创建一个聊天室

server端代码: #encodingutf-8 # 服务端代码 import socketdef server():server_socket socket.socket(socket.AF_INET, socket.SOCK_STREAM)host socket.gethostname()port 12345server_socket.bind((host, port))server_socket.listen(5)print(等待客户端连接…

使用Gradle构建编译Spring boot 2.7.x

一、环境准备 JDK 1.8Spring boot 2.7.xGradle 7.5.1 (安装参考:win11安装Gradle)Idea 2023.1 二、源码导入gitee(可选) 按需导入。如果能科学上网,可跳过这一步。 为避免github访问不稳定问题,建议将对应的代码导入到gitee。然后通过git管…

内存泄漏详解

文章目录 什么是内存泄漏内存泄漏的原因排查及解决内存泄漏避免内存泄漏及时释放资源设置合理的变量作用域及时清理不需要的对象避免无限增长避免内部类持有外部类引用使用弱引用 什么是内存泄漏 内存泄漏是指不使用的对象持续占有内存使得内存得不到释放,从而造成…

【Java语法基础】1、变量、运算符、输入输出

1.变量、运算符、输入输出 跟C一样,先把必须写的框架写出来: package org.example; public class Main{public static void main(String[] args){//在里面写实际的代码} }变量 必须先定义,才能使用。与C、C差不多。 没有赋初值的变量无法…

windows网络应急排查

一、系统排查 msinfo32 #GUI显示的系统信息systeminfo #简单了解系统信息用户信息排查 排查恶意账号: 黑客喜欢建立相关账号用作远控: 1.建立新账号2.激活默认账号3.建立隐藏账号(windows中账号名$)cmd方法 net user #打印用户账号信息 ---看不到$结尾的隐藏账…

Linux - 进程的概念、状态、僵尸进程、孤儿进程及进程优先级

目录 进程基本概念 描述进程-PCB task_struct-PCB的一种 task_struct内容分类 查看进程 通过系统目录查看 通过ps命令查看 通过系统调用获取进程的PID和PPID 通过系统调用创建进程- fork初始 fork函数创建子进程 使用if进行分流 Linux进程状态 运行状态-R 浅度睡眠状态-S…

Apache Filnk----入门

文章目录 Flink 概述Flink 是什么有界流和无界流有状态流处理Flink 特点Flink vs SparkStreamingFlink 分层API Flink 快速上手WordCount 代码编写批处理流处理读取socket文本流 Flink 概述 Flink 是什么 有界流和无界流 无界数据流: 有定义流的开始,但没有定义流…

ts一些解决vscode飘红的方法

1、查看是否有些ts的数据类型定义问题,属性缺少或者属性类型不对 把对应属性加上即可 2、在飘红的代码前面设置// ts-ignore忽略此行校验(不过一般不建议用这个方法) 3、移除高版本不用的属性(版本属性兼容问题) 原因…

PP-Human行为识别(RTSP协议视频流实时检测)

基于PaddleDetection本地实现PP-Human行为识别模块(RTSP协议视频流实时检测) 项目介绍环境准备1. Anaconda 创建环境2. 获取 PaddleDetection3. 获取 [MediaMTX](https://github.com/bluenviron/mediamtx/releases/tag/v1.8.4)4. FFmpeg 获取5. VLC 获取…

.NET开源、简单、实用的数据库文档生成工具

前言 今天大姚给大家分享一款.NET开源(MIT License)、免费、简单、实用的数据库文档(字典)生成工具,该工具支持CHM、Word、Excel、PDF、Html、XML、Markdown等多文档格式的导出:DBCHM。 支持的数据库 Sq…

IEEE官方列表会议 | 第三届能源与环境工程国际会议(CFEEE 2024)

会议简介 Brief Introduction 2024年第三届能源与环境工程国际会议(CFEEE 2024) 会议时间:2024年12月2日-4日 召开地点:澳大利亚凯恩斯 大会官网:CFEEE 2024-2024 International Conference on Frontiers of Energy and Environment Engineer…

Android APP 音视频(01)MediaCodec解码H264码流

说明: 此MediaCodec解码H264实操主要针对Android12.0系统。通过读取sd卡上的H264码流Me获取视频数据,将数据通过mediacodec解码输出到surfaceview上。 1 H264码流和MediaCodec解码简介 1.1 H264码流简介 H.264,也被称为MPEG-4 AVC&#xff…

uni-app 影视类小程序开发从零到一 | 开源项目分享

引言 在数字娱乐时代,对于电影爱好者而言,随时随地享受精彩影片成为一种日常需求。分享一款基于 uni-app 开发的影视类小程序。它不仅提供了丰富的影视资源推荐,还融入了个性化知乎日报等内容,是不错的素材,同时对电影…

就业管理功能概述:构建智慧校园企业招聘平台

在智慧校园整体解决方案中,就业管理模块连接着学校与企业两端,更成为学生们步入社会、开启职业生涯梦想的关键门户。这一功能的核心价值,在于它如何巧妙地运用科技的力量,简化招聘流程,提升招聘效率,同时为…

5G赋能车联网,无人驾驶引领未来出行

无人驾驶车联网应用已成为智能交通领域的重要发展趋势。随着无人驾驶技术的不断进步和5G网络的广泛部署,5G工业路由器在无人驾驶车联网中的应用日益广泛,为无人驾驶车辆提供了稳定、高效、低时延的通信保障。 5G工业路由器的优势 低时延:5G网…

Python教程(一):环境搭建及PyCharm安装

目录 引言1. Python简介1.1 编译型语言 VS 解释型语言 2. Python的独特之处3. Python应用全览4. Python版本及区别5. 环境搭建5.1 安装Python: 6. 开发工具(IDE)6.1 PyCharm安装教程6.2 永久使用教程 7. 编写第一个Hello World结语 引言 在当…

Open3D 可视化窗口中查看点的坐标数据

目录 一、概述 1.1实现步骤 1.2应用 二、代码实现 2.1关键函数 2.2完整代码 三、实现效果 3.1选取点 3.2数据显示 前期试读,后续会将博客加入下列链接的专栏,欢迎订阅 Open3D与点云深度学习的应用_白葵新的博客-CSDN博客 一、概述 可以使用Op…

Java语言程序设计基础篇_编程练习题**15.19 (游戏:手眼协调)

**15.19 (游戏:手眼协调) 请编写一个程序,显示一个半径为10像素的实心圆,该圆放置在面板上的随机位置,并填充随机的顔色,如图15-29b所示。单击这个圆时,它会消失,然后在另一个随机的位置显示新的随机颜色的…

【工具】轻松转换JSON与Markdown表格——自制Obsidian插件

文章目录 一、插件简介二、功能详解三、使用教程四、插件代码五、总结 一、插件简介 JsonMdTableConverter是一款用于Obsidian的插件,它可以帮助用户在JSON格式和Markdown表格之间进行快速转换。这款插件具有以下特点: 轻松识别并转换JSON与Markdown表格…

Java | Leetcode Java题解之第278题第一个错误的版本

题目&#xff1a; 题解&#xff1a; public class Solution extends VersionControl {public int firstBadVersion(int n) {int left 1, right n;while (left < right) { // 循环直至区间左右端点相同int mid left (right - left) / 2; // 防止计算时溢出if (isBadVers…