0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

使用Pytorch实现频谱归一化生成对抗网络(SN-GAN)

冬至子 来源:思否AI 作者:思否AI 2023-10-18 10:59 次阅读

自从扩散模型发布以来,GAN的关注度和论文是越来越少了,但是它们里面的一些思路还是值得我们了解和学习。所以本文我们来使用Pytorch 来实现SN-GAN

谱归一化生成对抗网络是一种生成对抗网络,它使用谱归一化技术来稳定鉴别器的训练。谱归一化是一种权值归一化技术,它约束了鉴别器中每一层的谱范数。这有助于防止鉴别器变得过于强大,从而导致不稳定和糟糕的结果。

SN-GAN由Miyato等人(2018)在论文“生成对抗网络的谱归一化”中提出,作者证明了sn - gan在各种图像生成任务上比其他gan具有更好的性能。

SN-GAN的训练方式与其他gan相同。生成器网络学习生成与真实图像无法区分的图像,而鉴别器网络学习区分真实图像和生成图像。这两个网络以竞争的方式进行训练,它们最终达到一个点,即生成器能够产生逼真的图像,从而欺骗鉴别器。

以下是SN-GAN相对于其他gan的优势总结:

  • 更稳定,更容易训练
  • 可以生成更高质量的图像
  • 更通用,可以用来生成更广泛的内容。

模式崩溃

模式崩溃是生成对抗网络(GANs)训练中常见的问题。当GAN的生成器网络无法产生多样化的输出,而是陷入特定的模式时,就会发生模式崩溃。这会导致生成的输出出现重复,缺乏多样性和细节,有时甚至与训练数据完全无关。

GAN中发生模式崩溃有几个原因。一个原因是生成器网络可能对训练数据过拟合。如果训练数据不够多样化,或者生成器网络太复杂,就会发生这种情况。另一个原因是生成器网络可能陷入损失函数的局部最小值。如果学习率太高,或者损失函数定义不明确,就会发生这种情况。

以前有许多技术可以用来防止模式崩溃。比如使用更多样化的训练数据集。或者使用正则化技术,例如dropout或批处理归一化,使用合适的学习率和损失函数也很重要。

Wassersteian损失

Wasserstein损失,也称为Earth Mover’s Distance(EMD)或Wasserstein GAN (WGAN)损失,是一种用于生成对抗网络(GAN)的损失函数。引入它是为了解决与传统GAN损失函数相关的一些问题,例如Jensen-Shannon散度和Kullback-Leibler散度。

Wasserstein损失测量真实数据和生成数据的概率分布之间的差异,同时确保它具有一定的数学性质。他的思想是最小化这两个分布之间的Wassersteian距离(也称为地球移动者距离)。Wasserstein距离可以被认为是将一个分布转换为另一个分布所需的最小“成本”,其中“成本”被定义为将概率质量从一个位置移动到另一个位置所需的“工作量”。

Wasserstein损失的数学定义如下:

对于生成器G和鉴别器D, Wasserstein损失(Wasserstein距离)可以表示为:

Jensen-Shannon散度(JSD): Jensen-Shannon散度是一种对称度量,用于量化两个概率分布之间的差异

对于概率分布P和Q, JSD定义如下:

JSD(P∥Q)=1/2(KL(P∥M)+KL(Q∥M))

M为平均分布,KL为Kullback-Leibler散度,P∥Q为分布P与分布Q之间的JSD。

JSD总是非负的,在0和1之间有界,并且对称(JSD(P|Q) = JSD(Q|P))。它可以被解释为KL散度的“平滑”版本。

Kullback-Leibler散度(KL散度):Kullback-Leibler散度,通常被称为KL散度或相对熵,通过量化“额外信息”来测量两个概率分布之间的差异,这些“额外信息”需要使用另一个分布作为参考来编码一个分布。

对于两个概率分布P和Q,从Q到P的KL散度定义为:KL(P∥Q)=∑x P(x)log(Q(x)/P(x))。KL散度是非负非对称的,即KL(P∥Q)≠KL(Q∥P)。当且仅当P和Q相等时它为零。KL散度是无界的,可以用来衡量分布之间的不相似性。

1-Lipschitz Contiunity

1- lipschitz函数是斜率的绝对值以1为界的函数。这意味着对于任意两个输入x和y,函数输出之间的差不超过输入之间的差。

数学上函数f是1-Lipschitz,如果对于f定义域内的所有x和y,以下不等式成立:

|f(x) — f(y)| <= |x — y|

在生成对抗网络(GANs)中强制Lipschitz连续性是一种用于稳定训练和防止与传统GANs相关的一些问题的技术,例如模式崩溃和训练不稳定。在GAN中实现Lipschitz连续性的主要方法是通过使用Lipschitz约束或正则化,一种常用的方法是Wasserstein GAN (WGAN)。

在标准gan中,鉴别器(也称为WGAN中的批评家)被训练来区分真实和虚假数据。为了加强Lipschitz连续性,WGAN增加了一个约束,即鉴别器函数应该是Lipschitz连续的,这意味着函数的梯度不应该增长得太大。在数学上,它被限制为:

∥∣D(x)D(y)∣≤K⋅∥xy

其中D(x)是评论家对数据点x的输出,D(y)是y的输出,K是Lipschitz 常数。

WGAN的权重裁剪:在原始的WGAN中,通过在每个训练步骤后将鉴别器网络的权重裁剪到一个小范围(例如,[-0.01,0.01])来强制执行该约束。权重裁剪确保了鉴别器的梯度保持在一定范围内,并加强了利普希茨连续性。

WGAN的梯度惩罚: WGAN的一种变体,称为WGAN-GP,它使用梯度惩罚而不是权值裁剪来强制Lipschitz约束。WGAN-GP基于鉴别器的输出相对于真实和虚假数据之间的随机点的梯度,在损失函数中添加了一个惩罚项。这种惩罚鼓励了Lipschitz约束,而不需要权重裁剪。

谱范数

从符号上看矩阵

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 编码器
    +关注

    关注

    45

    文章

    3643

    浏览量

    134525
  • 生成器
    +关注

    关注

    7

    文章

    315

    浏览量

    21011
  • 频谱仪
    +关注

    关注

    7

    文章

    340

    浏览量

    36048
  • pytorch
    +关注

    关注

    2

    文章

    808

    浏览量

    13226
收藏 人收藏

    评论

    相关推荐

    利用Arm Kleidi技术实现PyTorch优化

    PyTorch个广泛应用的开源机器学习 (ML) 库。近年来,Arm 与合作伙伴通力协作,持续改进 PyTorch 的推理性能。本文将详细介绍如何利用 Arm Kleidi 技术提升 Arm
    的头像 发表于 12-23 09:19 149次阅读
    利用Arm Kleidi技术<b class='flag-5'>实现</b><b class='flag-5'>PyTorch</b>优化

    大语言模型优化生成管理方法

    大语言模型的优化生成管理是个系统工程,涉及模型架构、数据处理、内容控制、实时响应以及伦理监管等多个层面。以下,是对大语言模型优化生成管理方法的梳理,由AI部落小编整理。
    的头像 发表于 12-02 10:45 82次阅读

    pytorch怎么在pycharm中运行

    部分:PyTorch和PyCharm的安装 1.1 安装PyTorch PyTorch个开源的机器学习库,用于构建和训练神经
    的头像 发表于 08-01 16:22 1427次阅读

    PyTorch如何实现多层全连接神经网络

    PyTorch实现多层全连接神经网络(也称为密集连接神经网络或DNN)是个相对直接的过程,涉及定义
    的头像 发表于 07-11 16:07 1202次阅读

    如何在PyTorch实现LeNet-5网络

    PyTorch实现LeNet-5网络个涉及深度学习基础知识、PyTorch框架使用以及网络
    的头像 发表于 07-11 10:58 794次阅读

    pytorch中有神经网络模型吗

    当然,PyTorch个广泛使用的深度学习框架,它提供了许多预训练的神经网络模型。 PyTorch中的神经网络模型 1. 引言 深度学习是
    的头像 发表于 07-11 09:59 701次阅读

    PyTorch神经网络模型构建过程

    PyTorch,作为个广泛使用的开源深度学习库,提供了丰富的工具和模块,帮助开发者构建、训练和部署神经网络模型。在神经网络模型中,输出层是尤为关键的部分,它负责将模型的预测结果以合适
    的头像 发表于 07-10 14:57 503次阅读

    PyTorch的介绍与使用案例

    学习领域的个重要工具。PyTorch底层由C++实现,提供了丰富的API接口,使得开发者能够高效地构建和训练神经网络模型。PyTorch
    的头像 发表于 07-10 14:19 398次阅读

    生成对抗网络(GANs)的原理与应用案例

    生成对抗网络(Generative Adversarial Networks,GANs)是种由蒙特利尔大学的Ian Goodfellow等人在2014年提出的深度学习算法。GANs通过构建两个
    的头像 发表于 07-09 11:34 1034次阅读

    如何使用PyTorch建立网络模型

    PyTorch个基于Python的开源机器学习库,因其易用性、灵活性和强大的动态图特性,在深度学习领域得到了广泛应用。本文将从PyTorch的基本概念、网络模型构建、优化方法、实际
    的头像 发表于 07-02 14:08 418次阅读

    使用PyTorch构建神经网络

    PyTorch个流行的深度学习框架,它以其简洁的API和强大的灵活性在学术界和工业界得到了广泛应用。在本文中,我们将深入探讨如何使用PyTorch构建神经网络,包括从基础概念到高级
    的头像 发表于 07-02 11:31 714次阅读

    神经网络架构有哪些

    、语音识别、自然语言处理等多个领域。本文将对几种主要的神经网络架构进行详细介绍,包括前馈神经网络、循环神经网络、卷积神经网络生成对抗
    的头像 发表于 07-01 14:16 715次阅读

    深度学习生成对抗网络GAN)全解析

    GANs真正的能力来源于它们遵循的对抗训练模式。生成器的权重是基于判别器的损失所学习到的。因此,生成器被它生成的图像所推动着进行训练,很难知道生成
    发表于 03-29 14:42 4573次阅读
    深度学习<b class='flag-5'>生成对抗</b><b class='flag-5'>网络</b>(<b class='flag-5'>GAN</b>)全解析

    生成式人工智能和感知式人工智能的区别

    生成新的内容和信息的人工智能系统。这些系统能够利用已有的数据和知识来生成全新的内容,如图片、音乐、文本等。生成式人工智能通常基于深度学习技术,如生成对抗
    的头像 发表于 02-19 16:43 1762次阅读

    基于国产AI编译器ICRAFT部署YOLOv5边缘端计算的实战案例

    人工智能领域中各种算法模型的不断研究和改进。随着深度学习的兴起,包括卷积神经网络(CNN)、循环神经网络(RNN)、生成对抗网络GAN)、
    的头像 发表于 01-03 10:17 3180次阅读
    基于国产AI编译器ICRAFT部署YOLOv5边缘端计算的实战案例