深度学习之GAN对抗神经网络

深度学习之GAN对抗神经网络1、结构图2、知识点3、代码及案例#coding:utf-8###对抗生成网络案例#####<imgsrc="jpg/3.png"alt=&qu

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

1、结构图

深度学习之GAN对抗神经网络

 

2、知识点

生成器(G):将噪音数据生成一个想要的数据
判别器(D):将生成器的结果进行判别,

3、代码及案例

深度学习之GAN对抗神经网络
深度学习之GAN对抗神经网络

# coding: utf-8

# ## 对抗生成网络案例 ##
# 
# 
# <img src="jpg/3.png" alt="FAO" width="590" >

# - 判别器 : 火眼金睛,分辨出生成和真实的 <br /> 
# <br /> 
# - 生成器 : 瞒天过海,骗过判别器 <br /> 
# <br /> 
# - 损失函数定义 : 一方面要让判别器分辨能力更强,另一方面要让生成器更真 <br /> 
# <br /> 
# 
# <img src="jpg/1.jpg" alt="FAO" width="590" >

# In[1]:


import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt

get_ipython().run_line_magic('matplotlib', 'inline')


# # 导入数据

# In[2]:


from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/data')


# ## 网络架构
# 
# ### 输入层 :待生成图像(噪音)和真实数据
# 
# ### 生成网络:将噪音图像进行生成
# 
# ### 判别网络:
# - (1)判断真实图像输出结果 
# - (2)判断生成图像输出结果
# 
# ### 目标函数:
# - (1)对于生成网络要使得生成结果通过判别网络为真 
# - (2)对于判别网络要使得输入为真实图像时判别为真 输入为生成图像时判别为假
# 
# <img src="jpg/2.png" alt="FAO" width="590" >

# ## Inputs

# In[3]:


#真实数据和噪音数据
def get_inputs(real_size, noise_size):
    
    real_img = tf.placeholder(tf.float32, [None, real_size])
    noise_img = tf.placeholder(tf.float32, [None, noise_size])
    
    return real_img, noise_img


# ## 生成器
# * noise_img: 产生的噪音输入
# * n_units: 隐层单元个数
# * out_dim: 输出的大小(28 * 28 * 1)

# In[4]:


def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
  
    with tf.variable_scope("generator", reuse=reuse):
        # hidden layer
        hidden1 = tf.layers.dense(noise_img, n_units)
        # leaky ReLU
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # dropout
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)

        # logits & outputs
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        
        return logits, outputs


# ## 判别器
# * img:输入
# * n_units:隐层单元数量
# * reuse:由于要使用两次

# In[5]:


def get_discriminator(img, n_units, reuse=False, alpha=0.01):

    with tf.variable_scope("discriminator", reuse=reuse):
        # hidden layer
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        
        # logits & outputs
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)
        
        return logits, outputs


# ## 网络参数定义
# * img_size:输入大小
# * noise_size:噪音图像大小
# * g_units:生成器隐层参数
# * d_units:判别器隐层参数
# * learning_rate:学习率

# In[6]:


img_size = mnist.train.images[0].shape[0]

noise_size = 100

g_units = 128

d_units = 128

learning_rate = 0.001

alpha = 0.01


# ## 构建网络

# In[7]:


tf.reset_default_graph()

real_img, noise_img = get_inputs(img_size, noise_size)

# generator
g_logits, g_outputs = get_generator(noise_img, g_units, img_size)

# discriminator
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)


# ### 目标函数:
# - (1)对于生成网络要使得生成结果通过判别网络为真 
# - (2)对于判别网络要使得输入为真实图像时判别为真 输入为生成图像时判别为假
# 
# <img src="jpg/2.png" alt="FAO" width="590" >

# In[8]:


# discriminator的loss
# 识别真实图片
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, 
                                                                     labels=tf.ones_like(d_logits_real)))
# 识别生成的图片
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, 
                                                                     labels=tf.zeros_like(d_logits_fake)))
# 总体loss
d_loss = tf.add(d_loss_real, d_loss_fake)

# generator的loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                labels=tf.ones_like(d_logits_fake)))


# ## 优化器

# In[9]:


train_vars = tf.trainable_variables()

# generator
g_vars = [var for var in train_vars if var.name.startswith("generator")]
# discriminator
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

# optimizer
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)


# # 训练

# In[10]:


# batch_size
batch_size = 64
# 训练迭代轮数
epochs = 300
# 抽取样本数
n_sample = 25

# 存储测试样例
samples = []
# 存储loss
losses = []
# 保存生成器变量
saver = tf.train.Saver(var_list = g_vars)
# 开始训练
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for batch_i in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            batch_images = batch[0].reshape((batch_size, 784))
            # 对图像像素进行scale,这是因为tanh输出的结果介于(-1,1),real和fake图片共享discriminator的参数
            batch_images = batch_images*2 - 1
            
            # generator的输入噪声
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
            
            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})
            _ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise})
        
        # 每一轮结束计算loss
        train_loss_d = sess.run(d_loss, 
                                feed_dict = {real_img: batch_images, 
                                             noise_img: batch_noise})
        # real img loss
        train_loss_d_real = sess.run(d_loss_real, 
                                     feed_dict = {real_img: batch_images, 
                                                 noise_img: batch_noise})
        # fake img loss
        train_loss_d_fake = sess.run(d_loss_fake, 
                                    feed_dict = {real_img: batch_images, 
                                                 noise_img: batch_noise})
        # generator loss
        train_loss_g = sess.run(g_loss, 
                                feed_dict = {noise_img: batch_noise})
        
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "判别器损失: {:.4f}(判别真实的: {:.4f} + 判别生成的: {:.4f})...".format(train_loss_d, train_loss_d_real, train_loss_d_fake),
              "生成器损失: {:.4f}".format(train_loss_g))    
        
        losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))
        
        # 保存样本
        sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
        gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
                               feed_dict={noise_img: sample_noise})
        samples.append(gen_samples)
        
        
        saver.save(sess, './checkpoints/generator.ckpt')

# 保存到本地
with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples, f)


# # loss迭代曲线

# In[11]:


fig, ax = plt.subplots(figsize=(20,7))
losses = np.array(losses)
plt.plot(losses.T[0], label='判别器总损失')
plt.plot(losses.T[1], label='判别真实损失')
plt.plot(losses.T[2], label='判别生成损失')
plt.plot(losses.T[3], label='生成器损失')
plt.title("对抗生成网络")
ax.set_xlabel('epoch')
plt.legend()


# # 生成结果

# In[12]:


# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)


# In[13]:


#samples是保存的结果 epoch是第多少次迭代
def view_samples(epoch, samples):
    
    fig, axes = plt.subplots(figsize=(7,7), nrows=5, ncols=5, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch][1]): # 这里samples[epoch][1]代表生成的图像结果,而[0]代表对应的logits
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes


# In[14]:


_ = view_samples(-1, samples) # 显示最终的生成结果


# # 显示整个生成过程图片

# In[15]:


# 指定要查看的轮次
epoch_idx = [10, 30, 60, 90, 120, 150, 180, 210, 240, 290] 
show_imgs = []
for i in epoch_idx:
    show_imgs.append(samples[i][1])


# In[16]:


# 指定图片形状
rows, cols = 10, 25
fig, axes = plt.subplots(figsize=(30,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

idx = range(0, epochs, int(epochs/rows))

for sample, ax_row in zip(show_imgs, axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)


# # 生成新的图片

# In[17]:


# 加载我们的生成器变量
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    sample_noise = np.random.uniform(-1, 1, size=(25, noise_size))
    gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
                           feed_dict={noise_img: sample_noise})


# In[18]:


_ = view_samples(0, [gen_samples])

View Code

4、优化目标

深度学习之GAN对抗神经网络

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/167863.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • ip addr命令作用_ipconfig命令的功能和作用

    ip addr命令作用_ipconfig命令的功能和作用一、ifconfig命令1)配置地址:比如修改eth0网卡的ip为192.168.174.100,子网掩码为255.255.255.0ifconfigeth0192.168.174.100/24使用ifconfig修改ip会直接在内存中生效,重启系统或者重启network服务就丢失。重启服务:Centos6:ser…

    2022年7月27日
    34
  • bs与cs架构的优缺点_bs架构与cs架构的区别详细讲解

    bs与cs架构的优缺点_bs架构与cs架构的区别详细讲解简介C/S又称Client/Server或客户/服务器模式。服务器通常采用高性能的PC、工作站或小型机,并采用大型数据库系统,如Oracle、Sybase、Informix或SQLServer。客户端需要安装专用的客户端软件。B/S是Brower/Server的缩写,客户机上只要安装一个浏览器(Browser),如NetscapeNavigator或InternetExplorer,服务器安装Oracle、Sybase、Informix或SQLServer等数据库。浏览器通过Web

    2022年8月31日
    4
  • 力争群雄:2012年度IT博客大赛100强脱颖而出[通俗易懂]

    力争群雄:2012年度IT博客大赛100强脱颖而出[通俗易懂]2012年度IT博客大赛于11月20日圆满结束。这一所谓的“海选”阶段为期33天,引无数网友和博主翘首以待,来源包括51CTO、独立个人博客、其他博客服务托管商,以及今年评选新增加的分类如独立博客、学生博客和团队博客等众多博主共同参加了这一角逐,其中100位实力雄厚和人气充盈的博主获得了前100强的殊荣。他们占据了25万张票选中的8成以上份量,并将为2012年度IT博客50…

    2022年7月21日
    12
  • win10自动更新有效强制永久关闭怎么办_win10怎么不自动更新

    win10自动更新有效强制永久关闭怎么办_win10怎么不自动更新网上的一些Win10彻底关闭WindowsUpdate自动更新的方法,主要是通过一些如设置流量计费或借助一些专门的小工具来实现,比如360来限速,但往往会发现,Win10自动更新就像打不死的小强,不管怎么关闭,之后还是会自动更新,让不少小伙伴颇为不爽。今天为大家带来了这篇教程,通过服务、注册表、组策略、计划任务中,全方位设置,彻底关闭Win10自动更新,感兴趣的小伙伴不妨试试吧。服务中关闭Wi…

    2025年6月16日
    3
  • ASP .NET DropDownList多级联动事件

    ASP .NET DropDownList多级联动事件思路假如有三级省、市、区,先加载出所有省选择省之后,加载出该省所有市选择市之后,加载出该市所有区重新选择省,则清空市和区重新选择市,则清空区想好数据结构,不同的数据结构做法不同例子数据结构publicclassArea{publicintPKID{get;set;}publicintParentID{get;set;}…

    2025年10月25日
    2
  • linux

    linux

    2021年6月30日
    85

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号