自动编码器复习
生成式对抗网络介绍
训练判别器
训练生成器
数据集
class Image_data(Dataset):
    """Paired image/label dataset for image-to-image translation.

    Input images live in ``path/data_path``; for each input ``<stem>.<ext>``
    the matching target is ``path/label_path/<stem><label_ext>``.

    Args:
        img_h: Height both images are resized to when ``process`` is true.
        img_w: Width both images are resized to when ``process`` is true.
        path: Dataset root directory (required).
        data_path: Sub-directory of ``path`` holding input images (required).
        label_path: Sub-directory of ``path`` holding label images (required).
        process: If true, resize and convert both images to tensors; the input
            is converted to RGB and mapped from [0, 1] to [-1, 1] by
            ``Normalize(0.5, 0.5)``, the label stays in [0, 1]. If false, the
            PIL images are returned untransformed.
        label_ext: Extension appended to the input stem to find its label
            (``'.jpg'``, as hard-coded previously).

    Raises:
        ValueError: If ``path``, ``data_path`` or ``label_path`` is missing.
    """

    def __init__(self, img_h=256, img_w=256, path=None, data_path=None,
                 label_path=None, process=True, label_ext='.jpg'):
        # Previously these parameters followed defaulted ones without defaults
        # (a SyntaxError). They keep their positional order but are validated.
        if path is None or data_path is None or label_path is None:
            raise ValueError('path, data_path and label_path are required')
        self.img_h = img_h
        self.img_w = img_w
        self.path = path
        self.data_path = data_path
        self.label_path = label_path
        self.process = process
        self.label_ext = label_ext
        # os.listdir order is arbitrary; sort so indices are reproducible.
        self.img_data = sorted(os.listdir(os.path.join(path, data_path)))
        if process:
            # Built once here rather than on every __getitem__ call.
            self._image_transform = transforms.Compose([
                transforms.Resize([img_h, img_w]),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
            self._label_transform = transforms.Compose([
                transforms.Resize([img_h, img_w]),
                transforms.ToTensor(),
            ])

    def __len__(self):
        """Return the number of input images found in ``path/data_path``."""
        return len(self.img_data)

    def __getitem__(self, item):
        """Load the ``item``-th (image, label) pair."""
        image_name = self.img_data[item]
        # splitext keeps inner dots ("a.b.png" -> "a.b"), unlike split('.')[0].
        stem = os.path.splitext(image_name)[0]
        image_path = os.path.join(self.path, self.data_path, image_name)
        label_path = os.path.join(self.path, self.label_path,
                                  stem + self.label_ext)
        # The with-blocks close the file handles; convert()/copy() return
        # in-memory images that stay valid after the file is closed.
        with Image.open(image_path) as img:
            # The 3-channel Normalize needs RGB; grayscale/RGBA would crash.
            image = img.convert('RGB') if self.process else img.copy()
        with Image.open(label_path) as lbl:
            label = lbl.copy()
        if self.process:
            image = self._image_transform(image)
            label = self._label_transform(label)
        return image, label
模型构建
生成器
class conv_block(nn.Module):
    """Two (3x3 Conv -> BatchNorm -> ReLU) stages; H and W are preserved."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=ch_in, out_channels=ch_out,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=ch_out, out_channels=ch_out,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class conv_up(nn.Module):
    """2x nearest upsampling followed by 3x3 Conv -> BatchNorm -> ReLU."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=ch_in, out_channels=ch_out,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class U_net(nn.Module):
    """U-Net generator with four 2x max-pool levels and skip connections.

    Channel widths are 32/64/128/256/512. The output has ``ch_out`` channels,
    the same H/W as the input, and values squashed into (0, 1) by a sigmoid.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder
        self.conv1 = conv_block(ch_in=ch_in, ch_out=32)
        self.conv2 = conv_block(ch_in=32, ch_out=64)
        self.conv3 = conv_block(ch_in=64, ch_out=128)
        self.conv4 = conv_block(ch_in=128, ch_out=256)
        self.conv5 = conv_block(ch_in=256, ch_out=512)
        # Decoder: each conv_upN sees [skip, upsampled] concatenated channels.
        self.up5 = conv_up(ch_in=512, ch_out=256)
        self.conv_up5 = conv_block(ch_in=512, ch_out=256)
        self.up4 = conv_up(ch_in=256, ch_out=128)
        self.conv_up4 = conv_block(ch_in=256, ch_out=128)
        self.up3 = conv_up(ch_in=128, ch_out=64)
        self.conv_up3 = conv_block(ch_in=128, ch_out=64)
        self.up2 = conv_up(ch_in=64, ch_out=32)
        self.conv_up2 = conv_block(ch_in=64, ch_out=32)
        self.conv1_1 = nn.Conv2d(in_channels=32, out_channels=ch_out,
                                 kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` on the bottom/right so its H, W equal ``skip``'s.

        Max-pooling an odd size floors it, so upsampling back is one pixel
        short; without this, torch.cat fails for H/W not divisible by 16.
        For sizes that already match (e.g. multiples of 16) it is a no-op.
        """
        diff_h = skip.size(-2) - up.size(-2)
        diff_w = skip.size(-1) - up.size(-1)
        if diff_h == 0 and diff_w == 0:
            return up
        return nn.functional.pad(up, [0, diff_w, 0, diff_h])

    def _decode(self, below, skip, up, fuse):
        """Upsample ``below``, concatenate ``skip`` first, then fuse."""
        upsampled = self._match_size(up(below), skip)
        return fuse(torch.cat((skip, upsampled), dim=1))

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        d5 = self._decode(x5, x4, self.up5, self.conv_up5)
        d4 = self._decode(d5, x3, self.up4, self.conv_up4)
        d3 = self._decode(d4, x2, self.up3, self.conv_up3)
        d2 = self._decode(d3, x1, self.up2, self.conv_up2)

        d1 = self.conv1_1(d2)
        return torch.sigmoid(d1)
判别器
class CNN(nn.Module):
    """Image classifier used as the GAN discriminator.

    Five conv_block stages (widths 32, 64, 128, 256, 512) separated by 2x
    max-pooling, then global average pooling, a linear layer and a sigmoid.
    Because of the adaptive (1, 1) pooling, the linear head does not depend
    on the input height/width.

    Args:
        ch_in: Number of input channels.
        num_class: Number of sigmoid outputs per sample (default 1).
    """

    def __init__(self, ch_in, num_class=1):
        super().__init__()
        base = 32
        widths = [base * mult for mult in (1, 2, 4, 8, 16)]

        stages = [conv_block(ch_in=ch_in, ch_out=widths[0])]
        for narrow, wide in zip(widths[:-1], widths[1:]):
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            stages.append(conv_block(ch_in=narrow, ch_out=wide))
        self.dis = nn.Sequential(*stages)

        self.fc = nn.Sequential(
            nn.Linear(widths[-1], num_class),
            nn.Sigmoid(),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        pooled = self.avg_pool(self.dis(x))
        # (N, C, 1, 1) -> (N, C) before the linear head.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
模型训练
class Trainer(object):
    """Alternating adversarial trainer for a U_net generator and CNN discriminator.

    Batches with an even index update the generator; batches with an odd
    index update the discriminator. The discriminator judges a pair
    (output image, input image) concatenated along the channel axis.

    Args:
        ch_in: Channels of the input images ``bx``.
        ch_out: Channels of the target/generated images ``by``.
        epoch: Number of passes over the dataset.
        batchsize: DataLoader batch size.
        lr: Adam learning rate for both networks.
        dataset: Dataset yielding ``(bx, by)`` tensor pairs (required).
        l1_weight: Weight of the L1 reconstruction term in the generator loss
            (100, as previously hard-coded).
        device: Device to train on; defaults to CPU.

    Raises:
        ValueError: If ``dataset`` is None.
    """

    def __init__(self, ch_in=3, ch_out=3, epoch=50, batchsize=16, lr=0.005,
                 dataset=None, l1_weight=100, device=None):
        from torch.utils.data import DataLoader

        if dataset is None:
            raise ValueError('dataset is required')
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.epoch = epoch
        self.batchsize = batchsize
        self.lr = lr
        self.l1_weight = l1_weight
        self.device = torch.device('cpu') if device is None else torch.device(device)
        self.dataset_loader = DataLoader(dataset=dataset,
                                         batch_size=self.batchsize,
                                         shuffle=True)
        # Generator
        self.gen = U_net(ch_in=self.ch_in, ch_out=self.ch_out).to(self.device)
        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(), lr=self.lr)
        self.gen_loss = nn.L1Loss()
        # Discriminator: it receives cat([output (ch_out), input (ch_in)]),
        # so its width is ch_in + ch_out (previously ch_in * 2, wrong when
        # ch_in != ch_out).
        self.dis = CNN(ch_in=self.ch_in + self.ch_out,
                       num_class=1).to(self.device)
        self.dis_optimizer = torch.optim.Adam(self.dis.parameters(), lr=self.lr)
        self.dis_loss = nn.BCELoss()

    def set_requires_grad(self, nets, requires_grad):
        """Set ``requires_grad`` on every parameter of ``nets``.

        ``nets`` may be a single module or a list; ``None`` entries are skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for parm in net.parameters():
                    parm.requires_grad = requires_grad

    def _generator_step(self, bx, by):
        """Update the generator once; return the scalar total loss."""
        self.set_requires_grad(self.dis, False)
        self.set_requires_grad(self.gen, True)
        fake = self.gen(bx)
        loss_rec = self.gen_loss(fake, by)
        pred_fake = self.dis(torch.cat([fake, bx], dim=1))
        # Adversarial term uses BCE against "real" targets (it was L1 before,
        # leaving dis_loss unused).
        loss_adv = self.dis_loss(pred_fake, torch.ones_like(pred_fake))
        loss = loss_adv + self.l1_weight * loss_rec
        self.gen_optimizer.zero_grad()
        loss.backward()
        self.gen_optimizer.step()
        return loss.item()

    def _discriminator_step(self, bx, by):
        """Update the discriminator once; return the scalar loss."""
        # Train D with G frozen (the original had this inverted).
        self.set_requires_grad(self.dis, True)
        self.set_requires_grad(self.gen, False)
        # The fake sample needs no graph back into the generator here.
        with torch.no_grad():
            fake = self.gen(bx)
        pred_fake = self.dis(torch.cat([fake, bx], dim=1))
        pred_real = self.dis(torch.cat([by, bx], dim=1))
        loss_fake = self.dis_loss(pred_fake, torch.zeros_like(pred_fake))
        loss_real = self.dis_loss(pred_real, torch.ones_like(pred_real))
        loss = (loss_fake + loss_real) / 2
        self.dis_optimizer.zero_grad()
        loss.backward()
        self.dis_optimizer.step()
        return loss.item()

    def train(self):
        """Run the training loop, printing every step's loss.

        Returns:
            A list with one ``(mean generator loss, mean discriminator loss)``
            tuple per epoch; a mean is 0.0 if that network had no step.
        """
        history = []
        for _ in range(self.epoch):
            gen_total, gen_steps = 0.0, 0
            dis_total, dis_steps = 0.0, 0
            for i, (bx, by) in enumerate(self.dataset_loader):
                bx = bx.to(self.device)
                by = by.to(self.device)
                if i % 2 == 0:
                    # Train the generator
                    loss = self._generator_step(bx, by)
                    gen_total += loss
                    gen_steps += 1
                    print('生成器损失', loss)
                else:
                    # Train the discriminator
                    loss = self._discriminator_step(bx, by)
                    dis_total += loss
                    dis_steps += 1
                    print('判别器损失', loss)
            history.append((gen_total / max(gen_steps, 1),
                            dis_total / max(dis_steps, 1)))
        return history
发布者：全栈程序员-站长，转载请注明出处：https://javaforall.net/213038.html 原文链接：https://javaforall.net
