技术领域
[0001] 本发明涉及文本分类领域,具体为基于BiLSTM和混合注意力ResNet的文本分类网络。
相关背景技术
[0002] 文本分类是是自然语言处理中应用最广泛也是最重要的一个领域,目前文本分类已经在实际生活中的多个领域得到了大量的应用,例如情感分析、问题回答、主题标记等。早期的文本分类任务可以通过手动标记文本数据来完成,然而随着互联网的快速发展,需要处理的文本文数据的数量呈指数级增长,传统的手动处理数据面临巨大挑战,并且该方法的分类准确率容易受到专业人员的知识储备和精力等人为因素的影响,因此文本自动分类变得非常重要。
[0003] 自动文本分类方法可以分为基于规则的方法,基于机器学习和深度学习的方法。基本规则的方法通过一组事先定义好的规则来对文本进行分类,但是定义规则需要完整的领域知识。基于机器学习的方法已经被证实在文本分类领域非常有效。机器学习方法在分类过程中大多分为三步:数据预处理,特征提取(如词袋模型和TF‑IDF等方法)和分类方法(使用朴素贝叶斯( Bayes,NB)、K近邻算法(K‑Nearest Neighbor,KNN)、支持向量机(Support Vector Machines,SVM)等)尽管这些方法在分类任务中表现出色,但在实际使用过程中存在明显的限制。例如,非常依赖需要耗费大量时间和成本的特征工程,过于依赖领域知识导致该类方法在新的分类任务中效果有限。此外,这些方法通常忽略了文本数据中的序列信息、上下文信息以及单词本身的语义信息,这与人类理解句子的过程不符。
[0004] 近年来,深度学习方法在自然语言处理领域取得了突破性进展,与现有的的传统方法相比,深度学习方法避免了人工设计规则和特征的过程,可通过学习一种深层次非线性网络结构,自动从样本中挖掘出文本中的本质特征,能够捕获文本数据的深层次语义表征信息,使得识别和分类更加准确、可靠,以卷积神经网络(CNN)和循环神经网络(RNN)取得的效果最为显著。
[0005] CNN能够从数据中学习局部特征,不具备学习顺序相关性的能力,与CNN相比,RNN在顺序相关性的学习上有明显优势,无法并行地提取特征,文本分类实际上也是一种建模任务,由于RNN的特性,在文本相关的任务中RNN应用更加广泛,然而,对于长文本数据,传统的RNN会出现梯度爆炸和消失的问题,长短期记忆(LSTM)是一种以长短期记忆单元为隐藏单元的RNN架构,可以捕获长期依赖性,能够有效解决梯度消失和爆炸的问题,因为LSTM提取高级文本信息的强大能力,它在NLP中发挥着重要的作用,BiLSTM是LSTM的进一步发展,BiLSTM结合了前向隐藏层和后向隐藏层,可以访问前后的上下文表示。与BiLSTM相比,LSTM仅利用单向的上下文信息。因此,BiLSTM比LSTM在NLP任务中效果更好。
[0006] 对于现有的文本分类任务,仅使用BiLSTM也无法达到要求,原因在于文本的向量表示是高维向量,面对高维向量BiLSTM在进行分类任务时面临着网络参数大,优化困难,分类精度不够高的问题,针对现有文本分类模型存在的局限性,本发明提出了基于BiLSTM和混合注意力ResNet的文本分类网络。
具体实施方式
[0055] 下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施条例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
[0056] 实施例
[0057] 请参阅图1,图为本发明的模型框架图,基于BiLSTM和混合注意力ResNet的文本分类网络,所述Bi‑HybridAtt‑ResNet文本分类网络模型包括:
[0058] BiLSTM,用于建立长期依赖关系,BiLSTM从两个方向处理输入句子,即从左到右和从右到左,可以获得具有更多上下文信息的特征表示,采用了两个LSTM,其中一个处理过去的信息,另一个处理未来的信息,将两个LSTM的输出结合起来形成最终输出;
[0059] 混合注意力,用于提高分类精度,由多头注意力和坐标注意力组成,利用残差连接;
[0060] ResNet,用于增加网络深度,加快模型收敛,ResNet由卷积层和池化层组成,卷积层用于特征提取,池化层用于特征处理,数据进入ResNet经过池化层和多层卷积后,输出进入全连接层。
[0061] 词嵌入作为一种特征学习技术,其目标是将词汇表中的每个单词映射到一个特定维度实数向量,最后将向量输入模型进行训练,现有的词嵌入模型包括Word2Vec、GloVe、FastText等,本发明采用的是100‑d GloVe模型,GloVe模型是斯坦福大学在2014年提出的方法,通过将单词映射到向量空间之中,使词向量能够包含尽可能多的语义和语法信息,GloVe模型通过构建词共现矩阵,将语料库中每个单词的共现信息表示为矩阵形式,然后,对这个矩阵进行分解,生成每个单词的低维表征向量,GloVe结合了全局统计信息和局部上下文特征,从而生成的词向量既反映了语料库的整体统计特性,又捕捉了单词的局部上下文关系。
[0062] LSTM是对RNN的改进,通过在LSTM中引入门控机制来解决RNN的长距离依赖问题,和RNN一样,LSTM的输出受到当前输入和之前输出的影响,输入门、遗忘门和输出门,三个状态门和一个细胞状态构成一个LSTM单元,细胞状态与门是LSTM的关键部分,它们一起保留了输入元素之间的长程依赖性,通过这些门来调节进出LSTM的信息流,图2显示了LSTM的架构,等式(1)至(7)描述了输入门、遗忘门和输出门的工作原理,其中:
[0063] xt,ht:代表时间t时的输入状态和隐藏状态
[0064] Ct:代表时间t时的候选细胞状态和最终细胞状态
[0065] Ft,It,Outt:代表遗忘门、输入门和输出门
[0066] Wf,Wi,Wo:代表遗忘门、输入门、输出门的权重参数
[0067] Wc:代表输入单元权值参数
[0068] bf,bi,bo:代表遗忘门、输入门、输出门的偏置参数
[0069] σ:代表sigmoid函数
[0070] *,+:代表逐点乘法和加法
[0071] ++:代表串联运算符
[0072] ut=ht‑1++xt (1)
[0073] Ft=σ(Wfut+bf) (2)
[0074] It=σ(Wiut+bi) (3)
[0075]
[0076]
[0077] Outt=σ(Wout+bo) (6)
[0078] ht=Outt*tanh(Ct) (7)
[0079] 输入门定义了当前输入中应该关注哪些数据,而遗忘门则决定了哪些数据可以从之前的状态中保留,哪些数据可以从之前的状态中删除,输出门决定发送到下一个时间步的数据,单向LSTM通过处理右侧或左侧的输入来创建上下文表示,然而,前文和后文都会对单词的语义产生影响,仅处理单侧输入无法获得足够的语义信息,BiLSTM从两个方向处理输入句子,即从左到右和从右到左,这样可以获得具有更多上下文信息的特征表示,图3展示了BiLSTM的结构,它采用了两个LSTM,其中一个处理过去的信息,另一个处理未来的信息,将两个LSTM的输出结合起来形成最终输出。
[0080] 注意力机制的基本思想是根据每个信息元素的相关性和重要性分配不同的权重,用以强调那些对当前任务或上下文更为关键的部分,本发明提出的混合注意力机制由多头注意力和坐标注意力组成,利用残差连接来增强特征的重复利用。
[0081] 多头注意力机制的结构如图4所示,从图中可以看出多头注意力的核心部分是缩放点积运算,结构如图5所示,这与传统的注意力机制不同,传统的注意力机制只有一个前向神经网络,多头自注意力机制通过矩阵点积来加速计算,该模型给出了一个具有n个查询向量的矩阵Q,其中 键值 编码器的隐藏层值是 计算注意力机制的权重矩阵如下:
[0082]
[0083] 其中d是查询矩阵和键值的维度。两种最常用的注意力函数是加法注意力和点积(乘法)注意力,除了缩放因子 之外,点积注意力与本发明的算法相同,加法注意力使用具有单个隐藏层的前馈网络来计算兼容性函数,对于较小的d值,两种注意力的表现相似,甚至加法注意力表现更好,原因是当d值比较大时,点积会变大,softmax函数将具有较小的梯度值,为了减少这种影响,选择点积并乘以
[0084] 多头注意力计算首先输入一个矩阵向量 它使用不同的初始化矩阵线性计算h次编码器的键值和隐藏层,每次注意力机制并行执行,最终得到dv维度的输出值,然后将这些值连接一次并进行线性计算,得到最终的输出值,系数矩阵对应的第i个头查询、键值和编码器隐藏层值记为 然后通过缩放后点积计算查询与键值之间的相关性,最后输出混合语义表示Hi,数学公式如下:
[0085]
[0086] 最后,并行多头通过串联操作拼接形成向量M,然后通过线性计算得到最终的语义表示Y,计算公式如下:
[0087] M=Concat(H1,...,Hh) (10)
[0088]
[0089] 其中
[0090] 多头自注意力可以增强句子之间的表示能力,前向传播速度快并且多头自注意力计算实现了更高的并行性,这些优点有助于更快更全面的提取语义信息。
[0091] 坐标注意力(图6)利用两个一维全局池化操作分别将沿垂直和水平方向的输入特征聚合成两个单独的方向感知特征图,然后,这两个嵌入特定方向信息的特征图被分别编码为两个注意力图,每个注意力图捕获输入特征图沿一个空间方向的长程依赖性,通过这种方式,将位置信息保留在生成的注意力图中,然后,通过乘法将两个注意力图应用于输入特征图,以强调感兴趣的表示。
[0092] 坐标注意力中使用的是四维向量B×C×H×W,文本任务中的数据通常表示为三维向量B×C×S,为了使用完整的坐标注意力机制,数据在进入坐标注意力之前我们给数据填充一维,将其从三维变成四维B×C×S×F,F代表本文填充的第四维。
[0093] 坐标注意力将全局池化转化为一对一维特征编码,获得精确的位置信息,注意力模块通过位置信息捕获空间上的远程交互,具体来说,给定输入X,使用池化内核的两个空间范围(H,1)或(1,W)分别沿水平坐标和垂直坐标对每个通道进行编码,因此,第c个通道在高度h处的输出可以表示为:
[0094]
[0095] 同样地,宽度为w的第c个通道的输出可以写为:
[0096]
[0097] 上述两种变换分别沿着空间的两个方向聚合特征,产生一对方向感知特征图,这两个变换还允许注意力模块沿着一个空间方向捕获长距离依赖关系,并沿着另一个空间方向保留精确的位置信息,这有助于网络更准确地定位感兴趣的对象,产生特征图后,先将它们拼接起来,然后发送到共享的1×1卷积变换函数F1,得到:
[0098]
[0099] 其中[·,·]表示沿空间维度的串联操作,δ是非线性激活函数, 是在水平方向和垂直方向上编码空间信息的中间特征图,r是用于控制块大小的缩减比率,沿着h w空间维度将f分成两个独立的张量 和 另外两个1×1卷积变换F 和F
h w
用于分别将f和f变换为与输入X具有相同通道数的张量,得到:
[0100] gh=σ(Fh(fh)) (15)
[0101] gw=σ(Fw(fw)) (16)
[0102] σ是sigmoid函数,然后输出的gh和gw被扩展并分别用作注意力权重,坐标注意块Y的输出可以写为:
[0103]
[0104] 坐标注意力模块将沿水平和垂直方向的注意力同时应用于输入张量,两个注意力图中的每个元素都反映了相应行和列中是否存在感兴趣的对象,这种编码过程使坐标注意力能够更准确地定位感兴趣对象的确切位置,从而帮助整个模型更好地识别。
[0105] 将多头注意力和坐标注意力结合,形成混合注意力模块,能够结合两中注意机制的优势,在增强句子表示能力,提高并行性的基础上,更精确地聚焦感兴趣目标的位置,提升整个模型的识别能力。
[0106] 单纯使用BiLSTM无法提取深层次特征,取得更好的分类效果,使用多层卷积层可以挖掘文本的深层特征,但当神经网络参数过多时,会出现梯度消失和梯度爆炸导致参数更新停滞,为此引入ResNet,如图7所示,ResNet由卷积层和池化层组成,卷积层用于特征提取,池化层用于特征处理,数据进入ResNet经过池化层和多层卷积后,输出进入全连接层,ResNet使用快捷连接,恒等映射重构学习过程,重定向网络信息流,增加网络深度,这种架构提高了模型的表示能力,加快了网络收敛速度,有效解决了网络退化和梯度消失等常见问题,相邻的卷积层通过快捷连接形成残差块,多个残差块顺序排列组成ResNet,残差块如图8所示,定义为:
[0107] y=F(x,{Wi})+x (18)
[0108] 这里x和y是所考虑层的输入和输出向量,函数F(x,{Wi})表示要学习的残差映射,运算F+x通过快捷连接和逐元素加法来执行。
[0109] 式(18)中x和F的尺寸必须相等,如果尺寸不相等,可以通过线性投影Ws来匹配尺寸:
[0110] y=F(x,{Wi})+Wsx (19)
[0111] 与原始映射相比,残差映射F对输出的变化更加敏感,有更广的参数调整范围,能够加快模型的学习速度,提高了网络性能。
[0112] 使用交叉熵作为损失函数可以降低随机梯度下降过程中梯度消失的风险,损失函数可以表示如下式,二分类任务为(20),多分类任务为(21)
[0113]
[0114]
[0115] 其中N是训练样本的数量,M是类别数量,i表示训练样本,yi和yic是样本的标签,pi和pic是Bi‑HybridAtt‑ResNet的输出。
[0116] 本实施例通过对各种二分类和多分类数据集进行实验分析来评估提出的Bi‑HybridAtt‑ResNet模型,采用四个广泛使用的数据集来对本发明的模型进行评估,这些数据集在类型、大小、类别数和文档长度方面各不相同,表1显示了数据集的统计数据,使用Yelp Review Polarity(Yelp P)作为情感分析的数据集,使用AG’s News(AGNews)、DBPedia和Yahoo!Answers(Yahoo)作为主题分类的数据集,下面是数据集的详细描述:
[0117] Yelp Review Polarity:我们使用的极性Yelp数据集是Yelp数据集的子集,Yelp数据集来自于2015年的Yelp数据集挑战赛。极性数据集每个类别有280,000个训练样本和19,000个测试样本。
[0118] AGNews:是从AG的新闻文章语料库选择了四个最大的类构建的数据集,仅使用标题和描述字段。每个类包含30000个训练样本,1900个测试样本。
[0119] DBPedia:DBPedia是一个使用维基百科最常用的信息框生成的大型多语言知识库,从DBpedia 2014中选取了14个不重复的类构建的数据集,每个类包含40,000个训练样本和5,000个测试样本。用于此数据集的字段包含每篇维基百科文章的标题和摘要。
[0120] Yahoo!Answers:该数据集来自Yahoo!Answers综合问答1.0版数据集,从中选取了10个最大的主要类别构建了Yahoo!Answers数据集。每个类包含140,000个训练样本和5,
000个测试样本。我们使用的字段包括问题标题、问题内容和最佳答案。
[0121] 基准模型如下:
[0122] TextCNN:一种经典的基于CNN的分类器,利用一维卷积运算和最大池化来提取特征。
[0123] 基于CNN的模型:词级CNN和字符级CNN。
[0124] BERT:带有分类器的基于BERT的预训练模型。
[0125] 表1本文实验中使用的四个文本分类数据集。
[0126]
[0127] 文本分类任务常用的评价指标是准确率、精确率、召回率和F1分数,下面对这些指标进行详细说明。
[0128] 准确率指的是算法在各种预测中产生的正确假设的比例。
[0129]
[0130] 精确率是被正确预测为阳性的样本数占所有被预测为阳性样本数的比例。
[0131]
[0132] 召回率是正确阳性结果的数量除以所有相关样本(包括本来应该是阳性的样本)的数量
[0133]
[0134] F1分数是精度值和召回值的调和平均值。
[0135]
[0136] 其中TP:真阳性,TN:真阴性,FP:假阳性,FN:假阴性。
[0137] 涉及的超参数设置如下,ResNet滤波器大小为3,BILSTM隐藏层层数设置为1,大小设置为64,多头注意力的注意力头数目设置为8,采用Adam优化器,对所有非线性激活函数‑3使用ReLU,在训练期间应用比率为0.6的Dropout正则化,学习率设置为3×10 ,Batchsize设置为512模型训练20轮,本发明提出模型的实验环境编程语言是Python 3.8,深度学习框架是PyTorch1.8.0,硬件环境为英伟达RTX 4070GPU。
[0138] 本发明除了使用上述五个基准模型外还新增了两个流行的深度学习模型FastText和DPCNN,来与本发明提出的模型进行对比实验分析,四个数据集的对比结果分别在下面四个表上呈现,这个四个表对应的数据集分别是AGNews(表2)、DBpedia(表3)、Yelp P(表4)、Yahoo(表5),图9是所有模型在四个数据集上的准确率柱状图。
[0139] 表2所有模型在AGNews数据集上的表现,最佳结果以粗体显示。
[0140]
[0141] 表3所有模型在DBpedia数据集上的表现,最佳结果以粗体显示。
[0142]
[0143] 表4所有模型在Yelp P数据集上的表现,最佳结果以粗体显示。
[0144]
[0145]
[0146] 表5所有模型在Yahoo数据集上的表现,最佳结果以粗体显示。
[0147]
[0148] 如上面的表和图9所示,本发明提出的Bi‑HybridAtt‑ResNet在DBpedia、Yelp P和Yahoo三个数据集上的表现优于其他模型,在AGNews数据集上的表现仅次于BERT,这验证了其有效性,Bi‑HybridAtt‑ResNet融合了BiLSTM、混合注意力机制以及ResNet,通过这个方式弥补了单一RNN,CNN的不足,模型既能够更好的捕捉上下文语义信息,处理长距离依赖,也能给予关键信息更多的关注和权重,捕捉局部特征,增加特征的复用程度,减小过拟合风险,加速收敛。
[0149] 为了验证本文提出的模型配置的合理性,进行了四次消融实验,基于BiLSTM网络进行一步步改进,首先,在BILSTM网络后添加多头注意力模块,形成BiLSTM+Muti‑Head网络,接下来在网络底部添加坐标注意力模块,形成BiLSTM+Hybridatt网络,最后,添加ResNet,形成Bi‑HybridAtt‑ResNet。所有模型均在AGNews数据集上进行训练和测试,评价指标数据如表6所示
[0150] 损失函数曲线和评价指标数据如表6和图10所示,从图10可以看出,随着网络模型结构的不断改进,模型在数据集上的精度不断提高,曲线的波动也越来越小,如表6所示,随着网络结构的不断完善,各评价指标的值都在增加,数据相对稳定,多头注意力、坐标注意力和ResNet极大地改善了网络结构,使网络能够提取更多信息,获得更高的准确率。
[0151] 表6AGNews数据集上的消融实验
[0152]
[0153]
[0154] 以上内容是结合具体实施方式对本发明作进一步详细说明,不能认定本发明具体实施只局限于这些说明,对于本发明所属技术领域的普通技术人员来说,在不脱离本发明的构思的前提下,还可以作出若干简单的推演或替换,都应当视为属于本发明所提交的权利要求书确定的保护范围。