FFM算法 Python实现

FFM算法 Python实现本算法是CTR中的系列算法之一,具体的原理就不说了。网上其他的博客一大堆。都是互相抄来抄去,写上去之后容易让人误会。因此我只传上代码实现部分。大家做个参考。这里我们的FFM算法是基于Tensorflow实现的。为什么用Tensorflow呢?观察二次项,由于field的引入,Vffm需要计算的参数有nfk个,远多于FM模型的nk个,而且由于每次计算都依赖于乘以的xj的field,所以…

大家好,又见面了,我是你们的朋友全栈君。

本算法是CTR中的系列算法之一,具体的原理就不说了。网上其他的博客一大堆。都是互相抄来抄去,写上去之后容易让人误会。因此我只传上代码实现部分。大家做个参考。

这里我们的FFM算法是基于Tensorflow实现的。

为什么用Tensorflow呢?观察二次项,由于field的引入,Vffm需要计算的参数有 nfk 个,远多于FM模型的 nk个,而且由于每次计算都依赖于乘以的xj的field,所以,无法用fm的计算技巧(ab = 1/2(a+b)^2-a^2-b^2),所以计算复杂度是 O(n^2)。

因此使用Tensorflow的目的是想通过GPU进行计算。同时这也给我们提供了一个思路:如果模型的计算复杂度较高,当不能使用CPU快速完成模型训练时,可以考虑使用GPU计算。比如Xgboost是已经封装好可以用在GPU上的算法库,而那些没有GPU版本的封装算法库时,例如我们此次采用的FFM算法,我们可以借助Tensorflow的GPU版本框架设计算法,并完成模型训练。

代码主要分三部分:

build_data.py

主要是完成对原始数据的转化。主要包括构造特征值对应field的字典。

FFM.py

主要包括线性部分及非线性部分的代码实现。

tools.py

主要包括训练集的构造。

这里我们主要分析 FFM.py,也就是模型的构建过程:

首先初始化一些参数,包括:

  • k:隐向量长度
  • f :field个数
  • p:特征值个数
  • 学习率大小
  • 批训练大小
  • 正则化
  • 模型保存位置等

代码如下图所示:

FFM算法 Python实现

然后,构造了一个model类,主要存放:

  • 初始化的一些参数
  • 模型结构
  • 模型训练op(参数更新)
  • 预测op
  • 模型保存以及载入的op

代码如下图所示:

FFM算法 Python实现

之后,对模型构造部分代码进行分析,可发现模型由两部分组成,第一部分是下图红框内容,其实就是线性表达式 w^Tx+b,其中:

  • b shape(None,1)
  • x shape (batch_size,p)
  • w1 shape(p,1)    注:p为特征值个数

定义变量及初始化后,就可以构造线性模型,代码如下图所示:

FFM算法 Python实现

然后,定义一个Vffm变量用来存放交叉项的权重,并初始化。因为我们已经了解到Vffm是一个三维向量,所以,v :shape(p,f,k) 。

之后是vi,fj、vj,fi的构造。因为v 有p行,代表共有p个特征值,所以vifj = v[i, feature2field[j]],说人话就是第i个特征值在第j个特征值对应的field上的隐向量。

vjfi 的构造方法类似,所以vivj就可以求出来.然后就是把交叉项累加,然后 reshape 成(batch_size,1)的形状,以便与线性模型进行矩阵加法计算。

代码如下图所示:

FFM算法 Python实现

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/131719.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • android:layout_gravity 和 android:gravity 的区别

    android:layout_gravity 和 android:gravity 的区别

    2021年8月26日
    52
  • 高并发解决方案相关面试题

    高并发解决方案相关面试题什么是DNS解析域名DNS域名解析就是讲域名转化为不需要显示端口(二级域名的端口一般为80)的IP地址,域名解析的一般先去本地环境的host文件读取配置,解析成对应的IP地址,根据IP地址访问对应的服务器。若host文件未配置,则会去网络运营商获取对应的IP地址和域名.什么是NginxNginx是一个高级的轻量级的web服…

    2022年5月22日
    29
  • 数据结构之数组和链表的区别

    数据结构之数组和链表的区别第一题便是数据结构中的数组和链表的区别数组(Array)一、数组特点:所谓数组,就是相同数据类型的元素按一定顺序排列的集合;数组的存储区间是连续的,占用内存比较大,故空间复杂的很大。但数组的二分查找时间复杂度小,都是O(1);数组的特点是:查询简单,增加和删除困难;1.1在内存中,数组是一块连续的区域1.2数组需要预留空间在使用前需要提前申请所占内存的大小,…

    2022年6月29日
    15
  • C/C++常见面试知识点总结附面试真题—-20220326更新

    C/C++常见面试知识点总结附面试真题—-20220326更新以下内容部分整理自网络,部分为自己面试的真题。第一部分:计算机基础1.C/C++内存有哪几种类型?C中,内存分为5个区:堆(malloc)、栈(如局部变量、函数参数)、程序代码区(存放二进制代码)、全局/静态存储区(全局变量、static变量)和常量存储区(常量)。此外,C++中有自由存储区(new)一说。2.堆和栈的区别?1).堆存放动态分配的对象——即那些在程序运行时分配的对象…

    2022年7月15日
    16
  • Java并发的CAS原理详解[通俗易懂]

    Java并发的CAS原理详解[通俗易懂]Java并发编程中的CAS原理是很重要的概念。CAS加volatile关键字是实现并发包的基石。没有CAS就不会有并发包,synchronized是一种独占锁、悲观锁,java.util.concurrent中借助了CAS指令实现了一种区别于synchronized的一种乐观锁。乐观锁和悲观锁的概念请参考Java中的21种锁。在Java中java.util.concurrent.atomic包下面的原子变量就是使用了乐观锁的一种实现方式CAS实现。在JDK5之前Java语言是靠synchroniz

    2022年10月10日
    0
  • tcp粘包是怎么产生的_tcp报文格式

    tcp粘包是怎么产生的_tcp报文格式tcp粘包是怎么产生的?1、什么是tcp粘包?发送方发送的多个数据包,到接收方缓冲区首尾相连,粘成一包,被接收。2、原因TCP协议默认使用Nagle算法可能会把多个数据包一次发送到接收方。应用程读取缓存中的数据包的速度小于接收数据包的速度,缓存中的多个数据包会被应用程序当成一个包一次读取。3、处理方法发送方使用TCP_NODELAY选项来…

    2022年8月11日
    4

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号