STN in PyTorch [explained simply]


# -*- coding: utf-8 -*-
"""
Spatial Transformer Networks Tutorial
=====================================
**Author**: `Ghassen HAMROUNI <https://github.com/GHamrouni>`_

.. figure:: /_static/img/stn/FSeq.png

In this tutorial, you will learn how to augment your network using
a visual attention mechanism called spatial transformer
networks. You can read more about spatial transformer
networks in the `DeepMind paper <https://arxiv.org/abs/1506.02025>`__.

Spatial transformer networks are a generalization of differentiable
attention to any spatial transformation. Spatial transformer networks
(STN for short) allow a neural network to learn how to perform spatial
transformations on the input image in order to enhance the geometric
invariance of the model.
For example, an STN can crop a region of interest, and scale and correct
the orientation of an image. This is a useful mechanism because CNNs
are not invariant to rotation, scale, or more general affine
transformations.

One of the best things about STN is the ability to simply plug it into
any existing CNN with very little modification.
"""
# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np


# tensorboardX is not used anywhere below; uncomment only if you add logging.
# from tensorboardX import SummaryWriter
#from logger import Logger
#logger = Logger('./logs')

plt.ion()   # interactive mode

######################################################################
# Loading the data
# ----------------
#
# In this post we experiment with the classic MNIST dataset, using a
# standard convolutional network augmented with a spatial transformer
# network.

use_cuda = torch.cuda.is_available()

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=128, shuffle=True, num_workers=4)


# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)

######################################################################
# Depicting spatial transformer networks
# --------------------------------------
#
# A spatial transformer network boils down to three main components:
#
# -  The localization network is a regular CNN which regresses the
#    transformation parameters. The transformation is never learned
#    explicitly from this dataset; instead, the network automatically
#    learns the spatial transformations that enhance the global accuracy.
# -  The grid generator generates a grid of coordinates in the input
#    image corresponding to each pixel from the output image.
# -  The sampler uses the parameters of the transformation and applies
#    it to the input image.
#
# .. figure:: /_static/img/stn/stn-arch.png
#
# .. Note::
#    We need a version of PyTorch that provides the
#    affine_grid and grid_sample functions.
#
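
######################################################################
# To make the grid generator and the sampler concrete, here is a minimal
# pure-Python sketch for a single one-channel image. It is an
# illustration under stated assumptions, not PyTorch's implementation:
# the helper names ``make_affine_grid`` and ``sample_bilinear`` are
# hypothetical, pixel centers are assumed to span [-1, 1] exactly (the
# ``align_corners=True`` convention), and samples that fall outside the
# image read as zero.

```python
import math


def make_affine_grid(theta, height, width):
    """For each output pixel, return the (x, y) input coordinate in [-1, 1].

    theta is a 2 x 3 affine matrix given as a list of two rows.
    """
    grid = []
    for i in range(height):
        y = -1.0 + 2.0 * i / (height - 1)
        row = []
        for j in range(width):
            x = -1.0 + 2.0 * j / (width - 1)
            src_x = theta[0][0] * x + theta[0][1] * y + theta[0][2]
            src_y = theta[1][0] * x + theta[1][1] * y + theta[1][2]
            row.append((src_x, src_y))
        grid.append(row)
    return grid


def sample_bilinear(image, grid):
    """Read the image at the grid coordinates with bilinear interpolation."""
    height, width = len(image), len(image[0])

    def pixel(r, c):
        # Zero padding for coordinates outside the image (an assumption)
        if 0 <= r < height and 0 <= c < width:
            return image[r][c]
        return 0.0

    output = []
    for grid_row in grid:
        out_row = []
        for src_x, src_y in grid_row:
            # Map normalized coordinates back to pixel indices
            c = (src_x + 1.0) * (width - 1) / 2.0
            r = (src_y + 1.0) * (height - 1) / 2.0
            c0, r0 = math.floor(c), math.floor(r)
            dc, dr = c - c0, r - r0
            out_row.append(
                pixel(r0, c0) * (1 - dr) * (1 - dc)
                + pixel(r0, c0 + 1) * (1 - dr) * dc
                + pixel(r0 + 1, c0) * dr * (1 - dc)
                + pixel(r0 + 1, c0 + 1) * dr * dc
            )
        output.append(out_row)
    return output


image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0],
         [7.0, 8.0, 9.0]]
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
flip_x = [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

same = sample_bilinear(image, make_affine_grid(identity, 3, 3))
mirrored = sample_bilinear(image, make_affine_grid(flip_x, 3, 3))
```

# With the identity matrix the sampler returns the input unchanged, and
# negating the x scale mirrors each row. The network below delegates both
# steps to ``F.affine_grid`` and ``F.grid_sample``.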


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

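        # Initialize the last regressor layer so the STN starts out as the
        # identity transformation: zero weights, and a bias equal to the
        # flattened 2 x 3 identity affine matrix.
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))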

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)     # N x 10 x 3 x 3 for 28 x 28 inputs
        xs = xs.view(-1, 10 * 3 * 3)  # flatten to N x 90 (same as reshape)
        theta = self.fc_loc(xs)       # N x 6
        theta = theta.view(-1, 2, 3)  # N x 2 x 3 affine matrices

        # affine_grid arguments:
        #   theta: batch of affine matrices (N x 2 x 3)
        #   size:  target output image size (N x C x H x W), here N x 1 x 28 x 28
        # It returns a sampling grid of size N x H x W x 2, here N x 28 x 28 x 2.
        grid = F.affine_grid(theta, x.size())

        # grid_sample arguments:
        #   input: batch of images (N x C x IH x IW)
        #   grid:  flow field of size (N x OH x OW x 2)
        #   padding_mode: handling of out-of-range grid values,
        #       'zeros' (default) or 'border'
        # It returns a tensor of size N x C x OH x OW.
        x = F.grid_sample(x, grid)

        return x  # N x 1 x 28 x 28

    def forward(self, x):  # x: N x 1 x 28 x 28
        # transform the input
        x = self.stn(x)  # N x 1 x 28 x 28

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))  # N x 10 x 12 x 12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # N x 20 x 4 x 4
        x = x.view(-1, 320)  # N x 320
        x = F.relu(self.fc1(x))  # N x 50
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)  # N x 10
        return F.log_softmax(x, dim=1)
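

######################################################################
# The flattened sizes used in the network (10 * 3 * 3 for the
# localization regressor and 320 for the classifier) follow from the
# layer arithmetic on a 28 x 28 input. The sketch below recomputes them;
# ``conv_out`` and ``pool_out`` are hypothetical helper names, and the
# formulas assume no padding and no dilation, as in the layers above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Side length after a convolution with the given kernel size."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Side length after max pooling (floor rounding)."""
    return (size - kernel) // stride + 1


# Localization network: conv 7x7 -> pool 2 -> conv 5x5 -> pool 2
side = 28
side = pool_out(conv_out(side, 7))   # 28 -> 22 -> 11
side = pool_out(conv_out(side, 5))   # 11 -> 7 -> 3
loc_features = 10 * side * side      # 10 output channels

# Classifier: conv 5x5 -> pool 2 -> conv 5x5 -> pool 2
side = 28
side = pool_out(conv_out(side, 5))   # 28 -> 24 -> 12
side = pool_out(conv_out(side, 5))   # 12 -> 8 -> 4
cls_features = 20 * side * side      # 20 output channels

print(loc_features, cls_features)    # prints: 90 320
```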


model = Net()
if use_cuda:
    model.cuda()
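
######################################################################
# Why the identity initialization works: with all-zero weights, the last
# linear layer outputs its bias whatever the localization features are,
# so the initial theta is the identity affine matrix and the STN passes
# images through unchanged. A small pure-Python sketch follows; the
# ``linear`` and ``apply_affine`` helpers and the feature values are
# made up for illustration.

```python
def linear(weights, bias, inputs):
    """Plain fully connected layer: weights @ inputs + bias."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, bias)]


def apply_affine(theta, x, y):
    """Apply a 2 x 3 affine matrix to the point (x, y)."""
    return (theta[0][0] * x + theta[0][1] * y + theta[0][2],
            theta[1][0] * x + theta[1][1] * y + theta[1][2])


features = [0.3, -1.2, 0.7, 2.0]                     # arbitrary features
weights = [[0.0] * len(features) for _ in range(6)]  # zero-initialized
bias = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]                # flattened identity

theta_flat = linear(weights, bias, features)
theta = [theta_flat[0:3], theta_flat[3:6]]           # reshape to 2 x 3

point = apply_affine(theta, 0.25, -0.5)              # stays where it was
```

# During training the weights move away from zero, and theta starts to
# depend on the input image.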

######################################################################
# Training the model
# ------------------
#
# Now, let's use the SGD algorithm to train the model. The network is
# learning the classification task in a supervised way. At the same time,
# the model is learning the STN automatically in an end-to-end fashion.

optimizer = optim.SGD(model.parameters(), lr=0.01)


def train(epoch):
    # Switch to training mode (enables dropout)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()

        # Tensors track gradients directly since PyTorch 0.4, so no
        # Variable wrapper is needed. Backpropagation takes three steps:
        #   1. reset the gradients:    optimizer.zero_grad()
        #   2. compute the gradients:  loss.backward()
        #   3. update the parameters:  optimizer.step()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # Log every 100 batches (the original script used args.log_interval)
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
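

######################################################################
# The reset / backward / step pattern in the training loop can be
# mimicked by hand on a one-parameter problem. This is only an analogy
# and does not use PyTorch autograd: ``Param``, ``zero_grad``,
# ``backward``, ``step``, and ``fit`` are hypothetical stand-ins, and
# the loss is fixed to (w * x - y) ** 2.

```python
class Param:
    """A scalar parameter with an accumulated gradient."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def zero_grad(param):
    # Step 1: clear the gradient left over from the previous iteration
    param.grad = 0.0


def backward(param, x, y):
    # Step 2: d/dw (w * x - y)^2 = 2 * x * (w * x - y), added to .grad
    # (gradients accumulate, which is why step 1 is needed)
    param.grad += 2.0 * x * (param.value * x - y)


def step(param, lr):
    # Step 3: plain SGD update
    param.value -= lr * param.grad


def fit(steps=100, lr=0.1):
    w = Param(0.0)
    for _ in range(steps):
        zero_grad(w)
        backward(w, x=1.0, y=2.0)
        step(w, lr)
    return w.value


learned = fit()   # approaches 2.0, the minimizer of (w * 1 - 2)^2
```

# If ``zero_grad`` were skipped, each update would use the sum of all
# past gradients instead of the current one.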
#
# A simple test procedure to measure the STN's performance on MNIST.
#


def test():
    # Switch to evaluation mode: dropout is disabled and batch
    # normalization (BN) layers, if any, use their running statistics.
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability (the predicted class)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(test_loss, correct, len(test_loader.dataset),
                  100. * correct / len(test_loader.dataset)))
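

######################################################################
# What the test procedure computes per sample can be written out in
# plain Python: log-softmax over the class scores, negative
# log-likelihood of the target, and an argmax for the prediction. The
# function names and the scores below are made up for illustration; this
# is a sketch of the math, not of PyTorch's internals.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of class scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]


def evaluate(batch_scores, targets):
    """Return (average loss, number of correct predictions)."""
    total_loss = 0.0
    correct = 0
    for scores, target in zip(batch_scores, targets):
        log_probs = log_softmax(scores)
        total_loss += nll(log_probs, target)   # summed, not averaged
        pred = max(range(len(log_probs)), key=log_probs.__getitem__)
        correct += int(pred == target)
    return total_loss / len(targets), correct


batch_scores = [[2.0, 0.5, -1.0],
                [0.1, 0.2, 3.0],
                [1.0, 4.0, 0.0]]
targets = [0, 2, 0]
avg_loss, num_correct = evaluate(batch_scores, targets)
```

# The third sample is misclassified (the highest score is class 1, the
# target is class 0), so two of three predictions count as correct.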
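# A note on averaging the loss over one pass through the data:
#
# - A data pass is split into several mini-batches.
# - For each mini-batch, F.nll_loss(output, target) does not return the
#   summed loss of the mini-batch but its average, because the reduction
#   defaults to the mean (size_average=True in old PyTorch versions).
# - After one data pass, the mini-batch average losses can be summed.
# - Dividing that sum by the number of mini-batches gives the final
#   average loss per data point.
#
# This hides a bug: it assumes that every mini-batch has the same size,
# which is the only case where averaging the averages is correct. In
# practice the last mini-batch is rarely full.
#
# A more accurate way is to sum the loss within each mini-batch instead
# of averaging it, and to divide by the total number of data points at
# the end, which is what test() above does. Roughly
# (size_average=False is the older spelling of reduction='sum'):
#
#     for each mini-batch:
#         ...
#         test_loss += F.nll_loss(output, target, reduction='sum').item()
#         ...
#     ...
#     test_loss /= len(test_loader.dataset)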
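
######################################################################
# The averaging pitfall described above is easy to reproduce with plain
# numbers. The per-sample losses and the batch size below are invented
# for illustration; the point is only that a short last mini-batch gets
# too much weight when batch averages are averaged again.

```python
def mean(values):
    return sum(values) / len(values)


# Ten hypothetical per-sample losses, split into mini-batches of four.
# The last mini-batch holds only two samples.
sample_losses = [1.0] * 8 + [5.0] * 2
batch_size = 4
batches = [sample_losses[i:i + batch_size]
           for i in range(0, len(sample_losses), batch_size)]

# Biased: average of the per-batch averages
mean_of_batch_means = mean([mean(batch) for batch in batches])

# Accurate: sum every batch, divide by the number of data points
sum_over_batches = sum(sum(batch) for batch in batches)
true_mean = sum_over_batches / len(sample_losses)

print(mean_of_batch_means, true_mean)
```

# Here the two-sample batch with loss 5.0 counts as much as each full
# batch, so the first estimate is larger than the true average of 1.8.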
tot_time = 0  # accumulated wall-clock time in seconds

######################################################################
# Visualizing the STN results
# ---------------------------
#
# Now, we will inspect the results of our learned visual attention
# mechanism.
#
# We define a small helper function in order to visualize the
# transformations while training.


def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo the normalization applied by the data loaders, using the same
    # MNIST statistics that were passed to transforms.Normalize.
    mean = np.array([0.1307])
    std = np.array([0.3081])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp
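

######################################################################
# Converting a normalized tensor back to a displayable image inverts
# ``transforms.Normalize``: if z = (p - mean) / std, then
# p = std * z + mean. The sketch below checks the round trip on a few
# pixel values, assuming the MNIST statistics (0.1307, 0.3081) used by
# the data loaders; ``normalize`` and ``denormalize`` are hypothetical
# helper names.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Same arithmetic as transforms.Normalize for one pixel value."""
    return (pixel - mean) / std


def denormalize(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Invert normalize() and clip to the displayable range [0, 1]."""
    return min(1.0, max(0.0, std * value + mean))


pixels = [0.0, 0.25, 1.0]
restored = [denormalize(normalize(p)) for p in pixels]

# Inverting with different statistics does not recover the input:
mismatched = denormalize(normalize(0.0), mean=0.485, std=0.229)
```

# The inverse only reproduces the original pixel values when it uses the
# same mean and standard deviation as the forward normalization.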

# We want to visualize the output of the spatial transformer layer
# after training. To do so, we plot a batch of input images next to
# the corresponding batch transformed by the STN.


def visualize_stn():
    # Gradients are not needed for visualization
    with torch.no_grad():
        # Get a batch of test data
        data, _ = next(iter(test_loader))

        if use_cuda:
            data = data.cuda()

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')


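import time

# Train for 4 epochs to keep the run short (the commented-out variants
# used 20 epochs, or args.epochs in the command-line version).
for epoch in range(1, 4 + 1):
    epoch_start = time.time()
    train(epoch)
    test()
    tot_time += time.time() - epoch_start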

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()
print("Total time= {:.3f}s".format(tot_time))