简单的交叉熵损失函数,你真的懂了吗?

简单的交叉熵损失函数,你真的懂了吗?个人网站 红色石头的机器学习之路 CSDN 博客 红色石头的专栏知乎 红色石头微博 RedstoneWill 的微博 GitHub RedstoneWill 的 GitHub 微信公众号 AI 有道 ID redstonewill 说起交叉熵损失函数 CrossEntropy 脑海中立马浮现出它的公式 L ylog nbsp y 1

说起交叉熵损失函数「Cross Entropy Loss」,脑海中立马浮现出它的公式:

L=[ylog y^+(1y)log (1y^)] L = − [ y l o g   y ^ + ( 1 − y ) l o g   ( 1 − y ^ ) ]

我们已经对这个交叉熵函数非常熟悉,大多数情况下都是直接拿来使用就好。但是它是怎么来的?为什么它能表征真实样本标签和预测概率之间的差值?上面的交叉熵函数是否有其它变种?也许很多朋友还不是很清楚!没关系,接下来我将尽可能以最通俗的语言回答上面这几个问题。

1. 交叉熵损失函数的数学原理

我们知道,在二分类问题模型:例如逻辑回归「Logistic Regression」、神经网络「Neural Network」等,真实样本的标签为 [0,1],分别表示负类和正类。模型的最后通常会经过一个 Sigmoid 函数,输出一个概率值,这个概率值反映了预测为正类的可能性:概率越大,可能性越大。

Sigmoid 函数的表达式和图形如下所示:

g(s)=11+es g ( s ) = 1 1 + e − s


这里写图片描述

其中 s 是模型上一层的输出,Sigmoid 函数有这样的特点:s = 0 时,g(s) = 0.5;s >> 0 时, g ≈ 1,s << 0 时,g ≈ 0。显然,g(s) 将前一级的线性输出映射到 [0,1] 之间的数值概率上。这里的 g(s) 就是交叉熵公式中的模型预测输出 。

我们说了,预测输出即 Sigmoid 函数的输出表征了当前样本标签为 1 的概率:

y^=P(y=1|x) y ^ = P ( y = 1 | x )

很明显,当前样本标签为 0 的概率就可以表达成:

1y^=P(y=0|x) 1 − y ^ = P ( y = 0 | x )

重点来了,如果我们从极大似然性的角度出发,把上面两种情况整合到一起:

P(y|x)=y^y(1y^)1y P ( y | x ) = y ^ y ⋅ ( 1 − y ^ ) 1 − y

不懂极大似然估计也没关系。我们可以这么来看:

当真实样本标签 y = 0 时,上面式子第一项就为 1,概率等式转化为:

P(y=0|x)=1y^ P ( y = 0 | x ) = 1 − y ^

当真实样本标签 y = 1 时,上面式子第二项就为 1,概率等式转化为:

P(y=1|x)=y^ P ( y = 1 | x ) = y ^

两种情况下概率表达式跟之前的完全一致,只不过我们把两种情况整合在一起了。

重点看一下整合之后的概率表达式,我们希望的是概率 P(y|x) 越大越好。首先,我们对 P(y|x) 引入 log 函数,因为 log 运算并不会影响函数本身的单调性。则有:

log P(y|x)=log(y^y(1y^)1y)=ylog y^+(1y)log(1y^) l o g   P ( y | x ) = l o g ( y ^ y ⋅ ( 1 − y ^ ) 1 − y ) = y l o g   y ^ + ( 1 − y ) l o g ( 1 − y ^ )

我们希望 log P(y|x) 越大越好,反过来,只要 log P(y|x) 的负值 -log P(y|x) 越小就行了。那我们就可以引入损失函数,且令 Loss = -log P(y|x)即可。则得到损失函数为:

L=[ylog y^+(1y)log (1y^)] L = − [ y l o g   y ^ + ( 1 − y ) l o g   ( 1 − y ^ ) ]

非常简单,我们已经推导出了单个样本的损失函数,是如果是计算 N 个样本的总的损失函数,只要将 N 个 Loss 叠加起来就可以了:

L=i=1Ny(i)log y^(i)+(1y(i))log (1y^(i)) L = ∑ i = 1 N y ( i ) l o g   y ^ ( i ) + ( 1 − y ( i ) ) l o g   ( 1 − y ^ ( i ) )

这样,我们已经完整地实现了交叉熵损失函数的推导过程。

2. 交叉熵损失函数的直观理解

可能会有读者说,我已经知道了交叉熵损失函数的推导过程。但是能不能从更直观的角度去理解这个表达式呢?而不是仅仅记住这个公式。好问题!接下来,我们从图形的角度,分析交叉熵函数,加深大家的理解。

首先,还是写出单个样本的交叉熵损失函数:

L=[ylog y^+(1y)log (1y^)] L = − [ y l o g   y ^ + ( 1 − y ) l o g   ( 1 − y ^ ) ]

我们知道,当 y = 1 时:

L=log y^ L = − l o g   y ^

这时候,L 与预测输出的关系如下图所示:


这里写图片描述

看了 L 的图形,简单明了!横坐标是预测输出,纵坐标是交叉熵损失函数 L。显然,预测输出越接近真实样本标签 1,损失函数 L 越小;预测输出越接近 0,L 越大。因此,函数的变化趋势完全符合实际需要的情况。

当 y = 0 时:

L=log (1y^) L = − l o g   ( 1 − y ^ )

这时候,L 与预测输出的关系如下图所示:


这里写图片描述

同样,预测输出越接近真实样本标签 0,损失函数 L 越小;预测函数越接近 1,L 越大。函数的变化趋势也完全符合实际需要的情况。

从上面两种图,可以帮助我们对交叉熵损失函数有更直观的理解。无论真实样本标签 y 是 0 还是 1,L 都表征了预测输出与 y 的差距。

另外,重点提一点的是,从图形中我们可以发现:预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大,是一种类似指数增长的级别。这是由 log 函数本身的特性所决定的。这样的好处是模型会倾向于让预测输出更接近真实样本标签 y。

3. 交叉熵损失函数的其它形式

什么?交叉熵损失函数还有其它形式?没错!我刚才介绍的是一个典型的形式。接下来我将从另一个角度推导新的交叉熵损失函数。

这种形式下假设真实样本的标签为 +1 和 -1,分别表示正类和负类。有个已知的知识点是Sigmoid 函数具有如下性质:

1g(s)=g(s) 1 − g ( s ) = g ( − s )

这个性质我们先放在这,待会有用。

好了,我们之前说了 y = +1 时,下列等式成立:

P(y=+1|x)=g(s) P ( y = + 1 | x ) = g ( s )

如果 y = -1 时,并引入 Sigmoid 函数的性质,下列等式成立:

P(y=1|x)=1g(s)=g(s) P ( y = − 1 | x ) = 1 − g ( s ) = g ( − s )

重点来了,因为 y 取值为 +1 或 -1,可以把 y 值带入,将上面两个式子整合到一起:

P(y|x)=g(ys) P ( y | x ) = g ( y s )

这个比较好理解,分别令 y = +1 和 y = -1 就能得到上面两个式子。

接下来,同样引入 log 函数,得到:

log P(y|x)=log g(ys) l o g   P ( y | x ) = l o g   g ( y s )

要让概率最大,反过来,只要其负数最小即可。那么就可以定义相应的损失函数为:

L=logg(ys) L = − l o g g ( y s )

还记得 Sigmoid 函数的表达式吧?将 g(ys) 带入:

L=log11+eys=log (1+eys) L = − l o g 1 1 + e − y s = l o g   ( 1 + e − y s )

好咯,L 就是我要推导的交叉熵损失函数。如果是 N 个样本,其交叉熵损失函数为:

L=i=1Nlog (1+eys) L = ∑ i = 1 N l o g   ( 1 + e − y s )

接下来,我们从图形化直观角度来看。当 y = +1 时:

L=log (1+es) L = l o g   ( 1 + e − s )

这时候,L 与上一层得分函数 s 的关系如下图所示:


这里写图片描述

横坐标是 s,纵坐标是 L。显然,s 越接近真实样本标签 1,损失函数 L 越小;s 越接近 -1,L 越大。

另一方面,当 y = -1 时:

L=log(1+es) L = l o g ( 1 + e s )

这时候,L 与上一层得分函数 s 的关系如下图所示:


这里写图片描述

同样,s 越接近真实样本标签 -1,损失函数 L 越小;s 越接近 +1,L 越大。

4. 总结

本文主要介绍了交叉熵损失函数的数学原理和推导过程,也从不同角度介绍了交叉熵损失函数的两种形式。第一种形式在实际应用中更加常见,例如神经网络等复杂模型;第二种多用于简单的逻辑回归模型。


这里写图片描述

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/204674.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月19日 下午7:48
下一篇 2026年3月19日 下午7:48


相关推荐

  • 5分钟入门mp4文件格式是多少_Mp4格式

    5分钟入门mp4文件格式是多少_Mp4格式写在前面本文主要内容包括,什么是MP4、MP4文件的基本结构、Box的基本结构、常见且重要的box介绍、普通MP4与fMP4的区别、如何通过代码解析MP4文件等。写作背景:最近经常回答团队小伙伴关于直播&短视频的问题,比如“flv.js的实现原理”、“为什么设计同学给的mp4文件浏览器里播放不了、但本地可以正常播放”、“MP4兼容性很好,可不可以用来做直播”等。在解答的过程中,发现经常涉及MP4协议的介绍。之前这块有简单了解过并做了笔记,这里稍微整理一下,顺便作为团队参考文档,如

    2022年10月16日
    3
  • 惠普台式电脑安装系统按哪个键_hp不识别u盘装系统

    惠普台式电脑安装系统按哪个键_hp不识别u盘装系统当我们使用U盘给电脑装系统时,需要进入BIOS设置从USB启动,不过设置BIOS太麻烦了,而且大多数电脑现在都支持快捷键启动,如惠普笔记本,那么惠普usb装系统按哪个键呢?接下来小编就跟大家讲解一下,希望能够帮助到大家。惠普usb装系统步骤阅读1、将U盘插在电脑的USB接口,开机并不断按下启动U盘快捷键f12。2、在进入系统启动菜单中选择有USB字样的选项并回车。3、重启电脑,选择YunQiShi…

    2022年8月13日
    9
  • Linux安装vim

    Linux安装vim一 vim 介绍 vim 是多模式编辑器 是 vi 的升级版 不仅兼容 vi 的所有指令 还有新的特性 vi 是 ubantu 自带的 二 下载安装 ctrl alt t 打开控制台 输入 sudoapt getinstallvi 后输入密码 等待安装即可

    2026年3月20日
    2
  • Jdbc系列六:ResultSetMetaData类

    Jdbc系列六:ResultSetMetaData类一 使用 JDBC 驱动程序处理元数据 Java 通过 JDBC 获得连接以后 得到一个 Connection 对象 可以从这个对象获得有关数据库管理系统的各种信息 包括数据库中的各个表 表中的各个列 数据类型 触发器 存储过程等各方面的信息 根据这些信息 JDBC 可以访问一个实现事先并不了解的数据库 获取这些信息的方法都是在 DatabaseMeta 类的对象上实现的 而 DataBaseMe

    2026年2月22日
    3
  • 十大经典排序算法java(几种排序算法的比较)

    四种常用排序算法冒泡排序特点:效率低,实现简单思想(从小到大排):每一趟将待排序序列中最大元素移到最后,剩下的为新的待排序序列,重复上述步骤直到排完所有元素。这只是冒泡排序的一种,当然也可以从后往前排。publicvoidbubbleSort(intarray[]){intt=0;for(inti=0;i&amp;amp;lt;…

    2022年4月11日
    62
  • 我自己实际操作安装MyCat水平分割之分片枚举和取模算法(二)

    我自己实际操作安装MyCat水平分割之分片枚举和取模算法(二)

    2021年7月10日
    105

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号