batch内负采样

batch内负采样一般在计算softmax交叉熵时,需要用tf.nn.log_uniform_candidate_sampler多itemid做随机负采样。但是在类似dssm这种双塔模型中,item侧特征除了itemid外,还有其他meta特征,此时负样本对itemid做负采样后,还需要取相应负样本的meta特征。可是在tf训练数据中并不方便建立itemid与各类meta特征的映射表。为了解决dssm类模型的负采样问题,可以取一个batch内其他用户的正样本做为本用户的负样本,以解决负采样meta特征问题。好了,废话少说,

大家好,又见面了,我是你们的朋友全栈君。

一般在计算softmax交叉熵时,需要用tf.nn.log_uniform_candidate_sampler对itemid做随机负采样。但是在类似dssm这种双塔模型中,item侧特征除了itemid外,还有其他meta特征,此时负样本对itemid做负采样后,还需要取相应负样本的meta特征。可是在tf训练数据中并不方便建立itemid与各类meta特征的映射表。
为了解决dssm类模型的负采样问题,可以取一个batch内其他用户的正样本做为本用户的负样本,以解决负采样meta特征问题。好了,废话少说,直接上代码

     for i in range(NEG):
        rand = int((random.random() + i) * batchSize / NEG)
        item_y = tf.concat([item_y,
                            tf.slice(item_y_temp, [rand, 0], [batchSize - rand, -1]),
                            tf.slice(item_y_temp, [0, 0], [rand, -1])], 0)
      prod_raw = tf.reduce_sum(tf.multiply(tf.tile(user_y, [NEG + 1, 1]), item_y), 1, True)
      prod = tf.transpose(tf.reshape(tf.transpose(prod_raw), [NEG + 1, batchSize])) 
      # 转化为softmax概率矩阵。
      prob = tf.nn.softmax(prod)
      # 只取第一列,即正样本列概率。
      hit_prob = tf.slice(prob, [0, 0], [-1, 1])
      loss = -tf.reduce_mean(tf.log(hit_prob))

代码注解:
其中item_y和item_y_temp 初始化为item侧最后一层embedding值,shape为[batchSize, emb_size]。
user_y为user侧最后一层embedding值,shape为[batchSize, emb_size]。
NEG为负采样个数,batchSize为batch大小。

  1. 在每次循环中,通过rand值打乱item_y_temp的行顺序,相当于取其他用户的正样本做为本用户的负样本
  2. 经历NEG次循环后,item_y的shape变为[(NEG+1)*batchSize, emb_size];注:item_y初始值有batchSize行,每次循环累加batchSize行
  3. 与user_emb点乘后,prod_raw的shape为[(NEG+1)*batch_size,1],
  4. 经过reshape和转置后,prod的shape为[batch_size,(NEG+1)];注:prod的第一列为正样本,其他列为负样本。

后面即可计算出采样后的softmax交叉熵了。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/149671.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 解决哈希冲突的方法「建议收藏」

    解决哈希冲突的方法「建议收藏」在实际的应用中,选取合适的哈希函数可减少冲突,但冲突是不可避免的。所以我就想给大家说几种解决哈希冲突的方法啦~首先就是开放定址法,用这个方法处理冲突的核心思想就是在冲突发生的时候,形成一个地址序列,顺着这个序列挨个去检查探测,一直等到找到一个“空”的开放地址。把我们发生冲突的关键字值存放到这个“空”地址中去。这个地址的算法一般就是:Hi=(H(key)+di)%m  这里面的i=1,2,。

    2022年6月17日
    41
  • JS数组合并(5种)

    JS数组合并(5种)前言项目过程中,经常会遇到JS数组合并的情况,时常为这个纠结。这里整理一下。简单而实用的for最容易想到的莫过于for了。会变更原数组,当然也可以写成生成新数组的形式。letarr=[1,2]letarr2=[3,4]for(letiinarr2){arr.push(arr2[i])}console.log(arr)//[1,2,3,4]arr.concat(arr2)会生成新的数组。letarr=[1,2]let

    2022年6月30日
    45
  • layoutSubviews 和 drawRect

    layoutSubviews 和 drawRect转自http://justsee.iteye.com/blog/1886463UIView的setNeedsDisplay和setNeedsLayout方法。首先两个方法都是异步执行的。setNeedsDisplay会调用自动调用drawRect方法,这样可以拿到UIGraphicsGetCurrentContext,就可以画画了。而setNeedsLayout会默认调用lay

    2022年7月15日
    14
  • vue脚手架基本使用「建议收藏」

    vue脚手架基本使用「建议收藏」vue脚手架基本使用

    2022年4月22日
    62
  • 菜鸟教程-maven[通俗易懂]

    菜鸟教程-maven[通俗易懂]Maven基于项目对象模型(缩写:POM)概念 Maven是一个项目管理工具,可以对Java项目进行构建、依赖管理。 Maven是一个基于Java的工具,所以要做的第一件事情就是安装JDK。 Maven提倡使用一个共同的标准目录结构,Maven使用约定优于配置的原则,大家尽可能的遵守这样的目录结构。如下所示: 目录 目的 ${basedir} 存放pom.xml和所有的子目录 ${basedir}/src/main/java 项目的ja

    2025年10月6日
    3
  • getcomponent_getsocketopt

    getcomponent_getsocketoptGetMessage函数功能GetMessage是计算机编程中的一个函数,从调用线程的消息队列里取得一个消息并把其放于指定的结构。GetMessage函数可取得与指定窗口联系的消息和由PostThreadMesssge寄送的线程消息,接收一定范围的消息值,不接收属于其他线程或应用程序的消息。GetMessage获取消息成功后,线程把从消息队列中删除该消息,函数会一直等待直到有消息到来才有返回值。函数声明WINUSERAPIBOOLWINAPIGetMessage(_Out_LPMS

    2025年11月8日
    2

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号