本发明涉及小样本分类,尤其涉及一种对比自监督增强学习方法和系统。
背景技术:
::1、小样本分类模型旨在利用支持样本对查询样本进行分类。为了处理小样本学习场景,提出了基于度量的方法、基于聚类的方法和基于数据增强的方法,其中基于度量的方法用于学习基类样本的特征空间,并以此对新的查询样本分类。基于数据增强的方法通常会取得较好的性能,利用数据增广创建伪标签,以形成支持集和查询集。2、以往的研究在提取新类别支持样本特征嵌入时,都致力于创建一个针对新类别的强大特征嵌入表示。然而他们的性能受限,首先,现有方法常在基类样本上训练特征表示,但未充分挖掘和利用样本的可区分性及多样性,导致模型对新类别特征提取和泛化能力不足。其次,相关工作往往利用简单的变换进行预训练任务,缺乏复杂变换的设计,不能有效提高预训练的复杂性和模型的泛化能力。再者,在对比学习框架中,会引起语义上相近或相同的样本作为负对样本,这将导致模型偏向于学习与任务无关的特征,如背景或颜色分布,而非核心的类别判别特征。最后,无监督小样本分类往往遭受类别不均衡的影响,容易造成模型在优化过程中对易分类样本过度拟合,忽略难分类样本,进而导致整体分类性能下降。技术实现思路1、为了至少能够部分地解决模型对新类别特征提取和泛化能力不足、模型在优化过程中对易分类样本过度拟合,忽略难分类样本,进而导致整体分类性能下降以及模型偏向于学习与任务无关的特征的问题,本发明提供了一种对比自监督增强学习方法和系统,本发明通过对样本进行数据增强和置信度惩罚构建了多任务学习方法,并设置通用功能增强器、投影距离度量单元和可调整损失函数提高模型的泛化能力并解决分类不平衡问题,促进模型学习的精准性。2、为了实现上述目的,本发明的技术方案是:3、本发明第一方面提出了一种对比自监督增强学习方法,包括:4、步骤一:将训练样本进行数据增强,得到变换增强样本和旋转增强样本,便于构建多任务学习,增强对图像样本的理解,提升语义特征提取能力;5、步骤二:将变换增强样本和旋转增强样本输入到对比自监督增强学习模型中,得到变换特征向量、旋转特征向量、变换特征向量的置信度和旋转特征向量的置信度,便于使对比自监督增强学习模型在特征和输出空间中维持高度的泛化能力,以便精确捕捉图像的语义内容,使图像分类的准确提高;6、步骤三:根据变换特征向量的置信度得到变换增强样本的预测类别,根据变换特征向量计算变换增强样本的预测类别和真实类别之间的变换损失;7、根据旋转特征向量的置信度得到旋转增强样本的预测旋转角度,根据旋转特征向量计算旋转增强样本的预测旋转角度和真实旋转角度之间的旋转损失;8、基于不同数据增强类型的正样本计算置信度损失;9、步骤四:基于变换损失、旋转损失和置信度损失计算总损失,通过最小化总损失来优化对比自监督增强学习模型参数,得到最优对比自监督增强学习模型,用于提高对比自监督增强学习模型的图像分类准确率;10、步骤五:将目标样本输入最优对比自监督增强学习模型,得到图像分类结果。11、进一步地,所述步骤一中具体包括:12、训练样本包括样本和样本独立伪标签,对训练样本中的每个样本进行变换增强和旋转增强,得到变换增强样本和旋转增强样本,便于进行多任务学习;13、将样本独立伪标签进行重复得到重复后的样本独立伪标签,使重复后的样本独立伪标签的数量与变换增强样本数量相同,使样本独立伪标签与变换增强样本对应;14、对每个旋转增强样本赋予旋转标签,将旋转标签进行重复得到重复后的旋转标签,使重复后的旋转标签数量与旋转增强样本总数相同,使旋转标签与旋转增强样本对应。15、进一步地,所述步骤二中,对比自监督增强学习模型包括嵌入网络、通用功能增强器、投影距离度量单元和旋转分类器;16、将变换增强样本和旋转增强样本合并为一个张量,将得到的张量输入到嵌入网络中,得到特征向量,便于提取特征;17、将特征向量输入通用功能增强器中,得到增强后的特征向量,便于通过可学习的张量,将从基类中学习到的通用抽象知识存储,并让单个样本与全局语义特征交互,计算相关性,增强与以往学习到知识所相似的部分,促进了知识对新类别的有效迁移,从而提高了在新类别上的模型性能和泛化能力;18、将增强后的特征向量拆分为变换特征向量和旋转特征向量,便于进行多任务学习;19、将变换特征向量输入投影距离度量单元中,得到变换特征向量的置信度,便于根据变换特征向量的置信度进行分类;20、将旋转特征向量输入旋转分类器中,得到旋转特征向量的置信度,便于根据旋转特征向量的置信度计算旋转损失。21、进一步地,所述通用功能增强器中,按照以下公式表示q、k和v:22、q=f,k=β,v=wvβ23、其中,q为查询向量,k为关键向量,v为数值向量,f为输入的特征向量,f∈rh×w×c,β为可学习的特征权重滤波器,β∈rn×c,c为输入的特征向量的通道数,h为输入的特征向量的高度,w为输入的特征向量的宽度,n为输入的特征向量的批处理大小,wv为一个可学习的全连接层。24、进一步地,所述通用功能增强器中,按照以下公式对特征向量进行增强:25、f′=ffn(f+a)26、a=σ(fβt)v27、其中,f′为增强后的特征向量,ffn为前馈网络,σ为特征相关性度量单元,a为输入的特征向量的注意力响应。28、进一步地,所述投影距离度量单元中,按照以下公式计算投影距离:29、dist(x,y)=(||x|·cos(x,y)-|y||+||y|·cos(y,x)-|x||)(1+e-3cos(x,y))γ30、其中,dist(x,y)为投影距离,x和y分别为任意两个变换增强样本对应的向量,γ为缩放比例,|x|和|y|分别代表x和y的模长,cos为余弦函数。31、进一步地,所述步骤三中,变换损失和旋转损失通过可调整损失函数进行计算,按照以下公式表示变换损失和旋转损失:32、33、其中,lce为变换损失,lss为旋转损失,al为可调整损失函数,为旋转增强样本的旋转标签,为变换增强样本的独立伪标签,为训练样本与变换增强样本的独立伪标签的投影距离,pj为训练样本与第j个变换增强样本的投影距离,为训练样本与旋转增强样本的旋转标签的投影距离,qj为训练样本与第j个旋转增强样本的投影距离,αt为平衡因子,ω为参数,δ为阈值,βt为奖励程度参数,pt为基准真相概率,τ为可调的聚焦参数。34、进一步地,所述步骤四中,按照以下公式表示总损失函数:35、l=lce+η·lss+λ·lcr36、lcr(p,q)=p(log(p+ε)-log(q+ε))+q(log(q+ε)-log(p+ε)37、其中,l为总损失,lcr为置信度损失,η为权重系数,λ为权重系数,ε为常数,p为变换增强的正样本与所有变换增强样本产生的概率分布,q为旋转增强的正样本与所有旋转增强样本产生的概率分布。38、进一步地,所述步骤五具体包括:39、从目标样本中对支持集和查询集进行采样;40、将目标样本输入嵌入网络中,得到支持集和查询集的特征向量;41、将特征向量输入通用功能增强器中,得到增强后的支持集和查询集的特征向量,用于增强特征;42、将增强后的支持集和查询集的特征向量输入投影距离度量单元,得到查询集的置信度,便于进行图像分类;43、根据查询集的置信度得到查询集的预测标签,根据预测标签得到图像分类结果。44、本发明第二方面提出了一种对比自监督增强学习系统,包括:45、增强模块,用于将训练样本进行数据增强,得到变换增强样本和旋转增强样本,便于构建多任务学习,增强对图像样本的理解,提升语义特征提取能力;46、对比自监督增强学习模块,用于将变换增强样本和旋转增强样本输入到对比自监督增强学习模型中,得到变换特征向量、旋转特征向量、变换特征向量的置信度和旋转特征向量的置信度,便于使对比自监督增强学习模型在特征和输出空间中维持高度的泛化能力,以便精确捕捉图像的语义内容,使图像分类的准确提高;47、损失模块,用于根据变换特征向量的置信度得到变换增强样本的预测类别,根据变换特征向量计算变换增强样本的预测类别和真实类别之间的变换损失;48、根据旋转特征向量的置信度得到旋转增强样本的预测旋转角度,根据旋转特征向量计算旋转增强样本的预测旋转角度和真实旋转角度之间的旋转损失;49、基于不同数据增强类型的正样本计算置信度损失;50、训练模块,用于基于变换损失、旋转损失和置信度损失计算总损失,通过最小化总损失来优化对比自监督增强学习模型参数,得到最优对比自监督增强学习模型,用于提高对比自监督增强学习模型的图像分类准确率;51、分类模块,用于将目标样本输入最优对比自监督增强学习模型,得到图像分类结果。52、本发明的有益效果:53、(1)本发明提出了一种对比自监督增强学习方法和系统,通过基类信息进行多任务学习(multi-task training),引入对比学习来改进样本区分度,提出旋转预测任务迫使对比自监督增强学习模型关注旋转相关性强的结构特征,使其从多角度了解图像,减弱对比学习过度追求可分性而关注样本中无效信息产生的类内遥远样本,解决了对比自监督增强学习模型偏向于学习与任务无关的特征的问题。54、(2)本发明提出了通用功能增强器(generic feature enhancer,gfe),通过学习基类数据的通用特征,增强新类别数据的与之相似的语义特征,提升模型的泛化能力,进而解决了对比自监督增强学习模型对新类别特征提取和泛化能力不足的问题。55、(3)本发明提出了投影距离度量单元(pdma),通过向量间模长关系与角度关系(向量投影)捕获特征向量间的相似度关系,从而在特征空间中得到更好的决策边界。56、(4)本发明提出了可调整损失函数(adjustable loss,al),基于交叉熵损失函数,增加调节因子,更改不同类别样本损失权重,抬升难分类样本的损失占比,为易分类样本增加梯度,从而解决对比自监督增强学习模型在优化过程中对易分类样本过度拟合,忽略难分类样本,导致整体分类性能下降的问题。57、(5)本发明在多个少样本基准测试中进行了大量实验,结果表明本发明提出的方法获得了最先进的性能,消融研究和详细的可视化解释了不同组件的有效性。当前第1页12当前第1页12
技术特征:1.一种对比自监督增强学习方法,其特征在于,包括:
2.根据权利要求1所述的一种对比自监督增强学习方法,其特征在于,所述步骤一中具体包括:
3.根据权利要求1所述的一种对比自监督增强学习方法,其特征在于,所述步骤二中,对比自监督增强学习模型包括嵌入网络、通用功能增强器、投影距离度量单元和旋转分类器;
4.根据权利要求3所述的一种对比自监督增强学习方法,其特征在于,所述通用功能增强器中,按照以下公式表示q、k和v:
5.根据权利要求4所述的一种对比自监督增强学习方法,其特征在于,所述通用功能增强器中,按照以下公式对特征向量进行增强:
6.根据权利要求3所述的一种对比自监督增强学习方法,其特征在于,所述投影距离度量单元中,按照以下公式计算投影距离:
7.根据权利要求2所述的一种对比自监督增强学习方法,其特征在于,所述步骤三中,变换损失和旋转损失通过可调整损失函数进行计算,按照以下公式表示变换损失和旋转损失:
8.根据权利要求7所述的一种对比自监督增强学习方法,其特征在于,所述步骤四中,按照以下公式表示总损失函数:
9.根据权利要求1所述的一种对比自监督增强学习方法,其特征在于,所述步骤五具体包括:
10.一种对比自监督增强学习系统,其特征在于,包括:
技术总结本发明公开一种对比自监督增强学习方法和系统,包括:步骤一:将训练样本进行数据增强,得到变换增强样本和旋转增强样本;步骤二:将变换增强样本和旋转增强样本输入到对比自监督增强学习模型中;步骤三:计算变换损失、旋转损失和置信度损失;步骤四:基于变换损失、旋转损失和置信度损失计算总损失,通过最小化总损失来优化对比自监督增强学习模型参数,得到最优对比自监督增强学习模型;步骤五:将目标样本输入最优对比自监督增强学习模型,得到图像分类结果。本发明通过多任务学习、通用功能增强器、投影距离度量单元和可调整损失函数解决了无监督少样本图像分类问题。
技术研发人员:王龙葛,王志诚,于俊洋,王霆宇,吴进虎,王丹
受保护的技术使用者:河南大学
技术研发日:技术公布日:2024/11/26