本发明涉及人工智能,具体是一种基于类关系的二维规范化知识蒸馏方法,以及应用这种方法的系统以及计算机终端。
背景技术:
1、深度卷积神经网络被越来越多的应用于诸如图像分类,目标检测,语义分割等计算机视觉应用场景。在卷积神经网络帮助提高模型性能表现的背后,这些神经网络模型的复杂度也日渐庞大,对计算资源的要求和存储代价更加高昂。这导致基于卷积神经网络的方法往往难以直接完成在移动设备和嵌入式平台的部署。知识蒸馏作为模型压缩的重要方法之一,已经广泛应用于各种任务中。
2、知识蒸馏利用一个高性能但难以部署的教师模型去改善和提高一个轻量级的学生模型的性能表现,使得学生的性能能够逼近甚至超越教师。目前主流蒸馏方法包括logit蒸馏与特征蒸馏,其中直接学习教师输出的基于类关系蒸馏性能往往不够理想,这是因为传统logit蒸馏方法中,学生被要求对教师的预测进行精确匹配,但由于教师与学生模型之间容量与结构的差异,学生无法产生与教师相同的分布。另外,传统方法仅限于实例之间的知识传递,即类间关系的传递,而忽略了不同样本间的上下文信息,影响了模型的泛化能力,从而限制了logit知识蒸馏的精度和训练效率。
技术实现思路
1、本发明解决的技术问题是如何提高logit知识蒸馏的精度和训练效率。
2、为实现上述目的,本发明提供如下技术方案:
3、本发明公开一种基于类关系的二维规范化知识蒸馏方法,包括:
4、将训练集批次内样本即图片分别输入至教师模型和学生模型,获取教师模型和学生模型各自最后隐藏层的输出即教师logits和学生logits;其中,教师logits和学生logits为维度相同的二维矩阵,二维矩阵中的每个元素即logit表示模型对每个类别的置信度得分,且二维矩阵的行数和列数分别对应样本数和类别数;
5、对教师logits和学生logits中的每个logit分别进行类间维度规范化修正和类内维度规范化修正,根据修正结果计算类间规范化损失和类内规范化损失;
6、将学生logits与样本的原始真实标签进行训练,计算分类损失;
7、将所述类间规范化损失、所述类内规范化损失和所述分类损失三者进行加权得到总训练损失,基于所述总训练损失进行梯度下降和反向传播,从而优化学生模型参数,随后利用训练集下一批次的样本对学习模型进行持续优化,直至学生模型收敛。
8、本发明还公开一种基于类关系的二维规范化知识蒸馏系统,包括:样本输入模块、损失计算模块以及学生模型输出模块。
9、样本输入模块用于将训练集批次内样本即图片分别输入至教师模型和学生模型,获取教师模型和学生模型各自最后隐藏层的输出即教师logits和学生logits;其中,教师logits和学生logits为维度相同的二维矩阵,二维矩阵中的每个元素即logit表示模型对每个类别的置信度得分,且二维矩阵的行数和列数分别对应样本数和类别数;
10、损失计算模块用于对教师logits和学生logits中的每个logit分别进行类间维度规范化修正和类内维度规范化修正,根据修正结果计算类间规范化损失和类内规范化损失;所述损失计算模块还用于将学生logits与样本的原始真实标签进行训练,计算分类损失 ;所述损失计算模块还用于将所述类间规范化损失、所述类内规范化损失和所述分类损失三者进行加权得到总训练损失;
11、学生模型输出模块用于基于所述总训练损失进行梯度下降和反向传播,从而优化学生模型参数,随后利用训练集下一批次的样本对学习模型进行持续优化,直至学生模型收敛。
12、本发明还公开一种计算机终端,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时,实现如上所述的基于类关系的二维规范化知识蒸馏方法的步骤。
13、与现有技术相比,本发明的有益效果是:
14、1、本发明公开的基于类关系的二维规范化知识蒸馏方法,可以缩小学生模型与教师模型之间由于容量差距导致的性能问题,提升蒸馏效率。另外,该方法还克服了传统知识蒸馏中的局限性,通过将类间信息与类内相关性相结合,并利用规范化修正重构学生模型与教师模型的logits,使得学生模型能够更加关注logits之间的相对关系,从而更好地模仿教师模型。该方法实现简单但效果显著,大量实验证明了它在图像分类与目标检测任务上的有效性,并且具有显著的训练效率优势。
15、2、本发明公开的知识蒸馏系统以及计算机终端,通过应用上述方法,能够产生与上述方法相同的有益效果,在此不再赘述。
1.基于类关系的二维规范化知识蒸馏方法,其特征在于,包括:
2.根据权利要求1所述的基于类关系的二维规范化知识蒸馏方法,其特征在于,所述对教师logits和学生logits中的每个logit分别进行类间维度规范化修正和类内维度规范化修正包括:
3.根据权利要求2所述的基于类关系的二维规范化知识蒸馏方法,其特征在于,所述根据修正结果计算类间规范化损失和类内规范化损失包括:
4.根据权利要求3所述的基于类关系的二维规范化知识蒸馏方法,其特征在于,所述分类损失的计算公式为:
5.根据权利要求4所述的基于类关系的二维规范化知识蒸馏方法,其特征在于,所述总训练损失的计算公式为:
6.根据权利要求1所述的基于类关系的二维规范化知识蒸馏方法,其特征在于,所述样本为待执行计算机视觉任务的图片,所述计算机视觉任务的种类包括图像分类、目标检测和语义分割。
7.基于类关系的二维规范化知识蒸馏系统,其特征在于,包括:
8.一种计算机终端,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时,实现如权利要求1至6中任意一项所述的基于类关系的二维规范化知识蒸馏方法的步骤。
