pytorch固定BN层参数[通俗易懂]

pytorch固定BN层参数[通俗易懂]背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。原因:未固定主分支BN层中的running_mean和running_var。解决方法:将需要固定的BN层状态设置为eval。问题示例:环境:torch:1.7.0#-*-coding:utf-8-*-importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassNet(nn.M

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE使用 1年只要46元 售后保障 童叟无欺

背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:未固定主分支BN层中的running_meanrunning_var

解决方法:将需要固定的BN层状态设置为eval

问题示例

环境:torch:1.7.0

# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

运行结果:

-------parameters requires grad info--------
conv1.weight:   True
conv1.bias:     True
bn1.weight:     True
bn1.bias:       True
conv2.weight:   True
conv2.bias:     True
bn2.weight:     True
bn2.bias:       True
fc1.weight:     True
fc1.bias:       True
fc2.weight:     True
fc2.bias:       True
fc3.weight:     True
fc3.bias:       True
-------parameters requires grad info--------
conv1.weight:   False
conv1.bias:     False
bn1.weight:     False
bn1.bias:       False
conv2.weight:   False
conv2.bias:     False
bn2.weight:     False
bn2.bias:       False
fc1.weight:     False
fc1.bias:       False
fc2.weight:     False
fc2.bias:       False
fc3.weight:     False
fc3.bias:       False
epoch:0 tensor([[-0.0755,  0.1138,  0.0966,  0.0564, -0.0224]])
epoch:1 tensor([[-0.0763,  0.1113,  0.0970,  0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_meanrunning_var并没在可优化参数net.parameters

bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到结果正常了:

epoch:0 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])

参考:

Cannot freeze batch normalization parameters

BatchNorm2d

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/183783.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • JAVA多线程实现的三种方式

    JAVA多线程实现的三种方式

    2022年1月20日
    55
  • java中高级工程师面试汇总

    java中高级工程师面试汇总1接口服务数据被劫包如何防止数据恶意提交1.1:防篡改客户端提交请求之前,先对自己请求的参数全部进行拼接加密得到一个加密字符串sign 请求参数加上sign,然后再发送给服务器 服务器将参数获取后也进行相同的拼接加密得到自己的sign 比较与客户端发来的sign是否相同 不相同则是被第三方修改过的,拒绝执行关键:第三方不知道加密方式和请求参数拼接规则,而客户端与服务器是知道的,因此第三方不知道修改参数后如何生成与服务器生成相同的sign 只要请求修改了一点点加密得到的就是不同的签名

    2022年7月8日
    21
  • Linux下Apache与MySQL+PHP的综合应用案例

    Linux下Apache与MySQL+PHP的综合应用案例

    2021年7月31日
    59
  • c++迭代器iterator遍历map_iterator迭代器原理

    c++迭代器iterator遍历map_iterator迭代器原理什么是迭代器迭代器是一种可以遍历容器元素的数据类型。迭代器是一个变量,相当于容器和操纵容器的算法之间的中介。C++更趋向于使用迭代器而不是数组下标操作,因为标准库为每一种标准容器(如vector、map和list等)定义了一种迭代器类型,而只有少数容器(如vector)支持数组下标操作访问容器元素。可以通过迭代器指向你想访问容器的元素地址,通过*x打印出元素值。这和我们所熟知的指针极其类似。C语言有指针,指针用起来十分灵活高效。C++语言有迭代器,迭代器相对于指针而言功能更为丰富。vector,是数

    2025年7月1日
    4
  • 网络基础知识大全_网络基础知识入门到精通

    网络基础知识大全_网络基础知识入门到精通1)如何查看本机所开端口:用netstat-a—n命令查看!netstat结果显示有一些英文,简单说一下这些英文具体都代表什么:LISTEN:侦听来自远方的TCP端口的连接请求SYN-SENT:再

    2022年8月6日
    9
  • gis地理加权回归步骤_地理加权回归权重

    gis地理加权回归步骤_地理加权回归权重内容导读1)回归概念介绍;2)探索性回归工具(解释变量的选择)使用;3)广义线性回归工具(GLR)使用;*加更:广义线性回归工具的补充内容4)地理加权回归工具(GWR)使用+小结。说明:本节是这个学习笔记最后一部分。PART/04地理加权回归工具(GWR)使用上一节我们讲了GLR广义线性回归,它是一种全局模型,可以构造出最佳描述研究区域中整体数据关系的方程。如果这些关系在研究区域中是一致的,则GLR回归方程可以对这些关系进行很好的建模。不过,当这些关系在研

    2022年10月6日
    5

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号