图神经网络实战——分层自注意力网络

图神经网络实战——分层自注意力网络

    • 0. 前言
    • 1. 分层自注意力网络
      • 1.1 模型架构
      • 1.2 节点级注意力
      • 1.3 语义级注意力
      • 1.4 预测模块
    • 2. 构建分层自注意力网络
    • 相关链接

0. 前言

在异构图数据集上,异构图注意力网络的测试准确率为 78.39%,比之同构版本有了较大提高,但我们还能进一步提高准确率。在本节中,我们将学习一种专门用于处理异构图的图神经网络架构,分层自注意力网络 (hierarchical self-attention network, HAN)。我们将介绍其工作原理,以便更好地理解该架构与经典图注意力网络 (Graph Attention Networks,GAT) 之间的区别。最后,使用 PyTorch Geometric 实现此架构,并将结果与其它 GNN 模型进行比较。

1. 分层自注意力网络

1.1 模型架构

在本节中,我们将实现一个专为处理异构图而设计的图神经网络 (Graph Neural Networks, GNN) 模型——分层自注意力网络 (hierarchical self-attention network, HAN)。该架构由 Liu 等人于 2021 年提出。HAN 在两个不同层次上使用自注意力:

  • 节点级注意力 (Node-level attention):了解给定元路径中相邻节点的重要性
  • 语义级注意力 (Semantic-level attention):了解每个元路径的重要性。这是 HAN 的主要特点,它允许我们自动为给定任务选择最佳元路径。例如,在某些任务(如预测玩家人数)中,元路径 game-user-game 可能比 game-dev-game 更合适

接下来,我们将详细介绍 HAN 的三个主要组件——节点级注意力 (Node-level attention)、语义级注意力 (Semantic-level attention) 和预测模块 (prediction module),HAN 架构如下所示。

HAN 架构)

1.2 节点级注意力

与图注意力网络 (Graph Attention Networks,GAT) 一样,第一步将节点投影到每个元路径的统一特征空间中。然后,用第二个权重矩阵计算同一元路径中的节点对(两个投影节点的连接)的权重,并对这一结果应用非线性函数,然后用 softmax 函数对其进行归一化处理。 j j j 节点对 i i i 节点的归一化注意力分数(重要性)计算如下:
α i j Φ = exp ⁡ ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h j ] ) ) ∑ k ∈ N i Φ exp ⁡ ( σ ( a Φ T [ W Φ h i ∣ ∣ W Φ h k ] ) ) \alpha_{ij}^\Phi =\frac {\exp (\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_j]))}{\sum _{k\in \mathcal N_i^{\Phi}}\exp(\sigma(a_{\Phi}^T[W_{\Phi}h_i||W_{\Phi}h_k]))} αijΦ=kNiΦexp(σ(aΦT[WΦhi∣∣WΦhk]))exp(σ(aΦT[WΦhi∣∣WΦhj]))
其中, h i h_i hi 表示 i i i 节点的特征, W Φ W_{\Phi} WΦ Φ \Phi Φ 元路径的共享权重矩阵, a Φ a_{\Phi} aΦ Φ \Phi Φ 元路径的注意力权重矩阵, σ σ σ 是非线性激活函数(如 LeakyReLU), N i Φ \mathcal N_i^{\Phi} NiΦ 是节点(包括其自身)在 Φ \Phi Φ 元路径中的邻居集。使用多头注意力获得最终的嵌入:
Z i = ∣ ∣ k = 1 K σ ( ∑ k ∈ N i α i j Φ ⋅ W Φ h j ) Z_i=||_{k=1}^K\sigma(\sum _{k\in \mathcal N_i}\alpha _{ij}^{\Phi}\cdot W_{\Phi}h_j) Zi=k=1Kσ(kNiαijΦWΦhj)

1.3 语义级注意力

对于语义级注意力,我们对每个元路径的注意力得分(表示为 β Φ 1 , β Φ 2 , . . . , β Φ p β_{\Phi _1}, β_{\Phi _2}, ... , β_{\Phi _p} βΦ1,βΦ2,...,βΦp )重复类似的过程。对于给定元路径中的每个节点嵌入(表示为 Z Φ p Z_{\Phi _p} ZΦp),都将其馈送到一个多层感知机 (Multilayer Perceptron, MLP) 中,应用非线性变换。将这一结果与新的注意力向量 q q q 进行比较,作为相似性度量。我们将这一结果平均化,以计算给定元路径的重要性:
w Φ p = 1 ∣ V ∣ ∑ i ∈ V q T ⋅ tanh ⁡ ( W ⋅ z i Φ p + b ) w_{\Phi_p}=\frac 1{|V|}\sum_{i\in V}q^T\cdot \tanh(W\cdot z_i^{\Phi_p}+b) wΦp=V1iVqTtanh(WziΦp+b)
其中, W W W (MLP 的权重矩阵)、 b b b (MLP 的偏置)和 q q q (语义级注意力向量)在元路径中是共享的。
必须对这一结果进行归一化处理,以比较不同的语义级注意力得分。使用 softmax 函数来获得最终权重:
β Φ p = exp ⁡ ( w Φ p ) ∑ k = 1 P exp ⁡ ( w Φ k ) \beta _{\Phi_p}=\frac {\exp(w_{\Phi_p})}{\sum_{k=1}^P\exp(w_{\Phi_k})} βΦp=k=1Pexp(wΦk)exp(wΦp)
将节点级注意力和语义级注意力结合起来得到最终嵌入 Z Z Z
Z = ∑ p = 1 P β Φ p ⋅ Z Φ p Z=\sum_{p=1}^P\beta_{\Phi_p}\cdot Z_{\Phi_p} Z=p=1PβΦpZΦp

1.4 预测模块

最后一层(如多层感知机 (Multilayer Perceptron, MLP) )用于针对特定的下游任务(如节点分类或链接预测)对模型进行微调。

2. 构建分层自注意力网络

接下来,使用 PyTorch Geometric 在 DBLP 数据集上实现分层自注意力网络 (hierarchical self-attention network, HAN) 架构。

(1) 首先,导入 HAN 层:

import torch
import torch.nn.functional as F
from torch import nnimport torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HANConv, Linear

(2) 加载 DBLP 数据集,并为会议节点引入虚拟特征:

dataset = DBLP('.')
data = dataset[0]
print(data)data['conference'].x = torch.zeros(20, 1)

数据集

(3) 使用 HANConvHAN 卷积层和用于最终分类的线性层创建 HAN 类:

class HAN(nn.Module):def __init__(self, dim_in, dim_out, dim_h=128, heads=8):super().__init__()self.han = HANConv(dim_in, dim_h, heads=heads, dropout=0.6, metadata=data.metadata())self.linear = nn.Linear(dim_h, dim_out)

(4)forward() 方法中,我们必须指定需要关注作者:

    def forward(self, x_dict, edge_index_dict):out = self.han(x_dict, edge_index_dict)out = self.linear(out['author'])return out

(5) 使用懒初始化 (dim_in=-1) 来初始化模型,因此 PyTorch Geometric 会自动计算每个节点类型的输入大小:

model = HAN(dim_in=-1, dim_out=4)

(6) 实例化 Adam 优化器,并尝试将数据和模型传输到 GPU

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

(7) 实现 test() 函数计算分类任务的准确率:

@torch.no_grad()
def test(mask):model.eval()pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()return float(acc)

(8) 对模型进行 101epoch 的训练,与同构图神经网络 (Graph Neural Networks, GNN) 的训练循环唯一不同的是,需要指定关注的作者节点类型:

for epoch in range(101):model.train()optimizer.zero_grad()out = model(data.x_dict, data.edge_index_dict)mask = data['author'].train_maskloss = F.cross_entropy(out[mask], data['author'].y[mask])loss.backward()optimizer.step()if epoch % 20 == 0:train_acc = test(data['author'].train_mask)val_acc = test(data['author'].val_mask)print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')

训练过程如下:

模型训练过程

(9) 最后,在测试集上测试训练后的模型:

test_acc = test(data['author'].test_mask)
print(f'Test accuracy: {test_acc*100:.2f}%')# Test accuracy: 81.58%

HAN 的测试准确率为 81.58%,高于异构图注意力网络 (78.39%)和经典图注意力网络 (Graph Attention Networks,GAT) (73.29%)。这说明了构建良好的表示方法以聚合不同类型节点和关系的重要性。异构图的技术在很大程度上取决于具体应用,但尝试不同的模型对于构建高性能应能具有重要作用。

相关链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(18)——消息传播神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1549920.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

信息安全工程师(26)物理安全概念与要求

前言 物理安全是网络安全体系中的重要组成部分,它关注于保护物理环境、设备和资源免受未经授权的访问、破坏、损坏或盗窃。 一、物理安全概念 物理安全,也称为实体安全,是指通过采取各种物理措施来保护支持网络信息系统运行的硬件&#xff08…

你还在用Java8吗?

Java 11 在企业中,Java的不同版本使用情况随着时间在不断变化。根据最新的数据报告,以下是一些关键点: Java 11 和 Java 17 成为企业中最常用的长期支持(LTS)版本,使用率分别为 48% 和 45%,而 …

1Panel安装部署证书(httpsok.com)

1Panel安装部署证书(httpsok.com) 购买服务器 推荐购买香港服务器,这样通过域名访问就不需要备案。 创建静态站点 申请SSL证书 进入 httpsok.com,点击申请证书 输入站点域名 根据提示,添加DNS解析记录 添加成功后,提示域名验证…

免费送源码:Javaspringboot++MySQL springboot 社区互助服务管理系统小程序 计算机毕业设计原创定制

摘 要 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受居民的喜爱,社区互助服务管理系统小程序被居民普遍使用,为…

【二叉平衡搜索树】Treap

前置 本篇是平衡树-treap的补充学习笔记。 Treap - 树堆 学习基础:适合一定基础的:比如,实现了经典二叉搜索树(常用的几个函数写过), 和二叉堆(数组的上浮下沉会写吗?)&a…

Win11右键默认显示更多设置教程

Win11最大的变化之一莫过于右键菜单发生了变化,最大的问题是什么,是右键菜单很多时候需要点两次,实在是反人类,太麻烦了。 简直是反人类! 必须使用“显示更多选项”的右键菜单。 以管理员方式运行CMD 复制以下命令直…

IP6537_C_30W20V--移动设备快充的得力助手,集成 14 种快充协议的降压 SoC

IP6537_C_30W20V是一款集成同步开关的降压转换器、支 持 14 种输出快充协议、支持 Type-C 输出和 USB PD2.0/PD3.0(PPS)协议的 SoC,为车载充电器、 快充适配器、智能排插提供完整的解决方案。 IP6537_C_30W20V支持 USB Type-C 或者 USB A 输出, 5V 输出功…

C++的STL标准模板库容器--vector类

前言:vector类也可以叫作顺序表,他同样也在STL库中的容器模块中,我还是主要讲常用的几个成员函数和成员变量,同时它晚于string类,所以在设计上没有string冗余。 vector类是一个和数组类似的容器,它属于随机…

插上网线无法连接网络,控制面板以太网消失 | 如何重装网络驱动

如果你确定你的网线没问题,网线插口没问题,那你大概率就是驱动问题,可以试一下本方法。 0 以太网消失 事情是这样的,我工作时候需要接内网,插网线,摸鱼时候连外网,我就把网线关了。 每次插网线…

【Python大语言模型系列】开源机器人对话系统框架RASA介绍与使用(案例分析)

这是我的第361篇原创文章。 一、引言 Rasa是一个开源的对话式 AI 框架,用于构建自定义的对话式 AI 助手。它可以处理自然语言理解(NLU)和对话管理(DM),使得开发者能够轻松地创建功能丰富的对话式 AI 应用。…

番外篇 | 应对遮挡挑战,北航提出新型模型YOLOv5-FFM表现优异

前言:Hello大家好,我是小哥谈。在本文中,作者提出了一种改进的轻量级YOLOv5-FFM模型来解决行人检测遮挡问题。为了实现目标,作者在YOLOv5模型框架基础上进行了改进,并引入了Ghost模块和SE模块。此外,作者还设计了一个局部特征融合模块(FFM)来处理行人检测中的遮挡问题。…

I/O中断处理过程

中断控制器位于CPU和外设之间,用于处理I/O中断请求。以下是一个简化的中断控制器: 现在有A,B,C三个中断源。中断响应优先级:A>B>C,中断处理优先级:C>B>A 假设CPU正在处理A中断源的中断请求,此时…

Mixture-of-Experts (MoE): 条件计算的诞生与崛起【上篇】

大型语言模型(LLM)的现代进步主要是缩放定律的产物[6]。 假设模型是在足够大的数据集上训练出来的,那么随着底层模型规模的增加,我们会看到性能的平滑提升。 这种扩展规律最终促使我们创建了 GPT-3 以及随后的其他(更强…

Excel技巧:Excel批量提取文件名

Excel是大家经常用来制作表格的文件,比如输入文件名,如果有大量文件需要输入,用张贴复制或者手动输入的方式还是很费时间的,今天和大家分享如何批量提取文件名。 打开需要提取文件名的文件夹,选中所有文件&#xff0c…

在线翻译器工具横评:性能、准确率大比拼

无论是旅行者在异国他乡探寻风土人情,学者研究国外的前沿学术成果,还是商务人士与国际伙伴洽谈合作,都离不开一种高效、准确的语言沟通工具。而翻译器在线翻译能很好的帮我们解决这个问题。今天我们一起来探讨有那些好用的翻译工具。 1.福昕…

Golang | Leetcode Golang题解之第443题压缩字符串

题目: 题解: func compress(chars []byte) int {write, left : 0, 0for read, ch : range chars {if read len(chars)-1 || ch ! chars[read1] {chars[write] chwritenum : read - left 1if num > 1 {anchor : writefor ; num > 0; num / 10 {…

【题解】2022ICPC杭州-K

翻译 原题链接   简述一下就是每次询问重新定义一个字母排序表&#xff0c;问在这个顺序下n个字符串的序列的逆序数是多少。 字典树计算逆序数 先考虑初始状况下&#xff0c;即 a < b < . . . < z a<b<...<z a<b<...<z的情况下&#xff0c;逆序…

Arch - 架构安全性_验证(Verification)

文章目录 OverView导图1. 引言&#xff1a;数据验证的重要性概述2. 数据验证的基本概念3. 数据验证的层次前端验证后端验证 4. 数据验证的标准做法5. 自定义校验注解6. 校验结果的处理7. 性能考虑与副作用8. 小结 OverView 即使只限定在“软件架构设计”这个语境下&#xff0c…

金融科技革命:API接口开放平台,畅通金融服务之路

金融科技是近年来蓬勃发展的领域&#xff0c;它利用先进的技术手段来改善和创新金融服务。在金融科技的革命中&#xff0c;API接口开放平台扮演着重要的角色&#xff0c;它通过提供统一的接口服务&#xff0c;让金融机构和其他行业能够更方便地进行数据交换和合作。本文将以挖数…

聚星文社最新风格图库角色

聚星文社最新风格图库角色涵盖了各种不同的风格和类型。以下是一些可能的角色风格&#xff1a; Docs聚星文社https://iimenvrieak.feishu.cn/docx/ZhRNdEWT6oGdCwxdhOPcdds7nof 现代都市风格角色&#xff1a;这种角色通常穿着时尚的衣服&#xff0c;有时尚的发型和化妆。他们可…