python复现softmax损失函数详细版

python复现softmax损失函数详细版fromtorchimportnnimporttorchdefloss_func(output,target):one_hot=torch.zeros_like(output)foriinrange(target.size(0)):one_hot[i,target[i]]=1softmax_out=torch.exp(output)/(torch.unsque…

大家好,又见面了,我是你们的朋友全栈君。

主要内容

  • softmax和交叉熵公式
  • 单个样本求解损失
  • 多个样本求解损失

softmax和交叉熵公式

  • softmax

先来看公式,softmax的作用是将样本对应的输出向量,转换成对应的类别概率值。这里使用以e为底的指数函数,将向量值归一化为0-1的概率值;
Alt
使用numpy的代码实现也很简单,但是当数值过大时会发生溢出,此时会将向量中的其他值减去最大值,数值平移到0附近。会避免溢出现象。ps:这里暂时不考虑这种情况
在这里插入图片描述

  • softmax交叉熵
    交叉熵是用来衡量分布p和q之间的相似度,越相似交叉熵越小。其中 p ( x ) p(x) p(x)是真实标签的one_hot编码, q ( x ) q(x) q(x)是预测值。需要注意的是这里的 q ( x ) q(x) q(x)必须是经过softmax的概率值。
    Alt

单个样本求解损失

#conding=utf-8

from torch import nn
import torch
import numpy as np

def MySoftmax(vector):
    return np.exp(vector)/np.exp(vector).sum()

def LossFunc(target,output):
    output = MySoftmax(output)
    one_hot = np.zeros_like(output)
    one_hot[:,target] = 1
    # print(one_hot)
    loss = (-np.log(output)*one_hot).sum()
    return loss
target = np.array([1])
output = np.array([[8,-3.,10]])
softmax_out = MySoftmax(output)
np.set_printoptions(suppress=True)
print(softmax_out)

# torch自带的softmax实现
print(nn.Softmax()(torch.Tensor(output)))

print(LossFunc(target,output))
print(nn.CrossEntropyLoss(reduction="sum")(torch.Tensor(output),torch.Tensor(target).long()))

需要注意的是现有的框架中基本都会在损失函数内部进行softmax转换。我这里设置的loss值没有求平均,所以reduction=“sum”

多个样本求解损失

#conding=utf-8

from torch import nn
import torch
import numpy as np

# def MySoftmax(vector):
# return np.exp(vector)/np.exp(vector).sum()
#
# def LossFunc(target,output):
# output = MySoftmax(output)
# one_hot = np.zeros_like(output)
# one_hot[:,target] = 1
# # print(one_hot)
# loss = (-np.log(output)*one_hot).sum()
# return loss
# target = np.array([1])
# output = np.array([[8,-3.,10]])
# softmax_out = MySoftmax(output)
# np.set_printoptions(suppress=True)
# print(softmax_out)
#
# # torch自带的softmax实现
# print(nn.Softmax()(torch.Tensor(output)))
#
# print(LossFunc(target,output))
# print(nn.CrossEntropyLoss(reduction="sum")(torch.Tensor(output),torch.Tensor(target).long()))

def loss_func(output,target):
    one_hot = torch.zeros_like(output)
    for i in range(target.size(0)):
        one_hot[i,target[i]]=1

    softmax_out = torch.exp(output)/( torch.unsqueeze(torch.exp(output).sum(dim=1),dim=1))
    # 确保每一个样本维度的概率之和为1
    print(softmax_out.sum(dim=1))
    loss = (-torch.log(softmax_out) * one_hot).sum()
    return loss

target = torch.Tensor([1,1,1]).long()
output = torch.Tensor([[10.,-5,5],[5,2,-1],[4,-9,5]])
softmax = nn.Softmax(dim=1)


criterion = nn.CrossEntropyLoss(reduction="sum")
print(criterion(output,target))
print(loss_func(output,target))

我这里使用的是torch的计算,主要原因是想使用label smoothing技巧,torch版在项目中应用更方便。
只是将numpy换成torch的形式,基本的公式都没有改变的。需要注意的是在多个样本求解softmax值是在样本的维度求概率。

喜欢的童鞋点个赞哦!大家有什么要了解的请留言,老汤尽量满足

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/153145.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • excel截取指定字符之后的字符串_怎样把单元格的字符串拼接起来

    excel截取指定字符之后的字符串_怎样把单元格的字符串拼接起来函数:mid需求:提取A1单元格字符串的一部分(第四个字符算起,截取2个字符)放在A2单元格。使用:=mid(A1,4,2)

    2025年6月8日
    3
  • 搜索引擎自动提交连接php文件,死链检测工具(自动提交给百度,逆天了)

    搜索引擎自动提交连接php文件,死链检测工具(自动提交给百度,逆天了)每个网站都避免不了404死链的存在。造成死链的原因有很多,比如说文章页被删除、链接被修改、网页链接更换存储路径等,这些都会成为死链。这些死链的产生,降低了搜索引擎对网站的友好度、影响用户体验,甚至会导致网站排名下降等。刚操作MAY博客的时候,文章的内容不是很多,遇到死链,只是简单的手动一个个去站长平台提交。但随着文章及页面的不断增加,一个个手动去操作,是不是麻烦了些。咦,是否能借助死链检测工具,自…

    2022年7月23日
    14
  • TransactionScope事务处理

    TransactionScope事务处理在我们日常开发的时候,有时候程序需要使用到事务,就比如,我们日常最熟悉的一个流程,那么就是银行的取款,当用户从ATM机器选择取款1000元的时候,恰巧这个时候如果停电,如果没有事务那么将会出现不堪设想的后果,银行都会倒闭。最近在开发一个功能,需要根据单据的信息生成2张单据,要么全部保存,要么都保存失败,做到事务的一致性、原子性,一开始我想到的是SQL和存储过程级别的事务,但是好像按照当前的系统的业务逻辑,这个方法的底层还是拼接SQL语句,后面又想着使用C#的ADO.NET级别的事务,根据数据生成sql,但

    2022年7月19日
    18
  • goland 2021.5.1激活码【在线注册码/序列号/破解码】

    goland 2021.5.1激活码【在线注册码/序列号/破解码】,https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月18日
    44
  • 激活成功教程quartus ii13.0_quartus ii 13.0安装

    激活成功教程quartus ii13.0_quartus ii 13.0安装文章目录一、QuartusII的下载二、QuartusII的安装三、QuartusII的激活成功教程1.下载激活成功教程器文件2.激活成功教程器的使用一、QuartusII的下载百度网盘下载链接:https://pan.baidu.com/s/1a9d-bq9RZmWrRV542X4IEA提取码:ifte说明:本链接来自于正点原子官方资料下载二、QuartusII的安装解压后双击运行exe文件:点击next:勾选“Iaccepttheagreement”,然后点击Next:

    2022年10月10日
    5
  • android之ContentObserver内容观察者的使用

    ContentObserver——内容观察者,目的是观察(捕捉)特定Uri引起的数据库的变化,继而做一些相应的处理,它类似于   数据库技术中的触发器(Trigger),当ContentObserver所观察的Uri发生变化时,便会触发它。(1)注册:    public final void  registerContentObserver(Uri uri, boolean noti

    2022年3月11日
    49

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号