BatchNorm1d

BatchNorm1d 参考：https://zhuanlan.zhihu.com/p/100672008 与 https://www.jianshu.com/p/2b94da24af3b

大家好,又见面了,我是你们的朋友全栈君。

参考:

https://zhuanlan.zhihu.com/p/100672008

https://www.jianshu.com/p/2b94da24af3b

https://github.com/ptrblck/pytorch_misc

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: test2.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: 5月 19, 2021
# ---
import torch
import torch.nn as nn
import numpy as np

# Fix RNG state so the demo prints the same numbers on every run.
np.random.seed(10)
torch.manual_seed(10)

# Three samples, three features: the (N, C) layout BatchNorm1d expects.
data = np.array(
    [[1, 2, 7],
     [1, 3, 9],
     [1, 4, 6]],
    dtype=np.float32,
)

# Reference result: PyTorch's own BatchNorm1d in training mode
# (gamma initialized to 1, beta to 0).
bn_torch = nn.BatchNorm1d(num_features=3)
data_torch = torch.from_numpy(data)
bn_output_torch = bn_torch(data_torch)
print("bn_output_torch:", bn_output_torch)



def fowardbn(x, gam, beta):
    """Manual batch-norm forward pass matching nn.BatchNorm1d in training mode.

    Args:
        x: (N, D) input tensor.
        gam: (D,) scale parameter (gamma).
        beta: (D,) shift parameter (beta).

    Returns:
        (out, cache): out = gam * x_hat + beta; cache holds the
        intermediates (x, gam, beta, x_hat, mean, var, eps).
    """
    eps = 1e-05  # matches the nn.BatchNorm1d default
    mean = x.mean(dim=0)
    # Normalization uses the *biased* variance, exactly as PyTorch does in
    # the forward pass.  (The original also computed running_mean/running_var
    # here, but never returned or reused them - dead code, removed.)
    var = x.var(dim=0, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    print("x_mean:", mean, "x_var:", var, "self._gamma:", gam, "self._beta:", beta)
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache

class MyBN:
    """Minimal NumPy re-implementation of nn.BatchNorm1d (training mode)."""

    def __init__(self, momentum, eps, num_features):
        """
        Initialize parameters.
        :param momentum: momentum used to track the running mean/variance
        :param eps: small constant guarding against division by zero
        :param num_features: number of features (D) per sample
        """
        # Running statistics tracked across batches, as BatchNorm does.
        self._running_mean = 0
        self._running_var = 1
        # Momentum used when updating self._running_xxx.
        self._momentum = momentum
        # Guards against a zero denominator.
        self._eps = eps
        # Learnable affine parameters; PyTorch initializes gamma=1, beta=0.
        self._beta = np.zeros(shape=(num_features, ))
        self._gamma = np.ones(shape=(num_features, ))

    def batch_norm(self, x):
        """
        Batch-norm forward pass.
        :param x: (N, D) input array
        :return: normalized and affine-transformed output, same shape as x
        """
        x_mean = x.mean(axis=0)
        x_var = x.var(axis=0)
        # PyTorch convention: running = (1 - momentum) * running + momentum * batch.
        # The original had the two weights swapped, which inverted the decay
        # (0.9 weight on the new batch instead of 0.1); fixed here, consistent
        # with the fowardbn() reference above.
        # NOTE(review): PyTorch tracks running_var with the *unbiased* batch
        # variance, while np.var here is biased (ddof=0) - confirm if the
        # running stats are ever consumed for inference.
        self._running_mean = (1 - self._momentum) * self._running_mean + self._momentum * x_mean
        self._running_var = (1 - self._momentum) * self._running_var + self._momentum * x_var
        # Normalize with the biased batch statistics, then apply gamma/beta.
        x_hat = (x - x_mean) / np.sqrt(x_var + self._eps)
        y = self._gamma * x_hat + self._beta
        print("x_mean:", x_mean, "x_var:", x_var, "self._gamma:", self._gamma, "self._beta:", self._beta)
        return y

# Run the NumPy implementation with the exact affine parameters held by the
# PyTorch layer, so the two outputs can be compared directly.
my_bn = MyBN(momentum=0.1, eps=1e-05, num_features=3)
my_bn._gamma = bn_torch.weight.detach().numpy()
my_bn._beta = bn_torch.bias.detach().numpy()
bn_output = my_bn.batch_norm(data)
print("MyBN bn_output:", bn_output)


# Same comparison for the functional variant.
manual_out, _ = fowardbn(data_torch.detach(), bn_torch.weight.detach(), bn_torch.bias.detach())
print("fowardbn out2: ", manual_out)

 

 

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: test.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: 5月 19, 2021
# ---

import numpy as np
np.set_printoptions(suppress=True)
import torch
import torch.nn as nn

# Deterministic demo output.
np.random.seed(10)
torch.manual_seed(10)

# 4x3 raw values, scaled by 10 and reshaped to (C=3, L=4), then given a
# leading batch dimension -> (1, 3, 4) as Conv1d/BatchNorm1d expect.
raw = [
    [0.1, 0.3, 0.4],
    [0.5, 0.3, 0.2],
    [0.4, 0.6, 0.1],
    [0.5, 0.3, 0.2],
]
data_np = np.array(raw, dtype=np.float32) * 10
print("data_np.shape:", data_np.shape)
data_np = data_np.reshape((3, -1))
print("data_np.shape:", data_np.shape)
t_data = torch.unsqueeze(torch.from_numpy(data_np), dim=0)
print("t_data.shape:", t_data.shape)
print(t_data)

class PointNet(nn.Module):
    """Tiny Conv1d + BatchNorm1d network used to inspect BN on (B, C, L) data."""

    def __init__(self):
        super(PointNet, self).__init__()
        # 1x1 convolution: maps 3 input channels to 5 output channels.
        self.conv1 = torch.nn.Conv1d(3, 5, 1)
        self.bn1 = nn.BatchNorm1d(5)
        # Re-initialize parameters for the experiment: conv weights ~ N(0, 1)
        # with zero bias, BN gamma = 5 and beta = 0.
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                module.weight.data.normal_(0, 1)
                if module.bias is not None:
                    module.bias.data.zero_()
                # Keep a NumPy copy of the conv kernel for the manual check.
                self.weight = np.asarray(module.weight.data)
            elif isinstance(module, nn.BatchNorm1d):
                module.weight.data.fill_(5)  # gamma = 5 (instead of the default 1)
                module.bias.data.zero_()

    def forward(self, x):
        """Return (conv output, BN output, NumPy copy of the conv weights)."""
        conv_out = self.conv1(x)
        bn_out = self.bn1(conv_out)
        return conv_out, bn_out, self.weight

# Instantiate the network and run the toy batch through it.
pn = PointNet()
result1, result2, weight = pn(t_data); weight = torch.from_numpy(weight)
print("weight.shape:", weight.shape); print("weight:", weight)
print("result1.shape:", result1.shape); print(result1)
print("result2.shape:", result2.shape); print(result2)

# Manually reproduce the 1x1 Conv1d: for each position n along the length
# axis, the m-th output channel is the dot product of the input column
# t_data[0, :, n] with the m-th kernel weight[m, :, 0].
for n in range(t_data.shape[2]):
    dots = []  # renamed from `sum`, which shadowed the builtin
    for m in range(weight.shape[0]):
        dots.append(torch.dot(t_data[0, :, n], weight[m, :, 0]))
    print("sum:", dots)
#pytorch nn.BatchNorm1d 与手动python实现不一样--解决办法 https://www.jianshu.com/p/2b94da24af3b
#https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
def fowardbn(x, gam, beta, dim=0):
    """Manual batch-norm forward pass with a configurable reduction dim.

    Args:
        x: input tensor; statistics are reduced over dimension ``dim``.
        gam: scale parameter (gamma), broadcastable against the normalized x.
        beta: shift parameter (beta), broadcastable against the normalized x.
        dim: dimension holding the samples to normalize over (default 0).

    Returns:
        (out, cache): out = gam * x_hat + beta; cache holds the
        intermediates (x, gam, beta, x_hat, mean, var, eps).
    """
    eps = 1e-05  # nn.BatchNorm1d default
    mean = x.mean(dim=dim)
    # Biased variance, matching PyTorch's forward-pass normalization.
    # (The original also computed running stats here, but they were never
    # used, and they always reduced over dim=0 even when dim=1 was
    # requested - dead and inconsistent, so removed.)
    var = x.var(dim=dim, unbiased=False)
    # Relies on broadcasting: callers permute x so that the reduced dim
    # lines up (e.g. (B, L, C) with dim=1, or (L, C) with dim=0).
    x_hat = (x - mean) / torch.sqrt(var + eps)
    out = gam * x_hat + beta
    print("x_mean:", mean, "x_var:", var, "self._gamma:", gam, "self._beta:", beta)
    cache = (x, gam, beta, x_hat, mean, var, eps)
    return out, cache



# Compare the manual BN against nn.BatchNorm1d's result two ways.  For a
# (B=1, C=5, L=4) conv output, BN normalizes each of the 5 channels over
# the remaining batch*length positions.

# Way 1: keep the batch dim, move channels last -> (1, L, C), reduce dim=1.
permuted = result1.permute(0, 2, 1)
out, cache = fowardbn(permuted, pn.bn1.weight, pn.bn1.bias, dim=1)
out = out.permute(0, 2, 1)
print("out1", out)

# Way 2: drop the batch dim, transpose to (L, C), reduce dim=0.
flat = result1.squeeze().permute(1, 0)
out, cache = fowardbn(flat, pn.bn1.weight, pn.bn1.bias, dim=0)
out = out.permute(1, 0)
print("out2", out)

# Side demo of np.mean axis semantics: averaging a (5, 4) matrix over
# axis=1 yields one value per row, and the last row's mean equals the mean
# of the same four values taken as a standalone 1-D vector.
x = np.array([[-1.2089,  6.8342, -0.3317, -5.2298],
         [ 2.5075,  9.6109,  8.8057,  9.0995],
         [ 4.2763,  1.2605,  6.7774, 11.4138],
         [ 1.0103,  1.0549,  0.3408,  0.0656],
         [-2.2381,  1.9428, -3.6522, -7.8491]])
x = x.mean(axis=1)
y = np.array([-2.2381,  1.9428, -3.6522, -7.8491]).mean(axis=0)
print(x)
print(y)

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/144784.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • java批量修改数据库数据_sql批量更新多条数据

    java批量修改数据库数据_sql批量更新多条数据批量更新mysql更新语句很简单,更新一条数据的某个字段,一般这样写:代码如下:UPDATEmytableSETmyfield=’value’WHEREother_field=’other_value’;如果更新同一字段为同一个值,mysql也很简单,修改下where即可:代码如下:UPDATEmytableSETmyfield=’value’WHEREother_…

    2025年6月10日
    1
  • 用python3实现粒子群优化算法(PSO)

    用python3实现粒子群优化算法(PSO)粒子群优化算法(ParticleSwarmOptimization,PSO)属于进化算法的一种,是通过模拟鸟群捕食行为设计的。从随机解出发,通过迭代寻找最优解,通过适应度来评价解的品质。设想这样一个场景:一群鸟在随机搜索食物。在这个区域里只有一块食物。所有的鸟都不知道食物在那里。但是他们知道当前的位置离食物还有多远。那么找到食物的最优策略是什么呢。最简单有效的就是搜寻目前离食物最近的鸟的周围区…

    2022年5月24日
    39
  • Oracle11g软硬件基本要求,Oracle 11g的安装

    Oracle11g软硬件基本要求,Oracle 11g的安装Oracle11g有基本安装和高级安装两种方式。两种方式对硬件要求也不相同,oracle11g软件非常大,对硬件要求很高。目前只是讲述在windows环境下的安装,Linux环境下安装以后会讲,下表给出了安装Oracle11g所需的硬件配置。系统要求说明CPU最低主频550MHZ以上内存1GB以上虚拟内存物理内存的2倍磁盘空间基本安装需4.55G,高级安装需4.92G一、Windows环境下安装…

    2022年7月25日
    25
  • 分布式锁的应用场景和三种实现方式的区别_负载均衡策略

    分布式锁的应用场景和三种实现方式的区别_负载均衡策略多线程对同一资源的竞争,需要用到锁,例如Java自带的Synchronized、ReentrantLock。但只能用于单机系统中,如果涉及到分布式环境(多机器)的资源竞争,则需要分布式锁。分布式锁的主要作用:保证数据的正确性:比如:秒杀的时候防止商品超卖,表单重复提交,接口幂等性。避免重复处理数据:比如:调度任务在多台机器重复执行,缓存过期所有请求都去加载数据库。分布式锁的主要特性:互斥:同一时刻只能有一个线程获得锁。可重入:当一个线程获取锁后,还可以再次获取这个锁,避免死锁发生。高可用:当

    2022年9月8日
    1
  • 视音频数据处理入门:RGB、YUV像素数据处理[通俗易懂]

    视音频数据处理入门:RGB、YUV像素数据处理[通俗易懂]有段时间没有写博客了,这两天写起博客来竟然感觉有些兴奋,仿佛找回了原来的感觉。前一阵子在梳理以前文章的时候,发现自己虽然总结了各种视音频应用程序,却还缺少一个适合无视音频背景人员学习的“最基础”的程序。因此抽时间将以前写过的代码整理成了一个小项目。

    2022年7月16日
    11
  • iPhone 抓包工具Charles使用[通俗易懂]

    iPhone 抓包工具Charles使用[通俗易懂]Charles是在Mac下常用的截取网络封包的工具,在做iOS开发时,我们为了调试与服务器端的网络通讯协议,常常需要截取网络封包来分析。Charles通过将自己设置成系统的网络访问代理服务器,使得所有的网络访问请求都通过它来完成,从而实现了网络封包的截取和分析。Charles主要的功能包括:支持SSL代理。可以截取分析SSL的请求。支持流量控制。可以模拟慢速网络

    2022年5月16日
    543

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号