随机森林 算法原理详解与实现步骤

随机森林 算法原理详解与实现步骤最近在研究 TLD 算法 在做目标检测的时候 用到了随机森林的组合分类器 话说现在的检测问题可以归结为分类问题 丰富了解决问题的手段 废话少说 接着看 1 随机森林原理介绍随机森林 指的是利用多棵树对样本进行训练并预测的一种分类器 该分类器最早由 LeoBreiman 和 AdeleCutler 提出 并被注册成了商标 简单来说 随机森林就是由多棵 CART Classificati

最近在研究TLD算法,在做目标检测的时候,用到了随机森林的组合分类器。话说现在的检测问题可以归结为分类问题,丰富了解决问题的手段。废话少说,接着看。

1.随机森林原理介绍

随机森林,指的是利用多棵树对样本进行训练并预测的一种分类器。该分类器最早由Leo Breiman和Adele Cutler提出,并被注册成了商标。简单来说,随机森林就是由多棵CART(Classification And Regression Tree)构成的。对于每棵树,它们使用的训练集是从总的训练集中有放回采样出来的,这意味着,总的训练集中的有些样本可能多次出现在一棵树的训练集中,也可能从未出现在一棵树的训练集中。在训练每棵树的节点时,使用的特征是从所有特征中按照一定比例随机地无放回的抽取的,根据Leo Breiman的建议,假设总的特征数量为M,这个比例可以是sqrt(M),1/2sqrt(M),2sqrt(M)。

因此,随机森林的训练过程可以总结如下:

(1)给定训练集S,测试集T,特征维数F。确定参数:使用到的CART的数量t,每棵树的深度d,每个节点使用到的特征数量f,终止条件:节点上最少样本数s,节点上最少的信息增益m

对于第1-t棵树,i=1-t:

(2)从S中有放回的抽取大小和S一样的训练集S(i),作为根节点的样本,从根节点开始训练

(3)如果当前节点上达到终止条件,则设置当前节点为叶子节点,如果是分类问题,该叶子节点的预测输出为当前节点样本集合中数量最多的那一类c(j),概率p为c(j)占当前样本集的比例;如果是回归问题,预测输出为当前节点样本集各个样本值的平均值。然后继续训练其他节点。如果当前节点没有达到终止条件,则从F维特征中无放回的随机选取f维特征。利用这f维特征,寻找分类效果最好的一维特征k及其阈值th,当前节点上样本第k维特征小于th的样本被划分到左节点,其余的被划分到右节点。继续训练其他节点。有关分类效果的评判标准在后面会讲。

(4)重复(2)(3)直到所有节点都训练过了或者被标记为叶子节点。

(5)重复(2),(3),(4)直到所有CART都被训练过。

利用随机森林的预测过程如下:

对于第1-t棵树,i=1-t:

(1)从当前树的根节点开始,根据当前节点的阈值th,判断是进入左节点( =th),直到到达,某个叶子节点,并输出预测值。

(2)重复执行(1)直到所有t棵树都输出了预测值。如果是分类问题,则输出为所有树中预测概率总和最大的那一个类,即对每个c(j)的p进行累计;如果是回归问题,则输出为所有树的输出的平均值。

注:有关分类效果的评判标准,因为使用的是CART,因此使用的也是CART的评判标准,和C3.0,C4.5都不相同。

对于分类问题(将某个样本划分到某一类),也就是离散变量问题,CART使用Gini值作为评判标准。定义为Gini=1-∑(P(i)*P(i)),P(i)为当前节点上数据集中第i类样本的比例。例如:分为2类,当前节点上有100个样本,属于第一类的样本有70个,属于第二类的样本有30个,则Gini=1-0.7×07-0.3×03=0.42,可以看出,类别分布越平均,Gini值越大,类分布越不均匀,Gini值越小。在寻找最佳的分类特征和阈值时,评判标准为:argmax(Gini-GiniLeft-GiniRight),即寻找最佳的特征f和阈值th,使得当前节点的Gini值减去左子节点的Gini和右子节点的Gini值最大。

对于回归问题,相对更加简单,直接使用argmax(Var-VarLeft-VarRight)作为评判标准,即当前节点训练集的方差Var减去减去左子节点的方差VarLeft和右子节点的方差VarRight值最大。

 

2.OpenCV函数使用

OpenCV提供了随机森林的相关类和函数。具体使用方法如下:

(1)首先利用CvRTParams定义自己的参数,其格式如下

 

 CvRTParams::CvRTParams(int max_depth, int min_sample_count, float regression_accuracy, bool use_surrogates, int max_categories, const float* priors, bool calc_var_importance, int nactive_vars, int max_num_of_trees_in_the_forest, float forest_accuracy, int termcrit_type)

 

大部分参数描述都在http://docs.opencv.org/modules/ml/doc/random_trees.html上面有,说一下没有描述的几个参数的意义

bool use_surrogates:是否使用代理,指的是,如果当前的测试样本缺少某些特征,但是在当前节点上的分类or回归特征正是缺少的这个特征,那么这个样本就没法继续沿着树向下走了,达不到叶子节点的话,就没有预测输出,这种情况下,可以利用当前节点下面的所有子节点中的叶子节点预测输出的平均值,作为这个样本的预测输出。

const float*priors:先验知识,这个指的是,可以根据各个类别样本数量的先验分布,对其进行加权。比如:如果一共有3类,第一类样本占整个训练集的80%,其余两类各占10%,那么这个数据集里面的数据就很不平均,如果每类的样本都加权的话,就算把所有样本都预测成第一类,那么准确率也有80%,这显然是不合理的,因此我们需要提高后两类的权重,使得后两类的分类正确率也不会太低。

float regression_accuracy:回归树的终止条件,如果当前节点上所有样本的真实值和预测值之间的差小于这个数值时,停止生产这个节点,并将其作为叶子节点。

后来发现这些参数在决策树里面有解释,英文说明在这里http://docs.opencv.org/modules/ml/doc/decision_trees.html#cvdtreeparams

具体例子如下,网上找了个别人的例子,自己改成了可以读取MNIST数据并且做分类的形式,如下:

 

复制代码
#include 
    
    // 
     opencv general include file #include 
     
     // 
      opencv machine learning include file #include 
      
      using 
      namespace cv; 
      // 
       OpenCV API is in the C++ "cv" namespace 
      /* 
       
      */ 
      // 
       global definitions (for speed and ease of use)  
      // 
      手写体数字识别 
      #define NUMBER_OF_TRAINING_SAMPLES 60000 
      #define ATTRIBUTES_PER_SAMPLE 784 
      #define NUMBER_OF_TESTING_SAMPLES 10000 
      #define NUMBER_OF_CLASSES 10 
      // 
       N.B. classes are integer handwritten digits in range 0-9 
      /* 
       
      */ 
      // 
       loads the sample database from file (which is a CSV text file) 
       inline  
      void revertInt( 
      int& 
      x) { x=((x& 
      0x000000ff)<< 
      24)|((x& 
      0x0000ff00)<< 
      8)|((x& 
      0x00ff0000)>> 
      8)|((x& 
      0xff000000)>> 
      24 
      ); };  
      int read_data_from_csv( 
      const 
      char* samplePath, 
      const 
      char* 
       labelPath, Mat data, Mat classes,  
      int 
       n_samples ) { FILE* sampleFile=fopen(samplePath, 
      " 
      rb 
      " 
      ); FILE* labelFile=fopen(labelPath, 
      " 
      rb 
      " 
      );  
      int mbs= 
      0,number= 
      0,col= 
      0,row= 
      0 
      ; fread(&mbs, 
      4, 
      1 
      ,sampleFile); fread(&number, 
      4, 
      1 
      ,sampleFile); fread(&row, 
      4, 
      1 
      ,sampleFile); fread(&col, 
      4, 
      1 
      ,sampleFile); revertInt(mbs); revertInt(number); revertInt(row); revertInt(col); fread(&mbs, 
      4, 
      1 
      ,labelFile); fread(&number, 
      4, 
      1 
      ,labelFile); revertInt(mbs); revertInt(number); unsigned  
      char 
       temp;  
      for( 
      int line = 
      0; line < n_samples; line++ 
      ) {  
      // 
       for each attribute on the line in the file 
      for( 
      int attribute = 
      0; attribute < (ATTRIBUTES_PER_SAMPLE + 
      1); attribute++ 
      ) {  
      if (attribute < 
       ATTRIBUTES_PER_SAMPLE) {  
      // 
       first 64 elements (0-63) in each line are the attributes fread(&temp, 
      1, 
      1 
      ,sampleFile);  
      // 
      fscanf(f, "%f,", &tmp); data.at< 
      float>(line, attribute) = static_cast< 
      float> 
      (temp);  
      // 
       printf("%f,", data.at 
       
         (line, attribute)); 
        
       }  
      else 
      if (attribute == 
       ATTRIBUTES_PER_SAMPLE) {  
      // 
       attribute 65 is the class label {0 ... 9} fread(&temp, 
      1, 
      1 
      ,labelFile);  
      // 
      fscanf(f, "%f,", &tmp); classes.at< 
      float>(line, 
      0) = static_cast< 
      float> 
      (temp);  
      // 
       printf("%f\n", classes.at 
       
         (line, 0)); 
        
       } } } fclose(sampleFile); fclose(labelFile);  
      return 
      1; 
      // 
       all OK 
      }  
      /* 
       
      */ 
      int main( 
      int argc, 
      char 
       argv ) {  
      for ( 
      int i= 
      0; i< argc; i++ 
      ) std::cout< 
      
        std::endl; 
       
      
     
   // lets just check the version first printf ("OpenCV version %s (%d.%d.%d)\n", CV_VERSION, CV_MAJOR_VERSION, CV_MINOR_VERSION, CV_SUBMINOR_VERSION); //定义训练数据与标签矩阵 Mat training_data = Mat(NUMBER_OF_TRAINING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1); Mat training_classifications = Mat(NUMBER_OF_TRAINING_SAMPLES, 1, CV_32FC1); //定义测试数据矩阵与标签 Mat testing_data = Mat(NUMBER_OF_TESTING_SAMPLES, ATTRIBUTES_PER_SAMPLE, CV_32FC1); Mat testing_classifications = Mat(NUMBER_OF_TESTING_SAMPLES, 1, CV_32FC1); // define all the attributes as numerical // alternatives are CV_VAR_CATEGORICAL or CV_VAR_ORDERED(=CV_VAR_NUMERICAL) // that can be assigned on a per attribute basis  Mat var_type = Mat(ATTRIBUTES_PER_SAMPLE + 1, 1, CV_8U ); var_type.setTo(Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical // this is a classification problem (i.e. predict a discrete number of class // outputs) so reset the last (+1) output var_type element to CV_VAR_CATEGORICAL  var_type.at 
  
    (ATTRIBUTES_PER_SAMPLE, 
   0) = 
    CV_VAR_CATEGORICAL;  
   double result; 
   // 
    value returned from a prediction  
   // 
   加载训练数据集和测试数据集 
   if (read_data_from_csv(argv[ 
   1],argv[ 
   2], training_data, training_classifications, NUMBER_OF_TRAINING_SAMPLES) && 
    read_data_from_csv(argv[ 
   3],argv[ 
   4 
   ], testing_data, testing_classifications, NUMBER_OF_TESTING_SAMPLES)) {  
   /* 
   *步骤1:定义初始化Random Trees的参数* 
   */ 
   float priors[] = { 
    
   1, 
   1, 
   1, 
   1, 
   1, 
   1, 
   1, 
   1, 
   1, 
   1}; 
   // 
    weights of each classification for classes CvRTParams 
   params = CvRTParams( 
   20, 
   // 
    max depth 
   50, 
   // 
    min sample count 
   0, 
   // 
    regression accuracy: N/A here 
   false, 
   // 
    compute surrogate split, no missing data 
   15, 
   // 
    max number of categories (use sub-optimal algorithm for larger numbers) priors, 
   // 
    the array of priors 
   false, 
   // 
    calculate variable importance 
   50, 
   // 
    number of variables randomly selected at node and used to find the best split(s). 
   100, 
   // 
    max number of trees in the forest 
   0.01f, 
   // 
    forest accuracy CV_TERMCRIT_ITER | CV_TERMCRIT_EPS 
   // 
    termination cirteria 
    );  
   /* 
   *步骤2:训练 Random Decision Forest(RDF)分类器 
   */ 
    printf(  
   " 
   \nUsing training database: %s\n\n 
   ", argv[ 
   1 
   ]); CvRTrees* rtree = 
   new 
    CvRTrees;  
   bool train_result=rtree-> 
   train(training_data, CV_ROW_SAMPLE, training_classifications, Mat(), Mat(), var_type, Mat(),  
   params 
   );  
   // 
    float train_error=rtree->get_train_error();  
   // 
    printf("train error:%f\n",train_error);  
   // 
    perform classifier testing and report results 
    Mat test_sample;  
   int correct_class = 
   0 
   ;  
   int wrong_class = 
   0 
   ;  
   int false_positives [NUMBER_OF_CLASSES] = { 
    
   0, 
   0, 
   0, 
   0, 
   0, 
   0, 
   0, 
   0, 
   0, 
   0 
   }; printf(  
   " 
   \nUsing testing database: %s\n\n 
   ", argv[ 
   2 
   ]);  
   for ( 
   int tsample = 
   0; tsample < NUMBER_OF_TESTING_SAMPLES; tsample++ 
   ) {  
   // 
    extract a row from the testing matrix test_sample = 
    testing_data.row(tsample);  
   /* 
   *步骤3:预测 
   */ 
    result = rtree-> 
   predict(test_sample, Mat()); printf( 
   " 
   Testing Sample %i -> class result (digit %d)\n 
   ", tsample, ( 
   int 
   ) result);  
   // 
    if the prediction and the (true) testing classification are the same  
   // 
    (N.B. openCV uses a floating point decision tree implementation!) 
   if (fabs(result - testing_classifications.at< 
   float>(tsample, 
   0 
   )) >= 
    FLT_EPSILON) {  
   // 
    if they differ more than floating point error => wrong class wrong_class++ 
   ; false_positives[( 
   int) result]++ 
   ; }  
   else 
    {  
   // 
    otherwise correct correct_class++ 
   ; } } printf(  
   " 
   \nResults on the testing database: %s\n 
   " 
   " 
   \tCorrect classification: %d (%g%%)\n 
   " 
   " 
   \tWrong classifications: %d (%g%%)\n 
   " 
   , argv[ 
   2 
   ], correct_class, ( 
   double) correct_class* 
   100/ 
   NUMBER_OF_TESTING_SAMPLES, wrong_class, ( 
   double) wrong_class* 
   100/ 
   NUMBER_OF_TESTING_SAMPLES);  
   for ( 
   int i = 
   0; i < NUMBER_OF_CLASSES; i++ 
   ) { printf(  
   " 
   \tClass (digit %d) false postives %d (%g%%)\n 
   " 
   , i, false_positives[i], ( 
   double) false_positives[i]* 
   100/ 
   NUMBER_OF_TESTING_SAMPLES); }  
   // 
    all matrix memory free by destructors  
   // 
    all OK : main returns 0 
   return 
   0 
   ; }  
   // 
    not OK : main returns -1 
   return - 
   1 
   ; } 
  
复制代码

 

MNIST样本可以在这个网址http://yann.lecun.com/exdb/mnist/下载,改一下路径可以直接跑的。

3.如何自己设计随机森林程序

有时现有的库无法满足要求,就需要自己设计一个分类器算法,这部分讲一下如何设计自己的随机森林分类器,代码实现就不贴了,因为在工作中用到了,因此比较敏感。

首先,要有一个RandomForest类,里面保存整个树需要的一些参数,包括但不限于:训练样本数量、测试样本数量、特征维数、每个节点随机提取的特征维数、CART树的数量、树的最大深度、类别数量(如果是分类问题)、一些终止条件、指向所有树的指针,指向训练集和测试集的指针,指向训练集label的指针等。还要有一些函数,至少要有train和predict吧。train里面直接调用每棵树的train方法即可,predict同理,但要对每棵树的预测输出做处理,得到森林的预测输出。

 

其次,要有一个sample类,这个类可不是用来存储训练集和对应label的,这是因为,每棵树、每个节点都有自己的样本集和,如果你的存储每个样本集和的话,需要的内存实在是太过巨大了,假设样本数量为M,特征维数为N,则整个训练集大小为M×N,而每棵树的每层都有这么多样本,树的深度为D,共有S棵树的话,则需要存储M×N×D×S的存储空间。这实在是太大了。因此,每个节点训练时用到的训练样本和特征,我们都用序号数组来代替,sample类就是干这个的。sample的函数基本需要两个就行,第一个是从现有训练集有放回的随机抽取一个新的训练集,当然,只包含样本的序号。第二个函数是从现有的特征中无放回的随机抽取一定数量的特征,同理,也是特征序号即可。

然后,需要一个Tree类,代表每棵树,里面保存树的一些参数以及一个指向所有节点的指针。

最后,需要一个Node类,代表树的每个节点。

 

需要说明的是,保存树的方式可以是最普通的数组,也可是是vector。Node的保存方式同理,但是个人不建议用链表的方式,在程序设计以及函数处理上太麻烦,但是在省空间上并没有太多的体现。

目前先写这么多,最后这部分我还会再扩充一些。

#2017.2.28

在github上开源了一个简单的随机森林程序,包含训练、预测部分,支持分类和回归问题,里面有mnist训练的实例,附了不少注释,比较适合入门学习,地址:

https://github.com/handspeaker/RandomForests

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/229001.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月16日 下午5:42
下一篇 2026年3月16日 下午5:42


相关推荐

  • 软件缺陷报告[通俗易懂]

    软件缺陷报告[通俗易懂]1、定义概述:标识并描述发现的缺陷,具有清晰、完整和可重视问题所需的信息的文档理解:测试人员发现缺陷,记录,通过缺陷报告将缺陷报告给开发人员,并对缺陷进行跟踪管理。缺陷报告是测试人员与开发人员之间重要的沟通方式2、什么是缺陷软件缺陷就是通常说的Bug,它是指在软件中存在的影响软件正常运行的问题3、软件缺陷产生的原因1、需求不明确和变更软件需求不清晰或者开发人员对需求理解偏差,导致软件设…

    2026年1月16日
    4
  • Vue进阶(四十三):Vuex之Mutations详解

    Vue进阶(四十三):Vuex之Mutations详解通俗的理解 mutations 里面装着一些改变数据方法的集合 这是 Veux 设计很重要的一点 就是把处理数据逻辑方法全部放在 mutations 里面 使得数据和视图分离 2 怎么用 mutations mutation 结构 每一个 mutation 都有一个字符串类型的事件类型 type 和回调函数 handler 也可以理解为 type handler 这和订阅发布有点类似 先注册事件 当触发

    2026年3月19日
    1
  • 公网IP和内网IP区别

    公网IP和内网IP区别什么是内网IP:一些小型企业或者学校,通常都是申请一个固定的IP地址,然后通过IP共享(IPSharing),使用整个公司或学校的机器都能够访问互联网。而这些企业或学校的机器使用的IP地址就是内网IP,内网IP是在规划IPv4协议时,考虑到IP地址资源可能不足,就专门为内部网设计私有IP地址(或称之为保留地址),一般常用内网IP地址都是这种形式的:10.X.X.X、172.16.X.X-1…

    2022年4月30日
    54
  • bs与cs架构的优缺点_bs架构与cs架构的区别详细讲解

    bs与cs架构的优缺点_bs架构与cs架构的区别详细讲解简介C/S又称Client/Server或客户/服务器模式。服务器通常采用高性能的PC、工作站或小型机,并采用大型数据库系统,如Oracle、Sybase、Informix或SQLServer。客户端需要安装专用的客户端软件。B/S是Brower/Server的缩写,客户机上只要安装一个浏览器(Browser),如NetscapeNavigator或InternetExplorer,服务器安装Oracle、Sybase、Informix或SQLServer等数据库。浏览器通过Web

    2022年8月31日
    6
  • 使用xsync脚本分发「建议收藏」

    使用xsync脚本分发「建议收藏」为什么使用xsync脚本来分发文件因为操作简单,只需要执行脚本在后面加上需要分发的文件就行了然后底层不一致,scp使用的是安全拷贝,而xsync使用的是增量拷贝由于底层不一致,xsync比scp快上许多使用脚本来分发文件之前不同节点之间的免密登录安排上脚本实现#!/bin/bash#1输入参数个数,如果没有参数就会退出pcount=$#if((pcount==0));thenechonoargs;exit;fi#2需要分发的文件名称p1=$1fname=`

    2022年5月18日
    56
  • Python列表(list)详解[通俗易懂]

    Python列表(list)详解[通俗易懂]Python内置的四种常用数据结构:列表(list)、元组(tuple)、字典(dict)以及集合(set)。这四种数据结构一但都可用于保存多个数据项,这对于编程而言是非常重要的,因为程序不仅需要使

    2022年7月3日
    40

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号