0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

图神经网络的知识蒸馏框架介绍

深度学习自然语言处理 来源:北邮 作者:刘佳玮 2021-04-04 16:48 次阅读

随着深度学习的成功,基于图神经网络(GNN)的方法[8,12,30]已经证明了它们在分类节点标签方面的有效性。大多数GNN模型采用消息传递策略[7]:每个节点从其邻域聚合特征,然后将具有非线性激活的分层映射函数应用于聚合信息。这样,GNN可以在其模型中利用图结构和节点特征信息。然而,这些神经模型的预测缺乏透明性,人们难以理解[36],而这对于与安全和道德相关的关键决策应用至关重要[5]。此外,图拓扑、节点特征和映射矩阵的耦合导致复杂的预测机制,无法充分利用数据中的先验知识。

例如,已有研究表明,标签传播法采用上述同质性假设来表示的基于结构的先验,在图卷积网络(GCN)[12]中没有充分使用[15,31]。作为证据,最近的研究提出通过添加正则化[31]或操纵图过滤器[15,25]将标签传播机制纳入GCN。他们的实验结果表明,通过强调这种基于结构的先验知识可以改善GCN。然而,这些方法具有三个主要缺点:

(1)其模型的主体仍然是GNN,并阻止它们进行更可解释的预测;

(2)它们是单一模型而不是框架,因此与其他高级GNN架构不兼容;

(3)他们忽略了另一个重要的先验知识,即基于特征的先验知识,这意味着节点的标签完全由其自身的特征确定。为了解决这些问题,我们提出了一个有效的知识蒸馏框架,以将任意预训练的GNN教师模型的知识注入精心设计的学生模型中。学生模型是通过两个简单的预测机制构建的,即标签传播和特征转换,它们自然分别保留了基于结构和基于特征的先验知识。具体来说,我们将学生模型设计为参数化标签传播和基于特征的2层感知机(MLP)的可训练组合。另一方面,已有研究表明,教师模型的知识在于其软预测[9]。通过模拟教师模型预测的软标签,我们的学生模型能够进一步利用预训练的GNN中的知识。

因此,学习的学生模型具有更可解释的预测过程,并且可以利用GNN和基于结构/特征的先验知识。我们的框架概述如图1所示。

6353b566-88e9-11eb-8b86-12bb97331649.png

图1:我们的知识蒸馏框架的示意图。学生模型的两种简单预测机制可确保充分利用基于结构/功能的先验知识。

在知识蒸馏过程中,将提取GNN教师中的知识并将其注入学生。因此,学生可以超越其相应的老师,得到更有效和可解释的预测。我们在五个公共基准数据集上进行了实验,并采用了几种流行的GNN模型,包括GCN[12]、GAT[30]、SAGE[8]、APPNP[13]、SGC[33]和最新的深层GCN模型GCNII[4]作为教师模型。实验结果表明,就分类精度而言,学生模型的表现优于其相应的教师模型1.4%-4.7%。

值得注意的是,我们也将框架应用于GLP[15],它通过操纵图过滤器来统一GCN和标签传播。结果,我们仍然可以获得1.5%-2.3%的相对改进,这表明了我们框架的潜在兼容性。此外,我们通过探究参数化标签传播与特征转换之间的可学习平衡参数以及标签传播中每个节点的可学习置信度得分,来研究学生模型的可解释性。总而言之,改进是一致,并且更重要的是,它具有更好的可解释性。本文的贡献总结如下:

我们提出了一个有效的知识蒸馏框架,以提取任意预训练的GNN模型的知识,并将其注入学生模型,以实现更有效和可解释的预测。

我们将学生模型设计为参数化标签传播和基于特征的两层MLP的可训练组合。因此,学生模型有一个更可解释的预测过程,并自然地保留了基于结构/特征的先验。因此,学习的学生模型可以同时利用GNN和先验知识。

五个基准数据集和七个GNN教师模型上的实验结果表明了我们的框架有效性。对学生模型中学习权重的广泛研究也说明了我们方法的可解释性。

2 方法

在本节中,我们将从形式化半监督节点分类问题开始,并介绍符号。然后,我们将展示我们的知识蒸馏框架,以提取GNN的知识。然后,我们将提出学生模型的体系结构,该模型是参数化标签传播和基于特征的两层MLP的可训练组合。最后,我们将讨论学生模型的可解释性和框架的计算复杂性。

2.1 半监督节点分类

我们首先概述节点分类问题。给定一个连通图和一个标记点集,其中师节点集,是边集,节点分类的目标是为每个节点无标记点集中的节点预测标签。每个节点拥有标签,其中是所有可能的标签集合。此外,图数据通常拥有节点特征,并且可以利用特征来提升分类准确率。每行矩阵的每行表示节点的维特征向量。

2.2 知识蒸馏框架

基于GNN的节点分类方法往往是一个黑盒,输入图结构、标记点集和节点特征,输出分类器。分类器将预测无标记点的标签为的概率,其中。对于标记节点,如果的标签为,那么,其余标签。简化起见,我们使用表示所有标签的概率分布。在本文中,我们框架里的教师模型可以使用任意GNN,例如GCN[12]或GAT[30]。我们称教师模型里的预训练分类器为。另一方面,我们使用表示学生模型,是参数,表示学生模型对节点v的预测概率分布。在知识蒸馏[9]的框架中,训练学生模型使其最小化与预训练教师模型的软标签预测,使得教师模型里的潜在知识被提取并注入学生模型中。因此,优化目标是对齐学生模型和与训练教师模型的输出,可以形式化为:

63b88e82-88e9-11eb-8b86-12bb97331649.png

其中度量两个预测概率分布之间的距离。特别地,本文使用欧氏距离。(注:我们还尝试最小化KL散度或最大化交叉熵。但是我们发现欧几里得距离的效果最好,并且在数值上更稳定。)

2.3 学生模型架构

我们假设节点的标签预测遵循两种简单的机制:(1)从其相邻节点传播标签;(2)从其自身特征进行转换。因此,如图2所示,我们将学生模型设计为这两种机制的组合,即参数化标签传播(PLP)模块和特征转换(FT)模块,它们可以自然地分别保留基于结构的先验知识和基于特征的先验知识。蒸馏后,学生将通过更易于解释的预测机制从GNN和先验知识中受益。

63edd1be-88e9-11eb-8b86-12bb97331649.png

图2:我们建议的学生模型的架构图。以中心节点为例,学生模型从节点的原始特征和统一的标签分布作为软标签开始,然后在每一层,将的软标签预测更新为来自的邻居的参数化标签传播(PLP)和的特征变换(FT)的可训练组合。最终,将使学生与经过训练的教师的软标签预测之间的距离最小化。在本小节中,我们将首先简要回顾传统的标签传播算法。然后,我们将介绍我们的PLP和FT模块及其可训练的组合。

2.3.1 标签传播

标签传播(LP)[40]是基于图的经典半监督学习模型。该模型仅遵循以下假设:由边连接(或占据相同流形)的节点极有可能共享相同的标签。基于此假设,标签将从标记的节点传播到未标记的节点以进行预测。正式地,我们使用表示LP的最终预测,使用表示k轮迭代后的LP预测。在这个工作中,如果是标记节点,我们将对节点的预测初始化为一个独热编码向量。否则,我们将为每个未标记的节点设置均匀分布,这表明所有类的概率在开始时都是相同的。初始化可以形式化为:

641a0928-88e9-11eb-8b86-12bb97331649.png

其中,是节点在第次迭代中的预测概率分布。在第k+1次迭代时,LP将按照如下方式更新无标记节点的预测:

6452b5fc-88e9-11eb-8b86-12bb97331649.png

其中,时节点的邻居集合,是控制节点更新平滑度的超参。注意LP没有需要训练的参数,因此以端到端的方式不能拟合教师模型的输出。因此,我们通过引入更多参数来提升LP的表达能力。

2.3.2 参数化标签传播模块

现在,我们将通过在LP中进一步参数化边缘权重来介绍我们的参数化标签传播(PLP)模块。如等式3所示,LP模型在传播过程中平等对待节点的所有邻居。但是,我们假设不同邻居对一个节点的重要性应该不同,这决定了节点之间的传播强度。更具体地说,我们假设某些节点的标签预测比其他节点更“自信”。例如,一个节点的预测标签与其大多数邻居相似。这样的节点将更有可能将其标签传播给邻居,并使它们保持不变。形式化来说,我们将给每个节点v设置一个置信度分数。在传播过程中,所有节点的邻居和自身将把他们的标签传播给。基于置信值越大,边缘权值越大的直觉,我们为重写了等式3中的预测更新函数如下:

64a69ed8-88e9-11eb-8b86-12bb97331649.png

其中是节点和节点的边权,通过下面的函数计算:

64e03026-88e9-11eb-8b86-12bb97331649.png

与LP相似,按照等式2初始化,在传播过程中,每个标记点的仍然保持独热真实编码向量。注意,作为可选项,我们可以进一步参数化置信度分数用于归纳设置:

6509de26-88e9-11eb-8b86-12bb97331649.png

其中,是一个可学习参数,将节点的特征映射为置信度分数。

2.3.3 特征转换模块

注意,通过边缘传播标签的PLP模块强调了基于结构的先验知识。因此,我们还引入了特征变换(FT)模块作为补充预测机制。FT模块仅通过查看节点的原始特征来预测标签。形式化来说,用表示FT模块的预测,我们使用两层MLP后接一个softmax函数来将特征转换为软标签预测:

652da946-88e9-11eb-8b86-12bb97331649.png

注:虽然单层逻辑回归更具可解释性,但我们发现两层逻辑回归对于提高学生的模型能力是必要的。

2.3.4 可训练组合

现在我们将结合PLP和FT模块作为我们的完整学生模型。细节上,我们将为每个节点学习一个可训练参数,来平衡PLP和FT之间的预测。换句话说,FT和PLP的预测将在每个传播步骤合并。我们将合并后的完整模型命名为CPF,等式4中的每个无标记节点的预测更新公式可以重新写做:

654b89d4-88e9-11eb-8b86-12bb97331649.png

其中边权和初始化与PLP模块一致。根据是否按照等式6参数化置信度分数,模型有两个变体,分别是归纳模型CPF-ind和转导模型CPF-tra。

2.4 整体算法与细节

假设我们的学生模型一共有K层,等式1中的蒸馏目标可以进一步写为:

657bc2d4-88e9-11eb-8b86-12bb97331649.png

其中,是范数,参数集合包括PLP和FT之间的平衡参数,PLP模块内部的置信度参数(或归纳设置下的参数),以及FT模块中MLP的参数。还有一个重要的超参数:传播层数。

2.5 对模型可解释性与计算复杂性的讨论

在本小节中,我们将讨论学习的学生模型的可解释性和算法的复杂性。经过知识蒸馏后,我们的学生模型CPF会将特定节点的标签作为标签传播和基于特征的MLP的预测之间的加权平均值进行预测。平衡参数指示基于结构的LP还是基于特征的MLP对于节点的预测更重要。LP机制几乎是透明的,我们可以轻松地找出节点在每个迭代中受哪个邻居影响的程度。另一方面,对基于特征的MLP的理解可以通过现有工作[21]或直接查看不同特征的梯度来获得。因此,学习过的学生模型比GNN教师具有更好的解释性。算法每次迭代(算法1的第3行到第13行)的时间复杂度和空间复杂度都是,这和数据集的规模线性相关。事实上,操作可以简单写成矩阵形式,对于真实数据集的训练过程,使用单GPU可以在几秒内完成。因此,我们提出的知识蒸馏框架的时间、空间效率都很高。

3 实验

在本节中,我们将从介绍实验中使用的数据集和教师模型开始。然后,我们将详细介绍教师模型和学生变体的实验设置。之后,我们将给出评估半监督节点分类的定量结果。我们还在不同数量的传播层和训练比率下进行实验,以说明算法的鲁棒性。最后,我们将提供定性案例研究和可视化效果,以更好地理解我们的学生模型CPF中的学习参数。

3.1 数据集

65a157f6-88e9-11eb-8b86-12bb97331649.png

表1:数据集统计信息我们使用五个公共基准数据集进行实验,数据集的统计数据如表1所示。如以前的文献[14,24,27]所做的那样,我们仅考虑最大的连通分量,并将边视为无向边。根据先前工作[24]中的实验设置,我们从每个类别中随机抽取20个节点作为标记节点,30个用于验证节点,所有其他节点用于测试。

3.2 教师模型及其设置

为了进行全面比较,我们在我们的知识蒸馏框架中考虑了七个GNN模型作为教师模型;对于每个数据集和教师模型,我们测试下列学生变体:

PLP: 只考虑参数化标签传播机制的学生变体;

FT:只考虑特征转换机制的学生变体;

CPF-ind:归纳设置下的完整模型;

CPF-tra:转导设置下的完整模型。

3.3 分类结果分析

65d07f9a-88e9-11eb-8b86-12bb97331649.png

表2:GCN[12]和GAT[30]作为教师模型的分类准确率

65f9bf5e-88e9-11eb-8b86-12bb97331649.png

表3:APPNP[13]和SGAE[8]作为教师模型的分类准确率

663d6c9a-88e9-11eb-8b86-12bb97331649.png

表4:SGC[33]和GCNII[4]作为教师模型的分类准确率

668d1bb4-88e9-11eb-8b86-12bb97331649.png

表5:GLP[15]作为教师模型的分类准确率五个数据集、七个GNN教师模型、四个学生变体模型上的实验结果在表格2,3,4,5中展示。

3.4 不同传播层数的分析

在本小节中,我们将研究关键超参数对学生模型CPF的体系结构(即传播层数)的影响。实际上,流行的GNN模型(例如GCN和GAT)对层数非常敏感。较大数量的层将导致过平滑的问题,并严重损害模型性能。因此,我们在Cora数据集上进行了实验,以进一步分析该超参数。

66ad1bf8-88e9-11eb-8b86-12bb97331649.png

图3:Cora数据集上具有不同数量传播层的CPF-ind和CPF-tra的分类精度。图例表示指导学生的老师模式。

3.5 不同训练比例的分析

为了进一步证明该框架的有效性,我们在不同的训练比例下进行了额外的实验。具体来说,我们以Cora数据集为例,将每个类的标记节点数量从5个变化到50个。实验结果如图4所示。

68bff9ba-88e9-11eb-8b86-12bb97331649.png

图4:Cora数据集上不同数量的标记节点下的分类精度。子标题指示相应的教师模型。

3.6 可解释性分析

现在,我们将分析学习的学生模型CPF的可解释性。具体来说,我们将探究PLP和FT之间的学习平衡参数以及每个节点的置信度得分。我们的目标是找出哪种节点具有最大或最小的和。在本小节中,我们将使用由GCN和GAT教师模型指导的CPF-ind学生模型在Cora数据集上进行展示。

68ede7e4-88e9-11eb-8b86-12bb97331649.png

图5:用于可解释性分析的平衡参数案例研究。此处的子标题表示该节点是按GCN/GAT作为教师模型,按大或小值选择的。

69547248-88e9-11eb-8b86-12bb97331649.png

图6:用于可解释性分析的置信度得分案例研究。此处的子标题表示该节点是按GCN/GAT作为教师模型,按大或小值选择的。

4 结论

在本文中,我们提出了一种有效的知识蒸馏框架,可以提取任意预训练的GNN(教师模型)的知识并将其注入精心设计的学生模型中。学生模型CPF被建立为两个简单预测机制的可训练组合:标签传播和特征转换,二者分别强调基于结构的先验知识和基于特征的先验知识。蒸馏后,学习的学生可以利用先验知识和GNN知识,从而超越GNN老师。在五个基准数据集上的实验结果表明,我们的框架可以通过更可解释的预测过程来一致,显着地改善所有七个GNN教师模型的分类精度。

在不同数量的训练比率和传播层数上进行的附加实验证明了我们算法的鲁棒性。我们还提供了案例研究,以了解学生架构中学习到的平衡参数和置信度得分。在未来的工作中,除了半监督节点分类之外,我们还将探索将我们的框架用于其他基于图的应用。例如,无监督节点聚类任务会很有趣,因为标签传播模式在没有标签的情况下不能应用。另一个方向是改进我们的框架,鼓励教师和学生模型互相学习,以取得更好的成绩。

原文标题:【WWW2021】图神经网络的知识提取与超越:一个有效的知识蒸馏框架

文章出处:【微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

责任编辑:haq

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4662

    浏览量

    99495
  • 深度学习
    +关注

    关注

    73

    文章

    5348

    浏览量

    120190

原文标题:【WWW2021】图神经网络的知识提取与超越:一个有效的知识蒸馏框架

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    rnn是递归神经网络还是循环神经网络

    RNN(Recurrent Neural Network)是循环神经网络,而非递归神经网络。循环神经网络是一种具有时间序列特性的神经网络,能够处理序列数据,具有记忆功能。以下是关于循环
    的头像 发表于 07-05 09:52 206次阅读

    递归神经网络是循环神经网络

    递归神经网络(Recurrent Neural Network,简称RNN)和循环神经网络(Recurrent Neural Network,简称RNN)实际上是同一个概念,只是不同的翻译方式
    的头像 发表于 07-04 14:54 221次阅读

    反向传播神经网络和bp神经网络的区别

    反向传播神经网络(Backpropagation Neural Network,简称BP神经网络)是一种多层前馈神经网络,它通过反向传播算法来调整网络中的权重和偏置,以达到最小化误差的
    的头像 发表于 07-03 11:00 170次阅读

    bp神经网络和卷积神经网络区别是什么

    BP神经网络(Backpropagation Neural Network)和卷积神经网络(Convolutional Neural Network,简称CNN)是两种不同类型的人工神经网络,它们在
    的头像 发表于 07-03 10:12 146次阅读

    卷积神经网络的原理是什么

    卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,广泛应用于图像识别、语音识别、自然语言处理等领域。本文将详细介绍卷积神经网络的原理,包括其
    的头像 发表于 07-02 14:44 134次阅读

    卷积神经网络和bp神经网络的区别

    卷积神经网络(Convolutional Neural Networks,简称CNN)和BP神经网络(Backpropagation Neural Networks,简称BPNN)是两种
    的头像 发表于 07-02 14:24 268次阅读

    神经网络架构有哪些

    、语音识别、自然语言处理等多个领域。本文将对几种主要的神经网络架构进行详细介绍,包括前馈神经网络、循环神经网络、卷积神经网络、生成对抗
    的头像 发表于 07-01 14:16 128次阅读

    利用深度循环神经网络对心电降噪

    具体的软硬件实现点击 http://mcu-ai.com/ MCU-AI技术网页_MCU-AI 我们提出了一种利用由长短期记忆 (LSTM) 单元构建的深度循环神经网络来降 噪心电信号 (ECG
    发表于 05-15 14:42

    《 AI加速器架构设计与实现》+第一章卷积神经网络观后感

    《 AI加速器架构设计与实现》+第一章卷积神经网络观感    在本书的引言中也提到“一胜千言”,读完第一章节后,对其进行了一些归纳(如图1),第一章对常见的神经网络结构进行了介绍
    发表于 09-11 20:34

    构建神经网络模型的常用方法 神经网络模型的常用算法介绍

    神经网络模型是一种通过模拟生物神经元间相互作用的方式实现信息处理和学习的计算机模型。它能够对输入数据进行分类、回归、预测和聚类等任务,已经广泛应用于计算机视觉、自然语言处理、语音处理等领域。下面将就神经网络模型的概念和工作原理,
    发表于 08-28 18:25 808次阅读

    人工神经网络和bp神经网络的区别

    人工神经网络和bp神经网络的区别  人工神经网络(Artificial Neural Network, ANN)是一种模仿人脑神经元网络结构和功能的计算模型,也被称为
    的头像 发表于 08-22 16:45 3679次阅读

    卷积神经网络和深度神经网络的优缺点 卷积神经网络和深度神经网络的区别

    深度神经网络是一种基于神经网络的机器学习算法,其主要特点是由多层神经元构成,可以根据数据自动调整神经元之间的权重,从而实现对大规模数据进行预测和分类。卷积
    发表于 08-21 17:07 3250次阅读

    卷积神经网络算法是机器算法吗

    神经网络的原理 先介绍一下卷积神经网络的原理。卷积神经网络中的核心结构是卷积层。卷积层中包含多组卷积核,每组卷积核会对输入数据进行卷积操作,生成一组输出特征
    的头像 发表于 08-21 16:49 645次阅读

    卷积神经网络介绍 什么是卷积神经网络算法

    卷积神经网络介绍 什么是卷积神经网络算法 卷积神经网络涉及的关键技术 卷积神经网络(Convolutional Neural Networ
    的头像 发表于 08-21 16:49 1587次阅读

    卷积神经网络的应用 卷积神经网络通常用来处理什么

    神经网络,卷积神经网络广泛用于图像识别、自然语言处理、视频处理等方面。本文将对卷积神经网络的应用进行详尽、详实、细致的介绍,以及卷积神经网络
    的头像 发表于 08-21 16:41 4524次阅读