Spatial Transformer Network_transgression

Spatial Transformer Network_transgression导读上一篇通俗易懂的SpatialTransformerNetworks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorch和numpy来实现了他们,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块插入到CNN中STN关键点解读STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖θ\thetaθ参数来实现的。刚开

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

导读

上一篇通俗易懂的Spatial Transformer Networks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorchnumpy来实现了,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块嵌入到CNN中

STN关键点解读

STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖 θ \theta θ参数来实现的。刚开始的时候我还以为训练STN还需要准备 θ \theta θ标签数据,实际上并不需要。

当输入图片通过STN模块之后获得变换后的图片,然后我们再将变换后的图片输入到CNN网络中,通过损失函数计算loss,然后计算梯度更新 θ \theta θ参数,最终STN模块会学习到如何矫正图片。

代码实现

  • 导包
import torch,torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
import numpy as np
from torchsummary import summary
import argparse
  • 定义网络结构
class STN_Net(nn.Module):
    def __init__(self,use_stn=True):
        super(STN_Net, self).__init__()
        self.conv1 = nn.Conv2d(1,10,kernel_size=5)
        self.conv2 = nn.Conv2d(10,20,kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320,50)
        self.fc2 = nn.Linear(50,10)
        #用来判断是否使用STN
        self._use_stn = use_stn

        #localisation net
        #从输入图像中提取特征
        #输入图片的shape为(-1,1,28,28)
        self.localization = nn.Sequential(
            #卷积输出shape为(-1,8,22,22)
            nn.Conv2d(1,8,kernel_size=7),
            #最大池化输出shape为(-1,1,11,11)
            nn.MaxPool2d(2,stride=2),
            nn.ReLU(True),
            #卷积输出shape为(-1,10,7,7)
            nn.Conv2d(8,10,kernel_size=5),
            #最大池化层输出shape为(-1,10,3,3)
            nn.MaxPool2d(2,stride=2),
            nn.ReLU(True)
        )
        #利用全连接层回归\theta参数
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3,32),
            nn.ReLU(True),
            nn.Linear(32,2*3)
        )

        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1,0,0,0,1,0]
        ,dtype=torch.float))

    def stn(self,x):
        #提取输入图像中的特征
        xs = self.localization(x)
        xs = xs.view(-1,10*3*3)
        #回归theta参数
        theta = self.fc_loc(xs)
        theta = theta.view(-1,2,3)

        #利用theta参数计算变换后图片的位置
        grid = F.affine_grid(theta,x.size())
        #根据输入图片计算变换后图片位置填充的像素值
        x = F.grid_sample(x,grid)

        return x

    def forward(self,x):
        #使用STN模块
        if self._use_stn:
            x = self.stn(x)
        #利用STN矫正过的图片来进行图片的分类
        #经过conv1卷积输出的shape为(-1,10,24,24)
        #经过max pool的输出shape为(-1,10,12,12)
        x = F.relu(F.max_pool2d(self.conv1(x),2))
        #经过conv2卷积输出的shape为(-1,20,8,8)
        #经过max pool的输出shape为(-1,20,4,4)
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))
        x = x.view(-1,320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x,training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x,dim=1)
  • 加载数据集
def get_dataloader(batch_size):
    # 加载数据集
    # 如果GPU可用就用GPU,否则用CPU
    device = torch.device("cuda" if torch.cuda.is_available()
    					   else "cpu")
    # 加载训练集
    train_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(root="D:/dataset", train=True, download=True,
                       transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                       ])), batch_size=batch_size, shuffle=True)

    # 加载测试集
    test_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(root="D:/dataset", train=False,
                       transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                       ])), batch_size=batch_size, shuffle=True)

    return train_dataloader,test_dataloader
  • 训练模型
def train(net,epoch_nums,lr,train_dataloader,per_batch,device):
    #使用训练模式
    net.train()
    #选择梯度下降优化算法
    optimizer = optim.SGD(net.parameters(),lr=lr)
    #训练模型
    for epoch in range(epoch_nums):
        for batch_idx,(data,label) in enumerate(train_dataloader):
            data,label = data.to(device),label.to(device)

            optimizer.zero_grad()
            pred = net(data)
            loss = F.nll_loss(pred,label)
            loss.backward()
            optimizer.step()

            if batch_idx % per_batch == 0:
                print("Train Epoch:{ 
   } [{ 
   }/{ 
   } ({ 
   :.0f}%)]\tLoss:
                { 
   :.6f}".format(epoch,batch_idx * len(data),
                len(train_dataloader.dataset),
                100. * batch_idx /len(train_dataloader),loss.item()))
  • 评估模型
def evaluate(net,test_dataloader,device):
    with torch.no_grad():
        #使用评估模式
        net.eval()
        eval_loss = 0
        eval_acc = 0
        for data,label in test_dataloader:
            data,label = data.to(device),label.to(device)
            pred = net(data)

            eval_loss += F.nll_loss(pred,label,
            size_average=False).item()
            pred_label = pred.max(1,keepdim=True)[1]
            eval_acc += pred_label.eq(label.view_as(pred_label)
            ).sum().item()

        eval_loss /= len(test_dataloader.dataset)
        print("evaluate set: Average loss: { 
   :.4f},Accuracy:{ 
   }/{ 
   } 
        ({ 
   :.2f}%)\n".format(
            eval_loss,eval_acc,len(test_dataloader.dataset),
            100*eval_acc / len(test_dataloader.dataset)))
  • 将pytorch的tensor转换为numpy的array
def tensor_to_array(img_tensor):
    img_array = img_tensor.numpy().transpose((1,2,0))
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    img_array = std * img_array + mean
    img = np.clip(img_array,0,1)
    return img
  • 可视化STN变换图片
def visualize_stn(net,dataloader,device):
    with torch.no_grad():
        data = next(iter(dataloader))[0].to(device)

        input_tensor = data.cpu()
        t_input_tensor = net.stn(data).cpu()

        in_grid = tensor_to_array(torchvision.utils.make_grid(
        input_tensor))
        out_grid = tensor_to_array(torchvision.utils.make_grid(
        t_input_tensor))

        f,axarr = plt.subplots(1,2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title("input images")

        axarr[1].imshow(out_grid)
        axarr[1].set_title("stn transformed images")

        plt.show()

在这里插入图片描述
通过对比输入图片和经过STN变换后的图片能够很明显发现,经过STN之后能将旋转的图片进行明显的纠正。

  • 参数设置
def parse_args():
    parse = argparse.ArgumentParser("config stn args")
    parse.add_argument("--lr",default=0.01,
    type=float,help="learning rate")
    parse.add_argument("--epoch_nums",default=20,
    type=int,help="iterated epochs")
    parse.add_argument("--use_stn",default=True,
    type=bool,help="whether to use STN module")
    parse.add_argument("--batch_size",default=64,
    type=int,help="batch size")
    parse.add_argument("--use_eval",default=True,
    type=bool,help="whether to evaluate")
    parse.add_argument("--use_visual",default=True,
    type=bool,help="visual STN transform image")
    parse.add_argument("--use_gpu",default=True,
    type=bool,help="whether to use GPU")
    parse.add_argument("--show_net_construct",default=False,
    type=bool,help="print net construct info")
    return parse.parse_args()
  • 主函数
if __name__ == "__main__":
    args = parse_args()
    if args.use_gpu and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    #加载数据集
    train_loader,test_loader = get_dataloader(args.batch_size)
    #创建网络
    net = STN_Net(args.use_stn).to(device)
    #打印网络的结构信息
    if args.show_net_construct:
        summary(net,(1,28,28))
    #训练模型
    train(net,args.epoch_nums,args.lr,train_loader
    ,args.batch_size,device)
    if args.use_eval:
        #评估模型
        evaluate(net,test_loader,device)
    if args.use_visual:
        #可视化展示效果
        visualize_stn(net,test_loader,device)

参考:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/183487.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 机器学习小组知识点10:多项式分布(Mutibinomial distribution)

    机器学习小组知识点10:多项式分布(Mutibinomial distribution)介绍把二项分布公式再推广,就得到了多项分布。二项分布的典型例子是扔硬币,硬币正面朝上概率为pp,重复扔nn次硬币,kk次为正面的概率即为一个二项分布概率。(严格定义见二项分布中伯努利实验定义)把二项扩展为多项就得到了多项分布。比如扔骰子,不同于扔硬币,骰子有6个面对应6个不同的点数,这样单次每个点数朝上的概率都是16\frac{1}{6}(对应p1p_1至p6p_6,它们的值不一定都是16\f

    2022年8月31日
    0
  • 交换机上uplink端口的作用是什么_uplink怎么用的

    交换机上uplink端口的作用是什么_uplink怎么用的PoE交换机是如今安防行业使用很广泛的一种设备,因为它是是一种为远程交换机(如IP电话或摄像机)提供电力和数据传输的交换机,具有非常重要的作用。而在使用PoE交换机时,就有朋友咨询到,有的PoE交换机上标着PoE,另外也看到有的标着PoE+。那么,PoE交换机与PoE+有什么区别呢?接下来就由飞畅科技的小编来为大家详细介绍下吧!1、什么是PoE交换机PoE交换机由IEEE802.3af标准定义,…

    2022年10月4日
    0
  • 《人工神经网络原理》读书笔记(六)-Boltzmann机[通俗易懂]

    《人工神经网络原理》读书笔记(六)-Boltzmann机[通俗易懂]全部笔记的汇总贴:《人工神经网络原理》-读书笔记汇总一、随机型神经网络的提出BP和Hopfield网络陷入局部最小点的原因网络误差或能量函数构成了含有多个极小点的非线性超曲面;网络误差或能量函数只能按照梯度下降方向单调变化,而不能有任何上升趋势。随机型神经网络的基本思想不但能够让网络误差或能量函数按照梯度下降方向变化,也能够让它们按照某种方式向梯度上升方向变化,这样才有可能使网络跳出局部极小点而向全局极小点收敛。随机型神经网络的特点神经元的输出状态有概率决定;网络连接权值的调整

    2022年7月15日
    13
  • 用了下FIREBIRD,发现真的不错哦

    用了下FIREBIRD,发现真的不错哦

    2021年7月30日
    71
  • 创建构建工程的方法(待完成)

    创建构建工程的方法(待完成)

    2021年5月8日
    126
  • 手把手教你搭建Android开发环境

    手把手教你搭建Android开发环境搭建开发环境,是学习一门技术的开始。参照网上的教程,整理了一下。进行Android开发应用开发时,首先需要有JDK和AndroidSDK的支持,还需要开发工具。在AndriodStudio2.2开始,安装AndroidStudio时,会自动安装JDK和AndroidSDK。下载网址:https://developer.android.google.cn/studio/一、进入网址,点击下载安卓工作室(原英文版,翻译后的界面了)二、点击下载后,出现协议界面,勾选同意,下载。三、双击刚

    2022年7月23日
    7

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号