前置知识
图神经网络的主要思想是:对目标节点,使用相邻的节点来进行计算。而相邻节点又要需要他的相邻节点来计算,形成多层的结构。
至于,相邻节点如何用来计算目标节点,方法有两个,一个是平均领域的信息;另一个是使用神经网络来进行计算。
数学表示:
当前时刻的节点,由上一层的embedding,和它相邻节点平均值相加得到,并且经过非线性映射。
模型训练:
对权重矩阵 W, B,使用随机梯度方法进行更新。
至于无监督训练,一般假设,相似的节点具有相同的embedding。
有监督训练(节点分类情形)
搭建并训练一个简单图神经网络
根据官网教程,搭建一个简单的用于多分类的图卷积网络。输入是一个无向图数据(Graph data),图数据由节点数据(Node)和边数据(Edge)构成。节点数据是一个N × C 的矩阵,其中N 表示节点个数,而C 表示节点特征长度。边数据是一个2 × E的矩阵。E 表示有向边个个数。在PyG中,由于边数据被定义为有向边,如果表示无向边图的话,那么E / 2 才是无向边的个数。代码如下所示。node_features代表C ,而node_class代表预测类别,例子中是7类,即node_class=7。
1 | class simpleNet(torch.nn.Module): # simpleNet继承自torch.nn.Module |
简易的训练和测试过程如下所示:
1 | def example_5(): |
跟官网稍不一样的是,我在mask变量后面加上强制转换为bool类型的操作,即.bool()。不然运行时候会报警告UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.。输出结果是:
1 | dataset.num_features: 1433 |