图神经网络GNN(一)GraphEmbedding

DeepWalk


使用随机游走采样得到每个结点x的上下文信息,记作Context(x)。
SkipGram优化的目标函数:P(Context(x)|x;θ)
θ = argmax P(Context(x)|x;θ)
DeepWalk这种GraphEmbedding方法是一种无监督方法,个人理解有点类似生成模型的Encoder过程,下面的代码中,node_proj是一个简单的线性映射函数,加上elu激活函数,可以看作Encoder的过程。Encoder结束后就得到了Embedding后的隐变量表示。其实GraphEmbedding要的就是这个node_proj,但是由于没有标签,只有训练数据的内部特征,怎么去训练呢?这就需要看我们的训练任务了,个人理解,也就是说,这种无监督的embedding后的结果取决于你的训练任务,也就是Decoder过程。Embedding后的编码对Decoder过程越有利,损失函数也就越小,编码做的也就越好。在word2vec中,有两种训练任务,一种是给定当前词,预测其前两个及后两个词发生的条件概率,采用这种训练任务做出的embedding就是skip-gram;还有一种是给定当前词前两个及后两个词,预测当前词出现的条件概率,采用这种训练任务做出的embedding就是CBOW.DeepWalk作者的论文中采用的是skip-gram。故复现也采用skip-gram进行复现。
针对skip-gram对应的训练任务,代码中的node_proj相当于编码器,h_o_1和h_o_2相当于解码器。Encoder和Decoder可以先联合训练,训练结束后,可以只保留Encoder的部分,舍弃Decoder的部分。当再来一个独热编码的时候,可以直接通过node_proj映射,即完成了独热编码的embedding过程。
(本代码假定在当前结点去往各邻接结点的可能性相同,即不考虑边的权重)

import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F
import networkx as nx
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Categorical
import matplotlib.pyplot as pltclass MyGraph():def __init__(self,device):super(MyGraph, self).__init__()self.G = nx.read_edgelist(path='data/wiki/Wiki_edgelist.txt',create_using=nx.DiGraph(),nodetype=None,data=[('weight',int)])self.adj_matrix = nx.attr_matrix(self.G)self.edges = nx.edges(self.G)self.edges_emb = torch.eye(len(self.G.edges)).to(device)self.nodes_emb = torch.eye(len(self.G.nodes)).to(device)class GraphEmbedding(nn.Module):def __init__(self,nodes_num,edges_num,device,emb_dim = 10):super(GraphEmbedding, self).__init__()self.device = deviceself.nodes_proj = nn.Parameter(torch.randn(nodes_num,emb_dim))self.edges_proj = nn.Parameter(torch.randn(edges_num,emb_dim))self.h_o_1 = nn.Parameter(torch.randn(emb_dim,nodes_num * 2))self.h_o_2 = nn.Parameter(torch.randn(nodes_num * 2,nodes_num))def forward(self,G:MyGraph):self.nodes_proj,self.edges_proj = self.nodes_proj.to(self.device),self.edges_proj.to(device)self.h_o_1,self.h_o_2 = self.h_o_1.to(self.device),self.h_o_2.to(self.device)# Encoderedges_emb,nodes_emb = torch.matmul(G.edges_emb,self.edges_proj),torch.matmul(G.nodes_emb,self.nodes_proj)nodes_emb = F.elu_(nodes_emb)edges_emb,nodes_emb = edges_emb.to(device),nodes_emb.to(device)# Decoderpolicy = self.DeepWalk(G,gamma=5,window=2)outputs = torch.matmul(torch.matmul(nodes_emb[policy[:,0]],self.h_o_1),self.h_o_2)policy,outputs = policy.to(device),outputs.to(device)return policy,outputsdef DeepWalk(self,Graph:MyGraph,gamma:int,window:int,eps=1e-9):# Calculate transpose matrixadj_matrix = torch.tensor(Graph.adj_matrix[0], dtype=torch.float32)for i in range(adj_matrix.shape[0]):adj_matrix[i,:] /= (torch.sum(adj_matrix[i]) + eps)adj_nodes = Graph.adj_matrix[1].copy()random.shuffle(adj_nodes)nodes_idx, route_result = [],[]for node in adj_nodes:node_idx = np.where(np.array(Graph.adj_matrix[1]) == node)[0].item()node_list = self.Random_Walk(adj_matrix,window=window,node_idx=node_idx)route_result.append(node_list)return torch.tensor(route_result)def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]node_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_listdef HMM_process(self,adj_matrix:torch.Tensor,node_idx:int,eps=1e-9):pi = torch.zeros((1, adj_matrix.shape[0]), dtype=torch.float32)pi[:,node_idx] = 1.0pi = torch.matmul(pi,adj_matrix)pi = pi.squeeze(0) / (torch.sum(pi) + eps)return piif __name__ == "__main__":epochs = 200device = torch.device("cuda:1")cross_entrophy_loss = CrossEntropyLoss().to(device)Graph = MyGraph(device)Embedding = GraphEmbedding(nodes_num=len(Graph.G.nodes), edges_num=len(Graph.G.edges),device=device).to(device)optimizer = torch.optim.Adam(Embedding.parameters(),lr=1e-5)scheduler=CosineAnnealingLR(optimizer,T_max=50,eta_min=0.05)loss_list = []epoch_list = [i for i in range(1,epochs+1)]for epoch in range(epochs):policy,outputs = Embedding(Graph)outputs = outputs.unsqueeze(1).repeat(1,policy.shape[-1]-1,1).reshape(-1,outputs.shape[-1])optimizer.zero_grad()loss = cross_entrophy_loss(outputs, policy[:,1:].reshape(-1))loss.backward()optimizer.step()scheduler.step()loss_list.append(loss.item())print(f"Loss : {loss.item()}")plt.plot(epoch_list,loss_list)plt.xlabel('Epoch')plt.ylabel('CrossEntrophyLoss')plt.title('Loss-Epoch curve')plt.show()

在这里插入图片描述

Node2Vec

在这里插入图片描述
在这里插入图片描述
修改Random_Walk函数如下:

    def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]if i > 0:v,t = node_list[-1],node_list[-2]x_list = torch.nonzero(adj_matrix[v]).squeeze(-1)for x in x_list:if t == x:  # 0pi[x] *= 1/self.pelif adj_matrix[t][x] == 1:  # 1pi[x] *= 1else:   # 2pi[x] *= 1/self.qnode_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_list

结果如下,这里令p=2,q=3,即1/p=0.5,1/q=0.33,会相对保守周围。结果似乎好了那么一点点。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146743.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

树莓派CM4开启I2C与UART串口登录同时serial0映射到ttyS0 开启多串口

文章目录 前言1. 树莓派开启I2C与UART串口登录2. 开启多串口总结: 前言 最近用CM4的时候使用到了I2C以及多个UART的情况。 同时配置端口映射也存在部分问题。 这里集中记录一下。 1. 树莓派开启I2C与UART串口登录 输入指令sudo raspi-config 跳转到如下界面&#…

3D WEB轻量化引擎HOOPS助力3D测量应用蓬勃发展:效率、精度显著提升

在3D开发工具领域,Tech Soft 3D打造的HOOPS SDK已经崭露头角,成为了全球领先的3D领域开发工具提供商。HOOPS SDK包括四种不同的3D软件开发工具,已成为行业的翘楚。 其中,HOOPS Exchange以其CAD数据转换的能力脱颖而出&#xff0c…

Arm Cache学习资料大汇总

关键词:cache学习、mmu学习、cache资料、mmu资料、arm资料、armv8资料、armv9资料、 trustzone视频、tee视频、ATF视频、secureboot视频、安全启动视频、selinux视频,cache视频、mmu视频,armv8视频、armv9视频、FF-A视频、密码学视频、RME/CC…

前端开发网站推荐

每个人都会遇见那么一个人,永远无法忘却,也永远不能拥有。 以下是一些可以用来查找和比较前端框架的推荐网站: JavaScript框架比较: 这些网站提供了对不同JavaScript框架和库的详细比较和评估。 JavaScripting: 提供了大量的JavaS…

JavaScript高阶班之ES6 → ES11(八)

JavaScript高阶班之ES6 → ES11 1、ES6新特性1.1、let 关键字1.2、const关键字1.3、变量的解构赋值1.3.1、数组的解构赋值1.3.2、对象的解构赋值 1.4、模板字符串1.5、简化对象写法1.6、箭头函数1.7、函数参数默认值1.8、rest参数1.9、spread扩展运算符1.9.1、数组合并1.9.2、数…

上古神器:十六位应用程序 Debug 的基本使用

文章目录 参考环境上古神器 DebugBug 与 DebuggingDebugDebug 应用程序淘汰原因使用限制 DOSBox学习 Debug 的必要性DOSBox-X Debug 的基本使用命令 R查看寄存器的状态修改寄存器的内容 命令 D显示内存中的数据指定起始内存空间地址指定内存空间的范围 命令 A使用命令语法错误查…

第8章 Spring(二)

8.11 Spring 中哪些情况下,不能解决循环依赖问题 难度:★★ 重点:★★ 白话解析 有一下几种情况,循环依赖是不能解决的: 1、原型模式下的循环依赖没办法解决; 假设Girl中依赖了Boy,Boy中依赖了Girl;在实例化Girl的时候要注入Boy,此时没有Boy,因为是原型模式,每次都…

Konva离屏缓存

前言 cache实例方法定义在Node基类上,通过该方法可以实现图形缓存,在Konva中Stage、Layer、Group、Shape等所有容器类和图形类都直接或间接继承了Node基类,故而都可以使用缓存方法。本篇文章就是探讨Konva背后的缓存机制,版本是v…

8.3Jmeter使用json提取器提取数组值并循环(循环控制器)遍历使用

Jmeter使用json提取器提取数组值并循环遍历使用 响应返回值例如: {"code":0,"data":{"totalCount":11,"pageSize":100,"totalPage":1,"currPage":1,"list":[{"structuredId":&q…

计算机网络笔记 第二章 物理层

2.1 物理层概述 物理层要实现的功能 物理层接口特性 机械特性 形状和尺寸引脚数目和排列固定和锁定装置 电气特性 信号电压的范围阻抗匹配的情况传输速率距离限制 功能特性 -规定接口电缆的各条信号线的作用 过程特性 规定在信号线上传输比特流的一组操作过程&#xff0…

3. 文档操作

1. 创建文档 1.1 创建一个文档 在相应的索引下面使用_doc创建文档,地址为:http://127.0.0.1:9200/students/_doc,创建一个姓名张三的学生信息: {"姓名":"张三","年级":5,"班级":2,&qu…

28391-2012 建筑施工机械与设备 人力移动式液压动力站

声明 本文是学习GB-T 28391-2012 建筑施工机械与设备 人力移动式液压动力站. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了人力移动式液压动力站(以下简称动力站)的范围、分类、要求、试验方法和检验规则。 本标准适用于以中小…

AI编程助手 Amazon CodeWhisperer 全面解析与实践

目录 引言Amazon CodeWhisperer简介智能编程助手智能代码建议代码自动补全 提升代码质量代码质量提升安全性检测 支持多平台多语言 用户体验和系统兼容性用户体验文档和学习资源个性化体验系统兼容性 功能全面性和代码质量功能全面性代码生成质量和代码安全性 CodeWhisperer的代…

MySQL数据库——索引(6)-索引使用(覆盖索引与回表查询,前缀索引,单列索引与联合索引 )、索引设计原则、索引总结

目录 索引使用(下) 覆盖索引与回表查询 思考题 前缀索引 语法 示例 前缀长度 前缀索引的查询流程 单列索引与联合索引 索引设计原则 索引总结 1.索引概述 2.索引结构 3.索引分类 4.索引语法 5.SQL性能分析 6.索引使用 7.索引设计…

应用在手机触摸屏中的电容式触摸芯片

触控屏(Touch panel)又称为触控面板,是个可接收触头等输入讯号的感应式液晶显示装置,当接触了屏幕上的图形按钮时,屏幕上的触觉反馈系统可根据预先编程的程式驱动各种连结装置,可用以取代机械式的按钮面板&…

机器人过程自动化(RPA)入门 9. 管理和维护代码

仅仅创建一个自动化项目是不够的。无论是决定使用哪种布局,还是正确命名步骤,以正确的方式组织项目都很重要。项目也可以在新的项目中重用,这对用户来说非常方便。本章解释了我们可以重用项目的方法。我们还将学习配置技术并看到一个示例。最后,我们将学习如何集成TFS服务器…

【JavaEE】HTML

JavaWeb HTML 超文本标记语言 超文本:文本、声音、图片、视频、表格、连接标记:有许许多多的标签组成 vscode开发工具搭建 因为我使用的IDEA是社区版,代码高亮补全缩进都有些问题,使用vscode是最好的选择~ 安装 Visual Stu…

Python实用技术二:数据分析和可视化(2)

目录 一,多维数组库numpy 1,操作函数:​ 2,numpy数组元素增删 1)添加数组元素 2)numpy删除数组元素 3)在numpy数组中查找元素 4)numpy数组的数学运算 3,numpy数…

Windows中实现将bat或exe文件作为服务_且实现命令行安装、配置、启动、删除服务

一、背景描述 在Windows环境下进行日常的项目开发过程中,有时候需要将bat文件或exe文件程序注册为Windows的服务实现开机自己运行(没有用户登陆,服务在开机后也可以照常运行)、且对于那些没有用户交互界面的exe程序来说只要在后台…