生成对抗学习

生成对抗学习生成对抗学习自动编码器复习生成式对抗网络介绍训练判别器训练生成器数据集模型构建生成器判别器模型训练自动编码器复习核心目标 构建输入等于输出用途 降维 特征提取 初始化深度网络训练方式 梯度下降 反向传播生成式对抗网络介绍最小最大游戏 零和博弈 游戏双方分别是生成器和判别器 生成器学习伪造数据 判别器学习判断数据的真实性 为了胜利双方不断自我优化 各自提高生成能力和判别能力 最终以假乱真 训练判别器真实数据集中采样数据 并标记为 1 生成器随机采样数据 并标记为 0 锁定生成器不训练 反向

自动编码器复习

生成式对抗网络介绍

训练判别器

训练生成器

数据集

class Image_data(Dataset): def __init__(self,img_h=256,img_w=256,path,data_path,label_path,process): self.img_h=imgz_h self.img_w=img_w self.path=path self.data_path=data_path self.label_path=label_path self.process=process self.img_data=os.listdir(self.path+'/'+self.data_path) def __len__(self): return len(self.img_data) def __getitem__(self,item): image_name=self.img_data[item] label_name=image_name.split('.')[0] image_path=self.path+'/'+self.data_path+'/'+image_name label_path=self.path+'/'+self.label_path+'/'+labe_name+'.jpg' image=Image.open(image_path) label=Image.open(label_path) if self.process: transforms_image=transforms.Compose([ transforms.Resize([self.img_h,self.img_w]), transforms.Totensor(), transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) image=transforms_image(image) transforms_label=transforms.Compose([ transforms.Resize([self.img_h,self.img_w]), transfoms.ToTensor()]) label=transforms_label(label) return image,label 

模型构建

生成器

class conv_block(nn.Module): def __init__(self,ch_in,ch_out): super(conv_block, self).__init__() self.conv=nn.Sequential( nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=3,stride=1,padding=1), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True), nn.Conv2d(in_channels=ch_out,out_channels=ch_out,kernel_size=3,stride=1,padding=1), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True) ) def forward(self,x): return self.conv(x) class conv_up(nn.Module): def __init__(self,ch_in,ch_out): super(conv_up, self).__init__() self.conv=nn.Sequential( nn.Upsample(scale_factor=2), nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=3,stride=1,padding=1), nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True) ) def forward(self,x): return self.conv(x) class U_net(nn.Module): def __init__(self,ch_in,ch_out): super(U_net, self).__init__() self.maxpool=nn.MaxPool2d(kernel_size=2,stride=2) self.conv1=conv_block(ch_in=ch_in,ch_out=32) self.conv2=conv_block(ch_in=32,ch_out=64) self.conv3=conv_block(ch_in=64,ch_out=128) self.conv4=conv_block(ch_in=128,ch_out=256) self.conv5=conv_block(ch_in=256,ch_out=512) self.up5=conv_up(ch_in=512,ch_out=256) self.conv_up5=conv_block(ch_in=512,ch_out=256) self.up4=conv_up(ch_in=256,ch_out=128) self.conv_up4=conv_block(ch_in=256,ch_out=128) self.up3=conv_up(ch_in=128,ch_out=64) self.conv_up3=conv_block(ch_in=128,ch_out=64) self.up2=conv_up(ch_in=64,ch_out=32) self.conv_up2=conv_block(ch_in=64,ch_out=32) self.conv1_1=nn.Conv2d(in_channels=32,out_channels=ch_out,kernel_size=1,stride=1,padding=0) def forward(self,x): x1=self.conv1(x) x2=self.maxpool(x1) x2=self.conv2(x2) x3=self.maxpool(x2) x3=self.conv3(x3) x4=self.maxpool(x3) x4=self.conv4(x4) x5=self.maxpool(x4) x5=self.conv5(x5) d5=self.up5(x5) d5=torch.cat((x4,d5),dim=1) d5=self.conv_up5(d5) d4=self.up4(d5) d4=torch.cat((x3,d4),dim=1) d4=self.conv_up4(d4) d3=self.up3(d4) d3=torch.cat((x2,d3),dim=1) d3=self.conv_up3(d3) d2=self.up2(d3) d2=torch.cat((x1,d2),dim=1) d2=self.conv_up2(d2) d1=self.conv1_1(d2) d1=torch.sigmoid(d1) return d1 

判别器

class CNN(nn.Module): def __init__(self,ch_in,num_class=1): super(CNN, self).__init__() ndf=32 self.dis=nn.Sequential( conv_block(ch_in=ch_in,ch_out=ndf), nn.MaxPool2d(kernel_size=2,stride=2), conv_block(ch_in=ndf,ch_out=2*ndf), nn.MaxPool2d(kernel_size=2,stride=2), conv_block(ch_in=2*ndf,ch_out=4*ndf), nn.MaxPool2d(kernel_size=2,stride=2), conv_block(ch_in=4*ndf,ch_out=8*ndf), nn.MaxPool2d(kernel_size=2,stride=2), conv_block(ch_in=8*ndf,ch_out=16*ndf) ) self.fc=nn.Sequential( nn.Linear(16*ndf,num_class), nn.Sigmoid(), ) self.avg_pool=nn.AdaptiveAvgPool2d((1,1)) def forward(self,x): out=self.dis(x) out=self.avg_pool(out) out=out.view(out.size(0),-1) out=self.fc(out) return out 

模型训练

class Trainer(object): def __init__(self,ch_in=3,ch_out=3,epoch=50,batchsize=16,lr=0.005,dataset=None): self.ch_in=ch_in self.ch_out=ch_out self.epoch=epoch self.batchsize=batchsize self.lr=lr self.dataset_loader=Dataloader(dataset=dataset,batch_size=self.batch_size,shuffle=True) #生成器 self.gen=U_net(ch_in=self.ch_in,ch_out=self.ch_out) self.gen_optimizer=torch.optim.Adam(self.gen.parameters(),lr=self.lr) self.gen_loss=nn.L1Loss() #判别器 self.dis=CNN(ch_in=self.ch_in*2,num_class=1) self.dis_optimizer=torch.optim.Adam(self.dis.parameters(),lr=self.lr) self.dis_loss=nn.BCELoss() def set_requires_grad(self,nets,requires_grad): if not isinstance(nets,list): nets=[nets] for net in nets: if net is not None: for parm in net.parameters(): parm.requires_grad=requires_grad def train(self): Tensor=torch.FloatTensor for epoch in range(self.epoch): epoch_gen_loss=0 epoch_dis_loss=0 for i,(bx,by) in enumerate(self.dataset_loader): one_label=Variable(Tensor(bx.size(0),1).fill_(1.0),requires_grad=False) zero_label=Variable(Tensor(bx.size(0),1).fill_(0),requires_grad=False) if i%2==0: #训练生成器 self.set_requires_grad(self.dis,False) self.set_erquires_grad(self.gen,True) bx_gen=self.gen(bx) loss_rec=self.gen_loss(bx_gen,by) fake_ab=torch.cat([bx_gen,bx],dim=1) dis_fake=self.dis(fake_ab) loss_gen=self.gen_loss(dis_fake,one_label) loss_gen=loss_gen+100*loss_rec self.gen_optimizer.zero_grad() loss_gen.backward() self.gen_optimizer.step() epoch_gen_loss+=loss_gen.item() print('生成器损失',loss_gen.item()) else: #训练判别器 self.set_requires_grad(self.dis.False) self.set_requires_grad(self.gen,True) bx_gen=self.gen(bx) fake_ab=torch.cat([bx_gen,bx],dim=1) dis_fake=self.dis(fake_ab) real_ab=torch.cat([by,bx],dim=1) dis_real=self.dis(real_ab) loss_fake=self.gen_loss(dis_fake,zero_label) loss_real=self.gen_loss(dis_real,one_label) loss_dis=(loss_fake+loss_real)/2 self.dis_optimizer.zero_grad() loss_dis.backward() self.dis_optimizer.step() epoch_dis_loss+=loss_dis.item() print('判别器损失',loss_dis.item()) 
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/213038.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月18日 下午6:41
下一篇 2026年3月18日 下午6:42


相关推荐

  • 向量范数与矩阵范数矩阵模的平方-函数和几何以及映射的关系-数学

    向量范数与矩阵范数矩阵模的平方-函数和几何以及映射的关系-数学…

    2022年5月15日
    51
  • Ubuntu保存退出vim编辑器「建议收藏」

    Ubuntu保存退出vim编辑器「建议收藏」命令模式,从键盘上输入的任何字符都被作为编辑命令来解释,vi下很多操作如配置编辑器、文本查找和替换、选择文本等都是在命令模式下进行的。输入模式,从键盘上输入的所有字符都被插入到正在编辑的缓冲区中,被当作正文。1.编辑进入vi/vim后按字母“i”或“I”即可进入编辑状态(此时左下角会出现“插入”),另外还可以用a…

    2022年6月11日
    44
  • RowBounds[通俗易懂]

    RowBounds[通俗易懂]在mybatis中,使用RowBounds进行分页,非常方便,不需要在sql语句中写limit,即可完成分页功能。但是由于它是在sql查询出所有结果的基础上截取数据的,所以在数据量大的sql中并不适用,它更适合在返回数据结果较少的查询中使用最核心的是在mapper接口层,传参时传入RowBounds(intoffset,intlimit)对象,即可完成分页。不需要修改xml配置添加limitmapper接口层代码如下List<Book>selectBoo

    2022年4月19日
    53
  • 3D移动 translate3d

    3D移动 translate3d3D转换我们主要学习工作中最常用的3D位移和3D旋转主要知识点3D移动在2D移动的基础上多加了一个可以移动的方向,就是z轴方向。translform:translateX(100px):仅仅是在x轴上移动 translform:translateY(100px):仅仅是在Y轴上移动 translform:translateZ(100px):仅仅是在Z轴上移动(注意:translateZ一般用px作单位) transform:translate3d(x,y,z):其中x、y、z分别指要移动的

    2025年8月9日
    4
  • mysql通配符转义_转义MySQL通配符

    mysql通配符转义_转义MySQL通配符小编典典_而%不是通配符在MySQL一般,而且不应该被转义,将它们放入普通的字符串字面量的目的。mysql_real_escape_string是正确的,足以满足此目的。addcslashes不应该使用。_并且%仅在LIKE-matching上下文中是特殊的。当您想为LIKE语句中的文字使用准备字符串时,要100%匹配百分之一百,而不仅仅是以100开头的任何字符串,都需要担心两种转义。首先是喜欢转…

    2022年6月16日
    46
  • 元数据与数据治理|大数据治理(第九篇)

    元数据与数据治理|大数据治理(第九篇)nbsp nbsp 魅族大数据平台的一个技术分享活动 话题是 大数据治理之路 魅族大数据平台工作人员分享了一些他们的大数据治理经验 很有内容 首先 他们整理了一个治理流程 架构图然后 依照架构图 大致讲了架构图中的每个模块 以及将模块串联起来的一个管理流程 流程图如下 然后 依照架构图 大致讲了架构图中的每个模块 以及将模块串联起来的一个管理流程 流程图如下 nbsp 流程图上面 其中 主数据管

    2026年3月16日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号