dcgan(bigbang)

来源:https://github.com/aymericdamien/TensorFlow-Examples#tutorials”””DeepConvolutionalGenerativeAdversarialNetwork(DCGAN).Usingdeepconvolutionalgenerativeadversarialnetworks(DCGAN)toge…

大家好,又见面了,我是你们的朋友全栈君。

来源:https://github.com/aymericdamien/TensorFlow-Examples#tutorials

""" Deep Convolutional Generative Adversarial Network (DCGAN).

Using deep convolutional generative adversarial networks (DCGAN) to generate
digit images from a noise distribution.

References:
    - Unsupervised representation learning with deep convolutional generative
    adversarial networks. A Radford, L Metz, S Chintala. arXiv:1511.06434.

Links:
    - [DCGAN Paper](https://arxiv.org/abs/1511.06434).
    - [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""

from __future__ import division, print_function, absolute_import

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Training Params
num_steps = 20000
batch_size = 32

# Network Params
image_dim = 784 # 28*28 pixels * 1 channel
gen_hidden_dim = 256
disc_hidden_dim = 256
noise_dim = 200 # Noise data points


# Generator Network
# Input: Noise, Output: Image
def generator(x, reuse=False):
    with tf.variable_scope('Generator', reuse=reuse):
        # TensorFlow Layers automatically create variables and calculate their
        # shape, based on the input.
        x = tf.layers.dense(x, units=6 * 6 * 128)
        x = tf.nn.tanh(x)
        # Reshape to a 4-D array of images: (batch, height, width, channels)
        # New shape: (batch, 6, 6, 128)
        x = tf.reshape(x, shape=[-1, 6, 6, 128])
        # Deconvolution, image shape: (batch, 14, 14, 64)
        x = tf.layers.conv2d_transpose(x, 64, 4, strides=2)
        # Deconvolution, image shape: (batch, 28, 28, 1)
        x = tf.layers.conv2d_transpose(x, 1, 2, strides=2)
        # Apply sigmoid to clip values between 0 and 1
        x = tf.nn.sigmoid(x)
        return x


# Discriminator Network
# Input: Image, Output: Prediction Real/Fake Image
def discriminator(x, reuse=False):
    with tf.variable_scope('Discriminator', reuse=reuse):
        # Typical convolutional neural network to classify images.
        x = tf.layers.conv2d(x, 64, 5)
        x = tf.nn.tanh(x)
        x = tf.layers.average_pooling2d(x, 2, 2)
        x = tf.layers.conv2d(x, 128, 5)
        x = tf.nn.tanh(x)
        x = tf.layers.average_pooling2d(x, 2, 2)
        x = tf.contrib.layers.flatten(x)
        x = tf.layers.dense(x, 1024)
        x = tf.nn.tanh(x)
        # Output 2 classes: Real and Fake images
        x = tf.layers.dense(x, 2)
    return x

# Build Networks
# Network Inputs
noise_input = tf.placeholder(tf.float32, shape=[None, noise_dim])
real_image_input = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])

# Build Generator Network
gen_sample = generator(noise_input)

# Build 2 Discriminator Networks (one from noise input, one from generated samples)
disc_real = discriminator(real_image_input)
disc_fake = discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)

# Build the stacked generator/discriminator
stacked_gan = discriminator(gen_sample, reuse=True)

# Build Targets (real or fake images)
disc_target = tf.placeholder(tf.int32, shape=[None])
gen_target = tf.placeholder(tf.int32, shape=[None])

# Build Loss
disc_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=disc_concat, labels=disc_target))
gen_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=stacked_gan, labels=gen_target))

# Build Optimizers
optimizer_gen = tf.train.AdamOptimizer(learning_rate=0.001)
optimizer_disc = tf.train.AdamOptimizer(learning_rate=0.001)

# Training Variables for each optimizer
# By default in TensorFlow, all variables are updated by each optimizer, so we
# need to precise for each one of them the specific variables to update.
# Generator Network Variables
gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')
# Discriminator Network Variables
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')

# Create training operations
train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for i in range(1, num_steps+1):

        # Prepare Input Data
        # Get the next batch of MNIST data (only images are needed, not labels)
        batch_x, _ = mnist.train.next_batch(batch_size)
        batch_x = np.reshape(batch_x, newshape=[-1, 28, 28, 1])
        # Generate noise to feed to the generator
        z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])

        # Prepare Targets (Real image: 1, Fake image: 0)
        # The first half of data fed to the generator are real images,
        # the other half are fake images (coming from the generator).
        batch_disc_y = np.concatenate(
            [np.ones([batch_size]), np.zeros([batch_size])], axis=0)
        # Generator tries to fool the discriminator, thus targets are 1.
        batch_gen_y = np.ones([batch_size])

        # Training
        feed_dict = {real_image_input: batch_x, noise_input: z,
                     disc_target: batch_disc_y, gen_target: batch_gen_y}
        _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
                                feed_dict=feed_dict)
        if i % 100 == 0 or i == 1:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))

    # Generate images from noise, using the generator network.
    f, a = plt.subplots(4, 10, figsize=(10, 4))
    for i in range(10):
        # Noise input.
        z = np.random.uniform(-1., 1., size=[4, noise_dim])
        g = sess.run(gen_sample, feed_dict={noise_input: z})
        for j in range(4):
            # Generate image from noise. Extend to 3 channels for matplot figure.
            img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
                             newshape=(28, 28, 3))
            a[j][i].imshow(img)

    f.show()
    plt.draw()
    plt.waitforbuttonpress()

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/128040.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 什么是android原生系统版本,定制安卓和原生Android到底有哪些不同之处?彻底真相了…

    什么是android原生系统版本,定制安卓和原生Android到底有哪些不同之处?彻底真相了…相信大家都知道最近在搞机圈有个大新闻,就是小米即将于8月份推出MIUI9。近日小米MIUI市场副总监@黄龙中就在微博上征求米粉意见,暗示MIUI9可能长下面这样。小米最新官方主题《几何》,浓浓flyme风自2010年MIUI横空出世,国产定制安卓ROM在国内掀起了一阵风暴。MIUI成功后,乐蛙、点心等三方定制ROM迅速崛起,但随着手机系统生态逐渐稳定、刷机需求降低,定制安卓系统的范围逐渐缩小…

    2022年6月19日
    50
  • HTML5 语义元素

    返回目录 http://hovertree.com/h/bjaf/html5zixueji.htm一个语义元素能够清楚的描述其意义给浏览器和开发者。无语义元素实例:<div&gt

    2021年12月23日
    38
  • 5G LTE窄带物联网(NB-IoT) 10

    5G LTE窄带物联网(NB-IoT) 10第7章物理子层物理子层是底层子层,负责MACPDU的物理信道,传输和接收;如图7.1所示。RRC提供PHY子层的配置参数。在MAC/PHY接口,传输信道在发送和接收时分别映射到物理信道,反之亦然[28]。 RRC将其配置参数发送到每个子层,包括PHY子层,如第4.2,5.2,6.2和7.1节所示。7.1RRC配置参数RRC将专用或默认无线电配置参数发送到PHY子层,…

    2022年10月6日
    2
  • Python简单游戏代码

    Python简单游戏代码本人新人一枚 第一次在 CSDN 上写博客 代码不难 主要是混个积分 代码如下 importpygame sys randomfrompy localsimport 定义颜色变量目标方块的颜色 redColor pygame Color 250 0 0 贪吃蛇的颜色 whiteColor pygame Color 255 255 255 背景颜色 b

    2025年11月3日
    2
  • oracle数据库创建user,Oracle数据库如何创建数据库用户呢?

    oracle数据库创建user,Oracle数据库如何创建数据库用户呢?摘要:下文讲述Oracle数据库中创建数据库用户的方法分享,如下所示;在oracle数据库中,当我们创建了相应的数据库后,通常我们会为数据库指定相应的用户,然后单独操作此数据库,下文讲述oracle数据库中创建数据库用户的方法分享实现思路:1.创建oracle用户前,需先创建表空间createtablespace表空间datafile’数据库文件名’size表空间大小如:SQL>…

    2022年7月14日
    19
  • 微信朋友圈奢侈品代购背后:圈子营销光明正大卖“假货”

    微信朋友圈奢侈品代购背后:圈子营销光明正大卖“假货”30岁的张华,一天是这样开始的:睁眼、拿起床头的手机,刷看自己的微信或微博。她自己的微信“朋友圈”更新的速度比往常多了许多,里面大多是一些名牌皮包、衣服的图片信息。记者随机采访了几名手机用户,发现大多数人的微信“朋友圈”里都有人在做这样的微信生意,集中在国际名牌LV、香奈儿、卡地亚等奢侈品,他们自称为“奢侈品代购”。这是一种新的电商“朋友销售模式”?还是暗藏的“假货A货圈”?“奢侈品”代购背后是怎

    2022年5月14日
    50

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号