技术领域
[0001] 本申请涉及空中手写字符生成技术领域,特别涉及一种生成对抗网络模型。
相关背景技术
[0002] 生成对抗网络(Generative Adversarial Network,简称GAN)是由伊恩·古德费洛等人于2014年提出的,它的设计灵感来源于博弈论中的“零和游戏”。GAN由两个主要组成部分构成:生成器(Generator)和判别器(Discriminator)。这两个部分通过对抗学习的方式相互竞争,从而使得生成器能够不断提高生成逼真样本的能力,而判别器则不断提高辨别真伪样本的能力。
[0003] 其中GAN在图像领域取得了显著成功,其生成器输入为随机噪声。为加速网络收敛,且生成特定类别数据,产生了条件生成式对抗网络(Conditional Gan,简写为CGAN)。然而,现在的生成对抗网络模型主要是图像生成,无法生成长度可变的联机手写字符数据。目前虽然已建成一些公开的空中手写字符数据集(如:IAHCC‑UCAS2016),但是人类语言有很多种,数据采集非常费时费力,且消耗大量财力,给空中手写字符识别算法研究带来了难度。因此,亟需一个能够生成长度可变的联机手写字符数据的GAN模型。
具体实施方式
[0026] 下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述。显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域技术人员在没有付出创造性劳动前提下所获得的所有其他实施例,都属于本申请的保护范围。
[0027] 空中手写脱离了键盘、书写板等传统人机交互设备,是一种新型的更加自然的人机交互方式。稳定、高效的空中手写人机交互系统依赖于高精度的分类模型,而分类模型的训练需要大量的空中手写数据。现实中空中手写数据采集费时费力,且成本高昂。而结合真实空中手写数据,使用电脑人工生成空中手写样本是一种解决训练数据不足的有效途径。基于此,本申请提供一种新型的生成对抗网络模型,该生成对抗网络模型用于端到端生成空中手写字符。空中手写数据是由按照时间先后顺序排列的坐标序列构成,不同的字符包含的坐标点数不同。因此本申请的创新之处在于利用卷积神经网络构成的生成对抗网络模型端到端生成长短不一的时序坐标序列数据,提出的生成式对抗网络模型包括生成器和判别器,判别器的输入为长度不同的真实空中手写字符,而生成器的输入为判别器获得的类均值加噪声。通过实验证明生成的样本混合真实样本能够有效提高分类器的识别精度。这项技术为空中手写人机交互的推广、应用提供了可靠的数据支持。具体如下:
[0028] 本申请实施例提供一种生成对抗网络模型,如图1所示,该生成对抗网络模型包括生成器和判别器。
[0029] 所述生成器和所述判别器均由卷积神经网络构成,所述卷积神经网络的输入卷积层的卷积核大小为2×2,其余各卷积层的卷积核大小为1×3。
[0030] 其中,生成器和判别器均使用卷积网络,用来捕获数据特征,使生成器生成的数据更加接近真实数据又存在一定差异。
[0031] 所述生成器用于生成长度可变的空中手写字符,其输入为各个类的混合了噪声的均值向量。
[0032] 其中,生成器用于生成和训练数据相似的新数据,其输入为各个类的混合了噪声的均值向量,输出为生成的生成长度可变的空中手写字符。如图1所示,M和N分别表示类均值向量和随机噪声,Pvirtual表示生成的虚拟数据(长度可变的虚拟空中手写字符)。
[0033] 其中,空中手写脱离了键盘、书写板等传统人机交互设备,是一种新型的人机交互方式,书写者利用空中手写系统隔空书写,写成的字符只有一个笔画,没有起笔/落笔信息。空中手写字符数据由一系列时间序列坐标构成,不同字符坐标数量不同。
[0034] 所述判别器用于与所述生成器进行对抗训练,以使生成的所述长度可变的空中手写字符更加接近真实数据。
[0035] 其中,利用真实数据单独训练判别器,训练完成后将判别器作为特征提取器提取各个字符类的均值向量,之后将均值向量和随机噪声点混合作为生成器的输入,训练生成对抗网络模型以生成虚拟数据,虚拟数据和真实数据等长且同样由坐标序列构成。
[0036] 其中,生成器和判别器通过对抗学习的方式相互竞争,从而使得生成器能够不断提高生成逼真样本的能力,而判别器则不断提高辨别真伪样本的能力。
[0037] 所述判别器包括全局平均池化层GAP,所述平均池化层GAP用于对所述判别器进行卷积操作后生成的不同尺寸的特征图转化为等长向量。
[0038] 其中,数据(长度可变的空中手写字符)的生成不是凭空产生的,生成式对抗网络需要有真实数据作为判别器的输入,生成器生成的虚拟数据(长度可变的空中手写字符)通过与真实数据进行竞争达到以假乱真的效果,从而生成空中手写字符数据。目前公开的较大的空中手写数据集只有IAHCC‑UCAS2016,该数据集(IAHCC‑UCAS2016)包含国标一级字库3755汉字字符和56个常用汉字字符,共计3811类,每个类包含115个样本。本发明以此为基础,采用公开空中手写数据集IAHCC‑UCAS2016作为真实数据来训练和测试本申请的生成对抗网络模型。具体的,我们选择前500个类进行实验,即训练集和测试集均有500个类,每个类有115个样本,其中每个类的前92个样本(约占80%)用于训练生成对抗网络模型,其余23个样本(共有11500种数据)用于测试判别器的识别精度。每个样本尺寸为2×T,可表示为[x1,x2,……,xT;y1,y2,……,yT],T为坐标点个数,且随着不同字符变化。数据集IAHCC‑UCAS2016中一些样本如图2中的第一行所示。
[0039] 其中,空中手写字符数据集中的数据由坐标序列构成,不是图像,不同字符的坐标序列的长度不同。由于不同的空中手写字符长度不同,经过一些卷积操作后得到的特征图尺寸也不相同,为了方便分类使用全局平均池化层GAP将特征图转化为等长向量,并用于分类。其中,各层特征图大小可表示为1×S,其中S对应于输入的文本序列的长度T,S随着输入字符不同而变化。
[0040] 其中,生成器生成的数据与真实样本尺寸一致为2×T,也可表示为[g1,g2,……,gT;h1,h2,……,hT]。
[0041] 在一些实施例中,所述判别器的输入为真实的空中手写字符样本和所述生成器生成的虚拟空中手写字符样本。
[0042] 其中,判别器的输入为真实的空中手写字符样本和生成器生成的虚拟空中手写字符样本,输出为真假两个类,该判别器主要用于判断输入的数据是真实的训练数据还是生成器生成的虚拟数据。
[0043] 其中,本申请验证生成器生成的数据(空中手写字符)的有效性的方法为:首先通过图示化生成的空中手写字符从而直观判断生成样本是否有效。此外,通过使用真实样本混合生成样本训练分类器(判别器),和只用真实样本训练分类器的识别效果对比,如果识别精度更高,说明生成样本有效。例如,我们采用100个类进行实验。首先用真实数据训练分类器,在测试集上分类得到准确率是99.17%,之后将真实数据和生成的数据混合训练分类器,在测试集上的准确率达到99.30%,显著提高了识别精度,从而证明生成样本是有效的。
[0044] 在一些实施例中,所述生成器包括第一生成网络模块和第二生成网络模块,所述第一网络模块和第二网络模块均为残差结构。
[0045] 如图1所示,生成器包括BlockD(第一生成网络模块)和BlockG(第二生成网络模块)。其中,BlockD和BlockG为残差结构,从而能够在网络深度增加时解决训练时的梯度消失问题。其中,“BlockD 100”表示网络模块BlockD,通道数为100;“BlockG 128”表示网络模块BlockG,通道数为128。
[0046] 其中,BlockD和BlockG的详细结构如图3所示。图3中,“conv1 k:1×3s:1”表示卷积核大小是1×3,卷积步长是1。“Batch Norm”表示批量规范化,“Lrelu”表示非线形激活函数。生成器用于生成空中手写样本,其尺寸与判别器输入的真实样本一致,生成器的训练目标是使得判别器无法准确分辨真实数据和生成的数据。
[0047] 在一些实施例中,所述生成器还包括第三生成网络模块和第四生成网络模块,所述第三网络模块和第四网络模块用于特征提取。
[0048] 如图1所示,生层器还包括BlockE(第三生成网络模块)和BlockF(第四生成网络模块),BlockE和BlockF用于特征提取。其中,其中,“BlockE 128”表示网络模块BlockE,通道数为128;“BlockF 100”表示网络模块BlockF,通道数为100。
[0049] 其中,BlockE和BlockF的详细结构如图3所示。图3中,“conv1 k:1×3s:1”表示卷积核大小是1×3,卷积步长是1。“Batch Norm”表示批量规范化,“Lrelu”表示非线形激活函数。生成器用于生成空中手写样本,其尺寸与判别器输入的真实样本一致,生成器的训练目标是使得判别器无法准确分辨真实数据和生成的数据。
[0050] 在一些实施例中,所述判别器包括第一判别网络模块和第二判别网络模块,所述第一判别网络模块和第二判别网络模块均为残差结构。
[0051] 如图1所示,判别器包括BlockA(第一判别网络模块)和BlockB(第二判别网络模块)。其中,BlockA和BlockB为残差结构,从而能够在网络深度增加时解决训练时的梯度消失问题。其中,“BlockA 128”表示网络模块BlockA,通道数为128;“BlockB 128”表示网络模块BlockB,通道数为128;“BlockB256”表示网络模块BlockB,通道数为256。
[0052] 其中,BlockA和BlockB的详细结构如图4所示。图4中,“conv1 k:1×3s:1”表示卷积核大小是1×3,卷积步长是1;“Lrelu”表示激活函数;“BatchNorm”表示批量规范化。
[0053] 在一些实施例中,所述判别器还包括第三判别网络模块,所述第三判别网络模块用于特征提取。
[0054] 如图1所示,判别器还包括BlockC(第三判别网络模块),BlockC用于特征提取。其中,“BlockC 100”表示网络模块BlockC,通道数为100。
[0055] 其中,BlockC的详细结构如图4所示。图4中,“conv1 k:1×3s:1”表示卷积核大小是1×3,卷积步长是1;“Lrelu”表示激活函数;“BatchNorm”表示批量规范化。
[0056] 在一些实施例中,所述第一判别网络模块、第二判别网络模块和第三判别网络模块均包括Padding层,所述Padding层用于输入和卷积操作后所述特征图长度对齐。
[0057] 如图4所示,BlockA(第一判别网络模块)、BlockB(第二判别网络模块)和BlockC(第三判别网络模块)均包括Padding层,“Padding”表示填充操作,用于输入和卷积后特征图长度对齐。
[0058] 在一些实施例中,所述判别器还包括全连接层,所述全连接层用于将所述生成对抗网络模型学到的“分布式特征表示”映射到样本标记空间。
[0059] 如图1所示,判别器还包括全连接层FC,FC用于将所述生成对抗网络模型学到的“分布式特征表示”映射到样本标记空间。
[0060] 在一些实施例中,所述判别器还包括二分类层。
[0061] 如图1所示,判别器还包括二分类层BCE。判别器的任务就是对输入样本进行分类。本申请按照类进行样本生成,因此分类问题为二分类问题,它的任务是判断输入是真实数据还是生成器生成数据。
[0062] 在一些实施例中,所述判别器的损失函数为二元交叉熵损失函数BCE。
[0063] 其中,本申请使用二元交叉熵损失函数(Binary Cross‑Entropy,BCE)作为判别器的损失函数。因为GAN模型(生成对抗网络模型)中的判别器是一个二元分类器,其任务区分输入样本是真实数据还是生成数据。BCE损失函数是对数似然损失(log‑likelihood loss),在概率框架下,对抗训练通过最大化生成样本属于真实分布的概率,最小化生成样本属于假数据的概率,达到以假乱真的目的。
[0064] 生成网络的损失函数是对抗损失V(D,G)(其中,D表示判别器,G表示生成器)中关于z(混合数据)的项,其损失函数为:
[0065]
[0066] 其中,Ez~pz(z)表示计算期望;z:生成器的输入混合数据;Pz表示其分布;G(z)表示生成器生成的样本;D(G(z)):表示生成样本在判别器上的输出,将其识别为真实样本的概率。
[0067] 生成网络的目标是希望生成的样本越像真的越好,即最小化LG:
[0068]
[0069] 判别网络的损失函数是关于真实样本x和生成样本 的函数,定义为:
[0070]
[0071] 判别器的目标是希望判别器能够更好的区分出生成样本和真实样本:
[0072]
[0073] 利用上述损失函数训练判别器和生成器,使生成样本与真实样本非常接近,达到以假乱真的目标,本申请生成的样本如图2所示。图2中第一行为真实样本,其余为生成的虚拟样本,从图2中可以直观看出生成的虚拟样本与真实样本非常接近,但又有差异,从而可以有效地丰富训练数据。
[0074] 如图5所示,图5为通过本申请实施例提供的生成对抗网络模型生成长度可变的联机手写字符数据的流程图。从图5中我们可以看到,实现该技术主要分为五个步骤:(1)准备一定量的真实的空中手写字符数据(如:IAHCC‑UCAS2016);(2)构建生成对抗网络模型;(3)训练构建好的生成对抗网络模型;(4)生成虚拟空中手写字符样本;(5)对生成样本的有效性进行验证。
[0075] 由上可知,本申请实施例为支持空中手写字符识别算法研究,采用数据增强技术,构建了一个端到端GAN模型用于生成不同长度的空中手写字符,自动生成大量虚拟数据配合真实数据进行分类模型训练。而现有技术中的GAN模型处理的是固定尺寸图像,不能胜任可变长度字符生成,因此还克服了现有基于卷积神经网络的生成式对抗网络模型无法生成可变长度的时间序列的缺陷。与传统的GAN模型用于固定尺寸图像数据生成不同,本申请用于长度可变的空中手写坐标序列数据的生成。作为一种新型的GAN结构,本申请为空中手写人机交互的发展与应用提供技术支持。相关理论、模型为推动模式识别、深度学习等的发展具有重要意义。
[0076] 上述所有可选技术方案,可以采用任意结合形成本申请的可选实施例,在此不再一一赘述。
[0077] 具体实施时,本申请不受所描述的各个步骤的执行顺序的限制,在不产生冲突的情况下,某些步骤还可以采用其它顺序进行或者同时进行。
[0078] 由上可知,本申请实施例提供的生成对抗网络模型,包括生成器和判别器;生成器和判别器均由卷积神经网络构成,卷积神经网络的输入卷积层的卷积核大小为2×2,其余各卷积层的卷积核大小为1×3;生成器用于生成长度可变的空中手写字符,其输入为各个类的混合了噪声的均值向量;判别器用于与生成器进行对抗训练,以使生成的长度可变的空中手写字符更加接近真实数据;判别器包括全局平均池化层GAP,平均池化层GAP用于对判别器进行卷积操作后生成的不同尺寸的特征图转化为等长向量。本申请的生成对抗网络模型结构,实现了对可变长度的空中手写字符的生成,帮助分类器(判别器)提高了识别精度,有效缓解了空中手写数据匮乏问题,为空中手写交互系统的发展提供了技术支持,克服了现有基于卷积神经网络的生成式对抗网络模型无法生成可变长度的时间序列的缺陷。
[0079] 以上对本申请实施例所提供的生成对抗网络模型进行了详细介绍。本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。