首页 / 知识蒸馏的无源域无监督领域自适应学习方法

知识蒸馏的无源域无监督领域自适应学习方法实质审查 发明

技术领域

[0001] 本发明属于领域自适应模式分类技术领域,涉及一种知识蒸馏的无源域无监督领域自适应学习方法。

相关背景技术

[0002] 机器学习是人工智能时代的一个重要研究方向,它的研究成果被广泛的应用到各个领域,在人们的日常生活中占据着越来越重要的地位。当前机器学习技术通常依赖于封[1]闭环境,其学习过程涉及三个基础假设 :一是独立同分布假设:训练与测试数据有相同的特征空间和数据分布;二是封闭的类别标签空间:训练和测试数据类别空间相同;三是大数据假设:有足够可用的带标注的训练样本,能保证学习到一个较好的分类模型。然而开放环境的机器学习过程是不满足上述三个基础假设的。学习环境不满足上述三大假设条件的机器学习称为开放环境的机器学习。当这三个假设条件不满足时,必须要研究开放环境下的机器学习问题,使机器学习方法适应开放环境成为新一代人工智能技术研究需要解决的关键问题之一。
[0003] 领域自适应(Domain Adaptation,DA[2][3])能在源域数据上学习性能良好的模型,自适应到分布不同甚至类别标签空间不同的目标域数据,解决无标注样本或少量标注样本的目标域数据的学习任务。DA能适应开放环境的应用场景,是开放环境下的机器学习的主[4]要范式之一 ,它是解决当前封闭环境下机器学习局限问题(包括分布外场景学习、封闭标签类别空间、无标签和少标签等)的主要途径。DA帮助大型预训练模型处理各种应用,推动机器学习乃至人工智能的发展。
[0004] 然而现有的领域自适应方法都普遍依赖于源域数据的可访问性,然而在现实应用场景常存在如下问题。(1)大数据时代中数据安全,隐私保护都显得尤为重要,比如医疗数据和面部数据这类私密数据集并不适合在模型训练中频繁披露和使用;(2)深度神经网络的训练需要大量的数据支撑,过于依赖于源域数据也会带来数据存储和传输的成本过高的问题;(3)与此同时大量源域数据也会增加计算机的计算负担。因此研究者们为了解决上述[5]问题,实现源域数据在无监督领域自适应方法中的“解耦合”,提出了无源领域自适应[5]
(Source‑Free DA,SFDA)。现有的SFDA 以封闭集领域自适应为基础,在不访问任何源数据的情况下将预训练的源模型迁移到未标注的目标域,完成目标域的学习任务。现有SFDA方[6]
法分为白盒SFDA和黑盒SFDA两大类 ,两者区别在于预训练模型的参数是否可用。
[0005] 在无源无监督领域自适应中,通常会提供一个已用源域数据训练好的源域模型,在之后的目标域数据训练过程中,源域数据不再参与,以仅提供源域模型而不提供源域数据的方式来完成下游任务,很好地解决了上述数据安全、数据存储和计算负担等一系列问题。
[0006] [1]Z.Yao,C.Liu,C.Y.Suen.Towards robust pattern recognition:a review.Proceedings of the IEEE,2020,108(6):894‑922.
[0007] [2]S.J.Pan,Q.Yang.A survey on transfer learning.IEEE Transactions On Knowledge Data Engineering.2010,22(10):1345‑1359.
[0008] [3]L.Zhang,X.Gao.Transfer adaptation learning:a decade survey.IEEE Transactions on Neural Networks and Learning Systems,2024,35(1):23‑44.[0009] [4]袁晓彤,张煦尧,刘希,程真,刘成林.面向开放环境的机器学习理论研究进展.模式识别与人工智能,2023,36(12):1059‑1071.
[0010] [5]J.Liang,D.Hu,J.Feng.Do we really need to access the source data?Source hypothesis transfer for unsupervised domain adaptation.Proceedings of the 37th International Conference on Machine Learning,2020.
[0011] [6]Y.Fang,P.Yap,W.Lin,H.Zhu,M.Liu.Source‑free unsupervised domain adaptation:a survey.arXiv preprint arXiv:2301.00265,2022.

具体实施方式

[0022] 下面结合附图和具体实施方式对本发明进行详细说明。
[0023] 方法的整个训练过程分为两步,其中教师网络和学生网络可选取ResNet等网络作为基础骨干网络。
[0024] 首先第一步,训练源域模型。将带有标签的源域数据输入到教师网络中,利用交叉熵损失 来优化网络参数, 定义为:
[0025]
[0026] 其中B为Batch_size,
[0027] 然后第二步,进行领域自适应。在此过程中,将训练好的源模型与一个单独的旋转分类器组合作为教师网络,并将源模型的分类器参数固定,教师网络特征提取器参数更新照旧,师生网络只输入目标域数据。
[0028] 对于教师网络,设计最大化互信息损失 来约束目标域数据,使其结果呈现区分性和多样性。互信息损失分为两部分,第一部分是熵最小化损失 熵最小化损失降低了学生网络输出的未知性,使得伪标签更具有可区分性。 定义为:
[0029]
[0030] 另一部分是多样性最大化损失 它缓解了熵最小化带来的退化解的问题,使网络输出呈现多样化,分类效果更好。 定义为:
[0031]
[0032] 其中 表示 求均值后的结果。
[0033] 将式(2)和(3)组合得到互信息损失 其表达式为:
[0034]
[0035] 为了使得教师网络分类更加准确,方法引入k均值聚类和样本旋转这种自监督方法。
[0036] 在k均值聚类方法中首先要计算每个类的类原型作为质心,质心表达式为:
[0037]
[0038] 伪标签通过判断与质心的余弦距离来获取,其表达式为:
[0039]
[0040] 每个epoch更新两次类质心,更新后的质心表达式为:
[0041]
[0042] 通过聚类得到的伪标签做交叉熵损失,表达式为:
[0043]
[0044] 其中
[0045] 在样本旋转方法中,目标域数据将分别进行90度,180度和270度的旋转,将旋转后的样本与原样本组成一个有四张图片,标签为 的小数据集,再输入进教师网络,通过特征提取器和单独的旋转分类器,计算交叉熵损失:
[0046]
[0047] 其中
[0048] 对于学生网络,我们只输入目标域数据,计算交叉熵损失 表达式为:
[0049]
[0050] 其中
[0051] 类似于教师网络,学生网络中也设计了最大化互信息损失,表达式为:
[0052]
[0053] 然而,由于训练初期教师网络和学生网络的标签都不可靠,但是教师网络至少有正确标签的约束,而学生网络会在不可靠伪标签的影响下过早收敛,所以早期阶段需要保证教师网络标签选取的优先级,引入阈值Tp实现,其表达式为:
[0054]
[0055] 其中γ设置为10,p为当前已进行迭代次数占总迭代次数的比例,因此p的取值由0到1变化。教师网络在[Tp,1]的范围内保证选取标签的优先级,因此最终的竞争模块输出为:
[0056]
[0057] 其中阈值Tp由式(15)定义。
[0058] 在竞争机制的作用下,训练初期阶段,学生网络将受到教师网络的指导和约束,学习教师网络的知识,随着训练的进行,学生网络逐渐超越教师网络,完成知识蒸馏和知识迁移。

当前第1页 第1页 第2页 第3页
相关技术
学习方法相关技术
源域相关技术
叶征春发明人的其他相关专利技术