softmax 损失函数与梯度推导「建议收藏」

softmax 损失函数与梯度推导「建议收藏」softmax与svm很类似,经常用来做对比,svm的lossfunction对wx的输出s使用了hingefunction,即max(0,-),而softmax则是通过softmaxfunction对输出s进行了概率解释,再通过crossentropy计算lossfunction。将score映射到概率的softmaxfunction:,其中,,j指代i-thclass。…

大家好,又见面了,我是你们的朋友全栈君。

softmax与svm很类似,经常用来做对比,svm的loss function对wx的输出s使用了hinge function,即max(0,-),而softmax则是通过softmax function对输出s进行了概率解释,再通过cross entropy计算loss function。

将score映射到概率的softmax function:p_i=\frac{e^{f_{i}}}{\sum_{k}e^{f_k}} \quad (1),其中,f_i=W_ix,j指代 i-th class。

对于某一个样本如 X_i 的lost function为L_i = -\sum_{j}y_jlog(p_j) \quad (2).

(注:

1、以下所有的公式为了便于表达,设定只有一个样品,即L_i全部写做 L

2、公式中没有进行偏移,实际算法为了避免指数计算容易越界,需要另做偏移处理)

需要求loss function对W的导数(梯度),实际上是进行链式求导。

从最内层的开始,\frac{\partial p_i}{\partial f_j}=\frac{\partial \frac{e^{f_{i}}}{\sum_{k}e^{f_k}}}{\partial f_j} \quad (3),其中,令g_i=e^{f_i},\quad h_i=\sum_{k}e^{f_k}      已知\frac{\mathrm{d} \frac{g(x)}{h(x)} }{\mathrm{d} x}=\frac{​{g}'(x)h(x)-{h}'(x)g(x)}{h^2(x)} \quad (4)

且有\frac{\partial g_i}{\partial f_j}= \begin{cases} & \text{ if } i=j \quad e^{f_i} \\ & \text{ if } i\neq j \quad 0 \end{cases} \quad (5)\frac{\partial h_i}{\partial f_j}=e^{f_j},for \quad all \quad j \quad (6)

那么(3)式则可以根据(4)(5)(6)写成(注意,下面用\sum作为h的简写

\begin{cases} & \text{ if } i=j, \quad \frac{e^{f_i}\sum-e^{f_j}e^{f_i}}{\sum ^2}=\frac{e^{f_i}}{\sum} \frac{\sum-e^{f_j}}{\sum}= p_i(1-p_j)\\ & \text{ if } i\neq j, \quad \frac{0-e^{f_j}e^{f_i}}{\sum^2} = - \frac{e^{f_j}}{\sum} \frac{e^{f_i}}{\sum}=-p_jp_i \end{cases}

根据链式法则:(\sum_ky_k=1,y是一个只有一个元素为1,其余为0的向量,真正的分类时y_i=1)

\frac{\partial L}{\partial f_i}=\frac{\partial L}{\partial p_k}\frac{\partial p_k}{\partial f_i}=-\sum_k y_k \frac{1}{p_k}\frac{\partial p_k}{\partial f_i}\\ = -y_i(1-p_i) -\sum_{k\neq i}y_k \frac{1}{p_k}(-p_kp_i)\\ = -y_i(1-p_i)+\sum_{k\neq i}y_kp_i\\ =-y_i+y_ip_i+\sum_{k\neq i}y_kp_i\\ =p_i(\sum_ky_k)-y_i=p_i-y_i

最后一步,因为f_i=W_ix,这儿i代表第i个类别。

所以:\frac{\partial L}{\partial W_i}=\frac{\partial L}{\partial f_i} \frac{\partial f_i}{\partial W_i}=(p_i-y_i)x(上面设定了x只有一个,但实际x有n个,是矩阵而非向量)。

上面的公式用代码表示如下:

 for ii in range(num_train):
    current_scores = scores[ii, :]

    # Fix for numerical stability by subtracting max from score vector.
    # important! make them range between infinity to zero
    shift_scores = current_scores - np.max(current_scores)

    # Calculate loss for this example.
    loss_ii = -shift_scores[y[ii]] + np.log(np.sum(np.exp(shift_scores)))
    loss += loss_ii

    for jj in range(num_classes):
      softmax_score = np.exp(shift_scores[jj]) / np.sum(np.exp(shift_scores))

      # Gradient calculation.不懂这儿为什么要乘以x[ii]
      if jj == y[ii]:
        dW[:, jj] += (-1 + softmax_score) * X[ii]
      else:
        dW[:, jj] += softmax_score * X[ii]
        
     # Average over the batch and add our regularization term.
  loss /= num_train
  loss += reg * np.sum(W*W)

  # Average over the batch and add derivative of regularization term.
  dW /= num_train
  dW += 2*reg*W

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/153123.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • webstrom最新激活码[在线序列号]

    webstrom最新激活码[在线序列号],https://javaforall.net/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

    2022年3月20日
    62
  • 我给大家整理了50个开源的Java项目

    我给大家整理了50个开源的Java项目大家好,我是孟哥。在学习交流群,其他小伙伴总是问我:孟哥,项目能不能搞得全一些。我想一次学个够。撸完50个项目,我住院了,但是好在项目总结完了。孟哥花了好几天,一次撸了50个项目给大家,非常的香,技术、知识非常的的全面。学起来贼带劲。源码开源,关注+评论(50个源码)+转发,私信我获取源码。系统的截图如下所示:源码开源,关注+评论(50个源码)+转发,私信我获取源码。…

    2022年7月7日
    17
  • 北京移动全网优惠_随着竞争的加剧

    北京移动全网优惠_随着竞争的加剧 【eNet硅谷动力消息】被叫全免计划终于推出了,这个计划可以说是大家翘首以盼,许多人大大节省了话费,对很多人来说是一个大大的福音,但也因此造成了中国通讯资费的改革提速,从而加剧了行业之间的竞争。  中移动北京公司市场部负责人介绍,5月23日公司正式推出了全球通标准资费“被叫全免计划”。自即日开始,北京地区的全球通客户切实实现被叫免费,接听时间没有限制,进一步呼应了社会的期盼。按照本次…

    2022年10月7日
    0
  • lvm 扩容和缩减「建议收藏」

    lvm 扩容和缩减「建议收藏」lvm扩容和缩减1、LVM简介LVM是逻辑卷管理(LogicalVolumeManager)的简称,它是Linux环境下对磁盘分区进行管理的一种机制,LVM是建立在硬盘和分区之上的逻辑层,来提高磁盘分区管理的灵活性。LVM的工作原理其实很简单,它就是通过将底层的物理磁盘抽象的封装起来,然后以逻辑卷的方式呈现给上层应用。在传统的磁盘管理机制中,我们的上层应用是直接访问文件系统,从而对底层的物理硬盘进行读取,而在LVM中,其通过对底层的硬盘进行封装,当我们对底层的物理硬盘进行操作时,其不再是针对于分

    2022年6月20日
    33
  • python正方形螺旋线的绘制

    python正方形螺旋线的绘制多试错,反正又不要成本。importturtlea=1foriinrange(50):turtle.left(90)turtle.fd(a+1)turtle.left(90)

    2022年7月5日
    25
  • Ubuntu中Anaconda安装opencv3[通俗易懂]

    Ubuntu中Anaconda安装opencv3[通俗易懂]关于如何安装,这篇blog中已经给出了很好的方法:https://blog.csdn.net/isuccess88/article/details/73546835,但由于自前段时间开始换源已经不能解决anaconda的下载速度,因此即使使用此方法也很难进行下去,下载速度太慢了。我特地下载了opencv3的opencv3-3.2.0-py35(链接:https://pan.baidu.com…

    2022年10月19日
    0

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号