0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

一个实用的GitHub项目:TensorFlow-Cookbook

DPVg_AI_era 来源:lq 2019-02-19 09:04 次阅读

今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook。 这是一个易用的TensorFlow代码集,包含了对GAN有用的一些通用架构和函数。

今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook。

这是一个易用的TensorFlow代码集,作者是来自韩国的AI研究科学家Junho Kim,内容涵盖了谱归一化卷积、部分卷积、pixel shuffle、几种归一化函数、 tf-datasetAPI,等等。

作者表示,这个repo包含了对GAN有用的一些通用架构和函数。

项目正在进行中,作者将持续为其他领域添加有用的代码,目前正在添加的是 tf-Eager mode的代码。欢迎提交pull requests和issues。

Github地址 :

https://github.com/taki0112/Tensorflow-Cookbook

如何使用

Import

ops.py

operations

from ops import *

utils.py

image processing

from utils import *

Network template

def network(x, is_training=True, reuse=False, scope="network"): with tf.variable_scope(scope, reuse=reuse): x = conv(...) ... return logit

使用DatasetAPI向网络插入数据

Image_Data_Class = ImageData(img_size, img_ch, augment_flag) trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16) trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat() trainA_iterator = trainA.make_one_shot_iterator() data_A = trainA_iterator.get_next() logit = network(data_A)

了解更多,请阅读:

https://github.com/taki0112/Tensorflow-DatasetAPI

Option

padding='SAME'

pad = ceil[ (kernel - stride) / 2 ]

pad_type

'zero' or 'reflect'

sn

usespectral_normalizationor not

Ra

userelativistic ganor not

loss_func

gan

lsgan

hinge

wgan

wgan-gp

dragan

注意

如果你不想共享变量,请以不同的方式设置所有作用域名称。

权重(Weight)

weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001) weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)

初始化(Initialization)

Xavier: tf.contrib.layers.xavier_initializer()

He: tf.contrib.layers.variance_scaling_initializer()

Normal: tf.random_normal_initializer(mean=0.0, stddev=0.02)

Truncated_normal: tf.truncated_normal_initializer(mean=0.0, stddev=0.02)

Orthogonal: tf.orthogonal_initializer(1.0) / # if relu = sqrt(2), the others = 1.0

正则化(Regularization)

l2_decay: tf.contrib.layers.l2_regularizer(0.0001)

orthogonal_regularizer: orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)

卷积(Convolution)

basic conv

x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')

partial conv (NVIDIAPartial Convolution)

x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv')

dilated conv

x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='SAME', sn=True, scope='dilate_conv')

Deconvolution

basic deconv

x = deconv(x, channels=64, kernel=3, stride=2, padding='SAME', use_bias=True, sn=True, scope='deconv')

Fully-connected

x = fully_conneted(x, units=64, use_bias=True, sn=True, scope='fully_connected')

Pixel shuffle

x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down') x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up')

down===> [height, width] -> [height // scale_factor, width // scale_factor]

up===> [height, width] -> [height * scale_factor, width * scale_factor]

Block

residual block

x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block') x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down') x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up')

down===> [height, width] -> [height // 2, width // 2]

up===> [height, width] -> [height * 2, width * 2]

attention block

x = self_attention(x, channels=64, use_bias=True, sn=True, scope='self_attention') x = self_attention_with_pooling(x, channels=64, use_bias=True, sn=True, scope='self_attention_version_2') x = squeeze_excitation(x, channels=64, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation') x = convolution_block_attention(x, channels=64, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention')

Normalization

Normalization

x = batch_norm(x, is_training=is_training, scope='batch_norm') x = instance_norm(x, scope='instance_norm') x = layer_norm(x, scope='layer_norm') x = group_norm(x, groups=32, scope='group_norm') x = pixel_norm(x) x = batch_instance_norm(x, scope='batch_instance_norm') x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm'): x = adaptive_instance_norm(x, gamma, beta):

如何使用condition_batch_norm,请参考:

https://github.com/taki0112/BigGAN-Tensorflow

如何使用adaptive_instance_norm,请参考:

https://github.com/taki0112/MUNIT-Tensorflow

Activation

x = relu(x) x = lrelu(x, alpha=0.2) x = tanh(x) x = sigmoid(x) x = swish(x)

Pooling & Resize

x = up_sample(x, scale_factor=2) x = max_pooling(x, pool_size=2) x = avg_pooling(x, pool_size=2) x = global_max_pooling(x) x = global_avg_pooling(x) x = flatten(x) x = hw_flatten(x)

Loss

classification loss

loss, accuracy = classification_loss(logit, label)

pixel loss

loss = L1_loss(x, y) loss = L2_loss(x, y) loss = huber_loss(x, y) loss = histogram_loss(x, y)

histogram_loss表示图像像素值在颜色分布上的差异。

gan loss

d_loss = discriminator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit) g_loss = generator_loss(Ra=True, loss_func='wgan_gp', real=real_logit, fake=fake_logit)

如何使用gradient_penalty,请参考:

https://github.com/taki0112/BigGAN-Tensorflow/blob/master/BigGAN_512.py#L180

kl-divergence (z ~ N(0, 1))

loss = kl_loss(mean, logvar)

Author

Junho Kim

Github地址 :

https://github.com/taki0112/Tensorflow-Cookbook

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 函数
    +关注

    关注

    3

    文章

    4276

    浏览量

    62303
  • GitHub
    +关注

    关注

    3

    文章

    465

    浏览量

    16352
  • tensorflow
    +关注

    关注

    13

    文章

    328

    浏览量

    60461

原文标题:【收藏】简单易用 TensorFlow 代码集,GAN通用框架、函数

文章出处:【微信号:AI_era,微信公众号:新智元】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    如何使用tensorflow快速搭建起深度学习项目

    我们继续以 NG 课题组提供的 sign 手势数据集为例,学习如何通过Tensorflow快速搭建起深度学习项目。数据集标签共有零到五总共 6 类标签,示例如下
    的头像 发表于 10-25 08:57 7671次阅读

    干货 | TensorFlow的55经典案例

    ://yann.lecun.com/exdb/mnist/第二步:为TF新手准备的各个类型的案例、模型和数据集初步了解:TFLearn TensorFlow接下来的示例来自TFLearn,这是
    发表于 10-09 11:28

    TensorFlow是什么

    来发现和理解濒临灭绝的海牛。位日本农民运用 TensorFlow 开发了应用程序,使用大小和形状等物理特性对黄瓜进行分类。使用 Tensor
    发表于 07-22 10:14

    TensorFlow的特点和基本的操作方式

    Tensorflow是Google开源的深度学习框架,来自于Google Brain研究项目,在Google第代分布式机器学习框架DistBelief的基础上发展起来。Tensorflow
    发表于 11-23 09:56

    The VHDL Cookbook

    The VHDL Cookbook 好东西哦。网上搜集,希望对你有用。
    发表于 03-25 14:37 19次下载

    github入门到上传本地项目步骤

    GitHub可以托管各种git库,并提供web界面,但与其它像 SourceForge或Google Code这样的服务不同,GitHub的独特卖点在于从另外
    发表于 11-29 16:51 2206次阅读

    github使用教程_github菜鸟教程

    GitHub 拥有非常鼓励合作的社区氛围。这方面源于 GitHub 的付费模式:私有项目
    发表于 11-29 17:22 1.5w次阅读
    <b class='flag-5'>github</b>使用教程_<b class='flag-5'>github</b>菜鸟教程

    提出快速启动自己的 TensorFlow 项目模板

    简洁而精密的结构对于深度学习项目来说是必不可少的,在经过多次练习和 TensorFlow 项目开发之后,本文作者提出了结合简便性、优化文
    的头像 发表于 02-07 11:47 3088次阅读
    提出<b class='flag-5'>一</b><b class='flag-5'>个</b>快速启动自己的 <b class='flag-5'>TensorFlow</b> <b class='flag-5'>项目</b>模板

    总结Tensorflow纯干货学习资源,分为教程、视频和项目三大板块

    基于Facebook中FastText的简单嵌入式文本分类器:https://github.com/apcode/tensorflow_fasttext。该项目是源于Facebook中
    的头像 发表于 04-16 11:39 1.1w次阅读

    人工智能凉了? GitHub年度报告揭示真相

    去年GitHub的报告中,人工智能非常火。今年情况如何?在下面的图表中,可以看到: Tensorflow在最热开源项目中排第三;在增长最快的项目中Pytorch排名第二,
    的头像 发表于 10-23 10:16 3498次阅读

    总结GitHub热门开源项目

    项目的热门程度,较为直观的判断方式就是它的Stars增长速度,排行第的flutter依然是Google家的,Flutter 是在2018年的2月份才推出第
    的头像 发表于 01-18 14:15 2891次阅读

    GitHub年度报告:Python首次击败Java

    作为 GitHub 上最受欢迎的项目TensorFlow 已经建立了庞大的软件社区。去
    的头像 发表于 11-22 15:14 2451次阅读

    TensorFlow Community Spotlight获奖项目

    Spotlight 获奖者,她用 TensorFlow 开发出款追踪坐姿的工具,当使用者坐姿不正确的情况下屏幕会变模糊 在这四
    的头像 发表于 11-26 09:43 1786次阅读

    上传本地项目代码到github

    GitHub面向开源及私有软件项目的托管平台,因为只支持git 作为唯的版本库格式进行托管,故名
    的头像 发表于 11-14 16:45 1030次阅读
    上传本地<b class='flag-5'>项目</b>代码到<b class='flag-5'>github</b>

    如何使用Github高效率的查找项目

    GitHub各位应该都很熟悉了,全球最大的开源社区,也是全球最大的同性交友网站~~,但是大部分同学使用GitHub应该就是通过别人的开源链接,点进去下载对应的项目,而真正使用Github
    的头像 发表于 09-24 14:43 643次阅读
    如何使用<b class='flag-5'>Github</b>高效率的查找<b class='flag-5'>项目</b>