GAN网络详解(从零入门)

GAN网络详解(从零入门)从一个小白的方式理解 GAN 网络 生成对抗网络 可以认为是一个造假机器 造出来的东西跟真的一样 下面开始讲如何造假 主要讲解 GAN 代码 代码很简单 我们首先以造小狗的假图片为例 首先需要一个生成小狗图片的模型 我们称之为 generator 还有一个判断小狗图片是否是真假的判别模型 discrimator 首先输入一个 1000 维的噪声 然后送入生成器 生成器的具体结构如下所示 不看也

从一个小白的方式理解GAN网络(生成对抗网络),可以认为是一个造假机器,造出来的东西跟真的一样,下面开始讲如何造假:(主要讲解GAN代码,代码很简单)

我们首先以造小狗的假图片为例。

首先需要一个生成小狗图片的模型,我们称之为generator,还有一个判断小狗图片是否是真假的判别模型discrimator,

GAN网络详解(从零入门)

首先输入一个1000维的噪声,然后送入生成器,生成器的具体结构如下所示(不看也可以,看完全篇回来再看也一样):

GAN网络详解(从零入门)

其实比较简单,代码如下所示:

def generator_model(): model = Sequential() model.add(Dense(input_dim=1000, output_dim=1024)) model.add(Activation('tanh')) model.add(Dense(128 * 8 * 8)) model.add(BatchNormalization()) model.add(Activation('tanh')) model.add(Reshape((8, 8, 128), input_shape=(8 * 8 * 128,))) model.add(UpSampling2D(size=(4, 4))) model.add(Conv2D(64, (5, 5), padding='same')) model.add(Activation('tanh')) model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(3, (5, 5), padding='same')) model.add(Activation('tanh')) return model

生成器接受一个1000维的随机生成的数组,然后输出一个64×64×3通道的图片数据。输出就是一个图片。不必太过深究,输入是1000个随机数字,输出是一张图片。

 

下面再看判别器代码与结构:

GAN网络详解(从零入门)

代码如下所示:

 def discriminator_model(): model = Sequential() model.add(Conv2D(64, (5, 5), padding='same', input_shape=(64, 64, 3))) model.add(Activation('tanh')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Conv2D(128, (5, 5))) model.add(Activation('tanh')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Flatten()) model.add(Dense(1024)) model.add(Activation('tanh')) model.add(Dense(1)) model.add(Activation('sigmoid')) return model

输入是64,64,3的图片,输出是一个数1或者0,代表图片是否是狗。

下面根据代码讲具体操作:

GAN网络详解(从零入门)

把真图与假图。进行拼接,然后打上标签,真图标签是1,假图标签是0,送入训练的网络。

# 随机生成的1000维的噪声 noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 1000)) # X_train是训练的图片数据,这里取出一个batchsize的图片用于训练,这个是真图(64张) image_batch = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE] # 这里是经过生成器生成的假图 generated_images = generator_model.predict(noise, verbose=0) # 将真图与假图进行拼接 X = np.concatenate((image_batch, generated_images)) # 与X对应的标签,前64张图为真,标签是1,后64张图是假图,标签为0 y = [1] * BATCH_SIZE + [0] * BATCH_SIZE # 把真图与假图的拼接训练数据1送入判别器进行训练判别器的准确度 d_loss = discriminator_model.train_on_batch(X, y)

这里要是看不明白的话可以结合别人的讲解结合来看。

在这里训练好之后,判别器的精度会不断提高。

下面是重头戏了,也是GAN网络的核心:

def generator_containing_discriminator(g, d): model = Sequential() model.add(g) # 判别器参数不进行修改 d.trainable = False model.add(d) return model

他的网络结构如下所示:

GAN网络详解(从零入门)

这个模型有生成器与判别器组成:看代码,这个模型上半部分是生成网络,下半部分是判别网络,生成网络首先生成假图,然后送入判别网络中进行判断,这里有一个d.trainable=False,意思是,只调整生成器,判别的的参数不做更改。简直巧妙。

 

然后我们来看如何训练生成网络,这一块也是核心区域:

 # 训练一个batchsize里面的数据 for index in range(int(X_train.shape[0]/BATCH_SIZE)): # 产生随机噪声 noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 1000)) # 这里面都是真图片 image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE] # 这里产生假图片 generated_images = g.predict(noise, verbose=0) # 将真图片与假图片拼接在一起 X = np.concatenate((image_batch, generated_images)) # 前64张图片标签为1,即真图,后64张照片为假图 y = [1] * BATCH_SIZE + [0] * BATCH_SIZE # 对于判别器进行训练,不断提高判别器的识别精度 d_loss = d.train_on_batch(X, y) # 再次产生随机噪声 noise = np.random.uniform(-1, 1, (BATCH_SIZE, 1000)) # 设置判别器的参数不可调整 d.trainable = False # ×××××××××××××××××××××××××××××××××××××××××××××××××××××××××× # 在此我们送入噪声,并认为这些噪声是真实的标签 g_loss = generator_containing_discriminator.train_on_batch(noise, [1] * BATCH_SIZE) # ×××××××××××××××××××××××××××××××××××××××××××××××××××××××××× # 此时设置判别器可以被训练,参数可以被修改 d.trainable = True # 打印损失值 print("batch %d d_loss : %s, g_loss : %f" % (index, d_loss, g_loss))

重点在于这句代码

g_loss = generator_containing_discriminator.train_on_batch(noise, [1] * BATCH_SIZE)

首先这个网络模型(定义在上面),先传入生成器中,然后生成器生成图片之后,把图片传入判别器中,标签此刻传入的是1,真实的图片,但实际上是假图,此刻判别器就会判断为假图,然后模型就会不断调整生成器参数,此刻的判别器的参数被设置为为不可调整,d.trainable=False,所以为了不断降低loss值,模型就会一直调整生成器的参数,直到判别器认为这是真图。此刻判别器与生成器达到了一个平衡。也就是说生成器产生的假图,判别器已经分辨不出来了。所以继续迭代,提高判别器精度,如此往复循环,直到生成连人都辨别不了的图片。

最后我训练了大概65轮,实际上生成比较真实的狗的图片我估计可能上千轮了,当然不同的网络结构,所需要的迭代次数也不一样。我这个因为太费时间,就跑了大概,可以看出大概有个狗模样。这个是训练了65轮之后的效果:

GAN网络详解(从零入门)

以上就是全部的内容了。

https://github.com/jensleeGit/Kaggle_self_use/tree/master/Generative%20Dog%20Images

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/202970.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月19日 下午10:53
下一篇 2026年3月19日 下午10:54


相关推荐

  • C语言随机数函数

    C语言随机数函数rand 简介 1 使用该函数首先应在开头包含头文件 stdlib h include stdlib h C 建议使用 include cstdlib 2 在标准的 C 库中函数 rand 可以生成 0 RAND MAX 之间的一个随机数 其中 RAND MAX 是 stdlib h 中定义的一个整数 它与系统有关 rand 函数没有输入参数 直接通过表达式 rand 来引用 例如可以用下面的语句来打印两个随机数 printf Random cstdlib stdlib h stdlib h

    2026年3月18日
    2
  • 黄金搭档:rsync与inotify

    黄金搭档:rsync与inotify目录一 rsync 同步 一 概述二 常用命令二 inotify 实时监控 概述三 服务组合过程四 实验一 rsync 二 rsync inotify 一 rsync 同步 一 概述 1 rsync 是 linux 系统下的数据镜像备份工具 RemoteSync 是快速增量备份工具 可以远程同步 支持本地复制 2 可以不改变原有的数据属性信息 实现数据的备份迁移特性 3 因 delta transfer 算法 二进制比较算法 受欢迎 4 使用 c s 架构 端口号为 873 二 常用命令常用选项含义

    2026年3月17日
    1
  • 前端面试ajax考点汇总_javascript常见面试题

    前端面试ajax考点汇总_javascript常见面试题前端面试题总结(四)ajax篇1、什么是AJAX,为什么要使用Ajax(请谈一下你对Ajax的认识)什么是ajax:AJAX是“AsynchronousJavaScriptandXML”的缩写。他是指一种创建交互式网页应用的网页开发技术。Ajax包含下列技术:基于web标准(standards-basedpresentation)XH…

    2022年8月29日
    4
  • 时光网打不开的解决办法

    时光网打不开的解决办法2010年11月4日Update:时光网已回归.在C:\Windows\System32\drivers\etc文件夹中打开hosts文件,以文本格式打开。(Windows7和Vista可能需要复

    2022年7月2日
    28
  • Python给图片添加盲水印

    Python给图片添加盲水印盲水印就是图片有水印但人眼看不出来 需要通过程序才能提取水印 相当于隐形 盖章 可以用在数据泄露溯源 版权保护等场景

    2026年3月18日
    3
  • 【编译原理】LL(1)语法分析器

    【编译原理】LL(1)语法分析器1 项目要求文法要求 1 从文件读入 每条产生式占用一行 2 文法为 LL 1 文法从文件中读入文法 从键盘上输入待分析的符号串 采用 LL 1 分析算法判断该符号串是否为该文法的句子 2 实验思路 首先实现集合 FIRST X 构造算法和集合 FOLLOW A 构造算法 再根据 FIRST 和 FOLLOW 集合构造出预测分析表 并对指定的句子打印出分析栈的分析过程 判断是否为该文法的句子 3 实验原理 1

    2026年3月19日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号