首页 / 基于强化学习的弱监督自学习方法

基于强化学习的弱监督自学习方法无效专利 发明

技术领域

[0001] 本发明涉及弱监督自适应策略选择领域,尤其涉及一种基于强化学习的弱监督自学习方法。

相关背景技术

[0002] 监督学习技术通过学习大量训练样本来构建预测模型,其中每个训练样本都有一个标签标明其真值输出,模型的效果很大程度上依赖于标签的质量。而在实际应用中,由于数据标注过程的高成本,很难获得如全部真值标签的强监督信息。这便产生了弱监督问题,如何在弱监督的数据上训练高效的模型成为当下的一大研究热点。
[0003] 目前已经存在一些基于机器学习甚至深度学习的弱监督问题解决方法,如主动学习、直推学习、标签传播等,它们大都只是单一地运用于弱监督场景中,而在复杂多变的实际应用中,方法的选择耗时耗力,同时单一方法也往往不足以训练出高效的模型。

具体实施方式

[0043] 下面根据附图1~图3,给出本发明的较佳实施例,并予以详细描述,使能更好地理解本发明的功能、特点。
[0044] 请参阅图1和图2,本发明实施例的一种基于强化学习的弱监督自学习方法,包括步骤:
[0045] S1:收集获取并预处理弱监督数据,获得弱监督数据特征集合。
[0046] 其中,S1步骤进一步包括步骤:
[0047] S11:数据清理,通过对原始弱监督数据进行填写缺失值、光滑噪声和识别解决数据不一致来实现数据的格式化、异常数据的清除错误纠正以及重复数据的清除;
[0048] S12:数据变换,通过平滑聚集,数据概化或规范化的方式将数据转换成学习模型需要的形式;
[0049] S13:数据表征,通过对业务属性进行邻接性,聚集性和结构性分析,对数据变换后的原始弱监督数据进行重新表征,获得弱监督数据特征集合。
[0050] S2:基于弱监督数据特征集合及业务场景,对弱监督场景进行信息量化。
[0051] 其中,S2步骤进一步包括步骤:
[0052] S21:量化弱监督数据特征集合的内部信息,内部信息包括标签比率、标签的均衡度比率和数据的分布与标签的互信息量;
[0053] S22:量化弱监督业务场景的外部辅助量,主要包含是否有可靠的模式或者业务规则以及众包数据的来源可信度等。
[0054] S3:基于DQN算法训练强化学习模型,确定弱监督算法调度策略。
[0055] 其中,S3进一步包括步骤:
[0056] S31:建立强化学习模型;强化学习模型包括一强化学习的状态集和一动作集,强化学习的状态集为连续状态空间,包括内部信息和外部辅助量,动作集包括若干弱监督算法,如主动学习、直推学习、标签传播、数据编辑、Snorkel等;
[0057] S32:初始化重播缓冲区D,重播缓冲区D中存储着之前智能体所经历的行为,用来在训练神经网络的时候打破经历之间的相关性,并且可以解决非静态分布问题;
[0058] 初始化一Q网络,记作Q,Q网络随机生成权重θ;该网络对应于Q-Learning算法中的Q函数,可以解决状态空间连续的问题,实际该网络是对于状态到动作的映射的拟合;
[0059] 初始化一target Q网络,记作 target Q网络结构与Q网络完全相同,target Q网络随机生成权重θ′;θ′=θ;该网络用于经验重放,是若干次迭代之前的Q网络;
[0060] 初始化状态s={x1,x2,…,xn}。其中,x1,x2,…,xn为S2步骤中定义的原始数据集的内部信息和外部辅助量;
[0061] S33:将当前状态s输入Q,输出所有动作对应的Q值Q(s,a;θ),a表示动作;基于ε-greedy策略选择一个动作a,有概率ε根据Q(s,a;θ)中最大值选择对应的动作,此时a=argmaxaQ(s,a;θ),而有概率(1-ε)随机选择一个动作;
[0062] S34:根据当前选择的动作a,使用对应的弱监督算法对当前弱监督数据特征集合进行标签增强,获得标签增强后的新数据集;
[0063] S35:评估新数据集的标签效用并反馈给智能体奖赏值,分别使用本次标签增强前的有标签数据和新数据集训练一个预测模型,并在一测试集结果进行预测,通过计算准确率、召回率和打扰率对两个模型的预测结果进行评估,并根据评估结果反馈给智能体一个奖励值;
[0064] 准确率precision表达为公式(1):
[0065]
[0066] 召回率recall表达为公式(2):
[0067]
[0068] 打扰率disturb表达为公式(3):
[0069]
[0070] 其中,TP为模型将正类判定为正类的数量,FP为模型将负类判定为正类的数量,FN为模型将正类判定为负类的数量,TN为模型将负类判定为负类的数量;
[0071] S36:将本次转换存储在重播缓冲区D中,记作(st,at,rt,st+1),其中st为本次动作之前的环境,at为本次执行的动作,rt为奖励值,st+1为执行本次动作后的环境状态;
[0072] S37:从重播缓冲区D中随机抽取一个minibatch的样本,并使用梯度下降法对Q进行更新,损失函数Loss表达为公式(4):
[0073] Loss=(yj-Q(sj,aj;θ))2   (4);
[0074] 其中,Q(sj,aj;θ)为在状态sj在执行动作aj时对应的Q值;
[0075]
[0076] 其中,第一种取值在sj+1为最终状态下成立,rj为在状态sj执行动作aj后环境反馈给智能体的奖励值,γ为衰变常数,maxa′ 为在输入为sj+1的情况下任意取a′∈A中 的最大值,A为S31中所定义的动作集;
[0077] S38:每隔若干步,更新 网络,将Q网络拷贝至target Q网络,使θ′=θ;
[0078] S39:重复步骤S33至S38直至标签增强后的数据集达到期望。
[0079] 以上结合附图实施例对本发明进行了详细说明,本领域中普通技术人员可根据上述说明对本发明做出种种变化例。因而,实施例中的某些细节不应构成对本发明的限定,本发明将以所附权利要求书界定的范围作为本发明的保护范围。

当前第1页 第1页 第2页 第3页