技术领域
[0001] 本发明涉及基于延时摄影的胚胎发育检测,具体地,涉及一种胚胎发育检测装置及其训练平台。
相关背景技术
[0002] 利用AI人工智能技术来对影像资料进行识别已经成为医疗领域降本增效的有效途径之一。已经有人应用人工智能算法来代替胚胎学家对胚胎培养的延时摄影(Time‑lapse)资料进行分析来实现自动化的胚胎生长发育状态评估。
[0003] 这样的延时摄影设备是现有技术中已知的,通过将高分辨率的摄像头与胚胎培养箱相结合,以特定的间隔、频率、角度等对胚胎发育过程进行连续动态监测并且拍摄图像,无需频繁干扰胚胎培养箱内环境,即能够对胚胎发育过程进行形态学观测和分析。
[0004] 对此,需要在延时摄影设备所拍摄的胚胎时间序列图像精确识别并测量指示胚胎发育状态的特定特征和参数,例如识别原核、核仁、细胞数量、细胞形状及面积等。
[0005] 现有技术中,为实现图像中特定目标检测,往往在候选区域产生大量的可能包含目标物的边框,再用分类器去判断每个边框里是否包含有目标物,以及目标物所属类别的概率或者置信度,通过后处理来改善边框,消除重复的检测目标。例如,CN111539308A公开了利用Faster RCNN的目标检测神经网络定位胚胎图像中卵裂球区域的边框,Faster RCNN首先使用卷积神经网络提取胚胎图像的特征图,通过区域提议网络获得多个提议区域,使用ROI‑Align层对每个提议区域进行池化,最后通过全连接层进行边框回归和逻辑回归,得到卵裂球区域的边框。
[0006] 在上述采用滑窗或者候选区域的目标检测系统中,分类器只能得到图像的局部信息,因此在检测物体时不能很好的利用上下文信息,从而容易在背景上预测出错误的物体信息,这导致需要精确度的胚胎图像识别结果不可靠。
具体实施方式
[0026] 以下在具体实施方式中详细叙述本发明的详细特征以及优点,其内容足以使任何本领域技术人员了解本发明的技术内容并据以实施,且根据本说明书所揭露的说明书、权利要求及附图,本领域技术人员可轻易地理解本发明相关的目的及优点。
[0027] 现有技术中,通过对胚胎进行覆盖多个胚胎阶段的延时(Time‑lapse)摄影,获取胚胎的多个图像组成的原始图像集PP。这样的延时摄影设备是现有技术中已知的,通过将高分辨率的摄像头与胚胎培养箱相结合,以特定的间隔、频率、角度等对胚胎发育过程进行连续动态监测并且拍摄图像,无需频繁干扰胚胎培养箱内环境,即能够对胚胎发育过程进行形态学观测和分析。
[0028] 本发明第一方面涉及胚胎发育检测装置的训练平台100。如图1所示,该训练平台100包括胚胎图像增强模块1、胚胎图像训练集构建模块2、模型总损失判断模块3、参数调节模块4。
[0029] 【胚胎图像增强模块】
[0030] 胚胎图像训练集构建模块2用于针对选定的胚胎特征对输入图像进行标注并构建与胚胎特征对应的胚胎特征训练集。
[0031] 胚胎图像增强模块1接收胚胎的至少两个图像组Ga,Gb,两个图像组Ga,Gb分别包括通过对胚胎进行延时摄影获取的多个胚胎图像,且两个图像组Ga,Gb分别对应于两个不同的胚胎发育阶段;并且胚胎图像增强模块1将两个图像组Ga,Gb的多个胚胎图像通过加权相加方式进行图像融合后作为胚胎图像训练集构建模块2的输入图像。
[0032] 可将来自不同胚胎阶段的任意两张胚胎图像进行融合,融合比列为lam:0.7.相加时两张胚胎图像对应的每个像素值直接相加。即inputs=0.7*images+0.3*images_random,其中images可为来自图像组Ga的图像,对应于胚胎发育第一阶段,images_random可为来自图像组Gb的图像,对应于不同于第一阶段的其余胚胎发育阶段的图像,例如胚胎发育第四阶段的图像。如此,有利于改善胚胎第四阶段之后的图像模糊性,有利于增加对胚胎第四阶段之后的胚胎图像识别的准确率。在此基础上,在后续训练阶段中,将不同胚胎阶段的任意两张胚胎图像的标签信息也进行融合,构建成新的数据集用于训练,提高模型的泛化能力。
[0033] 此外,胚胎图像增强模块1还可对胚胎图像进行各种增强处理,包括但不限于图像生成、图像融合、图像尺寸变化、旋转、翻折等。
[0034] 在本发明一种实施方式中,胚胎图像增强模块可将每一个胚胎图像的像素数据归一化到[‑1,1]之间,具体地,可计算图像数据的均值以及方差,然后利用原数据减去均值除以方差,使得数据归一化到[‑1,1]之间,使得模型更容易收敛;在此基础上,还可对归一化数据进行数据增强,包括:指定0.5的概率,翻转垂直或水平方向的图像。
[0035] 在本发明另一种实施方式中,胚胎图像增强模块1可将像素大小范围在320‑960像素之间的胚胎图像自动适应为与本发明所采用的胚胎目标检测初始模型M0匹配的640*640像素图像,这使得针对不同像素大小的胚胎原始图像都可以进行模型训练,避免对原始图像进行填充或者缩放带来的信息损失,并且提高了本发明所生成的胚胎目标检测模型M1对胚胎中小目标的检测准确率。
[0036] 【胚胎图像数据集构建模块】
[0037] 胚胎图像数据集构建模块2对延时摄影视频数据进行多种标注,生成数据集。
[0038] 根据关注特征的不同,对经胚胎图像增强模块处理的胚胎图像数据进行标注,预定选取的标注内容例如可包括细胞分裂的时间点,包括第一次卵裂时间、第二次卵裂时间、第三次卵裂时间,或者1细胞到2细胞、3细胞、4细胞、5细胞时间;细胞面积(二分裂时,各细胞面积);原核区分(雌性原核和雄性原核)、原核生成时间、原核消失时间、原核消失前8小时、雌雄原核的面积;核仁数量、原核消失前三小时核仁分布等。
[0039] 标注后数据生成对应的数据集,包括:分裂时间点数据集、细胞面积数据集、雌雄原核变化时间点数据集、原核面积数据集、核仁数量数据集、核仁分布模式数据集等。
[0040] 本发明中,既可以通过医师或经过培训的工作人员人工完成图像标准,也可采用自动标注模块完成胚胎图像标注。
[0041] 【目标检测网络初始模型】
[0042] 本发明中的目标检测网络初始模型M0整体结构如图2所示。
[0043] 将胚胎特征训练集送入胚胎目标检测初始模型M0后,模型总损失判断模块3根据损失函数计算胚胎目标检测初始模型M0的总损失,并由参数调节模块4根据总损失对胚胎目标检测初始模型M0中至少一个参数进行修改,以生成与胚胎特征对应的胚胎目标检测模型M1。因此目标检测初始模型M0与本发明第一方面所涉及的训练平台所生成的胚胎目标检测模型M1具有相同的结构,但其各个组成部分的参数设定有所不同。
[0044] 目标检测网络初始模型M0采用PANET采样方式进行信息采样,包括:增加从底向上路径扩充方式,充分利用胚胎图像信息,增强了图像特征,避免丢失信息。
[0045] 参见图2,胚胎目标检测初始模型M0为路径聚合网络,其中包括依次连接的切片层、第一下采样层、第二下采样层、第三下采样层、空间金字塔层、第一上采样层、第二上采样层、第一输出层、第二输出层、第三输出层和检测层,即,每一层的输出作为下一层的输入。
[0046] 其中切片层接收胚胎特征训练集;并且第一上采样层包括依次连接的第一残差卷积模块RES1、第一上采样模块LOS1,并且将第一上采样模块LOS的输出图像与第三下采样层的输出相连接,以作为第二采样层的输入图像;第二上采样层包括依次连接的第二残差卷积模块RES2、第二上采样模块LOS2,并且将第二上采样模块LOS的输出图像作为第一输出层的输入图像;
[0047] 第一输出层包括第一输出模块LO1,并且将第二上采样模块LOS2的输出图像与第二下采样模块LUS2的输出相连接,以作为第一输出模块LO1的输入图像;
[0048] 第二输出层包括第二输出模块LO2,并且将第一输出模块LO1的输出图像经过卷积操作conv后与第二残差卷积模块RES2的输出相连接,以作为第二输出模块LO2的输入图像;
[0049] 第三输出层包括第三输出模块LO3,并且将第二输出模块LO2的输出图像经过卷积操作conv后与第一残差卷积模块RES1的输出相连接,以作为第三输出模块LO3的输入图像;
[0050] 检测层包括检测模块D,第一输出层、第二输出层和第三输出层的输出图像分别作为检测模块的输入图像。
[0051] 下面详细描述每一层及其组成。
[0052] 【切片层】
[0053] 切片层包括切片模块F,该切片模块F包括切片操作Slice、连接操作Concat、和切片层CBS操作,其中:
[0054] 切片操作:首先接收大小640*640、通道数为3的原始输入图像,并将原始输入图像进行三通道切片操作,生成四个大小为320*320的切片;
[0055] 在连接操作concat中,将上述四个大小为320*320的切片进行连接,即,将相同大小的图像进行通道的扩充,得到12通道的320*320的图像;
[0056] 而后,进行设定为{32*12*3*3}的CBS操作,CBS操作包括卷积操作(C)、归一化(B)和SiLu(S)三个步骤,其中32代表卷积操作中采用32个卷积核,12代表12个输入通道,3*3代表卷积核大小;归一化(B)和SiLu(S)为现有技术中已知,最终得到32个通道的320*320图像。
[0057] 【第一下采样层】
[0058] 第一下采样层包括第一下采样模块LUS1,参见图4,该第一下采样模块LUS1包括卷积操作conv、多次CBS操作组合而成的单个残差网络、像素叠加add、连接操作concat和一次额外CBS操作,其中:
[0059] 在卷积操作conv中,接收来自切片层输出的32通道的320*320图像并将其经过64个卷积核、32通道、3*3卷积核大小的卷积操作,得到64通道的160*160图像;
[0060] 单个残差网络由利用四个CBS操作组合而成,如图4所示:
[0061] 下采样层第一CBS操作C1接收卷积操作conv输出的64通道的160*160图像,其设定为{32*64*1*1},其输出为32通道的160*160图像;
[0062] 下采样层第二CBS操作C2接收下采样层第一CBS操作C2输出的32通道的160*160图像,其设定为{32*32*1*1},其输出为32通道的160*160图像;
[0063] 下采样层第三CBS操作C3接收下采样层第二CBS操作C2输出的32通道的160*160图像,其设定为{32*32*3*3},其输出为32通道的160*160图像;
[0064] 下采样层第四CBS操作C4与下采样层第一CBS操作C1同样接收下采样层卷积操作conv输出的64通道的160*160图像并与下采样层第一CBS操作C1同样设定为{32*64*1*1}、其输出为32通道的160*160图像;
[0065] 在像素叠加add步骤中,将下采样层第三CBS操作C3输出的32通道的160*160图像与下采样层第一CBS操作C1输出的32通道160*160图像进行像素叠加,输出仍为32通道的160*160图像;
[0066] 连接操作contact:将像素叠加后的32通道的160*160图像与下采样层第四CBS操作C4输出的32通道的160*160图像进行连接,输出64通道的160*160图像;
[0067] 而后,再进行一次CBS操作,对连接操作输出的64通道的160*160图像进行设定为{64*64*1*1}的CBS操作,输出64通道的160*160图像,作为下采样层1的输出。
[0068] 【第二下采样层】
[0069] 第二下采样层包括第二下采样模块LUS2,如图5所示。
[0070] 对于第二下采样模块LUS2,其设定20~29分别为:{64*128*3*3}、{64*128*1*1}、{64*64*1*1}、{64*64*3*3}、{64*64*1*1}、{64*64*3*3}、{64*64*1*1}、{64*64*3*3}、{128*128*1*1},从而实现将来自第一下采样层的64通道160*160图像经过128个卷积核3*3卷积操作,得到128通道80*80图像,而后利用三层残差网络结构,对128通道80*80图像进行一次CBS操作,最终输出通道128的80*80图像。
[0071] 【第三下采样层】
[0072] 第三下采样层包括第三下采样模块LUS3,其具有与第二下采样模块LU2相同的结构,如图5所示。
[0073] 对于第三下采样模块LUS3来说,其设定30~39分别为:{128*256*3*3}、{128*256*1*1}、{128*128*1*1}、{128*128*3*3}、{128*128*1*1}、{128*128*3*3}、{128*128*1*1}、{128*128*3*3}、{245*256*1*1},从而实现将来自第二下采样模块LU2的128通道80*80图像经过256个卷积核3*3卷积操作,得到256通道80*80图像,而后利用三层残差网络结构,对
256通道80*80图像进行一次CBS操作,最终输出256通道40*40图像。
[0074] 【空间金字塔层】
[0075] 空间金字塔层包括空间金字塔模块LSPP,如图6所示。
[0076] 空间金字塔模块LSPP接收来自第三下采样层的256通道40*40图像,并且经过一次卷积核大小3*3、卷积核个数512的卷积操作conv,得到512通道20*20图像;
[0077] 再采用一次设定为{256*512*1*1}的CBS操作,得到256通道20*20的图像;
[0078] 而后,分别进行最大池化操作和深度拼接,得到1024通道20*20图像;最终经过一次设定为{512*1024*1*1}的CBS操作,输出512通道20*20图像。
[0079] 【第一上采样层】
[0080] 第一上采样层依次包括第一残差卷积模块RES1和第一上采样模块LOS1。
[0081] 对于第一残差卷积模块RES1来说,参见图7,其设定40~45分别为:{256*512*1*1}、{256*256*1*1}、{256*256*3*3}、{512*512*1*1}、{256*512*1*1}、{256*512*1*1};从而接收空间金字塔层输出的512通道20*20图像并输出256通道20*20图像。
[0082] 第一上采样模块LOS1中,再将上述256通道20*20图像经过卷积核为2*2上采样操作,输出256通道40*40图像。
[0083] 此后,经过连接操作concat,将第三下采样层LUS3输出的256通道40*40图像与第一上采样模块LOS1输出的256通道40*40图像进行连接,输出512通道40*40图像。
[0084] 【第二上采样层】
[0085] 第二上采样层包括第二残差卷积模块RES2和第二上采样模块LOS2。
[0086] 第二残差卷积模块RES2与第一残差卷积模块RES1具有相同的结构,如图7所示;对于第二残差卷积模块RES2,其设定50~50分别为:{128*512*1*1}、{128*128*1*1}、{128*128*3*3}、{256*256*1*1}、{128*256*1*1}、{256*512*1*1},从而接收第一上采样层输出的
512通道40*40图像并输出128通道40*40图像,在第二上采样操作LOS2中,再将128通道40*
40图像经过卷积核为2*2的上采样操作转换为128通道80*80图像。
[0087] 【第一输出层】
[0088] 第一输出层包括一次连接操作concat和第一输出模块LO1。
[0089] 首先接收来自第二上采样层LOS输出的128通道80*80图像,在连接操作concat中,将第二下采样层输出的128通道80*80图像与第二上采样模块LOS2输出的128通道80*80图像进行连接,得到256通道80*80图像。
[0090] 对于第一输出模块LO1来说,参见图8,其设定60~64分别为{64*256*1*1}、{64*64*1*1}、{64*64*3*3}、{128*128*1*1}、{64*256*1*1},从而在第一输出模块LO1中,将上述
256通道80*80图像经过如图8所示的多次CBS操作及其组合后,输出128通道80*80图像。
[0091] 【第二输出层】
[0092] 第二输出层依次包括一次卷积操作conv、一次连接操作concat和第二输出模块LO2。
[0093] 在卷积操作conv中,将来自第一输出模块LO1的128通道80*80图像经过卷积核128个大小为3*3、步长为2的卷积操作,得到128通道40*40图像。
[0094] 在连接操作concat中,将上述128通道40*40图像与第二上采样层的第二残差卷积模块RES2的128通道40*40图像进行连接,得到256通道40*40图像。
[0095] 第二输出模块LO2具有与第一输出模块LO1相同的结构,如图8所示,对于第二输出模块LO2来说,其设定70~74分别为{128*256*1*1}、{128*128*1*1}、{128*128*3*3}、{256*256*1*1}、{128*128*1*1}、{128*256*1*1},从而在第二输出模块LO2中,将上述256通道40*
40图像经过如图8所示的多次CBS操作及其组合后,输出256通道40*40图像。
[0096] 【第三输出层】
[0097] 第三输出层依次包括一次卷积操作conv、一次连接操作concat和第三输出模块LO3。
[0098] 在卷积操作conv中,将来自第二输出模块LO2的256通道40*40图像,在经过卷积核256个大小为3*3、步长为2的卷积操作,得到256通道20*20图像。
[0099] 在连接操作concat中,将上述256通道20*20图像与第一上采样层的第一残差卷积模块RES1的256通道20*20图像进行连接,得到512通道20*20图像;
[0100] 第三输出模块LO3具有与第一输出模块LO1、第二输出模块LO2相同的结构,如图8所示,对于第三输出模块LO2来说,其设定80~84分别为{256*512*1*1}、{256*256*1*1}、{256*256*3*3}、{512*512*1*1}、{256*512*1*1};从而在第三输出模块LO2中,将上述256通道40*40图像经过如图8所示的多次CBS操作及其组合后,输出512通道20*20图像。
[0101] 【检测层】
[0102] 如图9所示,检测层包括检测层模块D,该检测层模块D包括三个锚框层DML1、DML2、DML3,该三个锚框层的感受野分别为8*8、16*16和32*32,分别对应于第一输出层的80*80个特征图网格、第二输出层的40*40个特征图网格和第三输出层的20*20个特征图网格。
[0103] 每一锚框层分别接收来自第一输出层、第二输出层和第三输出层的输出图像,且每一锚框层DML1、DML2、DML3分别具有三个预测框。
[0104] 下面,以第一层锚框层DML1接收来自第一输出层的输出图像为例进行描述,则第一输出层输出的128通道80*80图像送入到第一层锚框层DML1,第一层锚框层DML1会通过其所含有三个预测框F11、F12、F13去预测第一输出层输出的128通道80*80图像,即,80*80个特征图网格中的每一个网格都会被第一锚框层DML1的三个预测框F11、F12、F13进行卷积识别,因此第一输出层的输出图像经过第一层锚框层DML1后会得到3*80*80个预测结果,每一个预测结果信息中包括目标类别、目标中心点及宽高坐标、置信度。
[0105] 即,第一层锚框层DML1得到第一输出层的128通道80*80图像的80*80个特征图网格,每个特征图网格的感受野为640/80=8*8大小,负责检测小目标,输出3*80*80个小尺寸检测目标的信息。
[0106] 类似地,第二锚框层DML2得到第一输出层的128通道80*80图像的40*40个特征图网格,每个特征图网格的感受野为640/40=16*16大小,负责检测中等目标;将该40*40个特征图网格结合第二层锚框层DML2中的三个预测框F21、F22、F23进行卷积操作,输出3*40*40个中尺寸检测目标的信息。
[0107] 类似地,第三锚框层DML3得到第一输出层的128通道80*80图像的20*20个特征图网格,每个特征图网格的感受野为640/20=32*32大小,负责检测大型目标;将该20*20个特征图网格结合第三层锚框层DML3中的三个预测框F31、F32、F33进行卷积操作,输出3*20*20个大尺寸检测目标的信息。
[0108] 最终,根据检测层输出每个网格的置信度obj,判断该网格中是否包含胚胎目标。具体地,将置信度obj的阈值设定为0.5,即,置信度ob不超过0.5的目标被去除;根据非极大值抑制算法对目标信息中预测框进行筛选、剔除对同一个胚胎目标重复检测矩形框的目标、最后根据目标信息中分类概率值,对胚胎目标概率值最大的目标进行保留信息,输出对应的胚胎的类别cls、置信度obj、胚胎目标的中心点及宽高坐标以生成矩形框D1。
[0109] 如图10(a)示出送入本发明胚胎发育检测装置的原始图像,图10(c)示出本发明胚胎发育检测装置的一个检测结果,例如为基于第一输出层生成的矩形框D1表征胚胎的大小以及精确位置。具体地,矩形框D1的左上角标识“t3”是本发明胚胎发育检测装置检测得出的与该胚胎图像对应的胚胎分类cls,“0.92”是表示该分类结果的置信度obj,并且根据检测模型输出的中心点坐标和宽高生成该矩形框D1。
[0110] 【模型总损失判断模块】
[0111] 模型总损失判断模块3采用BCEWithLogitsLoss损失函数和CIOU损失函数计算胚胎总损失loss。
[0112] 具体地,本发明中,模型总损失是由三个损失模块构成,分别是胚胎矩形框损失box_loss、置信度损失obj_loss和胚胎分类概率损失cls_loss,胚胎总损失Loss为以上三个损失的加权和。
[0113] 矩形框表征胚胎的大小以及精确位置,矩形框损失box_loss用来计算检测层输出的预测胚胎矩形框D1与实际图片中胚胎标签中的矩形框坐标之间距离误差。
[0114] 矩形框损失box_loss采用CIOU损失函数,其公式如下:
[0115]
[0116]
[0117]
[0118] 如图10(b)所示,其中:
[0119] Distance_2:胚胎预测框RG的中心点和胚胎真实框RR的中心点的欧氏距离;
[0120] Distance_C:C的对角线距离(框RY左上角和右下角对角线距离);
[0121] C:胚胎预测框RG和胚胎真实框RR的最小外接矩阵,如图中框RY所示;
[0122] v:(其中w为宽,h为高,gt代表为胚胎真实框RR,p代表胚胎预测框RG)长宽比影响因子
[0123] IOU是指胚胎检测模型输出的矩形框与原图标签矩形框的交并比,如图10(b)所示。首先先计算出两个矩形框相交的面积:
[0124] S1=(xp2‑xl1)*(yp2–yl1)
[0125] 计算相并部面积:
[0126] S2=(xp2–xp1)*(yp2–yp1)+(yl2–yl1)*(xl2–xl1)‑S1
[0127] 因此IOU的计算为:
[0128] IOU=S1/S2
[0129] CIOU在IOU的基础上将预测矩形框重叠面积、中心点距离、宽高比加入了计算。
[0130] 置信度表征所预测框的可信程度,取值范围0~1,值越大说明该矩形框中越可能存在胚胎,置信度损失obj_loss计算网络的置信度;
[0131] 分类概率损失表征胚胎的类别,分类概率损失cls_loss计算检测层输出的胚胎类别与实际图片中胚胎的类别是否正确。
[0132] 胚胎分类概率损失cls_loss和置信度损失obj_loss采用BCEWithLogitsLoss损失函数,其公式如下:
[0133]
[0134] 如在胚胎分类概率损失cls_loss的计算中,yn指标签中胚胎的类别,xn是胚胎检测模型输出的预测胚胎类别值,在置信度损失obj_loss的计算中,yn指胚胎检测模型输出的预测框与原图标签目标框的CIOU,使用CIOU作为该预测框的置信度标签,xn表示胚胎检测模型输出的预测置信度值t时刻得到的随机梯度值。
[0135] 胚胎总损失Loss为以上三个损失的加权和,本发明中将胚胎置信度损失取得最大权重,胚胎矩形框损失和胚胎分类损失的权重次之,因此a=0.4,b=0.3,c=0.3:
[0136] Loss=a*obj_loss+b*loss_box+c*clc_loss
[0137] 【参数调节模块】
[0138] 参数调节模块4通过反向传播计算出胚胎目标检测初始模型M1中至少一个参数的梯度,并且通过Adam优化算法对所述参数进行优化。
[0139] 具体地,Adam优化函数利用梯度的一阶矩估计和二阶矩估计动态调整胚胎训练过程中的每个参数的学习率,相对于其他优化函数,Adam优点主要在于胚胎经过偏置校正后,每一次迭代学习率都有个确定范围,使得胚胎模型参数更加稳定。本发明所采用的Adam优化算法更新如下:
[0140] t←t+1
[0141] 计算梯度:
[0142]
[0143] 更新有偏一阶矩估计:
[0144] mt←β1·mt‑1+(1‑β1)gt
[0145] 更新有偏二阶矩估计:
[0146]
[0147] 计算偏差校正的一阶矩估计:
[0148]
[0149] 计算偏差校正的一阶矩估计:
[0150]
[0151] 更新参数:
[0152]
[0153] 其中,根据公式(1)计算梯度的指数移动平均数mt,m0初始化为0。参考Momentum算法,综合之前时间步的梯度动量。β1系数为指数衰减率,控制权重分配(动量与当前梯度),通常取接近于1的值,默认为0.9。gt为t时刻得到的随机梯度值。
[0154] mt←β1·mt‑1+(1‑β1)gt……(1)
[0155] 根据公式(2)计算梯度平方的指数移动平均数vt,v0初始化为0。系数为指数衰减率,控制之前的梯度平方的影响情况参考RMSProp算法,对梯度平方进行加权均值,默认为0.999。
[0156]
[0157] 由于m0初始化为0,会导致mt偏向于0,尤其在胚胎模型训练初期阶段。因此参考公式(3)mt进行偏差纠正,降低偏差对胚胎训练初期的影响。
[0158]
[0159] 与m0类似,因为v0初始化为0导致胚胎训练初始阶段vt偏向0,根据公式(4)对其进行纠正
[0160]
[0161] 由公式(5)可以更新参数,初始的学习率α乘以梯度均值与梯度方差的平方根之‑8比。其中默认学习率α=0.001;设定∈=10 ,避免除数变为0。
[0162]
[0163] 可以看出,对更新的步长计算,胚胎模型训练能够从梯度均值及梯度平方两个角度进行自适应地调节,而不是直接由当前梯度决定。
[0164] Adam相对于其他优化函数,具有以下优点:
[0165] ①胚胎训练时自动初始化学习率;
[0166] ②胚胎训练时自动调整学习率;
[0167] ③适用胚胎大规模的数据及参数场景;
[0168] ④适用胚胎检测多种不稳定目标函数结合场景;
[0169] 虽然本发明已参照当前的具体实施例来描述,但是本技术领域中的普通技术人员应当认识到,以上的实施例仅是用来说明本发明,在没有脱离本发明精神的情况下还可作出各种等效的变化或替换,因此,只要在本发明的实质精神范围内对上述实施例的变化、变型都将落在本申请的权利要求书的范围内。