batch内负采样

batch内负采样一般在计算softmax交叉熵时,需要用tf.nn.log_uniform_candidate_sampler多itemid做随机负采样。但是在类似dssm这种双塔模型中,item侧特征除了itemid外,还有其他meta特征,此时负样本对itemid做负采样后,还需要取相应负样本的meta特征。可是在tf训练数据中并不方便建立itemid与各类meta特征的映射表。为了解决dssm类模型的负采样问题,可以取一个batch内其他用户的正样本做为本用户的负样本,以解决负采样meta特征问题。好了,废话少说,

大家好,又见面了,我是你们的朋友全栈君。

一般在计算softmax交叉熵时,需要用tf.nn.log_uniform_candidate_sampler对itemid做随机负采样。但是在类似dssm这种双塔模型中,item侧特征除了itemid外,还有其他meta特征,此时负样本对itemid做负采样后,还需要取相应负样本的meta特征。可是在tf训练数据中并不方便建立itemid与各类meta特征的映射表。
为了解决dssm类模型的负采样问题,可以取一个batch内其他用户的正样本做为本用户的负样本,以解决负采样meta特征问题。好了,废话少说,直接上代码

     for i in range(NEG):
        rand = int((random.random() + i) * batchSize / NEG)
        item_y = tf.concat([item_y,
                            tf.slice(item_y_temp, [rand, 0], [batchSize - rand, -1]),
                            tf.slice(item_y_temp, [0, 0], [rand, -1])], 0)
      prod_raw = tf.reduce_sum(tf.multiply(tf.tile(user_y, [NEG + 1, 1]), item_y), 1, True)
      prod = tf.transpose(tf.reshape(tf.transpose(prod_raw), [NEG + 1, batchSize])) 
      # 转化为softmax概率矩阵。
      prob = tf.nn.softmax(prod)
      # 只取第一列,即正样本列概率。
      hit_prob = tf.slice(prob, [0, 0], [-1, 1])
      loss = -tf.reduce_mean(tf.log(hit_prob))

代码注解:
其中item_y和item_y_temp 初始化为item侧最后一层embedding值,shape为[batchSize, emb_size]。
user_y为user侧最后一层embedding值,shape为[batchSize, emb_size]。
NEG为负采样个数,batchSize为batch大小。

  1. 在每次循环中,通过rand值打乱item_y_temp的行顺序,相当于取其他用户的正样本做为本用户的负样本
  2. 经历NEG次循环后,item_y的shape变为[(NEG+1)*batchSize, emb_size];注:item_y初始值有batchSize行,每次循环累加batchSize行
  3. 与user_emb点乘后,prod_raw的shape为[(NEG+1)*batch_size,1],
  4. 经过reshape和转置后,prod的shape为[batch_size,(NEG+1)];注:prod的第一列为正样本,其他列为负样本。

后面即可计算出采样后的softmax交叉熵了。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/149671.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • 图像传感器的 DVP 信号

    图像传感器的 DVP 信号一、DVP简述DVP是数字视频端口(digitalvideoport)的简称,传统的sensor输出接口,采用并行输出方式,DVP总线PCLK极限约在96M左右,所有DVP最大速率最好控制在72M以下,DVP是并口,需要PCLK、VSYNC、HSYNC、D[0:11]——可以是8/10/12bit数据,具体情况要看ISP或baseband是否支持。DVP接口在信号完整性方面受限制,速率也受限制。如图1所示,并口传输数据需要帧同步信号(Vsync

    2022年5月27日
    33
  • [MAC]用beamoff给VMware的Mac OS X 10.10.x加速

    [MAC]用beamoff给VMware的Mac OS X 10.10.x加速MACOSX10 10 xYosemite 在 VMWare 中实在是太慢了 卡出翔 好在高人多 请装 beamoff 详见 https github com JasF beamoff gitCSDN 下载 http download csdn net download bytige 本站下载 http files cnblogs com files yipu b

    2025年11月13日
    3
  • 2014 (多校)1011 ZCC Loves Codefires

    2014 (多校)1011 ZCC Loves Codefires

    2021年12月2日
    48
  • 关于ie下阻止ActiveX控件

    关于ie下阻止ActiveX控件
    最近,公司的项目上有个部分要用到ActiveX控件。可是在访问的时候,就会弹出”Internetexplorer已经阻止站点用不安全方式使用ActiveX控件”一句。查了好多资料,除了更改ie的安全设置,没有其他方法。
    更改ie安全设置,需要更改的几个地方:
    首先,Internet选项–>安全
    1.选中Internet –“自定义级别”– “ActiveX控件和插件 “–“对未标记为可安全执行脚本的ActiveX控件初始化并执行脚本”(启用

    2022年5月14日
    47
  • java 重写和重载的区别[通俗易懂]

    java 重写和重载的区别[通俗易懂]classAnimal{ privateStringname; privateStringsex; privateintage; publicAnimal(){ //TODOAuto-generatedconstructorstub } publicvoidmove() { System.out.println("animalmove…

    2025年10月14日
    7
  • mysql左连接查询慢[通俗易懂]

    mysql左连接查询慢[通俗易懂]之前一直用的Oracle,今天用mysql查询一个很普通的左连接的时候,发现速度很慢。selectx.fid,x.isbirt,x.fscoresum,x.fsystemscore,x.feffectivescorefromtableaxleftjointablebhonx.fitemid=h.fidwhereh.fprojectid=’’这个sql耗时:2s多。我有点吓到了,后来我百度后发现然后我换了表的位置selectx.fid,x.isbirt,x.fsc

    2022年5月22日
    60

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号