背景技术:
1、机器学习(ml)模型的联邦学习是一种越来越流行的用于训练ml模型的ml技术。在传统的联邦学习中,本地ml模型本地存储在用户的客户端装置上,并且作为本地ml模型的基于云的对应物的全局ml模型远程存储在远程系统(例如,服务器集群)处。使用本地ml模型,客户端装置可以处理在客户端装置处检测到的用户输入以生成预测的输出,并且可以将预测的输出与真实值(ground truth)输出进行比较以使用监督学习技术来生成梯度。此外,客户端装置可以向远程系统传输该梯度。远程系统可以利用该梯度,以及可选地利用在附加客户端装置处以类似方式生成的附加梯度来更新全局ml模型的权重。此外,远程系统可以向客户端装置传输全局ml模型或全局ml模型的经更新的权重。客户端装置随后可以用全局ml模型替换本地ml模型,或者用全局ml模型的经更新的权重替换本地ml模型的权重,从而更新本地ml模型。
2、需注意,全局ml模型可以在远程系统处使用服务器数据集进行初始训练,并且按照上述方式使用联邦学习框架进行微调。换句话说,全局ml模型可以在远程服务器处基于服务器数据集进行初始训练,直到该全局ml模型可用为止,并且随后可以以隐私敏感的方式并基于在推理时部署全局ml模型时更可能遇到的客户端数据进行微调。然而,以这种方式训练的ml模型可能容易发生灾难性遗忘,因为在使用这种朴素(naïve)微调时基于在客户端装置处生成的梯度更新全局ml模型的权重时,在初始训练中从服务器数据集中学习到的信息可能会被突然遗忘。
技术实现思路
1、本文公开的实现方式涉及在机器学习(ml)模型的联邦学习中实现各种技术以减轻ml模型的灾难性遗忘。实现方式可以识别基于服务器数据集在远程服务器处进行了初始训练(例如,引导)的给定全局ml模型,并确定要在生成对应客户端梯度时被利用的基于服务器的数据,所述对应客户端梯度基于对应客户端数据生成并被利用以对给定全局ml模型进行微调。此外,实现方式可以向至少给定客户端装置传输(i)给定全局ml模型和(ii)基于服务器的数据,以使给定客户端装置基于使用(i)给定全局ml模型并根据(ii)基于服务器的数据处理对应客户端数据来生成对应客户端梯度,并将对应客户端梯度传输回远程服务器。此外,实现方式可以基于从给定客户端装置接收的对应客户端梯度(以及可选地从附加客户端装置接收的附加的对应客户端梯度)生成给定的经更新的全局ml模型。
2、例如,假设给定全局ml模型对应于全局自动语音辨识(asr)模型,该asr模型基于远程服务器可用的音频数据语料库在远程服务器处进行了初始训练。在此示例中,可以基于全局asr模型的全局权重和被利用以对全局asr模型进行初始训练的音频数据语料库来确定基于服务器的数据。进一步假设向给定客户端装置传输全局asr模型和基于服务器的数据,以基于在给定客户端装置处本地生成的客户端数据对全局asr模型进行微调。在此示例中,给定客户端装置可以获得捕获客户端的用户的一个或多个口头话语的音频数据,并且可以使用全局asr模型(例如,响应于从远程服务器接收到全局asr模型而本地存储在给定客户端装置处的全局asr模型的实例)处理音频数据,以生成预测的输出(例如,所辨识的文本、预测的音素等)。此外,给定客户端装置可以使用各种监督或半监督学习技术基于预测的输出来生成对应客户端梯度。例如,给定客户端装置可以根据基于服务器的数据修改或增强对应客户端梯度,并将经修改或增强的对应客户端梯度传输回远程服务器。此外,远程服务器可以基于对应客户端梯度(以及可选地从参与全局asr模型中的联邦学习的其他客户端装置接收的其他对应客户端梯度)更新全局asr模型的全局权重,从而生成经更新的全局asr模型。
3、在一些实现方式中,基于服务器的数据可以包括给定全局ml模型的一个或多个对应全局权重的一个或多个对应弹性权重巩固(ewc)损失项。一个或多个对应ewc损失项可增加对应损失惩罚,该对应损失惩罚减慢在基于客户端数据对给定全局ml模型进行微调期间对给定全局ml模型的一个或多个对应全局权重的学习。换句话说,一个或多个对应ewc损失项确保当给定全局ml模型随后基于客户端数据(例如,经由对应客户端梯度)进行更新时,给定全局ml模型的一个或多个对应全局权重不会过度拟合,从而减轻和/或消除灾难性遗忘。
4、在这些实现方式中,可以基于例如对应fisher信息矩阵来确定一个或多个对应ewc损失项,该对应fisher信息矩阵是基于给定全局ml模型的一个或多个对应全局权重针对被利用以在远程服务器处对给定全局ml模型进行初始训练的服务器数据集确定的。fisher信息旨在测量可观测随机变量所携带的有关分布的未知参数的信息量,并且fisher信息矩阵可以被计算为以矩阵形式(例如hessian矩阵)表示的该度量的期望值。在基于多个服务器数据集对给定全局ml模型进行初始训练的实现方式中,多个服务器数据集中的每个服务器数据集可以与一组一个或多个对应ewc损失项相关联,这些ewc损失项是基于多个服务器数据集中的每个服务器数据集的对应fisher信息矩阵确定的。
5、在这些实现方式的一些版本中,可以基于fisher信息矩阵的对角线确定一个或多个对应ewc损失项。例如,fisher信息矩阵的对角线的对应值可被利用以作为一个或多个对应全局权重的一个或多个对应ewc损失项。例如,fisher信息矩阵的对角线的第一个值(例如,行1,列1)可被利用以作为第一个全局权重的第一个ewc损失项,fisher信息矩阵的对角线的第二个值(例如,行2,列2)可被利用以作为第二个全局权重的第二个ewc损失项,fisher信息矩阵的对角线的第三个值(例如,行3,列3)可被利用以作为第三个全局权重的第三个ewc损失项,对于全局ml模型的每个其他全局权重依此类推。在这些实现方式的附加或替代版本中,可以在确定一个或多个ewc损失项时利用附加或替代值或值的组合,使得可以在更新一个或多个全局权重中的多个全局权重时利用给定ewc损失项,并且/或者可以在更新一个或多个全局权重中的给定全局权重时利用多个ewc损失项,但这些实现方式在计算上可能不那么高效。
6、继续上述示例,其中给定全局ml模型对应于全局asr模型,在给定客户端装置处本地生成并传输回服务器的对应客户端梯度可以使用一个或多个对应ewc损失项进行修改或增强。例如,在被传输回远程服务器之前,在给定客户端装置处本地生成的对应客户端梯度可以以加权或非加权的方式与一个或多个对应ewc损失项组合。因此,当远程服务器随后基于对应客户端梯度更新全局asr模型时,经更新的全局asr模型的一个或多个经更新的权重不会过度拟合到客户端数据。
7、在这些实现方式的一些版本中,以及在对给定全局ml模型的训练的后续迭代处,实现方式可以基于服务器数据集并基于给定的经更新的全局ml模型的一个或多个对应的经更新的全局权重来确定经更新的fisher信息矩阵和一个或多个对应的经更新的ewc损失项(例如,经更新的基于服务器的数据)。此外,实现方式可以向至少给定客户端装置传输(iii)给定的经更新的全局ml模型和(iv)经更新的基于服务器的数据,以使给定客户端装置基于使用(iii)给定的经更新的全局ml模型并基于(iv)经更新的基于服务器的数据处理对应的附加客户端数据来生成对应的附加客户端梯度,并将对应的附加客户端梯度传输回远程服务器。此外,实现方式可以基于从给定客户端装置接收的对应的附加客户端梯度(以及可选地从附加客户端装置接收的进一步附加的对应客户端梯度)生成给定的进一步经更新的全局ml模型。可以以这种方式继续对给定全局ml模型进行微调,直到满足用于使给定全局ml模型部署在给定客户端装置和/或多个附加客户端装置处以进行推理的一个或多个条件。
8、通过使用本文描述的技术,可以实现各种技术优势。作为一个非限制性示例,在利用本文所述的一个或多个对应ewc损失项时,当基于给定的对应客户端梯度更新ml模型时,可以通过惩罚给定的对应客户端梯度对ml模型的权重的影响程度来减轻和/或消除ml模型的灾难性遗忘。因此,ml模型在准确率和/或召回率方面可更加稳健。作为另一个非限制性示例,对于联邦学习的给定迭代,可以在远程服务器处确定基于服务器的数据,同时保持客户端数据的安全性,从而消除了客户端装置在确定在客户端装置处本地被利用以减轻和/或防止灾难性遗忘的基于服务器的数据时消耗不必要的计算资源和/或网络资源的需要。例如,客户端装置不需要处理被利用以对给定全局ml模型进行初始训练的大量数据。相反,参与给定全局ml模型的微调的每个客户端装置可以利用在服务器处确定的基于服务器的数据。
9、提供上述描述作为对本公开的一些实现方式的概述。下文将更详细地描述那些实现方式以及其他实现方式的进一步描述。
1.一种由一个或多个处理器实现的方法,所述方法包括:
2.如权利要求1所述的方法,
3.如任一项前述权利要求所述的方法,进一步包括:
4.如任一项前述权利要求所述的方法,进一步包括:
5.如任一项前述权利要求所述的方法,其中所述全局ml模型进一步基于附加服务器数据集在所述远程服务器处进行初始训练,所述方法进一步包括:
6.如任一项前述权利要求所述的方法,进一步包括:
7.如权利要求6所述的方法,其中,所述一个或多个条件包括以下中的一者或多者:在生成所述经更新的全局ml模型时是否利用了阈值数量的梯度、自所述经更新的全局ml模型更新以来是否已经过去了阈值持续时间、或所述经更新的全局ml模型的性能是否满足阈值性能度量。
8.如任一项前述权利要求所述的方法,其中,向所述给定客户端装置传输(i)所述全局ml模型和(ii)所述一个或多个全局权重中的每个全局权重的所述对应ewc损失项进一步使所述给定客户端装置:
9.如任一项前述权利要求所述的方法,其中,所述一个或多个全局权重中的每个全局权重的所述对应ewc损失项对应于所述fisher信息矩阵的对角线。
10.如任一项前述权利要求所述的方法,其中,在对所述全局ml模型进行初始训练时利用的所述服务器数据集是从公开可用的多媒体数据储存库中获得的。
11.如任一项前述权利要求所述的方法,其中,所述全局ml模型是在处理音频数据时利用的基于音频的全局ml模型。
12.如权利要求1至10中任一项所述的方法,其中,所述全局ml模型是在处理视觉数据时利用的基于视觉的全局ml模型。
13.一种由一个或多个处理器实现的方法,所述方法包括:
14.如权利要求13所述的方法,其中,向所述远程服务器传输所述给定客户端梯度进一步使所述远程服务器:
15.如权利要求13或权利要求14所述的方法,进一步包括:
16.如权利要求15所述的方法,其中,所述一个或多个条件包括以下中的一者或多者:当日时间、周中此日、所述给定客户端装置是否正在充电、所述给定客户端装置是否具有至少阈值荷电状态、所述给定客户端装置的温度是否低于温度阈值、或所述给定客户端装置是否正在由所述给定客户端装置的给定用户持有。
17.如权利要求13至16中任一项所述的方法,其中,所述全局ml模型是基于音频的全局ml模型,其中,所述客户端数据是由所述给定客户端装置的一个或多个传声器在所述给定客户端装置处本地生成的音频数据,并且其中,处理所述客户端数据以生成所述预测的输出包括:
18.如权利要求13至17中任一项所述的方法,其中,所述全局ml模型是基于视觉的全局ml模型,其中,所述客户端数据是由所述给定客户端装置的一个或多个视觉组件在所述给定客户端装置处本地生成的视觉数据,并且其中,处理所述客户端数据以生成所述预测的输出包括:
19.一种由一个或多个处理器实现的方法,所述方法包括:
20.如权利要求19所述的方法,进一步包括:
21.如权利要求20所述的方法,其中,所述一个或多个ewc损失项中的每个ewc损失项针对n个全局权重,并且其中,n是大于一的正整数。
22.一种远程服务器,包括至少一个远程处理器和存储指令的远程存储,所述指令在由所述至少一个远程处理器执行时使所述至少一个远程处理器执行如权利要求1至12中任一项所述的方法。
23.一种客户端装置,包括至少一个客户端装置处理器和存储指令的客户端装置存储,所述指令在由所述至少一个客户端装置处理器执行时使所述至少一个客户端装置处理器执行如权利要求13至21中任一项所述的方法。