STN in PyTorch (Explained Simply)


# -*- coding: utf-8 -*-
"""
Spatial Transformer Networks Tutorial
=====================================
**Author**: `Ghassen HAMROUNI <https://github.com/GHamrouni>`_

.. figure:: /_static/img/stn/FSeq.png

In this tutorial, you will learn how to augment your network using
a visual attention mechanism called spatial transformer
networks. You can read more about the spatial transformer
networks in the `DeepMind paper <https://arxiv.org/abs/1506.02025>`__

Spatial transformer networks are a generalization of differentiable
attention to any spatial transformation. Spatial transformer networks
(STN for short) allow a neural network to learn how to perform spatial
transformations on the input image in order to enhance the geometric
invariance of the model.
For example, it can crop a region of interest, scale and correct
the orientation of an image. It can be a useful mechanism because CNNs
are not invariant to rotation, scale, and more general affine
transformations.

One of the best things about STN is the ability to simply plug it into
any existing CNN with very little modification.
"""
# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import time


from tensorboardX import SummaryWriter  # optional TensorBoard logging (not used in this script)
#from logger import Logger
#logger = Logger('./logs')

plt.ion()   # interactive mode

######################################################################
# Loading the data
# ----------------
#
# In this post we experiment with the classic MNIST dataset, using a
# standard convolutional network augmented with a spatial transformer
# network.

use_cuda = torch.cuda.is_available()

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=128, shuffle=True, num_workers=4)


# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)
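
######################################################################
# A quick sanity check (an addition, not part of the original tutorial):
# pull one mini-batch from the training loader and confirm its shape.
# MNIST images arrive as (batch, 1, 28, 28) tensors and labels as (batch,).

sample_images, sample_labels = next(iter(train_loader))
print(sample_images.shape, sample_labels.shape)  # torch.Size([128, 1, 28, 28]) torch.Size([128])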

######################################################################
# Depicting spatial transformer networks
# --------------------------------------
#
# Spatial transformer networks boil down to three main components:
#
# -  The localization network is a regular CNN which regresses the
#    transformation parameters. The transformation is never learned
#    explicitly from this dataset; instead the network automatically learns
#    the spatial transformations that enhance the global accuracy.
# -  The grid generator generates a grid of coordinates in the input
#    image corresponding to each pixel from the output image.
# -  The sampler uses the parameters of the transformation and applies
#    it to the input image.
#
# .. figure:: /_static/img/stn/stn-arch.png
#
# .. Note::
#    We need the latest version of PyTorch that contains
#    affine_grid and grid_sample modules.
#
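
######################################################################
# Before defining the model, here is a minimal standalone sketch (an
# addition to the original post) of the grid generator and sampler: an
# affine matrix theta is turned into a sampling grid by affine_grid and
# applied to the image by grid_sample. With the identity matrix
# [[1, 0, 0], [0, 1, 0]] the output equals the input; the same matrix is
# used below to initialize the localization regressor.

def apply_affine(img_batch, theta):
    # img_batch: (N, C, H, W); theta: (N, 2, 3) affine matrices
    # align_corners is left at its default in both calls so the two match
    grid = F.affine_grid(theta, img_batch.size())   # (N, H, W, 2) sampling coordinates
    return F.grid_sample(img_batch, grid)           # (N, C, H, W) resampled images

_imgs = torch.rand(2, 1, 28, 28)
_identity = torch.tensor([[1., 0., 0.], [0., 1., 0.]]).repeat(2, 1, 1)
print((apply_affine(_imgs, _identity) - _imgs).abs().max())  # ~0: the identity transform leaves the images unchanged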


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with the identity transformation
        # theta = [[1, 0, 0], [0, 1, 0]], so the STN starts out as a no-op
        self.fc_loc[2].weight.data.fill_(0)
        self.fc_loc[2].bias.data = torch.FloatTensor([1, 0, 0, 0, 1, 0])

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)  # flatten: 64 x 10 x 3 x 3 -> 64 x 90
        theta = self.fc_loc(xs)       # 64 x 6
        theta = theta.view(-1, 2, 3)  # reshape into 64 x 2 x 3 affine transform matrices

        grid = F.affine_grid(theta, x.size())
        # theta (Tensor): input batch of affine matrices (N x 2 x 3)        64 x 2 x 3
        # size (torch.Size): the target output image size (N x C x H x W)   64 x 1 x 28 x 28
        # output (Tensor): sampling grid of size (N x H x W x 2)            64 x 28 x 28 x 2

        x = F.grid_sample(x, grid)
        '''
        Args:
        input (Tensor): input batch of images (N x C x IH x IW)
        grid (Tensor): flow-field of size (N x OH x OW x 2)
        padding_mode (str): padding mode for values outside the grid,
            'zeros' | 'border'. Default: 'zeros'
        output: N x C x OH x OW
        '''

        return x  # 64 x 1 x 28 x 28

    def forward(self, x):  # x: 64 x 1 x 28 x 28
        # transform the input
        x = self.stn(x)   # 64 x 1 x 28 x 28

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))  # 64 x 10 x 12 x 12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # 64 x 20 x 4 x 4
        x = x.view(-1, 320)  # 64 x 320
        x = F.relu(self.fc1(x))  # 64 x 50
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)  # 64 x 10
        return F.log_softmax(x, dim=1)


model = Net()
if use_cuda:
    model.cuda()
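
######################################################################
# A small smoke test (an addition to the original post): push a dummy
# batch through the untrained network and check that it returns one
# log-probability per class. With the identity initialization of fc_loc,
# model.stn() should leave the images essentially unchanged at this point.

_dummy = torch.rand(4, 1, 28, 28)
if use_cuda:
    _dummy = _dummy.cuda()
print(model(Variable(_dummy)).size())  # torch.Size([4, 10]): log-probabilities for the 10 digit classes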

######################################################################
# Training the model
# ------------------
#
# Now, let's use the SGD algorithm to train the model. The network is
# learning the classification task in a supervised way. At the same time,
# the model is learning the STN automatically, in an end-to-end fashion.

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    # switch the model to training mode (so dropout behaves as in training)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()

        data, target = Variable(data), Variable(target)  # wrap as Variables so autograd can track them
        optimizer.zero_grad()  # clear gradients accumulated from the previous step
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()  # update the weights
        '''
        Gradient backpropagation takes three steps:
            zero the accumulated gradients: optimizer.zero_grad()
            backpropagate the loss to compute gradients: loss.backward()
            update the parameters: optimizer.step()
        '''


        # the original tutorial logs every args.log_interval batches; here we log every 100
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
#
# A simple test procedure to measure the STN performance on MNIST.
#


def test():
    model.eval()   # switch to evaluation mode so dropout and batch normalization (BN) behave as during inference
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)

        # sum up batch loss
        test_loss += F.nll_loss(output, target, reduction='sum').item()
        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=True)[1]     # index of the max log-probability, i.e. the predicted class
        correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(test_loss, correct, len(test_loader.dataset),
                  100. * correct / len(test_loader.dataset)))
'''
    A pass over the data is split into mini-batches.
    For each mini-batch, F.nll_loss(output, target).item() is not the total loss of that
    mini-batch but its average loss (the default reduction averages over the batch).
    After one pass over the data you could sum these per-mini-batch average losses and
    divide by the number of mini-batches to get a per-sample average loss.

However, this hides a subtle bug: it assumes every mini-batch has the same size, which
is the only case in which averaging the averages is exact. In practice the last
mini-batch is rarely exactly "full".

The more accurate way is to not average within each mini-batch but to sum, and only
divide by the total number of data points at the very end, as done in test() above:

for each mini-batch:
    ...
    test_loss += F.nll_loss(output, target, reduction='sum').item()
    ...
...
test_loss /= len(test_loader.dataset)'''
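
######################################################################
# A tiny numeric sketch (an addition to the original post) of why the two
# ways of averaging differ once the last mini-batch is smaller. With
# per-sample losses [1, 1, 1, 1, 5] split into batches of size 4 and 1,
# averaging the batch averages gives (1 + 5) / 2 = 3, while summing and
# dividing by the number of samples gives 9 / 5 = 1.8, the true mean.

per_sample_losses = [1.0, 1.0, 1.0, 1.0, 5.0]
batches = [per_sample_losses[:4], per_sample_losses[4:]]        # sizes 4 and 1
mean_of_means = sum(sum(b) / len(b) for b in batches) / len(batches)
sum_then_divide = sum(sum(b) for b in batches) / len(per_sample_losses)
print(mean_of_means, sum_then_divide)  # 3.0 1.8 -> only the second matches the true per-sample mean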
tot_time = 0.0  # accumulated training time in seconds (updated in the epoch loop below)

######################################################################
# Visualizing the STN results
# ---------------------------
#
# Now, we will inspect the results of our learned visual attention
# mechanism.
#
# We define a small helper function in order to visualize the
# transformations while training.


def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# We want to visualize the output of the spatial transformer layer
# after training: we visualize a batch of input images and the
# corresponding batch transformed by the STN.


def visualize_stn():
    # Get a batch of training data
    data, _ = next(iter(test_loader))
    data = Variable(data)  # changed: the original used volatile=True, which is deprecated

    if use_cuda:
        data = data.cuda()

    input_tensor = data.cpu().data
    transformed_input_tensor = model.stn(data).cpu().data

    in_grid = convert_image_np(
        torchvision.utils.make_grid(input_tensor))

    out_grid = convert_image_np(
        torchvision.utils.make_grid(transformed_input_tensor))

    # Plot the results side-by-side
    f, axarr = plt.subplots(1, 2)
    axarr[0].imshow(in_grid)
    axarr[0].set_title('Dataset Images')

    axarr[1].imshow(out_grid)
    axarr[1].set_title('Transformed Images')


# the original tutorial trains for args.epochs (e.g. 20) epochs; we use 4 here
for epoch in range(1, 4 + 1):
    start = time.time()
    train(epoch)
    tot_time += time.time() - start
    test()

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()
print("Total time= {:.3f}s".format(tot_time))