【HGT】文献精讲:Heterogeneous Graph Transformer
标题: Heterogeneous Graph Transformer
(异构图Transformer)
作者团队: 加利福尼亚大学Yizhou Sun
摘要: 近年来,图神经网络(GNN)在结构化数据建模方面取得了成功。然而,大多GNN是为同质图设计的,其中所有节点和边都属于同一类型,这使得表示异构结构变得不可行。本文提出了异构图转换器(HGT)架构来建模web规模的异构图。为了建模异构性,本文设计了节点和边缘类型相关的参数来表征每个边缘上的异构注意力,使HGT能够为不同类型的节点和边缘维护专用表示。为了处理web规模的图数据,文章设计了异构小批图采样算法——HGSampling,以实现高效和可扩展的训练。在包含1.79亿个节点和20亿个边的开放学术图上进行的大量实验表明,所提出的HGT模型在各种下游任务上的性能始终优于所有最先进的GNN基线9%-21%。
文献链接: https://dl.acm.org/doi/abs/10.1145/3366423.3380027
代码链接: https://github.com/acbull/pyHGT
1. 背景
过去尝试采用GNN与异构网络进行学习的工作往往面临着以下几个问题:
- 大多数涉及为每种类型的异构图设计元路径或者变体,需要特定的领域知识;
- 要么简单地假设不同类型的节点和边共享相同的特征和表示空间,要么单独为节点类型或边类型保留不同的非共享权重,使得它们不足以捕获异构图的属性;
- 它们固有的设计和实现使得它们无法对web规模的异构图进行建模。
鉴于以上这些限制,本文设计了研究异构图的神经网络HGT,目标是保持节点和边类型的依赖表示,同时避免自定义元路径,并且可以扩展到web规模的异构图。
为了处理图的异构性,文章引入了依赖于节点和边类型的注意力机制,HGT中的异构相互注意不是参数化每种类型的边,而是通过基于其元关系三元组分解每个边 e = < s , t > e=<s,t> e=<s,t>来定义,其中s、t表示节点类型,e表示s与t之间的边类型。具体来说,HGT使用这些元关系来参数化权重矩阵,以计算每条边的注意力分数。因此,这样允许不同类型的节点和边保持其特定的表示空间。与此同时,不同类型的连接节点人若干可以交互、传递和聚合消息,不受其分布间隙的限制。
由于其体系结构的性质,HGT可以通过跨层的消息传递来合并来自不同类型的高阶邻居的信息,这也被称为“软”元路径。也就是说HGT只把它的一跳边作为输入,而不用手动设计元路径,所提出的注意力机制也可以自动和隐式地学习和提取对不同下游任务很重要的元路径。
2.方法
上图显示了HGT的整体架构。给定一个采样的异构子图,HGT提取所有连接的节点对,其中目标节点t通过边e被源节点s连接,HGT的目标是聚合来自s的信息,以获得节点目标t的上下文表示。该过程主要是三个组件:异构互注意力、异构消息传递和目标特定聚合。
2.1 Heterogeneous Mutual Attention 异构互注意力
将第 l l l个HGT层的输出表示为 H ( l ) H^{(l)} H(l),同时也作为第 l + 1 l+1 l+1个HGT层的输入。通过堆叠L层,可以得到整个图 H ( L ) H^{(L)} H(L)的节点表示,可以用于端到端训练或馈送到下游任务。
第一步首先是计算源节点 s s s与目标节点 t t t之间的相互注意力分数。在基于注意力的GNN中,有
H l [ t ] ← A g g r e g a t e ∀ s ∈ N ( t ) , ∀ e ∈ E ( s , t ) ( A t t e n t i o n ( s , t ) ⋅ M e s s a g e ( s ) ) H^{l} [t]\gets \underset{\forall s\in N(t),\forall e\in E(s,t)}{\mathbf{Aggregate } }(\mathbf{Attention}(s,t)\cdot \mathbf{Message}(s) ) Hl[t]←∀s∈N(t),∀e∈E(s,t)Aggregate(Attention(s,t)⋅Message(s))
其中三个基本的运算符:
- A t t e n t i o n \mathbf{Attention} Attention用于计算每个源节点的重要性
- M e s s a g e \mathbf{Message} Message用于使用源节点来提取消息
- A g g r e g a t e \mathbf{Aggregate } Aggregate通过关注权重对邻居信息进行聚合
例如,图注意力网络(GAT)采用了一种加性机制作为 A t t e n t i o n \mathbf{Attention} Attention,使用相同的权重来计算 M e s s a g e \mathbf{Message} Message,并利用简单平均和非线性激活函数来进行 A g g r e g a t e \mathbf{Aggregate } Aggregate步骤。形式上来看,GAT有:
A t t e n t i o n G A T ( s , t ) = ∀ s ∈ N ( t ) ( a ⃗ ( W H l − 1 [ t ] ∥ W H l − 1 [ s ] ) ) M e s s a g e G A T ( s ) = W H l − 1 [ s ] A g g r e g a t e G A T ( ⋅ ) = σ ( Mean ( ⋅ ) ) \begin{aligned} \mathbf{ Attention }_{G A T}(s, t) & =\underset{\forall s \in N(t)}{ }\left(\vec{a}\left(W H^{l-1}[t] \| W H^{l-1}[s]\right)\right) \\ \mathbf { Message }_{G A T}(s) & =W H^{l-1}[s] \\ \mathbf { Aggregate }_{G A T}(\cdot) & =\sigma(\operatorname{Mean}(\cdot)) \end{aligned} AttentionGAT(s,t)MessageGAT(s)AggregateGAT(⋅)=∀s∈N(t)(a(WHl−1[t]∥WHl−1[s]))=WHl−1[s]=σ(Mean(⋅))
虽然GAT对重要节点给予高关注度是有效的,但是它通过使用一个权重矩阵 W W W来假设节点 s s s和节点 t t t具有相同的特征分布,这种假设对于异构图来说通常是不正确的,因为每种类型的节点都会有自己的特征分布。基于此,作者设计了异构互注意力机制。
给定一个目标节点 t t t,以及它所有的邻居节点 s ∈ N ( t ) s \in N(t) s∈N(t) ,它们可能属于不同的分布,根据它们的元关系计算它们的相互注意力,即 < τ ( s ) , ϕ ( e ) , τ ( s ) > <\tau(s), \phi(e), \tau(s)> <τ(s),ϕ(e),τ(s)>三元组。
作者将目标节点 t t t映射为一个Query向量,将源节点 s s s映射为一个Key向量,并计算它们的点积作为注意力。其与普通的Transformer的关键区别在于,普通Transformer对所有单词使用一组投影,而在HGT中,每个元关系都有一组不同的投影权重。为了最大限度地实现参数共享,同时保持不同关系地特定特征,作者提出将交互算子地权重矩阵参数化为源节点投影、边投影和目标节点投影。具体来说,通过以下公式为每条边 e = ( s , t ) e=(s,t) e=(s,t)计算 h h h个头注意力:
A t t e n t i o n H G T ( s , e , t ) = Softmax ∀ s ∈ N ( t ) ( ∏ i ∈ [ 1 , h ] A T T − head i ( s , e , t ) ) A T T − head i ( s , e , t ) = ( K i ( s ) W ϕ ( e ) A T T Q i ( t ) T ) ⋅ μ ⟨ τ ( s ) , ϕ ( e ) , τ ( t ) ⟩ d K i ( s ) = K-Linear τ ( s ) i ( H ( l − 1 ) [ s ] ) Q i ( t ) = Q-Linear τ ( t ) i ( H ( l − 1 ) [ t ] ) \begin{aligned} \mathbf{Attention}_{H G T}(s, e, t) & =\underset{\forall s \in N(t)}{\operatorname{Softmax}}\left(\prod_{i \in[1, h]} A T T-\text { head }^{i}(s, e, t)\right) \\ A T T-\text { head }^{i}(s, e, t) & =\left(K^{i}(s) W_{\phi(e)}^{A T T} Q^{i}(t)^{T}\right) \cdot \frac{\mu_{\langle\tau(s), \phi(e), \tau(t)\rangle}}{\sqrt{d}} \\ K^{i}(s) & =\text { K-Linear }_{\tau(s)}^{i}\left(H^{(l-1)}[s]\right) \\ Q^{i}(t) & =\text { Q-Linear }_{\tau(t)}^{i}\left(H^{(l-1)}[t]\right) \end{aligned} AttentionHGT(s,e,t)ATT− head i(s,e,t)Ki(s)Qi(t)=∀s∈N(t)Softmax i∈[1,h]∏ATT− head i(s,e,t) =(Ki(s)Wϕ(e)ATTQi(t)T)⋅dμ⟨τ(s),ϕ(e),τ(t)⟩= K-Linear τ(s)i(H(l−1)[s])= Q-Linear τ(t)i(H(l−1)[t])
首先,对于第 i i i个注意的头 A T T − h e a d i ( s , e , t ) ATT-head^i(s,e,t) ATT−headi(s,e,t),通过线性函数 $K-Linear_{\tau(s)}^{i}:\mathbb{R} ^{d} \to \mathbb{R} ^{\frac{d}{h} } $ 将源节点 τ ( s ) \tau (s) τ(s)投影到第 i i i个键向量 K i ( s ) K^i(s) Ki(s),其中 h h h是注意力头的数量, d h \frac{d}{h} hd是每个头的向量维度。注意, K − L i n e a r τ ( s ) i K-Linear_{\tau(s)}^{i} K−Linearτ(s)i是由源节点 s s s的类型 τ ( s ) \tau (s) τ(s)索引的,这意味着每种类型的节点都有一个唯一的线性投影,以最大限度地模拟分布差异。类似地,通过线性函数 Q − L i n e a r τ ( t ) i Q-Linear _{\tau(t)}^{i} Q−Linearτ(t)i将目标节点 t t t投影到第 i i i个查询向量中。
接下来,计算查询向量 Q i ( t ) Q^i(t) Qi(t)和键向量 K i ( s ) K^i(s) Ki(s)的相似性。异构图的一个独特之处就是有可能存在不同边类型之间的一堆节点类型,例如 τ ( s ) \tau(s) τ(s)和 τ ( t ) \tau(t) τ(t)。因此仍然为每种类型的边 ϕ ( e ) \phi(e) ϕ(e)使用独特的边缘权重矩阵 W ϕ ( e ) A T T ∈ R d h × d h W^{ATT}_{\phi (e)} \in \mathbb{R}^{\frac{d}{h} \times \frac{d}{h}} Wϕ(e)ATT∈Rhd×hd。这样一来,即使在相同的节点类型对之间,模型也可以捕获不同的语义关系。另外,由于并非所有关系对目标节点的贡献相同,作者添加一个先验张量 μ ∈ R ∣ A ∣ × ∣ R ∣ × ∣ A ∣ \mu \in \mathbb{R}^{|\mathcal{A} |\times |\mathcal{R} | \times |\mathcal{A} |} μ∈R∣A∣×∣R∣×∣A∣来表示每个元关系三元组的一般意义,作为对注意力的自适应缩放。
最后,将 h h h个注意力头连接在一起,等到每个节点对的注意向量。随后对于每个目标节点 t t t,从它的邻居节点 N ( t ) N(t) N(t)收集所有的注意力向量,并使用softmax函数,使其满足 $ { \sum_{\forall s\in N(t)}}\mathbf{Attention} _{HGT} (s,e,t)=\mathbf{1} _{h\times 1} $.
2.2 Heterogeneous Message Passing 异构消息传递
与第一步计算异构互注意力并行,将信息从源节点传递到目标节点。对于一对节点 e = ( s , t ) e=(s,t) e=(s,t),通过以下公式计算其多头消息:
Message H G T ( s , e , t ) = ∥ i ∈ [ 1 , h ] MSG-head i ( s , e , t ) M S G − head i ( s , e , t ) = M-Linear τ ( s ) i ( H ( l − 1 ) [ s ] ) W ϕ ( e ) M S G \begin{array}{l} \operatorname{Message}_{H G T}(s, e, t)=\|_{i \in[1, h]} \text { MSG-head }^{i}(s, e, t) \\ M S G-\text { head }^{i}(s, e, t)=\text { M-Linear }{ }_{\tau(s)}^{i}\left(H^{(l-1)}[s]\right) W_{\phi(e)}^{M S G} \end{array} MessageHGT(s,e,t)=∥i∈[1,h] MSG-head i(s,e,t)MSG− head i(s,e,t)= M-Linear τ(s)i(H(l−1)[s])Wϕ(e)MSG
为了得到第 i i i个消息头 M S G − h e a d i ( s , e , t ) MSG-head^i(s,e,t) MSG−headi(s,e,t),首先使用线性函数$M-Linear_{\tau(s)}^{i}:\mathbb{R} ^{d} \to \mathbb{R} ^{\frac{d}{h} } 将 将 将\tau(s) 型源节点 型源节点 型源节点s 投影到第 投影到第 投影到第i 个消息向量中。随后使用矩阵 个消息向量中。随后使用矩阵 个消息向量中。随后使用矩阵W^{MSG}{\phi (e)} \in \mathbb{R}^{\frac{d}{h} \times \frac{d}{h}} 合并边缘依赖性。最后连接所有 合并边缘依赖性。最后连接所有 合并边缘依赖性。最后连接所有h 个消息头以获取每个节点对的 个消息头以获取每个节点对的 个消息头以获取每个节点对的\mathbb{Message}{HGT}(s,e,t)$
2.3 Target-Specific Aggregation 目标特定聚合
最后将计算的异构多头注意力和消息从源节点聚合到目标节点。在计算异构互注意力过程中softmax已经使每个目标节点 t t t的注意力向量汇聚成了一个和。因此此时可以简单地使用注意力向量作为权重来平均来自源节点的对应消息,并得到更新的向量 H ~ ( l ) [ t ] \widetilde{H}^{(l)}[t] H (l)[t]为:
H ~ ( l ) [ t ] = ⊕ ∀ s ∈ N ( t ) ( A t t e n t i o n H G T ( s , e , t ) ⋅ M e s s a g e H G T ( s , e , t ) ) \widetilde{H}^{(l)}[t]=\underset{\forall s \in N(t)}{\oplus}\left(\mathbf { Attention }_{H G T}(s, e, t) \cdot \mathbf { Message }_{H G T}(s, e, t)\right) H (l)[t]=∀s∈N(t)⊕(AttentionHGT(s,e,t)⋅MessageHGT(s,e,t))
以此将不同特征分布的源节点所有的邻居节点的信息聚合到目标节点 t t t。
最后再将目标节点 t t t的向量由其节点类型 τ ( t ) \tau(t) τ(t)的索引映射回其特定类型的分布。为此,使用线性函数 A − L i n e a r τ ( t ) A-Linear_{\tau(t)} A−Linearτ(t)来更新向量 H ~ ( l ) [ t ] \widetilde{H}^{(l)}[t] H (l)[t],然后进行非线性激活和残差连接为:
H ~ ( l ) [ t ] = σ ( A − L i n e a r τ ( t ) H ~ ( l ) [ t ] ) + H ~ ( l − 1 ) [ t ] \widetilde{H}^{(l)}[t]=\sigma(A-Linear_{\tau(t)\widetilde{H}^{(l)}[t]})+\widetilde{H}^{(l-1)}[t] H (l)[t]=σ(A−Linearτ(t)H (l)[t])+H (l−1)[t]
这样得到目标节点 t t t的第 l l l个HGT层的输出 H ( l ) [ t ] H^{(l)}[t] H(l)[t],将 l l l层的HGT块堆叠,使每个节点在整个图中达到不同类型和关系的节点的很大比例,即HGT为每个节点生成一个高度情境化的表示 H ( L ) H^{(L)} H(L),该表示可以输送到任何模型中进行下游异构网络任务,如节点分类和链路预测。
2.4 HGSampling 小批图采样算法
作者还提出一种高效的异构小批图采样算法—HGSampling——使HGT和传统GNN都能处理web规模的异构图。HGSampling能够为每种类型保持相似数量的节点和边,同时还能采样子图的密度,以最小化信息损失和样本方差。
以上为HGSampling的伪代码。其基本思想是为每种节点类型 τ \tau τ保持一个单独的节点预算 B [ τ ] B[\tau] B[τ],并使用重要采样策略对每种类型的节点进行相同数量的采样以减小方差。
3. 结果
最终在OAG数据集中进行节点分类任务发现,HGT模型远优于当前其他模型。