本发明涉及一种基于伪元路径生成的联邦异构图表示学习方法及系统,属于联邦异构图表示学习。
背景技术:
1、异构图被广泛用于建模现实中各种实体之间的复杂交互关系。异构图表示学习是指在低维向量空间中学习到异构图合适的嵌入,以更好地服务于进一步的推荐和预测等任务。能够有效挖掘图节点间的深层关系的图神经网络(graph neural network,gnn)常常被应用在异构图表示学习领域中。然而,gnn的性能往往依赖于图的数量和结构的完整性,为集中存储的图设计的异构图表示学习模型不能直接在分布式存储的图上取得同样好的性能。因此,联邦学习被引入到多机构合作的图表示学习中,联邦异构图表示学习方法在保护每个机构本地数据隐私的前提下,通过聚合各个参与方的模型参数来提升整体的模型性能。
2、现阶段,大部分联邦图表示学习的研究主要关注同构图,这些方法虽然通过联邦学习解决了图数据分布式存储导致的性能下降问题,但是不适合于学习异构图复杂的结构。一些研究关注到了联邦异构图表示学习,但是大部分的方法是专门为推荐系统设计的,这些方法只适用于具有特殊结构的“用户-物品”推荐图,而不适用于普遍的具有更复杂结构的异构图。为了学习异构图中的复杂关系,一些异构图表示学习模型使用元路径来挖掘节点的异构邻居,并用基于元路径的邻居信息来增强节点的表示能力,以学习到异构图的全局结构信息。然而,这些方法直接应用在联邦学习中的效果并不理想,因为在联邦学习场景下,由于数据不能跨客户端共享,一些节点跨客户端的连接关系会被切断,进而导致一些跨客户端的元路径不可达。在这种情况下,每个客户端的数据可以看成是缺失全局结构信息的子图,这些子图的不完整性会降低模型性能。因此,恢复跨客户端的缺失元路径所包含的全局信息,对联邦异构图表示学习来说非常重要。
3、一些研究关注到了联邦学习中异构图跨客户端的子图信息缺失问题,但是各有局限性。一些方法向中心服务器和其他客户端提供加密后的本地数据,以此来补充缺失信息,但是存在泄露本地数据敏感信息的风险。还有的方法通过联邦学习训练生成器,以生成伪节点来补充信息,但是这种方法要求生成的伪节点更接近本地,因此仍然无法补充跨客户端的信息,更无法恢复跨客户端的不可达元路径信息。
技术实现思路
1、针对上述现有技术的不足,本发明提供了一种基于伪元路径生成的联邦异构图表示学习方法及系统,之前的联邦异构图表示学习方法,要么没有恢复子图跨客户端的缺失信息,要么存在数据隐私泄露的风险,与之前的联邦异构图表示学习方法相比,本发明以一种保护隐私的方式恢复了跨客户端的结构信息,并且可以充分利用元路径来挖掘节点间潜在关系,从而提升性能。具体来说,本发明设计了一种基于关系的伪元路径生成方法,并利用生成的伪元路径来传递跨客户端的信息。此外,本发明还提出了联邦图融合与伪元路径分配方法,可以帮助恢复异构子图缺失的元路径信息。本发明将提出的新方法模块集成到联邦学习框架中,提升了联邦图表示学习的性能,同时可以很好地保护客户端本地的数据隐私。
2、术语解释:
3、异构图,是一种包含多于一种类型的节点和边的图结构,对于一个节点类型集合v={v1,…,vm}和边类型集合e={e1,…,en}的图g,如果m>1或n>1,则g被称为异构图。
4、子图,是指从原图中去除一些节点或边而得到的图结构,给定一个图g=(v,e),gsub=(v′,e′)是g的子图,当且仅当且
5、元路径,是一种描述异构图中以特定顺序排列的节点和边的路径模式,它可以用节点和边的类型以及特定顺序来定义,元路径φ可以表示为(v1,e1,v2…vi-1,ej,vi…vk-1,el,vk),其中i,j≤max(m,n),k≤m,l≤n。在元路径φ中,v1称为起始节点,最终到达的vk被称为目标节点,路径中的vi被称作中间节点,目标节点vk就是起始节点v0基于元路径的邻居节点。
6、联邦学习,是一种多机构协同训练模型的学习框架。联邦学习的参与者包括多个客户端和一个中心服务器,在经典的联邦学习框架中,每个客户端用本地数据训练本地模型,并将模型参数上传至中心服务器;中心服务器聚合这些本地模型参数,以形成全局模型参数,然后将全局模型参数发送给每个客户端;客户端再使用全局模型参数更新本地模型。联邦学习的目标是在保护客户端数据隐私的同时,优化全局模型,最终提升每个客户端的模型性能。
7、本发明的技术方案如下:
8、一种基于伪元路径生成的联邦异构图表示学习方法,包括如下步骤:
9、st1:从公开的异构图数据集中选择不同规模的数据集,并将其按照不平均的比例分配给参与联邦学习的每个客户端,同时切断跨客户端的边,以得到每个客户端本地异构子图数据集,客户端将本地异构子图数据集划分为训练集、验证集和测试集;其中,所有客户端在本地异构子图数据集上遵循着相同的学习任务,并使用相同结构的异构图表示学习模型,并且所有客户端使用相同定义的元路径来进行本地的学习任务;
10、st2:客户端基于本地异构子图数据集构造基于关系的伪邻居节点生成器,构建损失函数,优化生成的伪邻居节点,使其接近真实邻居节点的表示;
11、st3:利用基于关系的伪邻居节点生成器,客户端按照所有客户端都遵循的相同定义的元路径中的节点关系顺序,将本地异构子图中的真实节点作为起始节点,生成一些伪节点,并将起始节点和生成的伪节点按顺序连接成为伪元路径,客户端将生成的伪元路径加入本地异构子图训练集中,形成本地补充子图;
12、其中,每个客户端都遵循相同定义的元路径来进行学习任务,这种相同定义的元路径一般是在某些数据集上约定俗成的。比如在论文数据集中,一般使用的元路径为“作者-论文-主题-论文-作者”或者“作者-论文-作者”,在论文数据集上进行的学习任务大多使用这种定义的元路径。由于本文方案是基于联邦学习的,所以每个客户端都需要遵循这种相同定义的元路径。
13、st4:客户端将本地生成的伪元路径中全部伪节点和其拓扑结构上传至中心服务器,中心服务器接收每个客户端的伪节点和拓扑结构,对每两个来自不同客户端的相同拓扑位置的节点,构建相似度矩阵;
14、st5:中心服务器构建相似节点评估规则,基于相似节点评估规则聚类相似节点,并把相同的伪元路径结构赋予给来自不同客户端的相似节点;
15、st6:中心服务器将每个客户端增加的来自其他客户端的伪节点和伪元路径结构发送给相应客户端,客户端接收新增的来自其他客户端的伪节点和伪元路径结构,并将它们加入到本地补充子图中,形成联邦补充子图;
16、st7:基于本地训练集,客户端构建伪标签预测器,构建伪标签预测目标函数,优化预测的伪标签和真实标签接近;
17、st8:客户端用伪标签预测器预测新增的伪节点标签,进一步完善联邦补充子图的标签集,将联邦补充子图和其全部标签作为异构图表示学习模型的训练数据;
18、st9:客户端本地训练一个异构图表示学习模型,异构图表示学习模型基于异构图注意力网络结构,分别在节点级和语义级使用注意力机制,并利用元路径来学习节点间的复杂关系,通过设置超参数权衡不同部分的权重,得到最优的本地目标模型表达式和本地模型参数;
19、st10:客户端将最优的本地模型参数上传至中心服务器,中心服务器接收来自不同客户端的本地模型参数,通过差分隐私联邦平均方法来聚合这些本地模型参数,以形成全局模型参数;
20、st11:中心服务器将全局模型参数发送给每个客户端,客户端接收全局模型参数,并使用全局模型参数更新本地模型,即用接收到的全局模型参数替换本地模型参数;
21、st12:客户端在真实的本地异构子图训练集上进一步微调更新后的本地模型(微调就是用另一个数据集训练已经在某一个数据集上训练完的模型),然后将新的本地模型参数上传至中心服务器;
22、st13:中心服务器和客户端重复上述参数聚合、发送全局模型参数、本地微调和上传本地模型参数的过程,直至全局模型性能达到收敛;
23、st14:客户端将待学习的真实的本地异构子图通过最优的本地异构图表示学习模型表达式,学习到图中节点的嵌入表示,并进一步应用于节点分类等下游任务。
24、优选的,st3中,基于关系的伪邻居节点生成器是基于全连接神经网络的结构来训练的,具体来说,全连接神经网络以基于选定关系的起始节点嵌入表示、选定的关系嵌入表示和基于此关系的邻居节点嵌入表示作为输入,最终输出生成的伪邻居节点嵌入表示;
25、客户端在元路径的指导下,伪邻居节点生成器基于关系的伪邻居节点生成算法生成伪邻居节点,进而形成伪元路径,以补充本地异构子图;
26、基于关系的伪邻居节点生成算法的过程为:
27、按照一定的比例随机选定一些起始节点,根据特定的元路径所定义的关系集合的顺序,利用基于关系的伪邻居节点生成器,逐一生成节点对于特定关系的伪邻居节点,直到形成完整的伪元路径;
28、元路径中包含多种节点,从起始节点开始,按顺序一个节点一个节点地生成,最终生成的所有伪节点形成伪元路径。基于关系的伪邻居节点生成算法主要可以看成是两个步骤:1)“基于关系的伪邻居节点生成算法”(每次只能生成一种类型的伪节点):输入某种类型的节点和某种关系(即边的类型),通过“基于关系的伪邻居节点生成器”来生成输入节点通过输入的边连接到的某一种特定类型的节点,如输入节点类型是“作者”,输入关系类型是“写作”,则“作者”节点基于“写作”关系生成的邻居节点是“论文”节点,因为“论文”节点是生成器生成的而不是真实存在的,所以是“伪”的。2)“元路径指导的”(得到伪元路径):元路径提供节点顺序,多次生成伪节点,按照顺序连接得到伪元路径。
29、生成伪邻居节点的方法如下:以某种类型的节点和某种类型的关系作为输入,通过一个全连接神经网络来生成输入节点基于输入关系的伪邻居节点;生成的伪邻居节点是全连接神经网络的输出,对于每一种不同的输入节点类型和输入关系类型的组合,使用不同的基于关系的伪邻居节点生成器,以对应地生成每一种相应的伪邻居节点。
30、优选的,st2中,基于关系的伪邻居节点生成器使用adam作为优化器,使用relu作为激活函数,并使用dropout来防止过拟合,利用网络内部的隐藏层来学习到输入之间的映射关系;
31、为了使生成的伪邻居节点的嵌入表示尽可能接近真实节点,使用mse作为损失函数,以衡量生成的伪邻居节点嵌入与真实邻居节点嵌入y的距离;mse损失函数如式(i)所示:
32、
33、式(i)中,n代表节点嵌入的维度,yi是真实邻居节点嵌入y在第i维的具体值,是伪邻居节点嵌入在第i维的具体值;
34、随着本地伪邻居节点的生成和伪元路径的形成,客户端将生成的伪元路径加入本地异构子图训练集中,最终形成本地补充子图。
35、优选的,st4中,客户端筛选出本地伪元路径中的可信伪元路径,可信伪元路径是指伪元路径的拓扑结构中完全是生成的节点和结构,不包含开始生成伪元路径的真实的起始节点;筛选出本地全部可信伪元路径后,客户端将这些可信伪元路径上传至中心服务器;
36、中心服务器接收来自每个客户端的可信伪元路径,对于这些伪元路径中每一种类型的伪节点,中心服务器分别计算每两个来自不同客户端的在伪元路径中拓扑位置相同、并且类型也相同的伪节点的相似度,以形成伪节点相似度矩阵;
37、优选的,使用曼哈顿距离来衡量每对节点之间的相似度;节点间的曼哈顿距离计算如式(ii)所示:
38、
39、式(ii)中,embi和embj分别代表来自两个不同客户端i和j的节点嵌入,具体来说,embi=(x1,x2,…,xn)和embj=(y1,y2,…,yn);用si,j表示节点i和节点j之间的相似度,相似度的计算如式(iii)所示:
40、
41、由式(iii)可得,节点间的曼哈顿距离越小,节点间的相似度越大;
42、中心服务器根据节点间的相似度和一些约束条件构建相似节点评估规则,并基于此评估规则来聚类相似节点,从而为每一个客户端的伪节点寻找到来自其他客户端的最相似的一些伪节点。
43、优选的,约束条件包括节点相似度阈值τ和最大相似数量,具体来说,只有两个节点的相似度大于给定的阈值τ,这两个节点才会被判定为是相似的。每个节点的最大相似节点数量被设定为k,即按照相似度从高到低来排序后,相似节点最多会被保留k个,用vsimilar表示与某节点相似的节点所构成的集合,|vsimilar|表示该集合中的节点数量;约束条件如式(iv)所示:
44、
45、中心服务器将找到的符合式(iv)中要求的最相似的伪节点,即相似度从高到低来排序后的前k个节点聚类在一起,目的是为了还原同一个节点出现在不同客户端的情况,这些伪节点在不同客户端中基于伪元路径的下一步关系就是需要补充的“子图间缺失关系”。因此,中心服务器将相同的连接关系赋予这些相似的伪节点,使得伪节点可以连接到其他客户端最相似伪节点的邻居节点,从而增加跨客户端的伪元路径;
46、对于每一个客户端的每条伪元路径中的中间节点,中心服务器循环执行上述聚类和增加连接操作,最终,原本在中心服务器中彼此独立的来自不同客户端的伪元路径就形成了一个相互关联的融合图;
47、中心服务器检查每个客户端的伪节点新增的连接关系,并把这些新增的来自其他客户端的伪元路径及其结构关系发送给相应客户端。
48、优选的,st7中,基于半监督学习的思想,为保证伪元路径数据在训练过程中的可用性,伪元路径中的伪节点需要有一些伪标签,因此,客户端基于原始的本地异构子图数据集划分得到的本地训练集,来构建伪标签预测器,以给本地生成的伪节点和联邦补充的伪节点预测一些伪标签,实现数据的有效扩充;
49、使用一个全连接神经网络作为伪标签预测器,并选用relu激活函数和adam优化器,同时,通过最小化交叉熵损失来构建伪标签预测的目标函数,目的是让预测的伪标签尽可能和真实的标签接近;训练的损失如式(v)所示:
50、
51、式(v)中,yp代表神经网络的输出,labelr代表真实标签,yp[labelr]表示向量yp中对应于真实标签labelr的分量,∑iexp(yp[i])中的i表示所有可能的类别索引,yp[i]表示向量yp中对应于索引号为i的类别的分量,∑iexp(yp[i])表示对所有类别的预测分量进行指数运算后的求和;
52、st9中,客户端在新的训练数据集上训练一个异构图表示学习模型,新的训练集包含:本地真实子图训练集、本地补充的伪元路径、联邦补充的伪元路径、本地真实的标签集和标签预测器预测的伪标签集;
53、异构图表示学习模型学习利用元路径来学习异构图中节点间的复杂关系;模型使用注意力机制来衡量元路径和邻居节点的不同重要性;
54、基于异构图注意力网络(han)的思想和网络结构,分别在节点级和语义级使用注意力机制;
55、节点级注意力用于学习节点的基于元路径的邻居节点的重要性,除了真实节点外,这些邻居节点也包含在本地补充和联邦补充过程中增加的基于伪元路径的伪邻居;基于节点级注意力更新节点嵌入表示如式(vi)所示:
56、
57、式(vi)中,embi代表节点i的嵌入表示,φ代表某个元路径的统一定义,表示节点i基于某条真实元路径的真实邻居节点集合,而表示节点i基于某条伪元路径的伪邻居节点集合,代表权重系数,sig(·)是sigmoid激活函数,||是连接操作;
58、语义级注意力用来衡量不同元路径的重要性,|φ|表示不同定义的元路径数量,表示元路径φi中已经学习完嵌入表示的节点,qt是一个语义级注意力向量;基于语义级注意力更新节点的嵌入表示,如式(vii)所示:
59、
60、式(vii)中,w表示权重矩阵,b表示偏置向量,tanh(w·embi+b)表示对节点嵌入embi进行仿射变换,然后通过双曲线正切激活函数tanh进行非线性映射;其中,w和b随着模型的训练而不断更新;
61、在客户端训练本地异构图表示学习模型阶段,选用adam优化器和交叉熵损失函数,同时还引入dropout和早停策略以避免过拟合;训练完成后,客户端得到当前最优的本地模型.
62、优选的,st10中,采用加权平均的参数聚合方式,其中客户端ci的权重取决于其包含的分类目标节点的数据量,如式(viii)所示:
63、
64、式(viii)中,client表示所有参与联邦学习的客户端的集合,表示客户端ci中所有的目标节点所组成的集合,表示客户端ci中所有目标节点的数量;
65、基于差分隐私的思想,为进一步提升隐私保护,向聚合后的全局模型参数中加入随机生成的高斯噪声,其符合如式(ix)所示的正态分布:
66、
67、式(ix)中,u表示均值,σ表示标准差;u=0且服从正态分布的随机变量x即为加入的噪声,表示为x~n(0,σ2),可以通过控制σ的大小来控制加入噪声的程度,但需要进行隐私保护和模型效用之间合适的权衡;
68、中心服务器执行一种“差分隐私联邦平均”参数聚合方法,如式(x)所示:
69、
70、式(x)中,paramglobal代表全局模型参数,paramlocal代表客户端本地模型参数;
71、st12中,微调的具体方法如下:以本地异构子图数据集划分出的训练集作为训练集,训练更新后的本地模型。
72、st13中,全局模型是否收敛可以由全局模型在联邦训练的每一轮中的参数变化量来判断,具体判断方法如下:由中心服务器监控每一轮全局模型参数的变化量,如果连续几轮训练中,全局模型的参数变化量小于某个设定的阈值,则认为全局模型已经收敛。
73、一种基于伪元路径生成的联邦异构图表示学习系统,其特征在于,用于实现上述的基于伪元路径生成的联邦异构图表示学习方法,包括:
74、伪元路径生成单元,被配置为:客户端基于本地真实异构图中的节点和关系,生成一些伪节点,进而按照元路径的模式形成一些伪元路径;
75、联邦异构图融合单元,被配置为:中心服务器聚类对客户上传的伪节点聚类,并为相似节点增加新的伪元路径结构,以得到由伪元路径构成的彼此相连的融合异构图;
76、异构图表示学习单元,被配置为:客户端利用基于元路径和注意力机制的模型来学习本地异构图中节点的嵌入表示;
77、联邦参数聚合单元,被配置为:中心服务器利用基于差分隐私和联邦平均的方法,聚合来自不同客户端的模型参数,以得到全局模型参数。
78、一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,处理器执行所述计算机程序时实现上述的基于伪元路径生成的联邦异构图表示学习方法的步骤。
79、一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现上述的基于伪元路径生成的联邦异构图表示学习方法的步骤。
80、本发明未详尽之处,均可参见现有技术。
81、本发明的有益效果为:
82、1、本发明为多机构合作的异构图表示学习引入了一个完整的解决方案,通过引入伪元路径生成、联邦异构图融合和联邦参数聚合方法,将异构图表示学习方法很好地融入进联邦学习架构中,通过联邦学习的架构来提升模型性能,为联邦异构图表示学习领域提供了一种有益的思路和可行解决方案。
83、2、通过客户端伪元路径生成,本发明解决了联邦异构图表示学习中跨客户端的结构信息缺失问题。具体来说,在中心服务器的协调下,客户端可以通过生成的伪元路径来将本地异构图中重要的结构信息传递给其他客户端,每个客户端在本地训练模型时可以利用跨客户端的结构信息,因此客户端本地的异构图表示模型性能得到提升。
84、3、通过共享伪元路径而不是本地真实数据,本发明可以保障客户端的本地数据隐私,每个客户端的本地真实异构图都不会被其他客户端或者中心服务器获取,这种技术可以在很好地提升模型性能的同时,极大地提升联邦异构图表示学习方法中对客户端本地数据隐私的保护能力。
1.一种基于伪元路径生成的联邦异构图表示学习方法,其特征在于,包括如下步骤:
2.根据权利要求1所述的基于伪元路径生成的联邦异构图表示学习方法,其特征在于,st3中,基于关系的伪邻居节点生成器是基于全连接神经网络的结构来训练的,全连接神经网络以基于选定关系的起始节点嵌入表示、选定的关系嵌入表示和基于此关系的邻居节点嵌入表示作为输入,最终输出生成的伪邻居节点嵌入表示;
3.根据权利要求1所述的基于伪元路径生成的联邦异构图表示学习方法,其特征在于,st2中,基于关系的伪邻居节点生成器使用adam作为优化器,使用relu作为激活函数,并使用dropout来防止过拟合,利用网络内部的隐藏层来学习到输入之间的映射关系;
4.根据权利要求3所述的基于伪元路径生成的联邦异构图表示学习方法,其特征在于,st4中,客户端筛选出本地伪元路径中的可信伪元路径,可信伪元路径是指伪元路径的拓扑结构中完全是生成的节点和结构,不包含开始生成伪元路径的真实的起始节点;筛选出本地全部可信伪元路径后,客户端将这些可信伪元路径上传至中心服务器;
5.根据权利要求4所述的基于伪元路径生成的联邦异构图表示学习方法,其特征在于,约束条件包括节点相似度阈值τ和最大相似数量;每个节点的最大相似节点数量被设定为k,即按照相似度从高到低来排序后,相似节点最多会被保留k个,用vsimilar表示与某节点相似的节点所构成的集合,|vsimilar|表示该集合中的节点数量;约束条件如式(iv)所示:
6.根据权利要求5所述的基于伪元路径生成的联邦异构图表示学习方法,其特征在于,st7中,客户端基于原始的本地异构子图数据集划分得到的本地训练集,来构建伪标签预测器,以给本地生成的伪节点和联邦补充的伪节点预测一些伪标签,实现数据的有效扩充;
7.根据权利要求5所述的基于伪元路径生成的联邦异构图表示学习方法,其特征在于,st10中,采用加权平均的参数聚合方式,其中客户端ci的权重取决于其包含的分类目标节点的数据量,如式(viii)所示:
8.一种基于伪元路径生成的联邦异构图表示学习系统,其特征在于,用于实现权利要求1~7任一所述的基于伪元路径生成的联邦异构图表示学习方法,包括:
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,处理器执行所述计算机程序时实现权利要求1-7任一所述的基于伪元路径生成的联邦异构图表示学习方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-7任一所述的基于伪元路径生成的联邦异构图表示学习方法的步骤。