pytorch损失函数之nn.CrossEntropyLoss()、nn.NLLLoss()「建议收藏」

pytorch损失函数之nn.CrossEntropyLoss()、nn.NLLLoss()「建议收藏」这个损失函数用于多分类问题虽然说的交叉熵,但是和我们理解的交叉熵不一样

大家好,又见面了,我是你们的朋友全栈君。

nn.CrossEntropyLoss()这个损失函数用于多分类问题虽然说的是交叉熵,但是和我理解的交叉熵不一样。nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合,可以直接使用它来替换网络中的这两个操作。下面我们来看一下计算过程。
首先输入是size是(minibatch,C)。这里的C是类别数。损失函数的计算如下:

l o s s ( x , c l a s s ) = − l o g ( e x p ( x [ c l a s s ] ) ∑ j e x p ( x [ j ] ) ) = − x [ c l a s s ] + l o g ( ∑ j e x p ( x [ j ] ) ) loss(x,class)=-log(\frac{exp(x[class])}{\sum_jexp(x[j])})=-x[class]+log(\sum_jexp(x[j])) loss(x,class)=log(jexp(x[j])exp(x[class]))=x[class]+log(jexp(x[j]))

损失函数中也有权重weight参数设置,若设置权重,则公式为: l o s s ( x , c l a s s ) = w e i g h t [ c l a s s ] ( − x [ c l a s s ] + l o g ( ∑ j e x p ( x [ j ] ) ) ) loss(x,class)=weight[class](-x[class]+log(\sum_jexp(x[j]))) loss(x,class)=weight[class](x[class]+log(jexp(x[j])))
其他参数不具体说,和nn.BCELoss()设置差不多,默认情况下,对minibatch的loss求均值。

注意这里的标签值class,并不参与直接计算,而是作为一个索引,索引对象为实际类别

举个栗子,我们一共有三种类别,批量大小为1(为了好计算),那么输入size为(1,3),具体值为torch.Tensor([[-0.7715, -0.6205,-0.2562]])。标签值为target = torch.tensor([0]),这里标签值为0,表示属于第0类。loss计算如下:
import torch
import torch.nn as nn
import math

entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[-0.7715, -0.6205,-0.2562]])
target = torch.tensor([0])
output = entroy(input, target)
print(output)
#根据公式计算
输出:

tensor(1.3447)

动手自己算:

− x [ 0 ] + l o g ( e x p ( x [ 0 ] ) + e x p ( x [ 1 ] ) + e x p ( x [ 2 ] ) ) -x[0]+log(exp(x[0])+exp(x[1])+exp(x[2])) x[0]+log(exp(x[0])+exp(x[1])+exp(x[2])) = 0.7715 + l o g ( e x p ( − 0.7715 ) + e x p ( − 0.6205 ) + e x p ( − 0.2562 ) = 1.3447266007601868 =0.7715+log(exp(-0.7715)+exp(-0.6205)+exp(-0.2562)=1.3447266007601868 =0.7715+log(exp(0.7715)+exp(0.6205)+exp(0.2562)=1.3447266007601868

我们在看看是否等价nn.logSoftmax()和nn.NLLLoss()的整合:
m = nn.LogSoftmax()
loss = nn.NLLLoss()
input=m(input)
output = loss(input, target)
print('output:',output)
输出:

input: tensor([[-1.3447, -1.1937, -0.8294]])
output: tensor(1.3447)

可以看出nn.LogSoftmax()的对输入的操作就是: l o g ( e x p ( x ) ∑ i e x p ( x [ i ] ) ) log(\frac{exp(x)}{\sum_iexp(x[i])}) log(iexp(x[i])exp(x))x是输入向量。
而nn.NLLLoss()的操作是: l o s s n = − w n x n , y n loss_n=-w_nx_{n,y_n} lossn=wnxn,yn这里没有设置权重,也就是权重默认为1,x_{n,y_n}表示目标类所对应输入x中值,则loss就为 l o s s = − 1 ∗ x [ 0 ] = 1.3447 loss=-1*x[0]=1.3447 loss=1x[0]=1.3447
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/158215.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • git 命令怎么删除远程分支文件_git删除远程仓库分支

    git 命令怎么删除远程分支文件_git删除远程仓库分支本地删除请看:git命令怎么删除本地分支查看所有分支查看项目的远程分支:gitbranch-r删除远程分支比如我们要删除远程分支origin/SLT_table_reportgitpushorigin-d分支名我们执行:gitpushorigin-dSLT_table_report删除成功注意这里不能写成origin/SLT_table_report,不然会报错:具体请参考【git删除远程分支报错error:unabletodelete‘

    2022年10月16日
    3
  • 忆贵州三年的教书编程岁月:不弛于空想,不骛于虚声「建议收藏」

    忆贵州三年的教书编程岁月:不弛于空想,不骛于虚声「建议收藏」回首,2016年7月他离开北京回到了家乡贵州,成为了贵州财经大学的一名青年教师。转眼,2019年7月他迎来了人生的第三张通知书,即将辗转第三个城市,开始新的征途。教书三年,讲台前的每一次分享都值得回味,学生的每一句“老师好”,每一个问候和祝福,都留下了深刻的印象。

    2022年10月7日
    2
  • 封装前端UI组件库–dialog

    封装前端UI组件库–dialog1 前言 dialog 弹窗组件库的实现 目前包括自定义内容 提示语 内容 底部按钮 弹窗的宽高等等 其中提示语 内容 弹窗的宽高的实现与 button 极其类似 请看上文 实现原理都是调用的时候传入参数 在自定义组件里面接收参数 根据参数再做具体的操作等等 2 自定义插槽的实现由于需要自定义传入的参数 提示语 内容 底部按钮都需要插槽传入 过多 接收的时候又都是用来接 所以为了区分需要用到自定义插槽 给插槽取个名字 1 使用

    2025年10月19日
    2
  • docker企业实战视频教程

    docker企业实战视频教程

    2022年2月9日
    42
  • SSM框架中Dao层,Mapper层,controller层,service层,model层,entity层都有什么作用「建议收藏」

    SSM框架中Dao层,Mapper层,controller层,service层,model层,entity层都有什么作用「建议收藏」SSM是sping+springMVC+mybatis集成的框架。MVC即modelviewcontroller。model层=entity层。存放我们的实体类,与数据库中的属性值基本保持一致。service层。存放业务逻辑处理,也是一些关于数据库处理的操作,但不是直接和数据库打交道,他有接口还有接口的实现方法,在接口的实现方法中需要导入mapper层,mapper层是直接跟数据库…

    2022年7月12日
    26
  • 【自动化测试工具】QTP/UFT入门

    【自动化测试工具】QTP/UFT入门准备工作:QTP11.5安装教程:http://www.iquicktest.com/qtp-uft-11-5-download.html注: QuickTestPro(QTP)11.5后更名为UnifiedFunctionalTesting(UFT)1、安装后打开QTP,勾选Webadd-in,进入QTP后File-New-Test。2、选择File菜单下New

    2022年5月28日
    54

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号