batch内负采样

batch内负采样一般在计算softmax交叉熵时,需要用tf.nn.log_uniform_candidate_sampler多itemid做随机负采样。但是在类似dssm这种双塔模型中,item侧特征除了itemid外,还有其他meta特征,此时负样本对itemid做负采样后,还需要取相应负样本的meta特征。可是在tf训练数据中并不方便建立itemid与各类meta特征的映射表。为了解决dssm类模型的负采样问题,可以取一个batch内其他用户的正样本做为本用户的负样本,以解决负采样meta特征问题。好了,废话少说,

大家好,又见面了,我是你们的朋友全栈君。

一般在计算softmax交叉熵时,需要用tf.nn.log_uniform_candidate_sampler对itemid做随机负采样。但是在类似dssm这种双塔模型中,item侧特征除了itemid外,还有其他meta特征,此时负样本对itemid做负采样后,还需要取相应负样本的meta特征。可是在tf训练数据中并不方便建立itemid与各类meta特征的映射表。
为了解决dssm类模型的负采样问题,可以取一个batch内其他用户的正样本做为本用户的负样本,以解决负采样meta特征问题。好了,废话少说,直接上代码

     for i in range(NEG):
        rand = int((random.random() + i) * batchSize / NEG)
        item_y = tf.concat([item_y,
                            tf.slice(item_y_temp, [rand, 0], [batchSize - rand, -1]),
                            tf.slice(item_y_temp, [0, 0], [rand, -1])], 0)
      prod_raw = tf.reduce_sum(tf.multiply(tf.tile(user_y, [NEG + 1, 1]), item_y), 1, True)
      prod = tf.transpose(tf.reshape(tf.transpose(prod_raw), [NEG + 1, batchSize])) 
      # 转化为softmax概率矩阵。
      prob = tf.nn.softmax(prod)
      # 只取第一列,即正样本列概率。
      hit_prob = tf.slice(prob, [0, 0], [-1, 1])
      loss = -tf.reduce_mean(tf.log(hit_prob))

代码注解:
其中item_y和item_y_temp 初始化为item侧最后一层embedding值,shape为[batchSize, emb_size]。
user_y为user侧最后一层embedding值,shape为[batchSize, emb_size]。
NEG为负采样个数,batchSize为batch大小。

  1. 在每次循环中,通过rand值打乱item_y_temp的行顺序,相当于取其他用户的正样本做为本用户的负样本
  2. 经历NEG次循环后,item_y的shape变为[(NEG+1)*batchSize, emb_size];注:item_y初始值有batchSize行,每次循环累加batchSize行
  3. 与user_emb点乘后,prod_raw的shape为[(NEG+1)*batch_size,1],
  4. 经过reshape和转置后,prod的shape为[batch_size,(NEG+1)];注:prod的第一列为正样本,其他列为负样本。

后面即可计算出采样后的softmax交叉熵了。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/149671.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • python中sqrt函数用法_Python : sqrt() 函数

    python中sqrt函数用法_Python : sqrt() 函数开平方函数sqrt()返回x的平方根(x>0)语法:importmathmath.sqrt(x)注意:此函数不可直接访问,需要导入math模块,然后需要使用math静态对象调用此函数。参数x—数值表达式返回结果是浮点数。importmath#Thiswillimportmathmoduleprint”math.sqrt(100):”,math.s…

    2022年6月2日
    66
  • 修改asmx样式

    修改asmx样式今天看到一张图,asmx的WebService。长这样:当时就感觉有意思,这个页面风格和我们平时的不一样,我们平时的WebService长这样:我们如果在WebMetohd上面加注释,即[WebMethod(Description=”注释”)],那么长这样:那么问题就来了,第一张图里面的样式是如何实现的呢?在浏览器上进入调试模式观察,可以发现它的html和我们的有点不…

    2022年4月29日
    41
  • 微信不能登录网页版(微信手机网页登录)

     因为出于工作和学习的目的,我的个人电脑操作系统使用的是Ubuntu18.04LTS,就目前而言,许多优秀的软件都有Linux版本,虽然Linux的用户相对群体较小,但是其软件生态也在逐渐成长,而且日常使用浏览器就可以解决许多的应用需求。现在微信和QQ在生活和工作领域均是举足轻重,emm…  BUT!!!腾讯到目前为止并没有推出Linux版的微信和TI…

    2022年4月12日
    61
  • Springboot和Spring的区别?看完你就明白了

    Springboot和Spring的区别?看完你就明白了从一道面试题说起面试的时候经常会被问到,spring和springboot的区别。或者SpringMVC和Springboot的区别。其实这样的问法就不是特别合适。因为spring、springboot、springmvc他们三个在spring体系中就不在同一个维度。看一下spring的全部项目spring家族有很多项目,springboot、springframework、springcloud等。我们常用的也就是,springboot、springcloud、springsecu

    2022年6月8日
    42
  • Alex 的 Hadoop 菜鸟教程: 第15课 Impala 安装使用教程

    Alex 的 Hadoop 菜鸟教程: 第15课 Impala 安装使用教程本教程介绍Impala的安装,使用和JDBC调用。为什么用Impala?因为Hive太慢了!Impala也可以执行SQL,但是比Hive的速度快很多。为什么Impala可以比Hive快呢?因为Hive采用的是把你的sql转化成hadoop的MapReduce任务的代码,然后编译,打包成jar包,并分发到各个server上执行,这是一个相当慢的过程。而Impala根本就不用Hadoop的MapReduce机制,直接调用HDFS的API获取文件,在自己的内存中进行计算。

    2022年5月2日
    48
  • IP地址划分[通俗易懂]

    IP地址划分[通俗易懂]IP地址划分1IP地址分类(1)A类IP地址一个A类IP地址由1字节的网络地址和3字节主机地址组成,网络地址的最高位必须是“0”,地址范围:1.0.0.1——126.255.255.254二进制表示为:00000001000000000000000000000001——01111110111111111111111111111110可用的A类网络有126个,每个网络能容纳…

    2022年6月11日
    52

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号