变分自编码器VAE

变分自编码器VAE1 VAE amp GAN 变分自编码器 Variationala encoder VAE 是一类重要的生成模型 generativemo 除了 VAEs 还有一类重要的生成模型 GANsVAE 跟 GAN 比较 目标基本是一致的 希望构建一个从隐变量 Z 生成目标数据 X 的模型 但是实现上有所不同 生成模型的难题就是判断生成分布与真实分布的相似度 因为我们只知道两者的采样结果 不知道它们的分布表达式 KL 散度是根据两个概率分布的表达式来算它们的相似度的 我们只有样本

1. VAE & GAN

变分自编码器(Variational auto-encoder,VAE)是一类重要的生成模型(generative model)

除了VAEs,还有一类重要的生成模型GANs

VAE 跟 GAN 比较,目标基本是一致的——希望构建一个从隐变量 Z 生成目标数据 X 的模型,但是实现上有所不同。

生成模型的难题就是判断生成分布与真实分布的相似度,因为我们只知道两者的采样结果,不知道它们的分布表达式。 KL 散度是根据两个概率分布的表达式来算它们的相似度的,我们只有样本本身,没有分布表达式,当然也就没有方法算 KL 散度。

GAN 的思路很直接粗犷:既然没有合适的度量,那我干脆把这个度量也用神经网络训练出来吧

与GANs不同的是,VAEs是知道图像的密度函数(PDF)的(或者说,是我们设定的)

2. VAE

2.1 简单引入

变分自编码器VAE

观测数据是X,而X由隐变量Z产生,由Z->X是生成模型\theta,就是解码器;

而由x->z是识别模型\phi,类似于自编码器的编码器。

2.2 传统理解

有一批数据样本 {
X1,…,Xn},其整体用 X 来描述,如果能得到其分布,那我直接根据 p(X) 来采样,就可以得到所有可能的 X 了,但这是不现实的,因此引入:

变分自编码器VAE

p(X|Z) 是一个由 Z 来生成 X的模型,而我们假设 Z 服从标准正态分布,也就是 p(Z)=N(0,I)。如果这个能实现,那么我们就可以先从标准正态分布中采样一个 Z,然后根据 Z 来算一个 X

变分自编码器VAE

但观察上图,经过采样出来的Zk,进而生成的Xk不再对应着原来的 Xk,直接最小化 D(X̂ k,Xk)^2是很不科学的,而事实上代码也不是这样实现的

2.3 真正理解

在整个 VAE 模型中,并没有去使用 p(Z)(先验分布)是正态分布的假设,用的是假设 p(Z|X)(后验分布)是正态分布

给定一个真实样本 Xk,假设存在一个专属于 Xk 的分布 p(Z|Xk),服从正态分布;然后生成器X=g(Z),希望能够把从分布 p(Z|Xk) 采样出来的一个 Zk 还原为 Xk

因此,

        有多少个 X 就有多少个正态分布了。参数:均值 μ 和方差 σ^2(多元的话,都是向量)

于是构建两个神经网络 μk=f1(Xk),logσ^2=f2(Xk) 来算它们了。因为 σ^2 总是非负的,需要加激活函数处理,而拟合 logσ^2 不需要加激活函数,因为它可正可负。

变分自编码器VAE

但是,如果根据上图训练,模型希望重构 X,也就是最小化 D(X̂k,Xk)^2,但是这个重构过程受到噪声的影响,因为Zk 是通过重新采样过的。不过好在这个噪声强度(也就是方差)通过一个神经网络算出来的,所以最终模型为了重构得更好,肯定会想尽办法让方差为0。

方差为 0 的话,也就没有随机性了,所以采样其实都只是得到确定的结果(也就是均值)

模型会慢慢退化成普通的 AutoEncoder,噪声不再起作用

2.4 进一步理解—>分布标准化

VAE 还让所有的 p(Z|X) 都向标准正态分布看齐,这样就防止了噪声为零

假设  所有的 p(Z|X) 都很接近标准正态分布 N(0,I),那么根据定义:

变分自编码器VAE

因此,p(Z) 满足标准正态分布。然后我们就可以放心地从 N(0,I) 中采样来生成图像了。

变分自编码器VAE 2.5 损失

 怎么让所有的 p(Z|X) 都向 N(0,I) 看齐呢?最直接的方法是在重构误差的基础上中加入额外的 loss

因此,将一般(各分量独立的)正态分布与标准正态分布的 KL 散度KL(N(μ,σ^2)‖N(0,I))作为这个额外的 loss,计算结果为:

变分自编码器VAE

 2.6 模型实现

我们要从 p(Z|Xk) 中采样一个 Zk 出来,尽管我们知道了 p(Z|Xk) 是正态分布,但是均值方差都是靠模型算出来的,我们要靠这个过程反过来优化均值方差的模型,但是“采样”这个操作是不可导的,而采样的结果是可导的,于是我们利用了一个事实:

 变分自编码器VAE这样一来,“采样”这个操作就不用参与梯度下降了

3. VAE本质

VAE就是在自编码器模型上做进一步变分处理,使得编码器的输出结果能对应到目标分布的均值和方差;因此,它的 Encoder 有两个,一个用来计算均值,一个用来计算方差

本质上就是在常规的自编码器的基础上,对 encoder 的结果(在VAE中对应着计算均值的网络)加上了“高斯噪声”,使得结果 decoder 能够对噪声有鲁棒性;而那个额外的 KL loss(目的是让均值为 0,方差为 1),事实上就是相当于对 encoder 的一个正则项,希望 encoder 出来的东西零均值。

另外一个 encoder(计算方差的网络)是用来动态调节噪声的强度的。当 decoder 还没有训练好时(重构误差远大于 KL loss),就会适当降低噪声(KL loss 增加),使得拟合起来容易一些(重构误差开始下降)。反之,如果 decoder 训练得还不错时(重构误差小于 KL loss),这时候噪声就会增加(KL loss 减少),使得拟合更加困难了(重构误差又开始增加),这时候decoder 就要想办法提高它的生成能力了

变分自编码器VAE

重构的过程是希望没噪声的,而 KL loss 则希望有高斯噪声的,两者是对立的。所以,VAE 跟 GAN 一样,内部其实是包含了一个对抗的过程,只不过它们两者是混合起来,共同进化的

4. auto-encoder 和 VAE 对比

Auto-Encoder能够把一个高维的向量(28*28图像)压缩到只有30维,并且解码回的图像具备清楚的辨认度(如下图)。

变分自编码器VAE

但是这并没有达到我们真正想要构造的生成模型的标准,因为,对于一个生成模型而言,解码器部分应该是单独能够提取出来的,并且对于在规定维度下任意采样的一个编码,都应该能通过解码器产生一张清晰且真实的图片。

auto-encoder无法达到这一标准的原因:

变分自编码器VAE 

 如上图所示,假设有两张训练图片,经过训练自编码器模型已经能无损地还原这两张图片。接下来,我们在code空间上,两张图片的编码点中间处取一点,然后将这一点交给解码器,我们希望新的生成图片是一张清晰的图片(类似3/4全月的样子)。但是,实际的结果是,生成图片是模糊且无法辨认的乱码图。一个比较合理的解释是,因为编码和解码的过程使用了深度神经网络,这是一个非线性的变换过程,所以在code空间上点与点之间的迁移是非常没有规律的。

为了解决这个问题,我们可以引入噪声(VAE),使得图片的编码区域得到扩大,从而掩盖掉失真的空白编码点。

变分自编码器VAE

如上图所示,现在在给两张图片编码的时候加上一点噪音,使得每张图片的编码点出现在绿色箭头所示范围内,于是在训练模型的时候,绿色箭头范围内的点都有可能被采样到,这样解码器在训练时会把绿色范围内的点都尽可能还原成和原图相似的图片。然后我们可以关注之前那个失真点,现在它处于全月图和半月图编码的交界上,于是解码器希望它既要尽量相似于全月图,又要尽量相似于半月图,于是它的还原结果就是两种图的折中(3/4全月图)。

 由此我们发现,给编码器增添一些噪音,可以有效覆盖失真区域。不过这还并不充分,因为在上图的距离训练区域很远的黄色点处,它依然不会被覆盖到,仍是个失真点。为了解决这个问题,我们可以试图把噪音无限拉长,使得对于每一个样本,它的编码会覆盖整个编码空间,不过我们得保证,在原编码附近编码的概率最高,离原编码点越远,编码概率越低。在这种情况下,图像的编码就由原先离散的编码点变成了一条连续的编码分布曲线,如下图所示。

变分自编码器VAE 

5. pytorch代码

import torch import torchvision from torch import nn from torch import optim import torch.nn.functional as F from torch.autograd import Variable from torch.utils.data import DataLoader from torchvision import transforms from torchvision.utils import save_image from torchvision.datasets import MNIST import os import datetime if not os.path.exists('./vae_img'): os.mkdir('./vae_img') def to_img(x): x = x.clamp(0, 1) x = x.view(x.size(0), 1, 28, 28) return x num_epochs = 100 batch_size = 128 learning_rate = 1e-3 img_transform = transforms.Compose([ transforms.ToTensor() # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) dataset = MNIST('./data', transform=img_transform, download=True) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 20) self.fc22 = nn.Linear(400, 20) self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1) def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() if torch.cuda.is_available(): eps = torch.cuda.FloatTensor(std.size()).normal_() else: eps = torch.FloatTensor(std.size()).normal_() eps = Variable(eps) return eps.mul(std).add_(mu) def decode(self, z): h3 = F.relu(self.fc3(z)) # return F.sigmoid(self.fc4(h3)) return torch.sigmoid(self.fc4(h3)) def forward(self, x): mu, logvar = self.encode(x) z = self.reparametrize(mu, logvar) return self.decode(z), mu, logvar strattime = datetime.datetime.now() model = VAE() if torch.cuda.is_available(): # model.cuda() print('cuda is OK!') model = model.to('cuda') else: print('cuda is NO!') reconstruction_function = nn.MSELoss(size_average=False) # reconstruction_function = nn.MSELoss(reduction=sum) def loss_function(recon_x, x, mu, logvar): """ recon_x: generating images x: origin images mu: latent mean logvar: latent log variance """ BCE = reconstruction_function(recon_x, x) # mse loss # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) KLD = torch.sum(KLD_element).mul_(-0.5) # KL divergence return BCE + KLD optimizer = optim.Adam(model.parameters(), lr=1e-3) for epoch in range(num_epochs): model.train() train_loss = 0 for batch_idx, data in enumerate(dataloader): img, _ = data img = img.view(img.size(0), -1) img = Variable(img) img = (img.cuda() if torch.cuda.is_available() else img) optimizer.zero_grad() recon_batch, mu, logvar = model(img) loss = loss_function(recon_batch, img, mu, logvar) loss.backward() # train_loss += loss.data[0] train_loss += loss.item() optimizer.step() if batch_idx % 100 == 0: endtime = datetime.datetime.now() print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} time:{:.2f}s'.format( epoch, batch_idx * len(img), len(dataloader.dataset), 100. * batch_idx / len(dataloader), loss.item() / len(img), (endtime-strattime).seconds)) print('====> Epoch: {} Average loss: {:.4f}'.format( epoch, train_loss / len(dataloader.dataset))) if epoch % 10 == 0: # 生成图像 z = torch.randn(batch_size, 20).to(device) out = model.decode(z).view(-1, 1, 28, 28) save_image(out, './vae_img/sampled-{}.png'.format(epoch)) # 重构图像 save = to_img(recon_batch.cpu().data) save_image(save, './vae_img/image_{}.png'.format(epoch)) torch.save(model.state_dict(), './vae.pth') 

Reference:

https://zhuanlan.zhihu.com/p/

http://www.gwylab.com/note-vae.html

https://blog.csdn.net/weixin_/article/details/

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/212485.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月18日 下午7:57
下一篇 2026年3月18日 下午7:57


相关推荐

  • vscode一键配置C/C++多个C及CPP文件编译与tasks.json和launch.json原理

    vscode一键配置C/C++多个C及CPP文件编译与tasks.json和launch.json原理vscode配置环境及配置原理搜了很多的教程,发现要么教程太老,给的配置信息里面有些参数都不能使用了,要么就是直接扔下自己的配置信息就没了,不知道咋来的,也不能拿过来直接用,让我这种小白无从下手,于是就摸索整理一下,帮助一下像我这样小白刚入手的小伙伴们。原理我觉得最重要的就是我们要明白各个配置文件是干嘛的,它是怎么被vscode使用的,明白这一点,那么自己就可以比较清晰参数该怎么改,应该改哪些参数,而不是拿着别人的配置文件,无从下手。配置文件基本的原理(只是原理,不是咋配置的):vscode使用的最

    2025年8月11日
    4
  • 二叉树中序遍历的三种方法

    二叉树中序遍历的三种方法二叉树是一种重要的数据结构 对于二叉树的遍历也很重要 这里通过三种方法简单介绍一下二叉树的中序遍历 中序遍历就是先遍历二叉树的左子树 然后遍历根节点 最后遍历右子树 例如下面的二叉树 中序遍历的结果如下 5 10 6 15 2 对于中序遍历 直观上的结果就是将二叉树所有节点投影到下面的一条直线上 得到的顺序就是二叉树的中序遍历结果 1 递归法递归方法是最容易想到的方法 递归调用遍历方法先遍历左子

    2026年3月18日
    2
  • 【SpringBoot】20、SpringBoot中打war包需要注意「建议收藏」

    【SpringBoot】20、SpringBoot中打war包需要注意「建议收藏」最近在做一个项目,遇到了项目打成war包的一个问题,项目创建时选择的时jar包方式,后因项目部署要求,需要打成war包部署,遇到很多坑,在此做一下记录一、修改打包方式原:<version>0.0.1-SNAPSHOT</version><packaging>jar</packaging>改后:<version>0.0.1-SNAPSHOT</version><packaging>war</p

    2022年5月10日
    38
  • jvm类的加载机制_java类加载流程及原理

    jvm类的加载机制_java类加载流程及原理1.类加载器的组织结构转载请注明出处:http://blog.csdn.net/seu_calvin/article/details/52301541类加载器ClassLoader是具有层次结构的,也就是父子关系。其中,Bootstrap是所有类加载器的父亲。(1)Bootstrapclassloader:启动类加载器当运行Java虚拟机时,这个类加载器被创建,…

    2022年8月11日
    9
  • flashfxp注册码

    flashfxp注册码FlashFXP4.0注册码key(通用):——–FlashFXPRegistrationDataSTART——–FLASHFXPVENSVURFnQEAAAGGZJcQuuC6/Znb915ltgBNBmXkEQhOgVxpo/z4OJEIfnjjL/LLDCQbiZE9+N8EbDIQP/sQQf5D+faH6owMEG7/wINp3590f9jk462O98CWS

    2022年7月26日
    14
  • RPM 安装位置

    RPM 安装位置rpm-qplxxxxxx.rpm1.如何安装rpm软件包rmp软件包的安装可以使用程序rpm来完成。执行下面的命令rpm-iyour-package.rpm其中your-package.rpm是你要安装的rpm包的文件名,一般置于当前目录下。安装过程中可能出现下面的警告或者提示:…conflictwith…可能是要安装的包里有一些文件可

    2022年4月30日
    199

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号