pytorch损失函数之nn.CrossEntropyLoss()、nn.NLLLoss()「建议收藏」

pytorch损失函数之nn.CrossEntropyLoss()、nn.NLLLoss()「建议收藏」这个损失函数用于多分类问题虽然说的交叉熵,但是和我们理解的交叉熵不一样

大家好,又见面了,我是你们的朋友全栈君。

nn.CrossEntropyLoss()这个损失函数用于多分类问题虽然说的是交叉熵,但是和我理解的交叉熵不一样。nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合,可以直接使用它来替换网络中的这两个操作。下面我们来看一下计算过程。
首先输入是size是(minibatch,C)。这里的C是类别数。损失函数的计算如下:

l o s s ( x , c l a s s ) = − l o g ( e x p ( x [ c l a s s ] ) ∑ j e x p ( x [ j ] ) ) = − x [ c l a s s ] + l o g ( ∑ j e x p ( x [ j ] ) ) loss(x,class)=-log(\frac{exp(x[class])}{\sum_jexp(x[j])})=-x[class]+log(\sum_jexp(x[j])) loss(x,class)=log(jexp(x[j])exp(x[class]))=x[class]+log(jexp(x[j]))

损失函数中也有权重weight参数设置,若设置权重,则公式为: l o s s ( x , c l a s s ) = w e i g h t [ c l a s s ] ( − x [ c l a s s ] + l o g ( ∑ j e x p ( x [ j ] ) ) ) loss(x,class)=weight[class](-x[class]+log(\sum_jexp(x[j]))) loss(x,class)=weight[class](x[class]+log(jexp(x[j])))
其他参数不具体说,和nn.BCELoss()设置差不多,默认情况下,对minibatch的loss求均值。

注意这里的标签值class,并不参与直接计算,而是作为一个索引,索引对象为实际类别

举个栗子,我们一共有三种类别,批量大小为1(为了好计算),那么输入size为(1,3),具体值为torch.Tensor([[-0.7715, -0.6205,-0.2562]])。标签值为target = torch.tensor([0]),这里标签值为0,表示属于第0类。loss计算如下:
import torch
import torch.nn as nn
import math

entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[-0.7715, -0.6205,-0.2562]])
target = torch.tensor([0])
output = entroy(input, target)
print(output)
#根据公式计算
输出:

tensor(1.3447)

动手自己算:

− x [ 0 ] + l o g ( e x p ( x [ 0 ] ) + e x p ( x [ 1 ] ) + e x p ( x [ 2 ] ) ) -x[0]+log(exp(x[0])+exp(x[1])+exp(x[2])) x[0]+log(exp(x[0])+exp(x[1])+exp(x[2])) = 0.7715 + l o g ( e x p ( − 0.7715 ) + e x p ( − 0.6205 ) + e x p ( − 0.2562 ) = 1.3447266007601868 =0.7715+log(exp(-0.7715)+exp(-0.6205)+exp(-0.2562)=1.3447266007601868 =0.7715+log(exp(0.7715)+exp(0.6205)+exp(0.2562)=1.3447266007601868

我们在看看是否等价nn.logSoftmax()和nn.NLLLoss()的整合:
m = nn.LogSoftmax()
loss = nn.NLLLoss()
input=m(input)
output = loss(input, target)
print('output:',output)
输出:

input: tensor([[-1.3447, -1.1937, -0.8294]])
output: tensor(1.3447)

可以看出nn.LogSoftmax()的对输入的操作就是: l o g ( e x p ( x ) ∑ i e x p ( x [ i ] ) ) log(\frac{exp(x)}{\sum_iexp(x[i])}) log(iexp(x[i])exp(x))x是输入向量。
而nn.NLLLoss()的操作是: l o s s n = − w n x n , y n loss_n=-w_nx_{n,y_n} lossn=wnxn,yn这里没有设置权重,也就是权重默认为1,x_{n,y_n}表示目标类所对应输入x中值,则loss就为 l o s s = − 1 ∗ x [ 0 ] = 1.3447 loss=-1*x[0]=1.3447 loss=1x[0]=1.3447
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/158215.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 可以识别图片上的文字的小程序

    可以识别图片上的文字的小程序微信上的小程序相信大家都不陌生,近年来,微信小程序从“跳一跳”之后,越发火了。由于小程序的出现,微信上的功能也逐渐增加了,今天就给大家介绍一个小程序,比较实用,它可以快速识别图片上的文字,这个小程序呢就叫“迅捷文字识别”。这是一个比较智能的文字识别的小程序,它可以将识别出来的字汉英互译,还可以直接拍照翻译,接下来就给大家介绍一下这个小程序的操作方法。1.首先,我们现在微信上找到这个程序,点…

    2022年5月29日
    47
  • Java高级工程师常见面试题(一)-Java基础「建议收藏」

    Java高级工程师常见面试题(一)-Java基础「建议收藏」博主其他相关文章:《Java高级工程师常见面试题-总结》1.String类为什么是final的。多线程安全,将字符串对象保存在字符串常量池中共享效率高2.HashMap的源码,实现原理,底层结构。HashMap基于哈希表的Map接口的实现。允许使用null值和null键。此类不保证映射的顺序,特别是它不保证该顺序恒久不变。值得注意的是HashMap不是线程安全的…

    2022年5月27日
    35
  • 协同过滤推荐算法(一)原理与实现

    协同过滤推荐算法(一)原理与实现一、协同过滤算法原理协同过滤推荐算法是诞生最早,并且较为著名的推荐算法。主要的功能是预测和推荐。算法通过对用户历史行为数据的挖掘发现用户的偏好,基于不同的偏好对用户进行群组划分并推荐品味相似的商品。协同过滤推荐算法分为两类,分别是基于用户的协同过滤算法(user-basedcollaboratIvefiltering),和基于物品的协同过滤算法(item-basedcollaborati…

    2022年6月24日
    31
  • 三十六:Redis过期键删除策略[通俗易懂]

    redisDb结构的expires字典保存了数据库中所有键的过期时间,我们称这个字典为过期字典:❑过期字典的键是一个指针,这个指针指向键空间中的某个键对象(也即是某个数据库键)。❑过期字典的值是一个longlong类型的整数,这个整数保存了键所指向的数据库键的过期时间——一个毫秒精度的UNIX时间戳。❑定时删除:在设置键的过期时间的同时,创建一个定时器(timer),让定时器在键的过…

    2022年4月13日
    47
  • jvm面试题目及答案_jvm原理面试题

    jvm面试题目及答案_jvm原理面试题Jvm面试题及答案整理965道(2021最新版)这是我收集的《Jvm最常见的965道面试题》高级Java面试问题列表。这些问题主要来自JVM核心部分,你可能知道这些棘手的JVM问题的答案,或者觉得这些不足以挑战你的Java知识,但这些问题都是容易在各种JVM面试中被问到的,而且包括我的朋友和同事在内的许多程序员都觉得很难回答。Jvm最新2021年面试题及答案,汇总版01、JAVA弱引用02、什么是堆03、什么是程序计数器04、各种回收器,各自优缺点,重点CMS、G1…

    2022年8月27日
    7
  • 词向量总结「建议收藏」

    词向量总结「建议收藏」词向量词向量是自然语言理解的重要工具,它的核心思想是把词映射到一个向量空间,并且这个向量空间很大程度上保留了原本的语义。词向量既可以作为对语料进行数据挖掘的基础,也可以作为更复杂的模型的输入,是现在nlp的主流工具。下面就总结一下nlp中经典的词向量方法。主要有:onehot、glove、cbow、skip-gram

    2022年5月6日
    38

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号