0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

生成式对抗网络基础知识直观解读

mK5P_AItists 来源:未知 作者:胡薇 2018-05-14 08:29 次阅读

大家都知道,自从生成式对抗网络(GAN)出现以来,便在图像处理方面有着广泛的应用。但还是有很多人对于GAN不是很了解,担心由于没有数学知识底蕴而学不会GAN。在本文中,谷歌研究员Stefan Hosein提供了一份初学者入门GAN的教程,在这份教程中,即使你没有拥有深厚的数学知识,你也能够了解什么是生成式对抗网络(GAN)。

类比

理解GAN的一个最为简单的方法是通过一个简单的比喻:

假设有一家商店,店主要从顾客那里购买某些种类的葡萄酒,然后再将这些葡萄酒销售出去。

然而,有些可恶的顾客为了赚取金钱而出售假酒。在这种情况下,店主必须能够区分假酒和正宗的葡萄酒。

你可以想象,在最初的时候,伪造者在试图出售假酒时可能会犯很多错误,并且店主很容易就会发现该酒不是正宗的葡萄酒。经历过这些失败之后,伪造者会继续尝试使用不同的技术来模拟真正的葡萄酒,而有些方法最终会取得成功。现在,伪造者知道某些技术已经能够躲过店主的检查,那么他就可以开始进一步对基于这些技术的假酒进行改善提升。

与此同时,店主可能会从其他店主或葡萄酒专家那里得到一些反馈,说明她所拥有的一些葡萄酒并不是原装的。这意味着店主必须改进她的判别方式,从而确定葡萄酒是伪造的还是正宗的。伪造者的目标是制造出与正宗葡萄酒无法区分的葡萄酒,而店主的目标是准确地分辨葡萄酒是否是正宗的。

可以这样说,这种循环往复的竞争正是GAN背后的主要思想。

生成式对抗网络的组成部分

通过上面的例子,我们可以提出一个GAN的体系结构。

GAN中有两个主要的组成部分:生成器和鉴别器。在上面我们所描述的例子中,店主被称为鉴别器网络,通常是一个卷积神经网络(因为GAN主要用于图像任务),主要是分配图像是真实的概率。

伪造者被称为生成式网络,并且通常也是一个卷积神经网络(具有解卷积层,deconvolution layers)。该网络接收一些噪声向量并输出一个图像。当对生成式网络进行训练时,它会学习可以对图像的哪些区域进行改进/更改,以便鉴别器将难以将其生成的图像与真实图像区分开来。

生成式网络不断地生成与真实图像更为接近的图像,而与此同时,鉴别式网络则试图确定真实图像和假图像之间的差异。最终的目标就是建立一个生成式网络,它可以生成与真实图像无法区分的图像。

用Keras编写一个简单的生成式对抗网络

现在,你已经了解什么是GAN,以及它们的主要组成部分,那么现在我们可以开始试着编写一个非常简单的代码。你可以使用Keras,如果你不熟悉这个Python库的话,则应在继续进行操作之前阅读本教程。本教程基于易于理解的GAN进行开发的。

首先,你需要做的第一件事是通过pip安装以下软件包:

- keras

- matplotlib

- tensorflow

- tqdm

你将使用matplotlib绘图,tensorflow作为Keras后端库和tqdm,以显示每个轮数(迭代)的花式进度条。

下一步是创建一个Python脚本,在这个脚本中,你首先需要导入你将要使用的所有模块和函数。在使用它们时将给出每个解释。

importos

import numpy as np

import matplotlib.pyplot as plt

from tqdm import tqdm

from keras.layers import Input

from keras.models import Model, Sequential

from keras.layers.core import Dense, Dropout

from keras.layers.advanced_activations import LeakyReLU

from keras.datasets import mnist

from keras.optimizers import Adam

from keras import initializers

你现在需要设置一些变量:

# Let Keras know that we are using tensorflow as our backend engine

os.environ["KERAS_BACKEND"] = "tensorflow"

# To make sure that we can reproduce the experiment and get the same results

np.random.seed(10)

# The dimension of our random noise vector.

random_dim = 100

在开始构建鉴别器和生成器之前,你首先应该收集数据,并对其进行预处理。你将会使用到常见的MNIST数据集,该数据集具有一组从0到9的单个数字图像。

MINST数字样本

def load_minst_data():

# load the data

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# normalize our inputs to be in the range[-1, 1]

x_train = (x_train.astype(np.float32) - 127.5)/127.5

# convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have

# 784 columns per row

x_train = x_train.reshape(60000, 784)

return (x_train, y_train, x_test, y_test)

需要注意的是,mnist.load_data()是Keras的一部分,这使得你可以轻松地将MNIST数据集导入至工作区域中。

现在,你可以开始创建你的生成器和鉴别器网络了。在这一过程中,你会使用到Adam优化器。此外,你还需要创建一个带有三个隐藏层的神经网络,其激活函数为Leaky Relu。对于鉴别器而言,你需要为其添加dropout层(dropout layers),以提高对未知图像的鲁棒性。

def get_optimizer():

return Adam(lr=0.0002, beta_1=0.5)

def get_generator(optimizer):

generator = Sequential()

generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))

generator.add(LeakyReLU(0.2))

generator.add(Dense(512))

generator.add(LeakyReLU(0.2))

generator.add(Dense(1024))

generator.add(LeakyReLU(0.2))

generator.add(Dense(784, activation='tanh'))

generator.compile(loss='binary_crossentropy', optimizer=optimizer)

return generator

def get_discriminator(optimizer):

discriminator = Sequential()

discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(512))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(256))

discriminator.add(LeakyReLU(0.2))

discriminator.add(Dropout(0.3))

discriminator.add(Dense(1, activation='sigmoid'))

discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)

return discriminator

接下来,则需要将发生器和鉴别器组合在一起!

def get_gan_network(discriminator, random_dim, generator, optimizer):

# We initially set trainable to False since we only want to train either the

# generator or discriminator at a time

discriminator.trainable = False

# gan input (noise) will be 100-dimensional vectors

gan_input = Input(shape=(random_dim,))

# the output of the generator (an image)

x = generator(gan_input)

# get the output of the discriminator (probability if the image is real or not)

gan_output = discriminator(x)

gan = Model(inputs=gan_input, outputs=gan_output)

gan.compile(loss='binary_crossentropy', optimizer=optimizer)

return gan

为了完整起见,你还可以创建一个函数,使其每训练20个轮数就对生成的图像进行1次保存。由于这不是本次课程的核心内容,因此你不必完全理解该函数。

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):

noise = np.random.normal(0, 1, size=[examples, random_dim])

generated_images = generator.predict(noise)

generated_images = generated_images.reshape(examples, 28, 28)

plt.figure(figsize=figsize)

for i in range(generated_images.shape[0]):

plt.subplot(dim[0], dim[1], i+1)

plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')

plt.axis('off')

plt.tight_layout()

plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

你现在已经编码了大部分网络,剩下的就是训练这个网络,并查看你创建的图像。

def train(epochs=1, batch_size=128):

# Get the training and testing data

x_train, y_train, x_test, y_test = load_minst_data()

# Split the training data into batches of size 128

batch_count = x_train.shape[0] / batch_size

# Build our GAN netowrk

adam = get_optimizer()

generator = get_generator(adam)

discriminator = get_discriminator(adam)

gan = get_gan_network(discriminator, random_dim, generator, adam)

for e in xrange(1, epochs+1):

print '-'*15, 'Epoch %d' % e, '-'*15

for _ in tqdm(xrange(batch_count)):

# Get a random set of input noise and images

noise = np.random.normal(0, 1, size=[batch_size, random_dim])

image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]

# Generate fake MNIST images

generated_images = generator.predict(noise)

X = np.concatenate([image_batch, generated_images])

# Labels for generated and real data

y_dis = np.zeros(2*batch_size)

# One-sided label smoothing

y_dis[:batch_size] = 0.9

# Train discriminator

discriminator.trainable = True

discriminator.train_on_batch(X, y_dis)

# Train generator

noise = np.random.normal(0, 1, size=[batch_size, random_dim])

y_gen = np.ones(batch_size)

discriminator.trainable = False

gan.train_on_batch(noise, y_gen)

if e == 1 or e % 20 == 0:

plot_generated_images(e, generator)

if __name__ == '__main__':

train(400, 128)

在训练400个轮数后,你可以查看生成的图像。在查看经过1个轮数训练后而生成的图像时,你会发现它没有任何真实结构,在查看经过40个轮数训练后而生成的图像时,你会发现数字开始成形,最后,在查看经过400个轮数训练后而生成的图像时,你会发现,除了一组数字难以辨识外,其余大多数数字都清晰可见。

训练1个轮数后的结果(上)| 训练40个轮数后的结果(中) | 训练400个轮数后的结果(下)

此代码在CPU上运行一次大约需要2分钟,这也是我们选择该代码的主要原因。你可以尝试进行更多轮数的训练,并向生成器和鉴别器中添加更多数量(种类)的层。当然,在仅使用CPU的前提下,采用更复杂和更深层的体系结构时,相应的代码运行时间也会有所延长。但也不要因此放弃尝试。

至此,你已经完成了全部的学习任务,你以一种直观的方式学习了生成式对抗网络(GAN)的基础知识!并且,你还在Keras库的协助下实现了你的第一个模型。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • GaN
    GaN
    +关注

    关注

    19

    文章

    1918

    浏览量

    72971

原文标题:无需数学背景!谷歌研究员为你解密生成式对抗网络

文章出处:【微信号:AItists,微信公众号:人工智能学家】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    华为网络基础知识教程

    华为网络基础知识教程
    发表于 08-18 15:16

    嵌入网络协议栈基础知识

    第2章 嵌入网络协议栈基础知识本章教程为大家介绍嵌入网络协议栈基础知识,本章先让大家有一个全
    发表于 08-03 06:24

    介绍嵌入网络协议栈基础知识

    第2章 嵌入网络协议栈基础知识本章教程为大家介绍嵌入网络协议栈基础知识,本章先让大家有一个全
    发表于 08-03 06:58

    介绍嵌入网络协议栈基础知识

    第2章 嵌入网络协议栈基础知识本章教程为大家介绍嵌入网络协议栈基础知识,本章先让大家有一个全
    发表于 08-04 08:17

    图像生成对抗生成网络

    图像生成对抗生成网络ganby Thalles Silva 由Thalles Silva暖身 (Warm up)Let’s say there’s a very cool party going
    发表于 09-15 09:29

    嵌入系统基础知识

    关于嵌入系统基础知识关于嵌入系统基础知识关于嵌入系统基础知识
    发表于 03-03 16:58 5次下载

    浅析生成对抗网络发展的内在逻辑

    生成对抗网络(Generative adversarial networks, GAN)是当前人工智能学界最为重要的研究热点之一。其突出的生成
    的头像 发表于 08-22 16:25 1w次阅读

    新型生成对抗分层网络表示学习算法

      针对当前链路预测算法无法有效保留网络图髙阶结构特征的问题,提岀一种生成对抗分层网络表示学习算法。根据网络图的一阶邻近性和二阶邻近性,递
    发表于 03-11 10:53 16次下载
    新型<b class='flag-5'>生成对抗</b><b class='flag-5'>式</b>分层<b class='flag-5'>网络</b>表示学习算法

    一种利用生成对抗网络的超分辨率重建算法

    针对传统图像超分辨率重建算法存在网络训练困难与生成图像存在伪影的冋题,提岀一种利用生成对抗网络
    发表于 03-22 15:40 4次下载
    一种利用<b class='flag-5'>生成</b><b class='flag-5'>式</b><b class='flag-5'>对抗</b><b class='flag-5'>网络</b>的超分辨率重建算法

    基于生成对抗网络的深度文本生成模型

    评论,对音乐作品自动生成评论可以在一定程度上解决此问题。在在线唱歌平台上的评论文本与音乐作品的表现评级存在一定的关系。因此,研究考虑音乐作品评级信息的评论文本自动生成的方为此提出了一种基于生成
    发表于 04-12 13:47 15次下载
    基于<b class='flag-5'>生成</b><b class='flag-5'>式</b><b class='flag-5'>对抗</b><b class='flag-5'>网络</b>的深度文本<b class='flag-5'>生成</b>模型

    基于生成对抗网络的图像补全方法

    图像补全是数字图像处理领域的重要研究方向,具有广阔的应用前景。提出了一种基于生成对抗网络(GAN)的图像补全方法。生成
    发表于 05-19 14:38 14次下载

    生成对抗网络应用及研究综述

    基于零和博弈思想的生成对抗网络(GAN)可通过无监督学习获得数据的分布,并生成较逼真的数据。基于GAN的基础概念及理论框架,硏究各类GAN
    发表于 06-09 11:16 13次下载

    基于像素级生成对抗网络的图像彩色化模型

    基于像素级生成对抗网络的图像彩色化模型
    发表于 06-27 11:02 4次下载

    「自行科技」一文了解生成对抗网络GAN

    生成对抗网络(Generative adversarial network, GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。
    的头像 发表于 09-16 09:25 4028次阅读

    生成对抗网络(GANs)的原理与应用案例

    生成对抗网络(Generative Adversarial Networks,GANs)是一种由蒙特利尔大学的Ian Goodfellow等人在2014年提出的深度学习算法。GANs通过构建两个
    的头像 发表于 07-09 11:34 793次阅读