BatchNorm1d

BatchNorm1d 参考资料:https://zhuanlan.zhihu.com/p/100672008 、 https://www.jianshu.com/p/2b94da24af3b

大家好,又见面了,我是你们的朋友全栈君。

参考:

https://zhuanlan.zhihu.com/p/100672008

https://www.jianshu.com/p/2b94da24af3b

https://github.com/ptrblck/pytorch_misc

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: test2.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: 5月 19, 2021
# ---
import torch
import torch.nn as nn
import numpy as np

# Fix both RNG states so the BatchNorm comparison below is reproducible.
np.random.seed(10)
torch.manual_seed(10)

# A tiny (N=3, D=3) batch used to compare torch's BatchNorm1d against the
# manual re-implementations further down in the file.
data = np.array(
    [[1, 2, 7],
     [1, 3, 9],
     [1, 4, 6]],
    dtype=np.float32,
)
bn_torch = nn.BatchNorm1d(num_features=3)
data_torch = torch.from_numpy(data)
bn_output_torch = bn_torch(data_torch)
print("bn_output_torch:", bn_output_torch)



def fowardbn(x, gam, beta):
    """Manual batch-norm forward pass over an (N, D) batch.

    Mirrors ``nn.BatchNorm1d`` in training mode: normalises with the
    *biased* batch variance, then applies the affine transform.

    :param x: (N, D) float tensor.
    :param gam: (D,) scale (gamma) tensor.
    :param beta: (D,) shift (beta) tensor.
    :return: ``(out, cache)`` where ``out`` is the normalised batch and
        ``cache`` holds the intermediates a backward pass would need.
    """
    eps = 1e-05
    # NOTE(review): the original also computed running_mean/running_var with
    # momentum here, but the results were discarded (dead code) — removed.
    mean = x.mean(dim=0)
    # PyTorch normalises with the biased estimator (unbiased=False); the
    # default unbiased variance would not match BatchNorm1d's output.
    var = x.var(dim=0, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    # Labels fixed: the original printed "self._gamma"/"self._beta",
    # copy-pasted from the class-based version below.
    print("x_mean:", mean, "x_var:", var, "gam:", gam, "beta:", beta)
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache

class MyBN:
    """NumPy re-implementation of ``nn.BatchNorm1d`` (training-mode forward)."""

    def __init__(self, momentum, eps, num_features):
        """
        :param momentum: momentum used when tracking the running mean/var
            (PyTorch convention: new = (1 - momentum) * old + momentum * batch).
        :param eps: small constant added to the variance for numerical safety.
        :param num_features: number of features D in the (N, D) input.
        """
        # Running statistics tracked across batches (PyTorch starts at 0 / 1).
        self._running_mean = 0
        self._running_var = 1
        # Momentum used when updating self._running_*.
        self._momentum = momentum
        # Guards against division by zero.
        self._eps = eps
        # Affine parameters; PyTorch initialises beta = 0, gamma = 1.
        self._beta = np.zeros(shape=(num_features, ))
        self._gamma = np.ones(shape=(num_features, ))

    def batch_norm(self, x):
        """Forward pass of batch normalisation.

        :param x: (N, D) array.
        :return: normalised output of the same shape.
        """
        x_mean = x.mean(axis=0)
        # Biased variance (ddof=0) is what BatchNorm uses to normalise.
        x_var = x.var(axis=0)
        # BUG FIX: the running statistics must keep (1 - momentum) of the OLD
        # value and take `momentum` of the new batch statistic; the original
        # had the two factors swapped relative to PyTorch's formula.
        self._running_mean = (1 - self._momentum) * self._running_mean + self._momentum * x_mean
        # PyTorch tracks the running variance with the UNBIASED estimator.
        self._running_var = (1 - self._momentum) * self._running_var + self._momentum * x.var(axis=0, ddof=1)
        # y = gamma * (x - mean) / sqrt(var + eps) + beta  (the paper's formula)
        x_hat = (x - x_mean) / np.sqrt(x_var + self._eps)
        y = self._gamma * x_hat + self._beta
        print("x_mean:", x_mean, "x_var:", x_var, "self._gamma:", self._gamma, "self._beta:", self._beta)
        return y

# Sanity check: feed the same batch through the NumPy implementation, using
# the affine parameters taken from the torch layer above.
my_bn = MyBN(momentum=0.1, eps=1e-05, num_features=3)
my_bn._gamma = bn_torch.weight.detach().numpy()
my_bn._beta = bn_torch.bias.detach().numpy()
bn_output = my_bn.batch_norm(data)
print("MyBN bn_output:", bn_output)


# Same check against the functional implementation.
out, cache = fowardbn(data_torch.detach(), bn_torch.weight.detach(), bn_torch.bias.detach())
print("fowardbn out2: ", out)

 

 

# ---------------------------------------------------------------------------
# Second experiment (originally a separate script, test.py, 2021-05-19):
# BatchNorm1d applied after a 1x1 Conv1d, PointNet-style.
# ---------------------------------------------------------------------------

import numpy as np
np.set_printoptions(suppress=True)
import torch
import torch.nn as nn

# Re-seed so the randomly initialised PointNet below is reproducible.
np.random.seed(10)
torch.manual_seed(10)

data = [
    [0.1, 0.3, 0.4],
    [0.5, 0.3, 0.2],
    [0.4, 0.6, 0.1],
    [0.5, 0.3, 0.2],
]
# Scale up, then reinterpret the 4x3 matrix as 3 channels of length 4.
data_np = np.array(data, dtype=np.float32) * 10
print("data_np.shape:", data_np.shape)
data_np = data_np.reshape((3, -1))
print("data_np.shape:", data_np.shape)
# Add a batch dimension: (B=1, C=3, L=4), the layout Conv1d expects.
t_data = torch.unsqueeze(torch.from_numpy(data_np), dim=0)
print("t_data.shape:", t_data.shape)
print(t_data)

class PointNet(nn.Module):
    """Minimal Conv1d -> BatchNorm1d stack used to inspect BN behaviour."""

    def __init__(self):
        super(PointNet, self).__init__()
        # Conv1d reference: https://blog.csdn.net/sunny_xsc1994/article/details/82969867
        self.conv1 = torch.nn.Conv1d(3, 5, 1)
        self.bn1 = nn.BatchNorm1d(5)
        # Weight-init references:
        # https://blog.csdn.net/Bear_Kai/article/details/99302341
        # https://www.jb51.net/article/177617.htm
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                # Standard-normal weights, zero bias.
                module.weight.data.normal_(0, 1)
                if module.bias is not None:
                    module.bias.data.zero_()
                # Keep a NumPy copy of the conv kernel for the manual
                # dot-product verification performed by the caller.
                self.weight = np.asarray(module.weight.data)
            elif isinstance(module, nn.BatchNorm1d):
                module.weight.data.fill_(5)  # gamma = 5 (instead of the default 1)
                module.bias.data.zero_()

    def forward(self, x):
        conv_out = self.conv1(x)
        bn_out = self.bn1(conv_out)
        return conv_out, bn_out, self.weight

pn = PointNet()
result1, result2, weight = pn(t_data)
weight = torch.from_numpy(weight)
print("weight.shape:", weight.shape); print("weight:", weight)
print("result1.shape:", result1.shape); print(result1)
print("result2.shape:", result2.shape); print(result2)
#print("result1_end:", pn.bn1(result1))
# PointNet paper walk-through: https://zhuanlan.zhihu.com/p/86331508
# Manually reproduce the 1x1 Conv1d: each output channel at position n is the
# dot product of the input column t_data[0, :, n] with that channel's kernel.
for n in range(t_data.shape[2]):
    # BUG FIX: renamed from `sum`, which shadowed the built-in of that name.
    dots = []
    for m in range(weight.shape[0]):
        # torch multiplication reference: https://zhuanlan.zhihu.com/p/212461087
        dots.append(torch.dot(t_data[0, :, n], weight[m, :, 0]))  # dot product
    print("sum:", dots)
#pytorch nn.BatchNorm1d 与手动python实现不一样--解决办法 https://www.jianshu.com/p/2b94da24af3b
#https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
def fowardbn(x, gam, beta, dim=0):
    """Manual batch-norm forward over ``x``, normalising along ``dim``.

    Matches ``nn.BatchNorm1d`` in training mode: the *biased* variance
    (``unbiased=False``) is used for normalisation.
    See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
    and https://www.jianshu.com/p/2b94da24af3b

    :param x: input tensor; statistics are taken along dimension ``dim``.
    :param gam: scale (gamma) parameter, broadcast against the normalised x.
    :param beta: shift (beta) parameter, broadcast likewise.
    :param dim: dimension over which mean/variance are computed (default 0).
    :return: ``(out, cache)`` — normalised output and backward intermediates.
    """
    eps = 1e-05
    # NOTE(review): the original computed running_mean/running_var here with a
    # hardcoded dim=0 (ignoring the `dim` argument) and then discarded them —
    # dead and inconsistent code, removed.
    mean = x.mean(dim=dim)
    var = x.var(dim=dim, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    # Labels fixed: the original printed "self._gamma"/"self._beta",
    # copy-pasted from the class-based implementation.
    print("x_mean:", mean, "x_var:", var, "gam:", gam, "beta:", beta)
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache



# Shape bookkeeping for BatchNorm over a B*C*L tensor:
#   input (1, 3, 4) -> Conv1d(3 in, 5 out, kernel 1) -> (1, 5, 4)
# BatchNorm1d normalises per channel, so move the channel axis last before
# calling the (N, D)-style manual implementation, then move it back.
bn_re = result1.permute(0, 2, 1)
out, cache = fowardbn(bn_re, pn.bn1.weight, pn.bn1.bias, dim=1)
out = out.permute(0, 2, 1)
print("out1", out)

# Equivalent route: drop the batch axis, transpose to (L, C), use dim=0.
bn_re = result1.squeeze()
bn_re = bn_re.permute(1, 0)
out, cache = fowardbn(bn_re, pn.bn1.weight, pn.bn1.bias, dim=0)
out = out.permute(1, 0)
print("out2", out)

# Quick NumPy sanity check on the axis semantics of ndarray.mean.
x = np.array([[-1.2089,  6.8342, -0.3317, -5.2298],
              [ 2.5075,  9.6109,  8.8057,  9.0995],
              [ 4.2763,  1.2605,  6.7774, 11.4138],
              [ 1.0103,  1.0549,  0.3408,  0.0656],
              [-2.2381,  1.9428, -3.6522, -7.8491]])
# axis=1 -> one mean per row (a length-5 vector).
x = x.mean(axis=1)
y = np.array([-2.2381,  1.9428, -3.6522, -7.8491])
# For a 1-D array, axis=0 collapses it to a scalar.
y = y.mean(axis=0)
print(x)
print(y)

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/144784.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • java的基础代码_java编程入门基础教程

    java的基础代码_java编程入门基础教程1.编写java源文件,认识java基本程序结构。创建一个文本文件,并重命名为”HelloWorld.java”用记事本打开,编写一段Java代码如下面所示例子所示。ClassHelloWorld.java{//main是程序的入口,所有程序都是从此处开始运行Publicstaticvoidmain(String[]arge){//在屏幕中打印输出“HelloWorld!”语句System.out.println(“HelloWorld”);}}2.下面对每条语句

    2022年10月17日
    2
  • win10图标变白纸_同是Office365,为什么你的软件图标还是旧版的?

    win10图标变白纸_同是Office365,为什么你的软件图标还是旧版的?为什么你的office365套件最新版的图标还是旧版?是新版图标还没向正式版用户推送吗?我的office365的账号有问题吗?难道是我打开的方式不对吗?旧版图标新版图标打开产品信息一看你的版本信息是这样的:版本号1808、半年频道怎么点更新都是显示已经到了最新的版本对吧?再看看新版图标的office365的产品信息是这样的:版本1904、每月频道这就定位到问题了,同是office365专业增强订阅…

    2022年10月19日
    2
  • 怎样推断一棵二叉树是全然二叉树

    怎样推断一棵二叉树是全然二叉树

    2021年12月1日
    53
  • JVM内存分配与管理详解

    JVM内存分配与管理详解概述了解C++的程序员都知道,在内存管理领域,都是由程序员维护与管理,程序员用于最高的管理权限,但对于java程序员来说,在内存管理领域,程序员不必去关心内存的分配以及回收,在jvm自动内存管理机制的帮助下,不需要想C++一样为每一个new操作去编写delete/free代码,这一切交给jvm,但正是这一切都交给了jvm,一旦出现内存泄漏与溢出,如果不了jvm,那么对于程序的编写与调试将会非常

    2022年6月1日
    55
  • 手机APP测试(测试点、测试流程、功能测试)

    手机APP测试(测试点、测试流程、功能测试)1、功能测试1.1启动APP安装完成后,是否可以正常打开,稳定运行APP的速度是可以让人接受,切换是否流畅网络异常时,应用是否会崩溃:在请求超时的情况下,如果程序逻辑处理的不好,就有可能发生

    2022年7月3日
    30
  • EL表达式详解_EL表达式问内置对象属性值

    EL表达式详解_EL表达式问内置对象属性值 EL表达式   1、EL简介1)语法结构    ${expression}2)[]与.运算符   EL提供.和[]两种运算符来存取数据。   当要存取的属性名称中包含一些特殊字符,如.或?等并非字母或数字的符号,就一定要使用[]。例如:     ${user.My-Name}应当改为${user

    2022年7月28日
    8

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号