1.本技术涉及计算机技术领域,尤其涉及一种模型训练、数据增强方法、装置、电子设备及存储介质。
背景技术:
2.随着数据采集技术的不断进步,越来愈多的数据正在被收集并广泛应用于商业分析,金融服务和医疗教育等各个方面。
3.但是,由于数据本身的不平衡性和采集手段的限制,相当多的数据存在没有标签或者标签不平衡的情况,导致模型结果不理想,甚至输出错误的结果,这对我们当前的数据预处理技术带来了极大的挑战。具体来讲,数据样本标签不平衡是指在不同标签的数据源中,某一些标签的数据占绝大部分,而另外一些标签的数据只占很少一部分。例如在二分类预测问题中,标签为“1”的数据占总量的99%,而标签为“0”的数据只占1%。这种数据常常会损害模型的效果,使一个二分类模型无法获得较好的预测结果。
技术实现要素:
4.为了解决上述技术问题或者至少部分地解决上述技术问题,本技术提供了一种模型训练、数据增强方法、装置、电子设备及存储介质。
5.第一方面,本技术提供了一种模型训练方法,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,所述方法包括:
6.所述生成器生成参考样本数据;
7.第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
8.第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
9.基于所述第一距离和所述第二距离确定目标函数;
10.利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
11.可选的,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
12.可选的,所述利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型,包括:
13.利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
14.将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
15.可选的,所述目标函数为:
16.[0017][0018]
其中,posdata表示正类数据,negdata表示负类数据,alldata表示生成的负类数据和原有负类数据的并集。d1表示第一判别器参数,d2表示第二判别器参数,g表示生成器参数。
[0019]
可选的,所述第一判别器和所述第二判别器的结构相同,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-relu层和sigmoid层。
[0020]
可选的,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-relu层。
[0021]
第二方面,本技术提供了一种数据增强方法,包括:
[0022]
利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如第一方面任一所述的模型训练方法训练得到的;
[0023]
将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
[0024]
第三方面,本技术提供了一种模型训练装置,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,所述装置包括:
[0025]
生成模块,用于所述生成器生成参考样本数据;
[0026]
第一计算模块,用于第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
[0027]
第二计算模块,用于第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
[0028]
选择模块,用于基于所述第一距离和所述第二距离确定目标函数;
[0029]
训练模块,用于利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
[0030]
可选地,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
[0031]
可选地,所述训练模块,还用于:
[0032]
利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
[0033]
将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
[0034]
可选地,所述目标函数为:
[0035][0036][0037]
其中,posdata表示正类数据,negdata表示负类数据,alldata表示生成的负类数
据和原有负类数据的并集。d1表示第一判别器参数,d2表示第二判别器参数,g表示生成器参数。
[0038]
可选地,所述第一判别器和所述第二判别器的结构相同,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-relu层和sigmoid层。
[0039]
可选地,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-relu层。
[0040]
第四方面,本技术提供了一种数据增强装置,包括:
[0041]
生成模块,用于利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如权利要求8所述的模型训练方法训练得到的;
[0042]
添加模块,用于将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
[0043]
第五方面,本技术提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
[0044]
存储器,用于存放计算机程序;
[0045]
处理器,用于执行存储器上所存放的程序时,实现第一方面任一所述的模型训练方法或第二方面所述的数据增强方法。
[0046]
第六方面,本技术提供了一种计算机可读存储介质,所述计算机可读存储介质上存储有模型训练方法的程序或者数据增强方法的程序,所述模型训练方法的程序被处理器执行时实现第一方面任一所述的模型训练方法的步骤,所述数据增强方法的程序被处理器执行时实现第二方面所述的数据增强方法的步骤。
[0047]
本技术实施例提供的上述技术方案与现有技术相比具有如下优点:
[0048]
本技术实施例提供的该方法,本发明实施例通过生成器生成参考样本数据,第一判别器计算参考样本数据与预设负样本数据之间的第一距离,第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离,再基于所述第一距离和所述第二距离确定目标函数,最后可以利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
[0049]
本发明实施例通过生成器生成参考样本数据,基于第一距离和第二距离确定目标函数,利用所述目标函数训练所述生成对抗网络模型,可以使训练完成的生成对抗网络模型的输出数据满足预设样本平衡条件,对较少的那一类样本生成额外的数据,即生成的输出数据可以使两类样本更加平衡,由于是生成额外的数据,所以不会对数据量造成损失,使得数据样本标签不平衡。
附图说明
[0050]
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本发明的实施例,并与说明书一起用于解释本发明的原理。
[0051]
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员而言,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0052]
图1为本技术实施例提供的一种生成对抗网络模型的原理示意图;
[0053]
图2为本技术实施例提供的一种模型训练方法的一种流程图;
[0054]
图3为图1中步骤s105的流程图;
[0055]
图4为本技术实施例提供的一种模型训练方法的另一种流程图;
[0056]
图5为本技术实施例提供的一种模型训练装置的结构图;
[0057]
图6为本技术实施例提供的另一种模型训练装置的结构图;
[0058]
图7为本技术实施例提供的一种电子设备的结构图。
具体实施方式
[0059]
为使本技术实施例的目的、技术方案和优点更加清楚,下面将结合本技术实施例中的附图,对本技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本技术的一部分实施例,而不是全部的实施例。基于本技术中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本技术保护的范围。
[0060]
在实现本发明的过程中,发明人发现,现有的技术方案往往通过上采样,下采样和对样本分配权重的方式以解决数据样本标签不平衡的问题。这些方法往往存在一些缺陷。其一,这些方法有时难以取得较好的效果。以下采样为例,这种方法通过对标签较多的那一类数据进行下采样,使得两种(或多种)标签的数据具有相似的数量。但是,对于不平衡较为严重的情况来说,这种方法会大大减少可以使用的数据量,损害了模型的效果。其二,有些方法对于模型的依赖较为严重,效果可能随着模型的变换而变化。比如对样本分配权重的方法就要求模型必须可以处理带权重的样本。另外对于样本权重的选取也增加了应用该方法的难度。为此,本发明实施例提供了一种模型训练、数据增强方法、装置、电子设备及存储介质,所述模型训练方法用于训练生成对抗网络模型,生成对抗网络:是机器学习中非监督式学习的一种方法,通过让两个神经网络相互博弈的方式进行学习。生成对抗网络由一个生成网络与一个判别网络组成。生成网络从潜在空间(latent space)中随机取样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。而生成网络则要尽可能地欺骗判别网络。两个网络相互对抗、不断调整参数,最终目的是使判别网络无法判断生成网络的输出结果是否真实。
[0061]
不同于一般的生成对抗网络模型,在本发明实施例中,同时利用正样本和负样本生成负样本数据对生成对抗网络模型进行训练。本发明实施例所依据的原理是:减小生成的数据与负样本之间的差异,并增大生成数据与正样本的差异。通过这种方法生成的负样本能够保持与真实负样本分布相近但是与正样本保持足够的分离间隔。使得重组后的数据能够使分类器更好地找到正负类的分离面。
[0062]
在本发明实施例中,如图1所示,生成对抗网络模型包括:生成器(generator)和两个判别器(discriminator),也就是说,模型训练方法用于训练生成器和两个判别器。其中,所述生成器的输出作为两个所述判别器的输入,假设两个判别器分别为第一判别器和第二判别器,生成器用于将输入的随机噪声数据转化为和真实负样本分布相近的数据,从而生成参考样本数据(负样本数据),达到数据增强的目的;
[0063]
将参考样本数据和预设负样本数据输入第一判别器中,第一判别器判别参考样本
数据和预设负样本数据之间的差距,也即第一判别器用于判断参考样本数据和预设负样本数据是否属于同一类;
[0064]
将参考样本数据和预设负样本数据合并,得到负类数据,将负类数据和预设正样本数据输入第二判别器,第二判别器判别负类数据和预设正样本数据之间的差距,也就是说,第二判别器用于判断负类数据和预设正样本数据是否为同一类。
[0065]
如图2所示,所述模型训练方法可以包括以下步骤:
[0066]
步骤s101,所述生成器生成参考样本数据;
[0067]
在本发明实施例中,所述生成器包括多个级联的生成单元,每个生成单元包括级联的全连接层、标准化层和leaky-relu层,其中,标准化层可以指batch-normalization算法层,batch-normalization算法层用于防止梯度爆炸,在本发明实施例中,第一级生成单元中全连接层和leaky-relu层的维度均为256,第二级生成单元中全连接层和leaky-relu层的维度均为512,第三级生成单元中全连接层和leaky-relu层的维度均为1024。
[0068]
在步骤s101之前,可以获取原始数据集及服从高斯分布的随机噪声数据,原始数据集中包括预设正样本数据和负样本数据。
[0069]
为表述方便,在本发明实施例中,将标签较少的样本称作负样本数据,将标签较多的样本称作正样本数据,并且令负样本的标签为-1,正样本的标签为1。
[0070]
在该步骤中,可以将服从高斯分布的随机噪声数据输入至生成器的输入层,随机噪声数据的维度是100维,生成器可以基于随机噪声数据生成参考样本数据。
[0071]
步骤s102,第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
[0072]
在本发明实施例中,所述第一判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层和leaky-relu层,第一级判别单元中全连接层和leaky-relu层的维度均为512,第二级判别单元中全连接层和leaky-relu层的维度均为256。
[0073]
步骤s103,第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
[0074]
在本发明实施例中,所述第二判别器和所述第一判别器的结构相同,所述第二判别器包括:多个级联的判别单元和sigmoid层,最后一级判别单元的输出作为sigmoid层的输入,每个所述判别单元包括级联的全连接层、leaky-relu层和sigmoid层。
[0075]
步骤s104,基于所述第一距离和所述第二距离确定目标函数;
[0076]
为了减小参考样本数据与负样本之间的差异,并增大参考样本数据与正样本的差异,也就是说,本发明实施例的目的是使目标样本数据可以使第一分类器产生较大的误差(即:使目标样本数据和预设负样本数据差距较小),而使第二分类器产生较小的误差(即:使目标样本数据和预设正样本数据差距较大)。
[0077]
也就是说,在本发明实施例中,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
[0078]
所以在该步骤中,可以基于第一距离和第二距离在参考样本数据中选择满足预设样本平衡条件的目标样本数据,预设样本平衡条件可以指与预设负样本预设负样本数据差距较小,且,和预设正样本数据差距较大。
[0079]
满足预设样本平衡条件的目标样本数据即参考样本数据中,第一距离较小且第二
距离较大的目标样本数据,示例性的,目标样本数据可以指参考样本数据中,第一距离小于预设第一阈值且第二距离大于预设第二阈值的目标样本数据。
[0080]
步骤s105,利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
[0081]
在该步骤中,可以将所述预设负样本数据和所述正样本数据输入生成对抗网络模型,基于生成对抗网络模型输出的输出数据与所述目标样本数据之间的差异,不断的调整生成对抗网络模型的模型参数,直至输出数据与所述目标样本数据一致,确定生成对抗网络模型收敛,得到所述生成对抗网络模型,以用于数据增强。
[0082]
本发明实施例通过生成器生成参考样本数据,第一判别器计算参考样本数据与预设负样本数据之间的第一距离,第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离,再基于所述第一距离和所述第二距离确定目标函数,最后可以利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
[0083]
本发明实施例通过生成器生成参考样本数据,基于第一距离和第二距离确定目标函数,利用所述目标函数训练所述生成对抗网络模型,可以使训练完成的生成对抗网络模型的输出数据满足预设样本平衡条件,对较少的那一类样本生成额外的数据,即生成的输出数据可以使两类样本更加平衡,由于是生成额外的数据,所以不会对数据量造成损失,使得数据样本标签不平衡。
[0084]
在本发明的又一实施例中,如图3所示,所述步骤s105可以包括以下步骤:
[0085]
步骤s301,利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
[0086]
在本发明实施例中,所述目标函数为:
[0087][0088][0089]
其中,posdata表示正类数据,negdata表示负类数据,alldata表示生成的负类数据和原有负类数据的并集。d1表示第一判别器参数,d2表示第二判别器参数,g表示生成器参数。
[0090]
步骤s302,将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
[0091]
本发明实施例通过目标函数,能够不断的调整模型参数,最终得到生成器参数、第一判别器参数和第二判别器参数,便于使生成对抗网络模型的输出数据满足预设样本平衡条件,对较少的那一类样本生成额外的数据,即生成的输出数据可以使两类样本更加平衡,由于是生成额外的数据,所以不会对数据量造成损失,使得数据样本标签不平衡。
[0092]
在本发明的又一实施例中,还提供一种数据增强方法,如图4所示,所述方法包括:
[0093]
步骤s401,利用生成对抗网络模型生成第二负样本数据,所述生成对抗网络模型是利用如前述方法实施例所述的模型训练方法训练得到的;
[0094]
在该步骤中,生成对抗网络模型的输入数据为服从高斯分布的随机噪声数据,再利用生成对抗网络模型进行数据增强时,生成对抗网络模型的输入数据与训练该生成对抗网络模型时输入至生成器的服从高斯分布的随机噪声数据相同。
[0095]
第二负样本数据加上预设负样本数据的总数一般应该与预设正样本数据的数量相同。
[0096]
生成第二负样本数据后,将第二负样本数据对应的数据标签设置为-1(即与预设负样本数据的标签相同)。
[0097]
步骤s402,将所述第二负样本数据加入原始数据集中,得到新数据集,所述原始数据集包括预设正样本数据和预设负样本数据。
[0098]
在该步骤中,可以将生成的第二负样本数据加入原数据集,并将整个数据集进行随机打乱操作,得到新数据集。
[0099]
本发明实施例能够生成第二负样本数据,并将生成第二负样本数据加入原始数据集,得到可直接用于训练的新数据集,新数据集对其所运用的模型没有依赖。
[0100]
在本发明的又一实施例中,还提供一种模型训练装置,生成对抗网络模型包括:生成器和两个判别器,所述生成器的输出作为两个所述判别器的输入,如图5所示,所述装置包括:
[0101]
生成模块11,用于所述生成器生成参考样本数据;
[0102]
第一计算模块12,用于第一判别器计算参考样本数据与预设负样本数据之间的第一距离;
[0103]
第二计算模块13,用于第二判别器计算由所述参考样本数据和预设负样本数据组成的负类数据与预设正样本数据之间的第二距离;
[0104]
选择模块14,用于基于所述第一距离和所述第二距离确定目标函数;
[0105]
训练模块15,用于利用所述目标函数训练所述生成对抗网络模型,直至所述生成对抗网络模型收敛,得到所述生成对抗网络模型。
[0106]
可选地,所述目标函数的优化目标为最小化所述第一距离,最大化所述第二距离。
[0107]
可选地,所述训练模块,还用于:
[0108]
利用所述目标函数训练所述生成对抗网络模型,得到所述生成器的生成器参数、所述第一判别器的第一判别器参数及所述第二判别器的第二判别器参数;
[0109]
将所述生成器参数、所述第一判别器参数及所述第二判别器参数输入所述生成对抗网络模型中,得到所述生成对抗网络模型。
[0110]
可选地,所述目标函数为:
[0111][0112][0113]
其中,posdata表示正类数据,negdata表示负类数据,alldata表示生成的负类数据和原有负类数据的并集。d1表示第一判别器参数,d2表示第二判别器参数,g表示生成器参数。
programmablegatearray,简称fpga)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
[0127]
在本发明的又一实施例中,还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有模型训练方法的程序或者数据增强方法的程序,所述模型训练方法的程序被处理器执行时实现前述方法实施例所述的模型训练方法的步骤,所述数据增强方法的程序被处理器执行时实现前述方法实施例所述的数据增强方法的步骤。
[0128]
需要说明的是,在本文中,诸如“第一”和“第二”等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个
……”
限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
[0129]
以上所述仅是本发明的具体实施方式,使本领域技术人员能够理解或实现本发明。对这些实施例的多种修改对本领域的技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本发明的精神或范围的情况下,在其它实施例中实现。因此,本发明将不会被限制于本文所示的这些实施例,而是要符合与本文所申请的原理和新颖特点相一致的最宽的范围。
转载请注明原文地址:https://tc.8miu.com/read-58.html