摘要
针对海洋气象传感网(MMSN)环境下海洋移动终端资源受限和网络流量不平衡导致网络入侵难以被准确检测的问题,提出了一种基于移动边缘计算的MMSN物理架构和一种基于平衡生成对抗网络的入侵检测模型。首先,利用改进的平衡生成对抗网络对不平衡数据进行数据增强。其次,利用基于分组卷积的轻量级网络对入侵数据进行分类。最后,通过计算机仿真证明了所提模型较传统数据增强模型具有更高识别各类攻击的能力,尤其是针对MMSN的少数类样本攻击。
0 引言伴随5G新基建系统的完善与6G标准的制定,海洋物联网(MIoT, maritime Internet of things)迎来了新的快速发展契机,并加快了我国海事信息系统和通信基础设施现代化。海洋气象传感网(MMSN, maritime meteorological sensor network)作为MIoT不可或缺的组成部分,通过IoT设备与海上智能交通、海事智能感知和海洋气象灾害预警等系统共享关键的接口和信息,并全方位地为MIoT服务提供精准的气象数据。
利用移动边缘计算(MEC, mobile edge com puting)将部分数据卸载到近端 MEC 服务器上处理,可以及时有效地处理各类海洋移动终端收集的海量数据,满足低时延的海事应用服务需求。然而,IoT 和先进通信技术在带来便利的同时,也使 MMSN 存在较大的攻击面。设备间的频繁信息交互易被入侵者侦测,进而增加了设备受到网络攻击的风险,甚至可能对整个 MMSN 造成严重的破坏。因此,为保障 MMSN 的完整性、可靠性以及可用性,亟须设计一个有效和可靠的安全机制。
防火墙、加密技术和入侵防御等传统的安全防御机制大多基于启发式和静态攻击签名,难以识别网络中日益多样化的攻击。近年来,基于人工智能的网络入侵检测系统(NIDS, network intrusion detection system)已被广泛应用到智慧电网、工业4.0和车联网等领域,并可以提供更可靠的安全服务保障。然而,与传统陆地IoT入侵检测不同,设计面向MMSN的入侵检测面临如下挑战。
1) MMSN 中的海洋移动终端分布范围广泛且稀疏,缺乏中心基础设施,受到网络攻击的方式相对隐蔽,导致收集的网络流量数据呈现高度不平衡特性。这严重限制了现有入侵检测模型的性能。
2) 海洋无线通信环境复杂多变,各类海洋移动终端计算与存储资源差异大,能耗敏感度不一,移动终端的强异构性导致部分终端出现入侵检测任务处理超负荷情况,实现可持续的入侵检测是保障MMSN的关键。
通过参考 MEC 卸载技术在海洋观监测传感网的研究以及现有的入侵检测研究,结合 MMSN 中入侵检测存在的挑战,以提供高可靠、可持续的入侵检测能力为目标,本文研究了 MMSN 中的入侵检测技术,主要贡献如下。
1) 提出一种基于MEC的MMSN物理架构,海洋移动终端可将数据处理和入侵检测任务部分卸载至近端MEC服务器上处理,保障MMSN的低时延和安全服务需求。
2) 针对入侵检测数据不平衡问题,提出一种改进的平衡生成对抗网络(A-BAGAN, advanced balancing generative adversarial network)数据增强模型来生成少数类攻击样本,改善入侵检测分类器受训练数据集不平衡的影响。
3) 针对MMSN中海洋移动终端的强异构性,本文提出一种基于分组卷积神经网络的轻量级入侵检测分类器LGCNN(lightweight group convolu tional neural network),在准确识别各类攻击的同时,降低对终端计算与存储资源的消耗。
1 相关工作1.1 智能型入侵检测系统NIDS 用来实时监控网络数据传输的异常行为,同时对检测到的网络攻击采取可应对的安全响应措施。NIDS根据检测技术可以分为基于签名和异常数据。前者通过已有攻击签名库对待检测数据特征进行匹配,该方法对已知攻击的识别效果较好,但是难以识别未知攻击。后者将正常数据与待检测数据之间的差异作为异常判断准则,其优点是可以检测出未知攻击,但常常伴随着较高的误报率。
近年来,基于机器学习(ML, machine learning)和深度学习(DL, deep learning)的智能型入侵检测方法已经得到深入的研究。文献提出了一种基于互信息的最优特征选择算法,并利用最小二乘支持向量机(LSSVM, least squares support vector machine)进行入侵检测,提高了检测精度,但时间复杂度有所提高。文献研究了特征选择对入侵检测分类器性能的影响,提出了一种基于余弦相似度的智能鸽群算法来选取最优特征子集,相较于传统算法具有更快的收敛速度。然而,该算法仍存在时间复杂度高的问题,且未对攻击分类。文献实现了基于单类支持向量机的入侵检测模型,从直方图角度提取网络数据包的特征,提高了检测精度,但增加了训练复杂度和部署成本。上述基于传统 ML 的入侵检测方法在处理大量高维数据时出现能力不足的问题,并且不能很好地处理数据不平衡问题。
相较于ML的入侵检测,基于DL的方法因其具有强大的数据表达能力,通常可获得更好的检测性能。文献在正常数据和异常数据的低维表示服从不同分布的假设下,提出了一种基于表示学习的异常检测方法,但该文献仅考虑了二分类情况。同时,为了降低模型的复杂度,文献提出了一种轻量级的入侵检测模型,使用改进的自动编码器(AE, autoencoder)来提取数据的特征,获得了较高的检测率,但其未能对攻击进行有效分类。为了减少对经验性知识的依赖,文献提出了一种基于卷积神经网络(CNN, convolutional neural network)的入侵检测模型,使用多目标优化算法搜索CNN结构,可以在参数空间获得较优的解,但搜索需要的时间开销大,难以实时部署。文献将CNN和长短期记忆(LSTM, long short-term memory)网络相结合,提出了一种分析时空特征的入侵检测模型,借助注意力机制充分融合时间和空间特征,提高了检测性能。然而,文献提出的模型参数量庞大,无法应用于资源受限场景。文献受群体分组决策启发,提出了一种新颖的RANet入侵检测模型,使用分组-门控卷积模块有效提取输入数据的特征,并减少了模型需要学习的参数量。为了进一步提高检测的准确率,文献提出了一种多阶段入侵检测模型,使用人工蜂群算法选取特征,并使用黑寡妇算法来优化卷积LSTM结构的检测模型,在多个公共数据集上取得了较高的检测率。然而,文献提出的模型复杂度高,占用资源大。文献设计了一种可解释性多输出结构的神经网络入侵检测模型,利用二分类输出结果辅助多分类决策,并使用注意力机制来解释特征的重要性,但检测精度较低。上述DL的研究工作增强了对高维数据的表达能力,但是依然存在难以处理数据不平衡的问题。由此可见,这些模型并不能很好地识别出频率低的攻击。
入侵检测数据集是典型的不平衡数据集,而上述基于智能型NIDS的研究又极少关注少数类攻击样本的检测效果,致使入侵者可以有针对地发起少数类攻击,从而不可避免地存在数据泄露的风险。因此,亟须在入侵检测模型训练之前,对训练数据集进行平衡处理。
1.2 数据不平衡处理方法数据不平衡指隶属某一类的样本数量远低于其他类别,该问题广泛存在于银行数据、医疗数据等领域。解决这一问题可以从算法和数据两方面入手进行研究。在算法方面,可以尝试去适应基于不平衡数据集的训练,进而提高少数类样本的识别精度,如代价敏感函数。然而,设计合适的代价系数矩阵需要专家知识,且相当复杂。在数据方面,可以通过增加少数类样本数量来处理不平衡问题,目前已经成为主流的研究方向。
在入侵检测领域,数据不平衡问题是制约检测性能的重要因素。这源于真实网络环境中收集到的原始流量大部分都是正常流量,某些低频率的攻击流量数量较少。因此,文献构建了门控循环单元模型来检测网络中的分布式拒绝服务(DDoS,distributed denial of service)攻击,并使用合成少数过采样技术(SMOTE, synthetic minority oversampling technique)对训练集中少数类样本进行了扩充,取得了显著的性能优势。然而,该文献采用的数据集中正负样本分布与实际不符。此外,随机欠采样、自适应综合(ADASYN, adaptive synthetic)过采样和随机过采样(ROS, random over sample)技术也常用于入侵检测数据不平衡处理。然而,传统的欠采样技术可能丢失多数类样本的有用信息,而过采样技术又无法很好地学习真实的数据分布且易受噪声点影响,导致生成的样本分布和真实数据分布差异很大。因此,上述传统数据增强方法不能充分地利用数据的深层次信息。这就表明不平衡数据的分布无法被准确地映射出来,同时可能对分类器的性能造成损害。
最近,生成对抗网络(GAN, generative adversarial network)在处理入侵检测数据不平衡问题上受到了极大的关注。文献利用条件 GAN(CGAN, conditional GAN)来生成少数类攻击样本。使用前馈神经网络从网络流量中生成特征向量,再使用CGAN为少数类攻击生成新样本,提高了少数类攻击检测率。文献受SMOTE-SVM数据合成思想启发,利用辅助分类器GAN(ACGAN, auxiliary classifier GAN)来生成支持向量附近的困难样本。文献与文献相似,均采用ACGAN来生成少数类攻击样本,不同之处在于文献将一维网络数据转化成二维图像数据。其中,文献使用常规的顺序排列方式将一维网络数据转化为图像数据;而文献则充分考虑图像中像素点与邻居点之间的相关性,运用t-SNE(t-distributed stochastic neighbor embedding)技术将一维网络数据转化为二维图像数据。上述工作均未充分地考虑训练集中少数类样本的数量可能会导致 GAN 无法准确学习少数类样本分布的问题。例如,ACGAN 在不平衡数据集上训练时,判别器的2个输出是相互矛盾的,这可能使生成的样本无法兼顾真实性和类别属性。此外,利用 GAN 作为数据增强的相关工作也并未使用度量性指标来衡量 GAN 生成的样本,缺乏对生成样本的有效性评估。值得说明的是,在计算机视觉领域可使用IS(inception score)和FID(Fréchet inception distance)来衡量GAN生成图像样本的质量,但上述指标在入侵检测领域并不适用。
2 海洋气象传感网入侵检测系统2.1 海洋气象传感网物理架构图1展示了本文提出的基于MEC的MMSN物理架构,该架构主要由陆地云服务器、卫星、MEC服务器和各类海洋移动终端组成。在 MMSN 中, MEC 服务器可以根据海域地理位置部署在不同基础设施上。对于近海区域和远海区域,MEC服务器分别部署在海岛基站和远海基站上,该服务器集合可由<span class="MathJax" id="MathJax-Element-1-Frame" tabindex="0" data-mathml="S={s1,s2,⋯,sM}" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">S={s1,s2,⋯,sM}S={s1,s2,⋯,sM}表示。对于每一个MEC服务器,均可通过卫星链路与陆地云服务器进行通信;船舶、海上浮标、无人机、探空气球和无人飞艇等移动终端在指定区域运行,这些移动终端集合表示为<span class="MathJax" id="MathJax-Element-2-Frame" tabindex="0" data-mathml="MT={mt1,mt2,⋯,mtL}" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">MT={mt1,mt2,⋯,mtL}MT={mt1,mt2,⋯,mtL};各移动终端上集成多个IoT 设备用于收集温湿度、气压、风向、风速和能见度等气象数据,移动终端mti上的 IoT 设备集合表示为<span class="MathJax" id="MathJax-Element-3-Frame" tabindex="0" data-mathml="Ik={i1,i2,⋯,iN}" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">Ik={i1,i2,⋯,iN}Ik={i1,i2,⋯,iN},1≤k≤L。
为了缓解部分移动终端计算资源不足的问题,可以通过正交频分多址接入(OFDMA, orthogonal frequency division multiple access)通信方式向近端MEC服务器进行部分任务的卸载。此外,考虑到海洋区域辽阔而 MEC 服务器覆盖范围有限,对于未覆盖的移动终端可通过卫星通信与陆地云服务器通信。由于各类移动终端隶属不同的海事机构和企业,数据中存在敏感信息。因此,在本文构建的MMSN物理架构中,假设陆地云服务器和MEC服务器都是诚信可靠的,对移动终端的信息和数据内容不感兴趣,能够严格履行卸载任务处理与计算结果反馈职责,且各移动终端之间不进行任务卸载。与现有的集中式 MMSN 物理架构相比,本文通过引入 MEC 技术可显著降低各终端入侵检测的响应时间、能耗和丢包率,特别是存在网络流量负荷的情况下。
2.2 面向海洋气象传感网的入侵检测当各类海洋移动终端在指定区域工作时,IoT设备每时每刻都在产生和收集数据,常见的MMSN访问会使这些设备容易受到网络攻击。在MMSN中,移动终端大至无人飞艇和货轮,小至无人水面艇、探空气球和海上浮标,根据移动终端的可用算力、存储资源、运行速率和安全程度等异构的特征属性,入侵者能够有针对性地发起不同类型的攻击。
探测(Probe)攻击。无人飞艇、货轮等大型移动终端航速相对缓慢稳定,计算处理能力和存储能力强,能源充足,安全程度高。然而,大型移动终端部署的传统安全防御机制仍然存在一定的漏洞。入侵者可根据这些漏洞对大型移动终端发起探测攻击,在不被发觉的情况下,收集有关 MMSN的有价值信息、网络拓扑结构和设备特点(如设备的型号、功能和支持的网络协议等),为进一步攻击做准备。
DoS/DDoS攻击。无人机、中型船舶等中型移动终端运行速率较快,计算处理能力和存储能力一般,能源较充足,安全程度一般。当中型移动终端频繁通过开放的无线信道与MEC服务器卸载通信时,容易被入侵者通过监听信道的方式发现。这使入侵者可以在短时间内通过一对一或多对一的方式向这些目标发送大量的无效请求,导致设备不胜负荷,这可能会中断MIoT服务或者阻碍合法请求的实现,甚至可能导致上述移动终端偏离预定航线。
图1
图1 基于MEC的MMSN物理架构
暴力破解攻击。海上浮标、无人水面艇和探空气球等小型移动终端计算处理能力和存储能力弱,能源受限,安全程度低。入侵者可通过枚举、字典等暴力破解方式对小型移动终端上的设备进行大量认证,从而获取设备信息和敏感数据。
因此,为避免海洋移动终端上的各类IoT设备在正常运行过程中受到网络攻击,在各移动终端和MEC服务器上部署NIDS来有效地检测网络攻击是至关重要的。NIDS 检测过程主要包括以下 4 个步骤:1) 各移动终端利用抓包工具对经过其IoT设备的网络流量进行捕获;2) 将捕获的原始数据包转化为观测值,每组观测值包含有关网络连接的统计信息和属性,这些观测值有助于识别网络攻击;3) 将上述观测值进行预处理后输入检测分类器中检测;4) 检测分类器根据输入数据进行判别,然后输出检测结果。
对于资源约束型移动终端,可将部分检测任务卸载到近端 MEC 服务器上处理,MEC 服务器可间歇性地访问陆地云服务器以提供足够的计算资源。本文假设从各移动终端捕获的网络数据具有相同的特征空间,且构建的NIDS拟在陆地云服务器上进行离线集中式训练,然后在线分布式部署到移动终端和MEC服务器上,因此检测模型的训练过程不会占用大量海洋移动终端的计算与存储资源。
3 基于平衡生成对抗网络的入侵检测模型针对 MMSN 中存在的网络安全隐患,本文提出了一种基于平衡生成对抗网络(BAGAN, balancing generative adversarial network)的入侵检测模型,可以有效地降低Probe、DoS/DDoS和暴力破解等网络攻击的威胁。BAGAN 的入侵检测模型整体框架如图2所示,主要包括以下3个模块:1) 预处理模块将原始网络流量表征成向量,并划分训练集和测试集;2) 不平衡处理模块对训练数据进行增强;3) 检测模块利用增强后的混合数据集和测试集进行训练和测试。面向不平衡处理模块,本文提出了一种针对 MMSN 中缺乏少数类训练样本问题的平衡生成对抗网络数据增强算法,该算法可有效提高少数类攻击的识别精度;面向检测模块,本文构建了一种基于分组卷积的检测方案,用于解决MMSN中部分移动终端资源受限的问题。
图2
图2 BAGAN的入侵检测模型整体框架
3.1 不平衡处理模块入侵检测数据不平衡容易提高少数类攻击漏检的概率,影响检测模型的分类性能。与传统的陆地IoT和车联网相比,MMSN中的移动终端分布范围广泛且密度低,受到网络攻击的方式相对隐蔽,资源相对匮乏,部署环境复杂多变。这会导致在MMSN中收集到的网络数据不平衡的特性更加显著,严重制约现有检测模型的检测能力。为解决这一问题,可以利用 GAN 来生成少数类攻击样本。然而,传统 GAN训练成功的概率依靠一定规模的样本数量,直接使用传统GAN来处理不平衡的入侵检测数据,会严重抑制GAN对少数类攻击样本的建模能力。
BAGAN 主要用于解决 GAN 在不平衡图像数据集上生成少数类样本困难的问题,其训练过程包括3个阶段,即AE训练阶段、GAN初始化阶段和GAN训练阶段。AE训练阶段不需要样本确切的类别信息,以无监督方式处理多数类和少数类样本,能够有效地学习所有类样本的公共特征。AE训练完成后分别计算各类样本的潜在表示服从的高斯分布,并将其作为 GAN 训练阶段随机噪声的先验分布。然后,将 AE 的权重迁移到 GAN 中,作为GAN的初始状态,使GAN在训练前能够继承AE的先验知识,处于一个良好的初始状态。最后,对 GAN 进行对抗训练。其中,判别器为多节点输出结构,需要将样本与类别相匹配。
由于 MMSN 中入侵检测数据的高度不平衡特点,直接采用传统 BAGAN 来生成少数类攻击样本存在较大困难。首先,在AE完成训练后,各类样本的潜在表示之间存在较大重叠区域,使生成的样本类别模糊,从而难以学习到良好的条件先验分布。其次,BAGAN在对抗训练中进行优化所得到的交叉熵损失函数与f散度相关,增加了训练不稳定性;同时也未利用梯度惩罚来稳定GAN的训练过程。针对上述 BAGAN 存在的问题,本文提出了一种改进的平衡生成对抗网络来生成少数类攻击样本。
具体地,针对样本类别模糊问题,通过改进条件变分自动编码器(ICVAE, improving conditional variational autoencoder)代替BAGAN中的AE,其结构如图3(a)所示,主要由编码器En和解码器De构成。其中,输入数据<span class="MathJax" id="MathJax-Element-4-Frame" tabindex="0" data-mathml="x" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">xx传送至编码器中计算均值编码<span class="MathJax" id="MathJax-Element-5-Frame" tabindex="0" data-mathml="μ" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">μμ和标准差编码<span class="MathJax" id="MathJax-Element-6-Frame" tabindex="0" data-mathml="σ" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">σσ;<span class="MathJax" id="MathJax-Element-7-Frame" tabindex="0" data-mathml="μ" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">μμ和<span class="MathJax" id="MathJax-Element-8-Frame" tabindex="0" data-mathml="σ" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">σσ再通过“重参数技巧”计算得到潜在表示向量<span class="MathJax" id="MathJax-Element-9-Frame" tabindex="0" data-mathml="z" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">zz;最后,向量<span class="MathJax" id="MathJax-Element-10-Frame" tabindex="0" data-mathml="z" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">zz联合类别信息<span class="MathJax" id="MathJax-Element-11-Frame" tabindex="0" data-mathml="y" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">yy输入解码器中得到重构数据<span class="MathJax" id="MathJax-Element-12-Frame" tabindex="0" data-mathml="x^" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">xˆx^。当<span class="MathJax" id="MathJax-Element-13-Frame" tabindex="0" data-mathml="∀x∼Pr(x)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">∀x∼Pr(x)∀x∼Pr(x) (<span class="MathJax" id="MathJax-Element-14-Frame" tabindex="0" data-mathml="Pr(x)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">Pr(x)Pr(x)为真实数据分布)时,结合ICVAE的结构特点和最大似然准则可得不等式关系如下
[size="normal">y<mo]logpθ(x|y)=DKL(qφ(z|x)||pθ(z|x,y))+logpθ(x|y)=DKL(qφ(z|x)||pθ(z|x,y))+
[size="normal">z|<mstyle] ELBO≥ELBO=Eqφ(z∣∣x)[logpθ(x∣∣z,y)] − ELBO≥ELBO=Eqφ(z|x)[logpθ(x|z,y)] −
[size="normal">z|<mstyle] DKL(qφ(z|x)||pθ(z|y)) (1) DKL(qφ(z|x)||pθ(z|y)) (1)
其中,<span class="MathJax" id="MathJax-Element-18-Frame" tabindex="0" data-mathml="DKL(⋅||⋅)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">DKL(⋅||⋅)DKL(⋅||⋅) 表示 2 个分布之间的相对熵, <span class="MathJax" id="MathJax-Element-19-Frame" tabindex="0" data-mathml="qφ(z|x)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">qφ(z|x)qφ(z|x) 表示通过编码器估计的后验分布, <span class="MathJax" id="MathJax-Element-20-Frame" tabindex="0" data-mathml="pθ(x|z,y)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">pθ(x|z,y)pθ(x|z,y)表示重构结果,<span class="MathJax" id="MathJax-Element-21-Frame" tabindex="0" data-mathml="pθ(z|y)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">pθ(z|y)pθ(z|y)表示先验分布。
本文设<span class="MathJax" id="MathJax-Element-22-Frame" tabindex="0" data-mathml="pθ(z|y)≡N(0,I)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">pθ(z|y)≡N(0,I)pθ(z|y)≡N(0,I) ,满足<span class="MathJax" id="MathJax-Element-23-Frame" tabindex="0" data-mathml="z" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">zz与<span class="MathJax" id="MathJax-Element-24-Frame" tabindex="0" data-mathml="y" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">yy是解纠缠的条件,从而可以最大限度地促使编码器学习样本中的公共特征。由式(1)知,ICVAE 在训练过程中需尽可能地提高证据下界(ELBO, evidence lower bound),即最小化样本重构误差和<span class="MathJax" id="MathJax-Element-25-Frame" tabindex="0" data-mathml="qφ(z|x)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">qφ(z|x)qφ(z|x)与<span class="MathJax" id="MathJax-Element-26-Frame" tabindex="0" data-mathml="pθ(z|y)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">pθ(z|y)pθ(z|y)之间的相对熵损失,表示为
[size="normal">x|<mstyle]LICVAE=−Eqφ(z|x)[logpθ(x∣∣z,y)]+LICVAE=−Eqφ(z|x)[logpθ(x|z,y)]+
[size="normal">z|<mstyle] DKL(qφ(z|x)∥p(z∣∣y)) (2) DKL(qφ(z|x)‖p(z|y)) (2)
另外,针对BAGAN训练不稳定问题,本文对其结构进行了调整,GAN结构如图3(b)所示。首先,将判别器D中的多节点输出层Dc替换为原始GAN中的单节点输出层,并将样本类别信息输入生成器G和判别器D中。其次,结合BAGAN对少数类样本学习策略,可以从均匀的标签集合中随机采样来生成伪标签,但需要满足生成样本的数量等于真实样本的采样数量的条件。再次,设置随机噪声<span class="MathJax" id="MathJax-Element-29-Frame" tabindex="0" data-mathml="z" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">zz服从<span class="MathJax" id="MathJax-Element-30-Frame" tabindex="0" data-mathml="Pz(z)≡pθ(z|y)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">Pz(z)≡pθ(z|y)Pz(z)≡pθ(z|y)分布且特定类样本的生成由输入类别信息控制。最后,判别器D的权重采取随机方式进行初始化,这是由于D的输入联合了类别信息导致其与ICVAE中的En输入不一致。
图3
图3 A-BAGAN结构
在对抗训练过程中,对判别器D施加梯度惩罚Wasserstein GAN(WGAN-GP, Wasserstein GAN with gradient penalty)和深度无悔分析GAN相结合的梯度约束项来稳定对抗训练;结合样本和标签信息的匹配关系,对错误匹配标签的样本加以惩罚。因此,判别器D优化的损失函数LD为
[size="normal">xr~Pr<mo]LD=−εxr~Pr(x)[logD(xr,yr)]−LD=−εxr~Pr(x)[logD(xr,yr)]−
[size="normal">z~Pz<mo] εz~Pz(z)[log(1−D(G(z,yf),yf))] − εz~Pz(z)[log(1−D(G(z,yf),yf))] −
[size="normal">xr~Pr<mo] εxr~Pr(x)[log(1−D(xr,yw))]+ εxr~Pr(x)[log(1−D(xr,yw))]+
[size="normal"><mover] λεxˆ~Pxˆ(xˆ)[(∥∇xˆD(xˆ,yr)∥2−1) 2] (3) λεx^~Px^(x^)[(‖∇x^D(x^,yr)‖2−1) 2] (3)
其中,<span class="MathJax" id="MathJax-Element-35-Frame" tabindex="0" data-mathml="yr" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">yryr、<span class="MathJax" id="MathJax-Element-36-Frame" tabindex="0" data-mathml="yf" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">yfyf和<span class="MathJax" id="MathJax-Element-37-Frame" tabindex="0" data-mathml="yw" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">ywyw分别表示真实标签、伪标签和错误标签的独热形式,<span class="MathJax" id="MathJax-Element-38-Frame" tabindex="0" data-mathml="x^=αxr+(1−α)xf" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">xˆ=αxr+(1−α)xfx^=αxr+(1−α)xf, <span class="MathJax" id="MathJax-Element-39-Frame" tabindex="0" data-mathml="α∼N(0,1)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">α∼N(0,1)α∼N(0,1) 表示 真 实 样 本 <span class="MathJax" id="MathJax-Element-40-Frame" tabindex="0" data-mathml="xr" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">xrxr 和 生 成 样 本<span class="MathJax" id="MathJax-Element-41-Frame" tabindex="0" data-mathml="xf=G(z,yf)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">xf=G(z,yf)xf=G(z,yf)的插值样本,λ表示梯度惩罚因子。
生成器 G 在对抗训练中需要优化的损失函数LG为
[size="normal">z~P<mstyle]LG= −Ez~Pz(z)[logD(G(z,yf),yf)] (4)LG= −Ez~Pz(z)[logD(G(z,yf),yf)] (4)
A-BAGAN模型训练算法如算法1所示。
算法1 A-BAGAN模型训练算法
初始化 编码器En的参数<span class="MathJax" id="MathJax-Element-43-Frame" tabindex="0" data-mathml="θEn" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θEnθEn ,解码器De的参数<span class="MathJax" id="MathJax-Element-44-Frame" tabindex="0" data-mathml="θDe" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θDeθDe ,生成器G的参数<span class="MathJax" id="MathJax-Element-45-Frame" tabindex="0" data-mathml="θG" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θGθG,判别器D的参数<span class="MathJax" id="MathJax-Element-46-Frame" tabindex="0" data-mathml="θD" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θDθD
1) 定义 ICVAE 训练阶段的迭代次数 T1,批次大小m1
2) for t1=1:T1
3) 采样真实样本 <span class="MathJax" id="MathJax-Element-47-Frame" tabindex="0" data-mathml="{xri} i=1m1∼Pr(x)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">{xir} m1i=1∼Pr(x){xri} i=1m1∼Pr(x)
4) 根据式(2)计算批样本的重构误差和相对熵损失之和的平均值
5) 利用Adam优化算法更新参数<span class="MathJax" id="MathJax-Element-48-Frame" tabindex="0" data-mathml="θEn" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θEnθEn和<span class="MathJax" id="MathJax-Element-49-Frame" tabindex="0" data-mathml="θDe" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θDeθDe
6) end for
7) 生成器G继承解码器De的权重,判别器D的参数随机初始化
8) 定义GAN训练阶段的迭代次数T2,批次大小m2,D和G训练次数比nd
9) for t2=1:T2
10) for t3=1:nd
11) 采样真实样本 <span class="MathJax" id="MathJax-Element-50-Frame" tabindex="0" data-mathml="{xri} i=1m2∼Pr(x)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">{xir} m2i=1∼Pr(x){xri} i=1m2∼Pr(x)
12) 为真实样本匹配错误标签 <span class="MathJax" id="MathJax-Element-51-Frame" tabindex="0" data-mathml="{ywi} i=1m2" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">{yiw} m2i=1{ywi} i=1m2, <span class="MathJax" id="MathJax-Element-52-Frame" tabindex="0" data-mathml="ywi~U{1,2,⋯,C}\yri" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">yiw~U{1,2,⋯,C}\yirywi~U{1,2,⋯,C}\yri,C 为类别总数
13) 采样随机噪声 <span class="MathJax" id="MathJax-Element-53-Frame" tabindex="0" data-mathml="{zi} i=1m2∼Pz(z)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">{zi} m2i=1∼Pz(z){zi} i=1m2∼Pz(z) ,伪标签<span class="MathJax" id="MathJax-Element-54-Frame" tabindex="0" data-mathml="{yfi} i=1m2" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">{yif} m2i=1{yfi} i=1m2 ,<span class="MathJax" id="MathJax-Element-55-Frame" tabindex="0" data-mathml="yfi~U{1,2,⋯,C}" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">yif~U{1,2,⋯,C}yfi~U{1,2,⋯,C}
14) 计算各插值样本 <span class="MathJax" id="MathJax-Element-56-Frame" tabindex="0" data-mathml="x^i=αxri+(1−α)⋅xfi" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">xˆi=αxir+(1−α)⋅xifx^i=αxri+(1−α)⋅xfi ,<span class="MathJax" id="MathJax-Element-57-Frame" tabindex="0" data-mathml="α∼N(0,1)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">α∼N(0,1)α∼N(0,1),<span class="MathJax" id="MathJax-Element-58-Frame" tabindex="0" data-mathml="1≤i≤m2" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">1≤i≤m21≤i≤m2
15) 根据式(3)计算判别损失LD
16) 利用Adam优化算法更新参数<span class="MathJax" id="MathJax-Element-59-Frame" tabindex="0" data-mathml="θD" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θDθD
17) end for
18) 采样随机噪声 <span class="MathJax" id="MathJax-Element-60-Frame" tabindex="0" data-mathml="{zi} i=1m2∼Pr(x)" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">{zi} m2i=1∼Pr(x){zi} i=1m2∼Pr(x),伪标签<span class="MathJax" id="MathJax-Element-61-Frame" tabindex="0" data-mathml="{yfi} i=1m2" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">{yif} m2i=1{yfi} i=1m2,<span class="MathJax" id="MathJax-Element-62-Frame" tabindex="0" data-mathml="yfi~U{1,2,⋯,C}" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">yif~U{1,2,⋯,C}yfi~U{1,2,⋯,C}
19) 根据式(4)计算生成损失LG
20) 利用Adam优化算法更新参数<span class="MathJax" id="MathJax-Element-63-Frame" tabindex="0" data-mathml="θG" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">θGθG
21) end for
3.2 检测模块相比陆地通信环境,海洋无线传输环境更加复杂多变,MMSN中存在大量的计算和存储能力较弱且能源受限的多种移动终端。为提供持续入侵检测能力,适当降低模型所需的计算与存储资源是迫切和必要的。因此,在构建检测模型时,本文借鉴了AlexNet 中的分组卷积结构。与普通卷积相比,分组卷积不仅占用资源更少且拥有一定的正则化作用。
由于网络数据为序列数据,本文使用的卷积均为一维卷积。LGCNN结构如图4所示,主要由普通卷积层、分组卷积层、逐点卷积层和分类层组成。具体地,普通卷积层包含一层卷积和一层最大池化,用于对输入数据升维;分组卷积层包含上下2个分支,每个分支均为两层卷积,有利于上下分支学习到输入特征图不同通道的局部信息;逐点卷积层包含一个1×1卷积,用于融合特征图不同通道的信息,同时具有降维作用;分类层由两层全连接和一层Softmax构成,用于对输入特征的概率建模和分类,并在第一个全连接层后添加Dropout防止过拟合。
4 仿真实验结果与讨论分析4.1 实验数据集选取本文选取NSL-KDD和CIC-IDS2017数据集进行仿真实验。前者作为基准网络数据集,已经被广泛用于验证入侵检测模型的有效性。后者作为最新的数据测试集之一,能够较全面地代表当下的MMSN网络环境,有效地模拟真实MMSN的网络流量特性。
图4
图4 LGCNN结构
NSL-KDD数据集是在KDDCup99数据集的基础上删除冗余数据所生成的,各类样本呈现出高度不平衡特性。其中,训练集 KDDTrain+包含 21 种不同的攻击,并将这些攻击划分成 4 种攻击,即DoS、Probe、U2R和R2L。测试集KDDTest+将网络攻击细分为 37 种,包含众多未知攻击。训练集和测试集一共包括148 517条记录,每条记录含有41个有关网络连接的特征和一个标签,这41个特征中包含 3 个字符型特征和 38 个数值型特征, NSL-KDD数据分布如表1所示。
表1 NSL-KDD数据分布
[td]数据集 | Normal/条 | DoS/条 | Probe/条 | U2R/条 | R2L/条 | 总计/条 | KDDTrain+ | 67 343 | 45 927 | 11 656 | 52 | 995 | 125 973 | KDDTest+ | 9 711 | 7 459 | 2 421 | 200 | 2 754 | 22 544 |
新窗口打开| 下载CSV
CIC-IDS2017数据集一共包括2 830 743条记录,涵盖DoS、暴力破解、Port Scan和Bot等14种攻击。每条记录包括 78 个数值型特征和一个标签。本文首先对数据集进行数据清洗,删除了1 358条存在空字符的记录和1 509条存在无穷值的记录,然后将作用相似的攻击合并成一种攻击,形成6种攻击。此外,因为该数据集过于庞大,使用全部的记录只会增加训练时长,不利于实验验证。因此,本文对Benign、DoS和Port Scan类型随机抽取部分记录,其余类型则保持不变,并按照4:1的比例划分训练集和测试集,如表2所示。
4.2 实验数据集预处理NSL-KDD 数据集中存在字符型特征和数值型特征,且数值型特征取值范围差异很大,因此需要对数据进行预处理。对于字符型特征,本文使用独热编码进行转化。对于数值型特征,使用极小极大归一化方法将特征的取值范围限制在[0,1]。预处理后的数据维度从41扩展至122。CIC-IDS2017数据集中只包含数值型特征,因此仅对数据进行归一化处理,处理后的数据维度不变。
4.3 实验评估指标本文使用平均欧氏距离(MED, mean Euclidean distance)和最大均值差异(MMD, maximum mean discrepancy)2 个统计指标来衡量生成样本的有效性,如式(5)和式(6)所示。MED 通过计算真实样本集<span class="MathJax" id="MathJax-Element-64-Frame" tabindex="0" data-mathml="Xr={xri}i=1m" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">Xr={xir}mi=1Xr={xri}i=1m 和生成样本集<span class="MathJax" id="MathJax-Element-65-Frame" tabindex="0" data-mathml="Xf={xfj}j=1n" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">Xf={xjf}nj=1Xf={xfj}j=1n 的总体均值之间的欧氏距离来评估样本的相似性。MMD 使用核函数<span class="MathJax" id="MathJax-Element-66-Frame" tabindex="0" data-mathml="k:Xr⊗Xf→H" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">k:Xr⊗Xf→Hk:Xr⊗Xf→H将样本映射到再生希尔伯特空间<span class="MathJax" id="MathJax-Element-67-Frame" tabindex="0" data-mathml="H" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">HH,计算投影后真实样本与生成样本的总体均值之差来衡量真实分布和生成分布的差异,<span class="MathJax" id="MathJax-Element-68-Frame" tabindex="0" data-mathml="⊗" role="presentation" style="box-sizing: border-box; list-style: none; display: inline; line-height: normal; font-size-adjust: none; text-indent: 0px; text-align: left; letter-spacing: normal; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">⊗⊗表示哈达玛积。
[size="normal">Xr,<mstyle]MED(Xr,Xf)=∥∥∥1m∑i=1mxir−1n∑j=1nxjf∥∥∥22 (5)MED(Xr,Xf)=‖1m∑i=1mxri−1n∑j=1nxfj‖22 (5)
[size="normal">Xr,<mstyle]MMD2(Xr,Xf)=∥∥∥1m∑i=1mkxir−1n∑j=1nkxjf∥∥∥2H (6)MMD2(Xr,Xf)=‖1m∑i=1mkxri−1n∑j=1nkxfj‖H2 (6)
另外,采用精确率(Precision)、召回率(Recall)和F1值作为模型分类效果的评估指标。
4.4 仿真实验参数设置本文仿真实验是在Python3.8和Pytorch 1.10.0环境下进行的。A-BAGAN采用全连接结构,其参数设置如表3所示。LGCNN参数设置如表4所示。为更好地分析本文模型的性能,本文还实现了5种数据增强模型作为对比,分别是 ROS、SMOTE、ADASYN、CWGAN-GP(conditional WGAN-GP)以及BAGAN。
表2 CIC-IDS2017数据分布
[td]数据集 | Benign/条 | DoS/条 | Port Scan/条 | 暴力破解/条 | Web Attack/条 | Bot/条 | Infiltration/条 | 总计/条 | 训练集 | 127 193 | 30 380 | 12 704 | 11 065 | 1 744 | 1 565 | 29 | 184 680 | 测试集 | 31 799 | 7 595 | 3 176 | 2 767 | 436 | 391 | 7 | 46 171 |
新窗口打开| 下载CSV
表3 A-BAGAN参数设置
[td]参数设置 | NSL-KDD | CIC-IDS2017 | 编码器En | 122-256-128-64-32 | 78-85-128-64-32-15 | 解码器De | 37-64-128-256-122 | 22-32-64-128-85-78 | 训练周期/轮 | 30, 300 | 40, 500 | 批次大小/条 | 128, 128 | 128, 128 | 学习率 | 0.000 1, 0.000 2 | 0.01, 0.000 2 |
新窗口打开| 下载CSV
表4 LGCNN参数设置
[td]参数 | NSL-KDD | CIC-IDS2017 | 普通卷积层 | [1,3]×8, poolsize:2 | [1,3]×8, poolsize:2 | 分组卷积层 | [1,2]×3,[1,2]×8 | [1,2]×12,[1,2]×18 | 逐点卷积层 | [1,1]×1 | [1,1]×1 | 分类层 | 160-80-5, Dropout:0.5 | 128-64-7, Dropout:0.5 | 训练周期/轮 | 100 | 200 | 批次大小/条 | 256 | 1 024 | 学习率 | 0.000 5 | 0.001 |
新窗口打开| 下载CSV
4.5 实验结果分析4.5.1 潜在表示可视化对比
首先,对比分析BAGAN和A-BAGAN中AE训练阶段学习之后的潜在表示,并利用t-SNE技术将潜在表示映射到二维平面可视化。图5(a)和图6(a)分别展示了NSL-KDD和CIC-IDS2017数据集在计算机仿真实验中AE训练阶段学习到的潜在表示可视化结果。从图5(a)中可以看出,各类样本分布散乱且重叠现象明显。例如,R2L 与 Normal 样本的重叠比例较高,这是因为R2L是一种伪装式攻击,通常作用于数据包负载;其余部分与正常数据包的特点相似,因此R2L攻击样本与Normal样本区分度不大。同样地,如图6(a)所示,在 CIC-IDS2017实验中可以观察到类似的现象。
本文提出的 ICVAE 可以有效地学习不同类型样本之间的公共特征。从图5(b)和图6(b)可以看出,各类样本均匀地融合在一起,并难以通过某种规则推断出潜在表示所属的类别信息。这说明本文提出的ICVAE在训练后能够为GAN提供一个良好的初始状态,从而有助于在后续对抗训练中克服生成样本类别模糊的问题。
4.5.2 生成样本统计评估
为直观地评估 A-BAGAN 的生成能力,通过MED和MMD衡量真实攻击样本与生成攻击样本之间的相似性,并选取 ADASYN、CWGAN-GP和BAGAN作为对比。具体评估结果如表5和表6所示。
从表5可知,在NSL-KDD数据集生成样本的统计评估中,本文模型在4种生成的攻击样本上取得了总体最优的结果。其中,MED最大值不超过0.062 7, MMD平均值为0.075 0,这说明本文模型能够有效学到真实数据分布;ADASYN 在攻击样本生成中取得了整体最差结果。例如,DoS类生成样本的MED和MMD分别为5.482 5、2.531 2,这说明生成的DoS样本分布与实际分布差异大;CWGAN-GP在攻击样本的生成中获得了总体次优的结果,这源于其采用Wasserstein 距离作为度量,并引入了梯度惩罚;BAGAN作为一种基于不平衡数据集的生成模型,并未表现出良好的统计结果,这是因为BAGAN并未解决类别模糊和训练不稳定的问题。
从表6可知,在CIC-IDS2017数据集生成样本的统计评估中,本文模型依然能够取得总体最优的结果;ADASYN 总体结果表现最差;CWGAN-GP在Infiltration类生成样本表现出最差的统计结果,MMD达到了4.374 4;与NSL-KDD数据集实验中表现结果不同,BAGAN 在 DoS、暴力破解和Bot 等攻击上的统计指标表现良好。上述分析表明,本文模型生成的样本兼顾真实性和类别属性,且在对抗训练过程中具有良好的稳定性,其效果优于传统的数据增强模型,这进一步体现了A-BAGAN 应用于 MMSN 入侵检测数据不平衡处理是可行的。
图5
图5 NSL-KDD数据集潜在表示可视化结果
图6
图6 CIC-IDS2017数据集潜在表示可视化结果
表5 NSL-KDD攻击样本扩充的统计评估
[td]模型 | DoS |
| Probe |
| U2R |
| R2L | MED | MMD | MED | MMD | MED | MMD | MED | MMD | ADASYN | 5.482 5 | 2.531 2 | | 0.915 6 | 0.501 9 | | 0.163 8 | 0.262 5 | | 0.109 0 | 0.224 2 | CWGAN-GP | 0.178 2 | 0.117 3 | | 0.006 1 | 0.005 1 | | 0.098 2 | 0.189 9 | | 0.034 1 | 0.064 4 | BAGAN | 0.546 9 | 0.673 6 | | 1.037 0 | 0.561 9 | | 0.340 7 | 0.505 0 | | 0.052 3 | 0.133 4 | 本文模型 | 0.003 7 | 0.009 7 |
| 0.022 7 | 0.014 1 |
| 0.062 7 | 0.256 2 |
| 0.009 7 | 0.019 8 |
新窗口打开| 下载CSV
表6 CIC-IDS2017数据集攻击样本扩充的统计评估
[td]模型 | DoS |
| Port Scan |
| 暴力破解 |
| Web Attack |
| Bot |
| Infiltration | MED | MMD | MED | MMD | MED | MMD | MED | MMD | MED | MMD | MED | MMD | ADASYN | 1.234 6 | 1.073 2 | | 0.044 2 | 1.093 8 | | 0.260 6 | 0.626 0 | | 0.378 0 | 1.721 7 | | 0.403 3 | 0.826 6 | | 0.012 9 | 0.039 2 | CWGAN-GP | 0.001 8 | 0.003 7 | | 0.017 4 | 0.267 4 | | 0.078 7 | 0.138 5 | | 0.143 3 | 2.373 5 | | 0.007 9 | 0.060 0 | | 4.078 1 | 4.374 4 | BAGAN | 0.130 5 | 0.157 3 | | 0.018 8 | 0.683 8 | | 0.012 9 | 0.029 4 | | 0.017 0 | 0.101 3 | | 0.028 1 | 0.063 7 | | 0.263 9 | 0.792 7 | 本文模型 | 0.148 4 | 0.134 4 |
| 0.001 0 | 0.035 2 |
| 0.001 4 | 0.004 0 |
| 0.025 6 | 0.140 5 |
| 0.016 3 | 0.029 1 |
| 0.511 9 | 0.759 8 |
新窗口打开| 下载CSV
4.5.3 检测模型性能分析
为分析本文数据增强模型对检测性能的改善情况,本节对比了LGCNN在原始数据集和其他5种数据增强的混合数据集上训练后的整体分类性能。在NSL-KDD 数据集中,混合数据集中各类样本数量比例为 1:1,整体分类性能对比结果如图7 所示;CIC-IDS2017 数据集仅对 Web Attack、Bot 和Infiltration 类进行十倍数生成,生成数量分别为15 696、14 085和2 871,该差异取决于数据集本身的性质,生成过多样本易造成冗余和噪声,整体分类性能对比结果如图8所示。
图7
图7 NSL-KDD数据集中整体分类性能对比
图8
图8 CIC-IDS2017数据集中整体分类性能对比
由图7可知,在NSL-KDD数据集中,本文模型在3个指标上均取得了最优结果,其精确率、召回率和F1值分别为0.876 2、0.864 2和0.870 2。相较于数据增强之前,提高了0.040 5的精确率、0.062 3的召回率和0.051 7的F1值。相较于其余5种数据增强模型的平均性能,提高了0.032 4的精确率、0.038 2的召回率和0.035 5的F1值。在原始数据集上训练后,分类器整体性能差是由部分攻击类型样本数量少导致的。ROS、SMOTE和ADASYN等传统的数据增强模型取得了较差的性能,其最高的 F1 值为0.832 5。CWGAN-GP取得了次优的结果,在3个指标上的性能较均衡,但仍比本文模型低 0.013 3的 F1 值。BAGAN 相比较于传统增强方法略有提升,其F1值为0.839 4,提升效果并不显著。
由图8可知,CIC-IDS2017数据集中各增强模型在精确率、召回率和F1值指标上呈现出较均衡的结果,且分类性能均优于原始数据集上的性能。ADASYN相比较于ROS和SMOTE,取得了更好的性能,其F1值达到了0.979 2。BAGAN整体性能优于CWGAN-GP,取得了0.986 2的精确率、0.984 9的召回率、0.985 5的F1值的次优结果,这是因为 BAGAN 在该数据集生成了更加真实的样本。但本文模型相比较于BAGAN,分别提升了0.003 6的精确率、0.004 5的召回率和0.004 1的F1 值。上述实验结果说明了本文模型优于对比模型,也验证了对BAGAN改进的有效性。
为更直观地评估本文模型与其他增强模型对少数类攻击的识别效果,本文从2个数据集中共选取4种少数类攻击用于对比,以综合指标F1值作为衡量指标,其对比结果如图9所示。从图9可以看出,本文模型对U2R、R2L和Web Attack类攻击均取得了最优结果,对Bot类攻击具有最高的竞争能力。
图9
图9 少数类攻击检测效果对比
最后,本文对比了几种入侵检测领域的最新研究,与文献中提出的算法作为对比。NSL-KDD数据集中不同算法的总体性能对比如表7所示。从表7 可以看出,本文模型在召回率和 F1值指标上展现出了最佳性能,而在精确率上取得次优结果。其中,F1值达到了0.870 2;相较于RANet、ROULETTE、IGAN和LCVAE分别提升了0.044 5、0.090 8、0.028 5和0.062 3。
表7 NSL-KDD数据集中不同算法的总体性能对比
[td]算法 | Precision | Recall | F1值 | RANet | 0.819 2 | 0.832 3 | 0.825 7 | ROULETTE | 0.828 5 | 0.804 3 | 0.779 4 | IGAN | 0.848 5 | 0.844 5 | 0.841 7 | LCVAE | 0.976 1 | 0.685 3 | 0.807 9 | 本文模型 | 0.876 2 | 0.864 2 | 0.870 2 |
新窗口打开| 下载CSV
由于在CIC-IDS2017数据集中数据设置不一致,本文实现了 LR(logistic regression)、SVM、MLP (multilayer perceptron)、CNN和LSTM作为对比算法,不同算法的总体性能对比对比结果如表8所示。从表8 可以看出,本文模型在精确率、召回率和 F1值指标中均取得最优结果。其中,F1值达到了0.989 6;相较于LR、SVM、MLP、CNN和LSTM分别提升了0.088 0、0.056 4、0.012 2、0.013 3和0.010 5。
表8 CIC-IDS2017数据集中不同算法的总体性能对比
[td]算法 | Precision | Recall | F1值 | LR | 0.896 6 | 0.906 7 | 0.901 6 | SVM | 0.935 1 | 0.931 3 | 0.933 2 | MLP | 0.978 2 | 0.976 7 | 0.977 4 | CNN | 0.976 8 | 0.975 8 | 0.976 3 | LSTM | 0.979 6 | 0.978 7 | 0.979 1 | 本文模型 | 0.989 8 | 0.989 4 | 0.989 6 |
新窗口打开| 下载CSV
LGCNN占用资源和模型大小如表9所示。从表9可以看出,本文提出的LGCNN具有较低的计算复杂度,其中,平均训练参数数量和浮点计算量(FLOP)分别为18 786个和44 632次,模型规模较小,这表明 LGCNN 具有轻量级性质,可满足MMSN的实际部署需求。
表9 LGCNN占用资源和模型大小
[td]数据集 | 训练参数数量/个 | FLOP/次 | 模型大小/KB | NSL-KDD | 22 940 | 34 688 | 92 | CIC-IDS2017 | 14 632 | 54 576 | 62 |
新窗口打开| 下载CSV
5 结束语本文描述了基于移动边缘计算的海洋气象传感网物理架构,并提出了一种基于平衡生成对抗网络的入侵检测模型。该模型利用A-BAGAN来解决入侵检测数据集不平衡的问题,构建基于分组卷积的LGCNN检测分类器以适应资源约束型海洋移动终端,并分别在公共网络数据集 NSL-KDD 和CIC-IDS2017上进行了计算机模拟仿真实验。实验结果表明,与传统的数据增强模型相比,A-BAGAN生成的样本兼顾真实性和类别属性,且对抗训练过程稳定,能够有效提高入侵检测分类器的识别效果,尤其是针对少数类样本的攻击。未来,将结合深度强化学习研究移动边缘计算对 MMSN 入侵检测任务计算量的影响。此外,为进一步提高MMSN安全性,将结合迁移学习开展特征选择的入侵检测研究。
|