BN层pytorch实现[通俗易懂]

BN层pytorch实现[通俗易懂]#CreatedbyXkyat2019/11/29importtimeimporttorchimporttorchvisionimporttorch.nnasnnimportsysimporttorchvision.transformsastransformsfromtorch.utils.data.dataloaderimportDataLoad…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

# Created by Xky at 2019/11/29
import time
import torch
import torchvision
import torch.nn as nn
import sys
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
import torch.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
class FlattenLayer(nn.Module):  # 自己定义层Flattenlayer
    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):  # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)

def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 判断当前模式是训练模式还是预测模式
    if not is_training:
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要保持
            # X的形状以便后面可以做广播运算
            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        # 训练模式下用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 拉伸和偏移
    return Y, moving_mean, moving_var

class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成0和1
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 不参与求梯度和迭代的变量,全在内存上初始化成0
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var, Module实例的traning属性默认为true, 调用.eval()后设成false
        Y, self.moving_mean, self.moving_var = batch_norm(self.training,
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y

net = nn.Sequential(
            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            BatchNorm(6, num_dims=4),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2), # kernel_size, stride
            nn.Conv2d(6, 16, 5),
            BatchNorm(16, num_dims=4),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),
            FlattenLayer(),
            nn.Linear(16*4*4, 120),
            BatchNorm(120, num_dims=2),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            BatchNorm(84, num_dims=2),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )
net = net.to(device)
# def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):
#     """Download the fashion mnist dataset and then load into memory."""
#     trans = []
#     if resize:
#         trans.append(torchvision.transforms.Resize(size=resize))
#     trans.append(torchvision.transforms.ToTensor())
#
#     transform = torchvision.transforms.Compose(trans)
#     mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
#     mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
#     if sys.platform.startswith('win'):
#         num_workers = 0  # 0表示不用额外的进程来加速读取数据
#     else:
#         num_workers = 4
#     train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
#     test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
#
#     return train_iter, test_iter
# batch_size = 256
# train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)


#get Data
batch_size = 256
#transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])
transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
                                              train=True, transform=transform)
test_set = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',
                                             train=False, transform=transform)
train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
test_iter = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

lr, num_epochs = 0.001, 5
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# evaluate_accuracy
def evaluate_accuracy(test_iterator, net):
    with torch.no_grad():
        device = list(net.parameters())[0].device
        test_acc_sum = 0.0
        ncount = 0
        for x_test, y_test in test_iterator:
            if isinstance(net, torch.nn.Module):
                net.eval()
                x_test = x_test.to(device)
                y_test = y_test.to(device)
                y_hat = net(x_test)
                test_acc_sum += (y_hat.argmax(dim=1) == y_test).sum().cpu().item()
                ncount+=len(y_test)
                net.train()
        test_acc = test_acc_sum/ncount
        return test_acc
def train(num_epoch):
    for epoch in range(num_epoch):
        l_sum, train_acc_sum, ncount, start = 0.0, 0.0, 0, time.time()
        for x_train, y_train in train_iter:
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            y_hat = net(x_train)
            l = loss(y_hat, y_train)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y_train).sum().cpu().item()
            ncount += y_train.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch: %d, train_loss: %.4f, train_acc: %.4f, test_acc: %.4f , spend_time: %.4f' %
              (epoch+1, l_sum/ncount,train_acc_sum/ncount, test_acc,time.time()-start))


if __name__ == "__main__":
    train(5)
# train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/181968.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • linux 动态库 静态库_静态库里面包含动态库

    linux 动态库 静态库_静态库里面包含动态库动态库与静态库文件系统补完文件的三个时间acm动态库与静态库动态链接与静态链接静态库文件系统补完文件的三个时间acm我们通过stat指令查看文件信息:[lyl@VM-4-3-centos2022-3-14]$statlog.txtFile:‘log.txt’Size:0 Blocks:0IOBlock:4096regularemptyfileDevice:fd01h/64769d Inode:790871

    2022年9月28日
    0
  • Ubuntu下查看cuda版本的两种方法

    Ubuntu下查看cuda版本的两种方法参考资料:https://blog.csdn.net/qq_16525279/article/details/80662217  在安装Pytorch等深度学习框架的时候,我们往往需要检查是否和cuda版本匹配。在Ubuntu系统下查看cuda版本的方法,我发现有两个。方法一:  比较常用的一个方法是使用如下命令:cat/usr/local/cuda/version.txt使用该命令的效果如下图所示:方法一使用效果方法二:  方法一主要是依据cuda安装时保存的关于版本的txt文件。但

    2022年6月12日
    204
  • Windows API——CFile, read, write,typeBinary函数「建议收藏」

    Windows API——CFile, read, write,typeBinary函数「建议收藏」文件操作API和CFile类在VC中,操作文件的方法有两种,一是利用一些API函数来创建,打开,读写文件,另外一个是利用MFC的CFile类,CFile封装了对文件的一般操作。下面酒主要介绍如何利用这两种方法操作文件。1.创建或打开一个文件API函数CreateFile可打开和创建文件、管道、邮槽、通信服务、设备以及控制台,但是在此时只是介绍用这个函数怎么实现创建和打开一个文件。HANDL…

    2022年8月18日
    6
  • Vue(3)webstorm代码格式规范设置与vue模板配置

    Vue(3)webstorm代码格式规范设置与vue模板配置编译器代码格式规范设置通常我们写代码时,代码缩进都是4个空格,但是在前端中,据全球投票统计,建议使用2个空格来进行代码缩进。首先我们打开webstorm中的设置,如果使用的是mac的同学直接使用c

    2022年7月30日
    61
  • sbc 通信_ipc进程间通信

    sbc 通信_ipc进程间通信SBC在企业IP通信系统中的应用刘航2008/05/04  摘要:本文针对企业IP通信系统建设实施的两大问题:终端接入安全和IP多媒体业务NAT穿越,介绍了基于SBC(SessionBorderController,会话边界控制器)的解决方案,并提出了利用SBC辅助实现IP录音的一种新应用模式。  关键词:IP通信、SBC、NAT穿越、安全、IP录音一、引言

    2022年9月12日
    0
  • java九九乘法表代码加解释_java九九乘法表编程代码是什么?

    java九九乘法表代码加解释_java九九乘法表编程代码是什么?展开全部packagech02;publicclassTEST{publicstaticvoidmain(String[]args){for(inti=1;i<=9;i++){for(intj=1;j<=i;j++){System.out.print(j+”*”+i+”=”+(i*j)+””);}System.out.println(…

    2022年7月15日
    15

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号