Pytorch_hook机制的理解及利用register_forward_hook(hook)中间层输出[通俗易懂]

Pytorch_hook机制的理解及利用register_forward_hook(hook)中间层输出[通俗易懂]参考文献:【1】梯度计算问题含公式:参考链接1.【2】pytorch改动和.data和.detch()问题:https://blog.csdn.net/dss_dssssd/article/details/83818181【3】hook技术介绍:https://www.cnblogs.com/hellcat/p/8512090.html【4】hook应用->中间层的输出:https://blog.csdn.net/qq_40303258/article/details/10688431

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全家桶1年46,售后保障稳定

参考文献:

【1】梯度计算问题含公式:参考链接1.

【2】pytorch改动和.data和.detch()问题:https://blog.csdn.net/dss_dssssd/article/details/83818181

【3】hook技术介绍:https://www.cnblogs.com/hellcat/p/8512090.html

【4】hook应用->中间层的输出:https://blog.csdn.net/qq_40303258/article/details/106884317

【5】hook函数介绍:参考链接2

需要了解的基本点:

(1)backward()是Pytorch中用来求梯度的方法。

(2)Variable是对tensor的封装,包含了三部分:

  •  .data:tensor本身
  • .grad:对应tensor的梯度
  • .grad_fn:该Variable是通过什么方式获得的

(3)pytorch 0.4版本后将tensor和Variable合并在了一起。

x = Variable(torch.randn(2, 1), requires_grad=True) # 利用Variable封装tensor
##等效 x = torch.rand(2,1,requires_grad=True)
x = torch.rand(2,1) # 不等效

Jetbrains全家桶1年46,售后保障稳定

(4)hook种类分为两种

Tensor级别  register_hook(hook) ->为Tensor注册一个backward hook,用来获取变量的梯度;hook必须遵循如下的格式:hook(grad) -> Tensor or None

nn.Module对象 register_forward_hook(hook)register_backward_hook(hook)两种方法,分别对应前向传播和反向传播的hook函数。

(5)hook作用:获取某些变量的中间结果的。Pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数在使用后应及时删除,以避免每次都运行钩子增加运行负载。

举例说明 Tensor级别  :

例子1(借鉴参考文献1和3)

import torch 
from torch.autograd import Variable 


def print_grad(grad):
    print('grad is \n',grad)
 
x = Variable(torch.randn(2, 1), requires_grad=True)
## x = torch.rand(2,1,requires_grad=True) #  等效
print('x value is \n',x)
y = x+3
print('y value is \n',y)
z = torch.mean(torch.pow(y, 1/2))
lr = 1e-3

y.register_hook(print_grad) 
z.backward() # 梯度求解
x.data -= lr*x.grad.data
print('new x is\n',x)
output:
x value is 
 tensor([[ 2.5474],
        [-1.1597]], requires_grad=True)
y value is 
 tensor([[5.5474],
        [1.8403]], grad_fn=<AddBackward0>)
grad is 
 tensor([[0.1061],
        [0.1843]])
new x is
 tensor([[ 2.5473],
        [-1.1599]], requires_grad=True)

分析:

对于z来说,求梯度最终求解的是对x的梯度(导数,偏导),因此y是一个中间变量。因此可以用register_hook()来获取其作为中间值的导数,否则z对于y的偏导是获取不到的。x的偏导和y的偏导实际上是相同值,推导如下图。

Pytorch_hook机制的理解及利用register_forward_hook(hook)中间层输出[通俗易懂]

不用register_hook()的例子。

#y.register_hook(print_grad) 

z.backward() # 梯度求解
print('y\'s grad is ',y.grad)
print('x\'s grad is \n',x.grad)
x.data -= lr*x.grad.data
print('new x is\n',x)

output:
y's grad is  None
x's grad is 
 tensor([[0.1544],
        [0.1099]])
new x is
 tensor([[-0.3801],
        [ 2.1755]], requires_grad=True)

可以看出,z对于x的grad是存在的,但是z对于中间变量y的grad是不存在的。也就验证了Pytorch会自动舍弃图计算的中间结果这句话。

举例说明 Module级别 

【1】register_forward_hook(hook)

在网络执行forward()之后,执行hook函数,需要具有如下的形式:

hook(module, input, output) -> None or modified output

hook可以修改input和output,但是不会影响forward的结果。最常用的场景是需要提取模型的某一层(不是最后一层)的输出特征,但又不希望修改其原有的模型定义文件,这时就可以利用forward_hook函数。

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))     #1 
        out = F.max_pool2d(out, 2)      #2
        out = F.relu(self.conv2(out))   #3
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

features = []
def hook(module, input, output): 
    # module: model.conv2 
    # input :in forward function  [#2]
    # output:is  [#3 self.conv2(out)]
    features.append(output.clone().detach())
    # output is saved  in a list 


net = LeNet() ## 模型实例化 
x = torch.randn(2, 3, 32, 32) ## input 
handle = net.conv2.register_forward_hook(hook) ## 获取整个Lenet模型 conv2的中间结果
y = net(x)  ## 获取的是 关于 input x 的 conv2 结果 

print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook删除 

以上文字和代码示例,均来自参考文献5中的示例,由于示例对于register_forward_hook(hook)没有过多注解,因此我加了一些注解。

个人理解:register_forward_hook(hook) 作用就是(假设想要conv2层),那么就是根据 model(该层),该层input,该层output,可以将 output获取。

register_forward_hook(hook)  最大的作用也就是当训练好某个model,想要展示某一层对最终目标的影响效果。

例子:【借鉴参考文献4】

class LayerActivations:
    features = None
    def __init__(self, model, layer_num):
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)
        # 获取model.features中某一层的output
    
    def hook_fn(self, module, input, output):
        self.features = output.cpu()
 
    def remove(self): ## 删除hook
        self.hook.remove()


''' 类似于以下格式
class CNNnet1(torch.nn.Module): ## wangluo jiegou  
    def __init__(self):
        super(CNNnet1,self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(),  
            torch.nn.ReLU(),
            torch.nn.Conv1d(),
            torch.nn.ReLU(),
            torch.nn.Conv1d(),
            torch.nn.BatchNorm1d(),
            torch.nn.MaxPool1d()
            torch.nn.ReLU()
        ) 
'''     
#### model= CNN()
#### train(model,train_loader,learning_rate,batch_size,epochs)
#### 
model.eval() 
test_dataset = DataSet(test_features, test_labels) 
test_loader = DataLoader(test_dataset,batch_size=1,shuffle=True)
        
img = next(iter(test_loader))[0] # gain a input 

for i in range(len(model.features)): # model.features is a nn.Sequential()
    conv_out = LayerActivations(model.features,i) # 实例化,获取每一层
    ouput = model(img)
    act = conv_out.features # gain the ith output
    conv_out.remove # delete the hook

    plt.imshow(act[0].detach().numpy(),cmap='hot') # output is showed using 热力图 
    plt.colorbar(shrink=0.4) # 句柄大小
    plt.show() 

大概画完了就是这个样子[每一层都有一个图,不做过多展示]:

Pytorch_hook机制的理解及利用register_forward_hook(hook)中间层输出[通俗易懂]

Pytorch_hook机制的理解及利用register_forward_hook(hook)中间层输出[通俗易懂]

其中 plt.imshow()是热力图画法,详情点击链接。可以把参考文献4中是将所有的中间层画到了一张画布上,因为卷积层尺寸不同,我就没放在一起。

[2]register_backward_hook(hook)

因为暂时没有用到,不做详细讲解,具体可参考参考文献5。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/213597.html原文链接:https://javaforall.net

(0)
上一篇 2025年7月18日 上午7:01
下一篇 2025年7月18日 上午7:43


相关推荐

  • 扫雷小游戏-纯网页版下载_扫雷游戏下载手机版

    扫雷小游戏-纯网页版下载_扫雷游戏下载手机版这两天在恶补前端的相关知识,看到JQuery的动画部分时,突然心血来潮想做一个扫雷的网页版,于是花了差不多一天的时间完成了一个初始版本,权当对这几天学习成果的一个回顾,若某处功能有更好实现方式欢迎留言

    2022年8月2日
    9
  • 【Java】JVM垃圾回收机制与类加载机制

    【Java】JVM垃圾回收机制与类加载机制不同于C++需要编程人员手动释放内存,Java有虚拟机,因此Java不需要程序员主动去释放内存,而是通过虚拟机自身的垃圾回收器(GarbageCollector-GC)来进行对象的回收。Java语言由于有虚拟机的存在,实现了平台无关性,在任意平台都是通过将代码转换为字节码文件,从而在平台下的虚拟机中运行代码的。JVM内存区域分布虚拟机栈:存放每个方法执行时的栈帧,一个方法调用到…

    2022年5月18日
    41
  • zencart模板修改的地方

    zencart模板修改的地方zencart模闆修改的地方,在修改前一定要備份數據庫和程序文件.1:新聞頁上的zencartnews,如何修改zencart,如圖:修改文件news.php,位置:includes\languages\english2,修改産品頁上面滾動圖片的高度和寬度:修改文件stylesheet_scrollpic.css,位置:includes\templates\模闆名\css在修改不讓自動…

    2022年7月27日
    9
  • OpenClaw光速国产化,大厂出的“龙虾”到底哪个最好用?

    OpenClaw光速国产化,大厂出的“龙虾”到底哪个最好用?

    2026年3月13日
    2
  • 流利说文本level6_流利说level4原文

    流利说文本level6_流利说level4原文Level6Unit11/4ListeningLesson1Harry’sInjury1-2DialogueLesson3Lovers’QuarrelReadingLesson4TheBoyWhoCriedWolfLesson5SurvivalintheOutback2/4ListeningLesson1T…

    2022年10月8日
    5
  • Spatial Transformer Network_transgression

    Spatial Transformer Network_transgression导读上一篇通俗易懂的SpatialTransformerNetworks(STN)(一)中,我们详细介绍了STN中会使用到的几个模块,并且用pytorch和numpy来实现了他们,这篇文章我们将会利用pytorch来实现一个MNIST的手写数字识别并且将STN模块插入到CNN中STN关键点解读STN有一个最大的特点就是STN模块能够很容易的嵌入到CNN中,只需要进行非常小的修改即可。上一篇文章我们也说了STN拥有平移、旋转、剪切、缩放等不变性,而这一特点主要是依赖θ\thetaθ参数来实现的。刚开

    2022年8月31日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号