OHEM的pytorch代码实现细节

OHEM的pytorch代码实现细节详细解读一下OHEM的实现代码:defohem_loss(batch_size,cls_pred,cls_target,loc_pred,loc_target,smooth_l1_sigma=1.0):”””Arguments:batch_size(int):numberofsampledroisforbboxhe…

大家好,又见面了,我是你们的朋友全栈君。

详细解读一下OHEM的实现代码:

def ohem_loss(
    batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):
    """
    Arguments:
        batch_size (int): number of sampled rois for bbox head training
        loc_pred (FloatTensor): [R, 4], location of positive rois
        loc_target (FloatTensor): [R, 4], location of positive rois
        pos_mask (FloatTensor): [R], binary mask for sampled positive rois
        cls_pred (FloatTensor): [R, C]
        cls_target (LongTensor): [R]

    Returns:
        cls_loss, loc_loss (FloatTensor)
    """
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)
    #这里先暂存下正常的分类loss和回归loss
    loss = ohem_cls_loss + ohem_loc_loss
    #然后对分类和回归loss求和

  
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)
    #再对loss进行降序排列
    keep_num = min(sorted_ohem_loss.size()[0], batch_size)
    #得到需要保留的loss数量
    if keep_num < sorted_ohem_loss.size()[0]:
    #这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
        keep_idx_cuda = idx[:keep_num]
        #保留到需要keep的数目
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]
        #分类和回归保留相同的数目
    cls_loss = ohem_cls_loss.sum() / keep_num
    loc_loss = ohem_loc_loss.sum() / keep_num
    #然后分别对分类和回归loss求均值
    return cls_loss, loc_loss

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/139445.html原文链接:https://javaforall.net

(0)
上一篇 2022年5月29日 下午8:00
下一篇 2022年5月29日 下午8:00


相关推荐

  • Mac(OSX)下媲美XShell的神器Termius「建议收藏」

    Mac(OSX)下媲美XShell的神器Termius「建议收藏」文章目录简介特点软件环境配置配置项配置密钥配置说明配置主机配置项简介XShell的大名不用多说,称它为Windows平台最好用的远程终端不为过吧。唯一不足的地方就是它只有Windows版本。所以今天跟大家介绍一款全平台的远程终端——Termius。Termius不仅涵盖了Windows、Linux、OSX,还变态得支持Android和iOS(以后在地铁、公交上都可以随时拿出手机来排查线上问题啦…………

    2022年7月20日
    27
  • sql语句字符串用单引号还是双引号_sql什么时候用单引号

    sql语句字符串用单引号还是双引号_sql什么时候用单引号总结一下SQL语句中引号(‘)、quotedstr()、(”)、format()在SQL语句中的用法以及SQL语句中日期格式的表示(#)、(”)在Delphi中进行字符变量连接相加时单引号用(”’),又引号用(””)表示首先定义变量varAnInt:integer=123;//为了方便在此都给它们赋初值。虽然可能在引赋初值在某些情况下不对AnIntStr:string=’45…

    2022年8月31日
    7
  • Kimi新一轮10亿美元融资正在进行 估值涨至180亿美元

    Kimi新一轮10亿美元融资正在进行 估值涨至180亿美元

    2026年3月14日
    3
  • vue之解决跨域问题[通俗易懂]

    vue之解决跨域问题[通俗易懂]同源策略:http协议、主机名、端口号都要相同。因为浏览器同源策略的影响,向后端服务器请求数据的时候,不能进行访问。可以采用代理服务器的方式,代理服务器:浏览器向一个相同同源策略的g代理服务器上请求资源,因为服务器之间没有同源策略,代理服务器就去找后端服务器请求资源,在返回给浏览器解决方法一:在根目录下新建vue.config.js文件,这里是js文件哈。module.exports={ lintOnSave:false,//取消格式化 devServer:{ proxy:

    2025年12月13日
    3
  • 截取示波器网络图片[通俗易懂]

    截取示波器网络图片[通俗易懂]■问题由来手边有一台相对比较古老的Tektronix的示波器TDS3054D示波器,四通道的。它可以通过联网获得波形的图片。对于记录观察到的波形相对比较方便。▲示波器及其联网获得屏幕图片在截取示波器波形窗口的过程中,由于上面出现红色的字体(HOME:TDS3054BAA(192.168.0.101))的影响,经常使得截取图片出现不完整,因此希望通过软件(PYTHON程序)自动完成精细截取的过程。▲截取示波器波形窗口TDS3054B的显示模式包括两种:普通显示模式:示波器的

    2022年10月12日
    5
  • 交换机telnet配置

    交换机telnet配置

    2021年7月30日
    72

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号