kNN(K-Nearest Neighbor)最邻近规则分类

kNN(K-Nearest Neighbor)最邻近规则分类

KNN最邻近规则,主要应用领域是对未知事物的识别,即推断未知事物属于哪一类,推断思想是,基于欧几里得定理,推断未知事物的特征和哪一类已知事物的的特征最接近;

K近期邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比較成熟的方法,也是最简单的机器学习算法之中的一个。该方法的思路是:假设一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上仅仅根据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 KNN方法尽管从原理上也依赖于极限定理,但在类别决策时,仅仅与极少量的相邻样本有关。因为KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其它方法更为适合。
  KNN算法不仅能够用于分类,还能够用于回归。通过找出一个样本的k个近期邻居,将这些邻居的属性的平均值赋给该样本,就能够得到该样本的属性。更实用的方法是将不同距离的邻居对该样本产生的影响给予不同的权值(weight),如权值与距离成正比(组合函数)。
  该算法在分类时有个基本的不足是,当样本不平衡时,如一个类的样本容量非常大,而其它类样本容量非常小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。 该算法仅仅计算“近期的”邻居样本,某一类的样本数量非常大,那么或者这类样本并不接近目标样本,或者这类样本非常靠近目标样本。不管如何,数量并不能影响执行结果。能够採用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的还有一个不足之处是计算量较大,由于对每个待分类的文本都要计算它到全体已知样本的距离,才干求得它的K个近期邻点。眼下经常使用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比較适用于样本容量比較大的类域的自己主动分类,而那些样本容量较小的类域採用这样的算法比較easy产生误分。

K-NN能够说是一种最直接的用来分类未知数据的方法。基本通过以下这张图跟文字说明就能够明确K-NN是干什么的

<span>kNN(K-Nearest Neighbor)最邻近规则分类</span>

简单来说,K-NN能够看成:有那么一堆你已经知道分类的数据,然后当一个新数据进入的时候,就開始跟训练数据里的每一个点求距离,然后挑离这个训练数据近期的K个点看看这几个点属于什么类型,然后用少数服从多数的原则,给新数据归类。

 

算法步骤:

step.1—初始化距离为最大值

step.2—计算未知样本和每一个训练样本的距离dist

step.3—得到眼下K个最临近样本中的最大距离maxdist

step.4—假设dist小于maxdist,则将该训练样本作为K-近期邻样本

step.5—反复步骤2、3、4,直到未知样本和全部训练样本的距离都算完

step.6—统计K-近期邻样本中每一个类标号出现的次数

step.7—选择出现频率最大的类标号作为未知样本的类标号

 

 

KNN的matlab简单实现代码

function target=KNN(in,out,test,k)

% in:       training samples data,n*d matrix

% out: training samples’ class label,n*1

% test:     testing data

% target:   class label given by knn

% k:        the number of neighbors

ClassLabel=unique(out);

c=length(ClassLabel);

n=size(in,1);

% target=zeros(size(test,1),1);

dist=zeros(size(in,1),1);

for j=1:size(test,1)

    cnt=zeros(c,1);

    for i=1:n

        dist(i)=norm(in(i,:)-test(j,:));

    end

    [d,index]=sort(dist);

    for i=1:k

        ind=find(ClassLabel==out(index(i)));

        cnt(ind)=cnt(ind)+1;

    end

    [m,ind]=max(cnt);

    target(j)=ClassLabel(ind);

end

 

R语言的实现代码例如以下

library(class)
data(iris)
names(iris)
m1<-knn.cv(iris[,1:4],iris[,5],k=3,prob=TRUE)
attributes(.Last.value)
library(MASS)
m2<-lda(iris[,1:4],iris[,5])  与判别分析进行比較
b<-data.frame(Sepal.Length=6,Sepal.Width=4,Petal.Length=5,Petal.Width=6)
p1<-predict(m2,b,type=”class”)

C++ 实现 :

//    KNN.cpp     K-近期邻分类算法
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <stdlib.h>
#include <stdio.h>
#include <memory.h>
#include <string.h>
#include <iostream>
#include <math.h>
#include <fstream>
using namespace std;
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    宏定义
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
#define  ATTR_NUM  4                        //属性数目
#define  MAX_SIZE_OF_TRAINING_SET  1000      //训练数据集的最大大小
#define  MAX_SIZE_OF_TEST_SET      100       //測试数据集的最大大小
#define  MAX_VALUE  10000.0                  //属性最大值
#define  K  7
//结构体
struct dataVector {
 int ID;                      //ID号
 char classLabel[15];             //分类标号
 double attributes[ATTR_NUM]; //属性 
};
struct distanceStruct {
 int ID;                      //ID号
 double distance;             //距离
 char classLabel[15];             //分类标号
};

////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    全局变量
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
struct dataVector gTrainingSet[MAX_SIZE_OF_TRAINING_SET]; //训练数据集
struct dataVector gTestSet[MAX_SIZE_OF_TEST_SET];         //測试数据集
struct distanceStruct gNearestDistance[K];                //K个近期邻距离
int curTrainingSetSize=0;                                 //训练数据集的大小
int curTestSetSize=0;                                     //測试数据集的大小
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    求 vector1=(x1,x2,…,xn)和vector2=(y1,y2,…,yn)的欧几里德距离
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
double Distance(struct dataVector vector1,struct dataVector vector2)
{
 double dist,sum=0.0;
 for(int i=0;i<ATTR_NUM;i++)
 {
  sum+=(vector1.attributes[i]-vector2.attributes[i])*(vector1.attributes[i]-vector2.attributes[i]);
 }
 dist=sqrt(sum);
 return dist;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    得到gNearestDistance中的最大距离,返回下标
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
int GetMaxDistance()
{
 int maxNo=0;
 for(int i=1;i<K;i++)
 {
  if(gNearestDistance[i].distance>gNearestDistance[maxNo].distance) maxNo = i;
 }
    return maxNo;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    对未知样本Sample分类
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
char* Classify(struct dataVector Sample)
{
 double dist=0;
 int maxid=0,freq[K],i,tmpfreq=1;;
 char *curClassLable=gNearestDistance[0].classLabel;
 memset(freq,1,sizeof(freq));
 //step.1—初始化距离为最大值
 for(i=0;i<K;i++)
 {
  gNearestDistance[i].distance=MAX_VALUE;
 }
 //step.2—计算K-近期邻距离
 for(i=0;i<curTrainingSetSize;i++)
 {
  //step.2.1—计算未知样本和每一个训练样本的距离
  dist=Distance(gTrainingSet[i],Sample);
  //step.2.2—得到gNearestDistance中的最大距离
  maxid=GetMaxDistance();
  //step.2.3—假设距离小于gNearestDistance中的最大距离,则将该样本作为K-近期邻样本
  if(dist<gNearestDistance[maxid].distance)
  {
   gNearestDistance[maxid].ID=gTrainingSet[i].ID;
   gNearestDistance[maxid].distance=dist;
   strcpy(gNearestDistance[maxid].classLabel,gTrainingSet[i].classLabel);
  }
 }
 //step.3—统计每一个类出现的次数
 for(i=0;i<K;i++) 
 {
  for(int j=0;j<K;j++)
  {
   if((i!=j)&&(strcmp(gNearestDistance[i].classLabel,gNearestDistance[j].classLabel)==0))
   {
    freq[i]+=1;
   }
  }
 }
 //step.4—选择出现频率最大的类标号
 for(i=0;i<K;i++)
 {
  if(freq[i]>tmpfreq) 
  {
   tmpfreq=freq[i];
    curClassLable=gNearestDistance[i].classLabel;
  }
 }
 return curClassLable;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    主函数
//
////////////////////////////////////////////////////////////////////////////////////////////////////////

void main()
{  
 char c;
    char *classLabel=””;
 int i,j, rowNo=0,TruePositive=0,FalsePositive=0;
 ifstream filein(“iris.data”);
 FILE *fp;
 if(filein.fail()){cout<<“Can’t open data.txt”<<endl; return;}
 //step.1—读文件 
 while(!filein.eof())
 {
  rowNo++;//第一组数据rowNo=1
  if(curTrainingSetSize>=MAX_SIZE_OF_TRAINING_SET)
  {
   cout<<“The training set has “<<MAX_SIZE_OF_TRAINING_SET<<” examples!”<<endl<<endl;
   break ;
  }  
  //rowNo%3!=0的100组数据作为训练数据集
  if(rowNo%3!=0)
  {   
   gTrainingSet[curTrainingSetSize].ID=rowNo;
   for(int i = 0;i < ATTR_NUM;i++)
   {     
    filein>>gTrainingSet[curTrainingSetSize].attributes[i];
    filein>>c;
   }   
   filein>>gTrainingSet[curTrainingSetSize].classLabel;
   curTrainingSetSize++;
   
  }
  //剩下rowNo%3==0的50组做測试数据集
  else if(rowNo%3==0)
  {
   gTestSet[curTestSetSize].ID=rowNo;
   for(int i = 0;i < ATTR_NUM;i++)
   {    
    filein>>gTestSet[curTestSetSize].attributes[i];
    filein>>c;
   }  
   filein>>gTestSet[curTestSetSize].classLabel;
   curTestSetSize++;
  }
 }
 filein.close();
 //step.2—KNN算法进行分类,并将结果写到文件iris_OutPut.txt
 fp=fopen(“iris_OutPut.txt”,”w+t”);
 //用KNN算法进行分类
 fprintf(fp,”************************************程序说明***************************************\n”);
 fprintf(fp,”** 採用KNN算法对iris.data分类。为了操作方便,对各组数据加入�rowNo属性,第一组rowNo=1!\n”);
 fprintf(fp,”** 共同拥有150组数据,选择rowNo模3不等于0的100组作为训练数据集,剩下的50组做測试数据集\n”);
 fprintf(fp,”***********************************************************************************\n\n”);
 fprintf(fp,”************************************实验结果***************************************\n\n”);
 for(i=0;i<curTestSetSize;i++)
 {
        fprintf(fp,”************************************第%d组数据**************************************\n”,i+1);
  classLabel =Classify(gTestSet[i]);
     if(strcmp(classLabel,gTestSet[i].classLabel)==0)//相等时,分类正确
  {
   TruePositive++;
  }
  cout<<“rowNo: “;
  cout<<gTestSet[i].ID<<”    \t”;
  cout<<“KNN分类结果:      “;

  cout<<classLabel<<“(正确类标号: “;
  cout<<gTestSet[i].classLabel<<“)\n”;
  fprintf(fp,”rowNo:  %3d   \t  KNN分类结果:  %s ( 正确类标号:  %s )\n”,gTestSet[i].ID,classLabel,gTestSet[i].classLabel);
  if(strcmp(classLabel,gTestSet[i].classLabel)!=0)//不等时,分类错误
  {
  // cout<<”   ***分类错误***\n”;
   fprintf(fp,”                                                                      ***分类错误***\n”);
  }
  fprintf(fp,”%d-最临近数据:\n”,K);
  for(j=0;j<K;j++)
  {
  // cout<<gNearestDistance[j].ID<<“\t”<<gNearestDistance[j].distance<<“\t”<<gNearestDistance[j].classLabel[15]<<endl;
   fprintf(fp,”rowNo:  %3d   \t   Distance:  %f   \tClassLable:    %s\n”,gNearestDistance[j].ID,gNearestDistance[j].distance,gNearestDistance[j].classLabel);
  }
  fprintf(fp,”\n”);
 }
    FalsePositive=curTestSetSize-TruePositive;
 fprintf(fp,”***********************************结果分析**************************************\n”,i);
 fprintf(fp,”TP(True positive): %d\nFP(False positive): %d\naccuracy: %f\n”,TruePositive,FalsePositive,double(TruePositive)/(curTestSetSize-1));
 fclose(fp);
    return;
}

 

以上内容为參考网上有关资料;加以总结;

 

 

 

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:全栈程序员-站长,转载请注明出处:https://javaforall.net/119215.html原文链接:https://javaforall.net

(0)
全栈程序员-站长的头像全栈程序员-站长


相关推荐

  • java CAS详解[通俗易懂]

    java CAS详解[通俗易懂]CAS解释:CAS(compareandswap),比较并交换。可以解决多线程并行情况下使用锁造成性能损耗的一种机制.CAS操作包含三个操作数—内存位置(V)、预期原值(A)和新值(B)。如果内存位置的值与预期原值相匹配,那么处理器会自动将该位置值更新为新值。否则,处理器不做任何操作。一个线程从主内存中得到num值,并对num进行操作,写入值的时候,线程会把第一次取到的num值和主内存中num值进行比较,如果相等,就会将改变后的num写入主内存,如果不相等,则一直循环对比,知道成功为止。CAS

    2022年7月9日
    24
  • 十分钟免费拥有永久网站

    十分钟免费拥有永久网站在人人都会上网的信息时代,拥有属于自己的网站,已经不是什么稀奇的事情了。GithubPages就可以满足我们的需求了。它是github公司提供的免费的静态网站托管服务,用起来方便而且功能强大,不仅没有空间限制,还可以绑定自己的域名。一、注册github账户注册流程和其它平台一样。注册地址:https://github.com/join?source=logingithu…

    2022年5月27日
    87
  • 交换排序之高速排序[通俗易懂]

    交换排序之高速排序

    2022年1月19日
    46
  • java调用HTTP接口(Get请求和Post请求)

    java调用HTTP接口(Get请求和Post请求)前提:一个Http接口:http://172.83.38.209:7001/NSRTRegistration/test/add.do?id=8888888&name=99999999id和name是传入的参数浏览器访问接口:java代码调用Http接口代码如下(代码中注释分为两部分:处理get请求和post请求):packagecom.inspur.OKHTTP…

    2022年5月24日
    814
  • pcanywhere的端口「建议收藏」

    pcanywhere的端口「建议收藏」PCANYWHERE使用端口TCP:5631,UDP:5632通常打开第一个端口就可以了

    2022年9月13日
    0
  • File.createTempFile异常「建议收藏」

    错误:File.createtempfilejava.io.winntfilesystem.createfileexclusively(nativemethod)原来是Eclipse默认的JRE不是JDK下的修改为JDK下的jre就可以了转载于:https://www.cnblogs.com/cszzy/archive/2012/12/28/2837790.html…

    2022年4月11日
    103

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注全栈程序员社区公众号