1 What Is a Decision Tree
2 Building a Decision Tree
2.1 Basic Steps for Building a Decision Tree
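In outline, recursive tree construction works as follows (a summary consistent with the createTree() implementation in section 2.3.1 below):
1. If every sample in the current data set belongs to the same class, return that class as a leaf.
2. If no features remain to split on, return the majority class as a leaf.
3. Otherwise pick the best feature by the selection criterion (information gain for ID3, gain ratio for C4.5).
4. Split the data set on each value of that feature and build a subtree for each subset recursively.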
2.2 An Example of Building a Decision Tree
The classic play-tennis example below shows how a decision tree is built. Whether we play tennis today (play) is determined mainly by the weather (outlook), the temperature (temperature), the humidity (humidity), and whether it is windy (windy). The sample contains 14 records in total.
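The calculations below assume the classic 14-row version of this data set, reproduced here for reference since the class counts matter:

outlook    temperature  humidity  windy  play
sunny      hot          high      FALSE  no
sunny      hot          high      TRUE   no
overcast   hot          high      FALSE  yes
rainy      mild         high      FALSE  yes
rainy      cool         normal    FALSE  yes
rainy      cool         normal    TRUE   no
overcast   cool         normal    TRUE   yes
sunny      mild         high      FALSE  no
sunny      cool         normal    FALSE  yes
rainy      mild         normal    FALSE  yes
sunny      mild         normal    TRUE   yes
overcast   mild         high      TRUE   yes
overcast   hot          normal    FALSE  yes
rainy      mild         high      TRUE   no

Of the 14 samples, 9 are yes and 5 are no.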
The following sections build the decision tree first with the ID3 algorithm and then with C4.5.
2.2.1 Building the Decision Tree with ID3
The ID3 algorithm selects split features by information gain.
2.2.1.1 How to Compute Information Gain
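For a data set D whose samples fall into K classes, with pk the proportion of samples in class k, the empirical entropy is

H(D) = -Σk pk·log2(pk)

and the information gain of a feature A is the entropy reduction obtained by splitting D on A:

g(D, A) = H(D) - H(D|A)

where H(D|A) is the conditional entropy of D given A, written out in section 2.2.1.3 below.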
2.2.1.2 Computing the Empirical Entropy of play
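With the 14 samples above (9 yes, 5 no), the empirical entropy of the play label works out to:

H(D) = -(9/14)·log2(9/14) - (5/14)·log2(5/14) ≈ 0.940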
2.2.1.3 Computing the Information Gain of the outlook Feature
The gain of a feature is computed as

g(D, A) = H(D) - Σj (|Dj|/|D|)·H(Dj)

where |D| is the total number of samples in the data set, j ranges over the distinct values of the current feature, |Dj| is the number of samples taking the j-th value, and H(Dj) is the entropy of those samples with respect to the target variable.
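On the sample data, outlook splits the set into sunny (5 samples: 2 yes, 3 no), overcast (4 samples: all yes) and rainy (5 samples: 3 yes, 2 no), so:

H(D|outlook) = (5/14)·0.971 + (4/14)·0 + (5/14)·0.971 ≈ 0.694
g(D, outlook) = 0.940 - 0.694 ≈ 0.246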
2.2.1.4 Computing the Information Gain of the temperature Feature
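By the same procedure, hot (2 yes, 2 no), mild (4 yes, 2 no) and cool (3 yes, 1 no) give:

H(D|temperature) ≈ 0.911,  g(D, temperature) = 0.940 - 0.911 ≈ 0.029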
2.2.1.5 Computing the Information Gain of the humidity Feature
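high (3 yes, 4 no) and normal (6 yes, 1 no) give:

H(D|humidity) ≈ 0.788,  g(D, humidity) = 0.940 - 0.788 ≈ 0.152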
2.2.1.6 Computing the Information Gain of the windy Feature
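TRUE (3 yes, 3 no) and FALSE (6 yes, 2 no) give:

H(D|windy) ≈ 0.892,  g(D, windy) = 0.940 - 0.892 ≈ 0.048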
2.2.1.7 Choosing the Root Node
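outlook has the largest information gain of the four features (0.246 > 0.152 > 0.048 > 0.029), so ID3 puts it at the root. Its overcast branch is already pure (all yes) and becomes a leaf; the sunny and rainy branches are split further in the next two sections.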
2.2.1.7.1 Computing on the outlook = sunny Subset
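The sunny subset holds 5 samples (2 yes, 3 no), with entropy H ≈ 0.971. Recomputing the gains on this subset, humidity separates it perfectly (high → no, normal → yes), achieving the maximum possible gain of 0.971, so humidity becomes the node under the sunny branch.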
2.2.1.7.2 Computing on the outlook = rainy Subset
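The rainy subset also holds 5 samples (3 yes, 2 no). Here windy separates it perfectly (TRUE → no, FALSE → yes), again with gain 0.971, so windy becomes the node under the rainy branch and the tree is complete.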
2.2.2 Building the Decision Tree with C4.5
2.2.2.1 Computing Each Feature's Information Gain Ratio
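C4.5 replaces the raw information gain with the gain ratio

gR(D, A) = g(D, A) / HA(D),  where  HA(D) = -Σj (|Dj|/|D|)·log2(|Dj|/|D|)

is the entropy of the split itself; dividing by it penalizes features with many distinct values. On the sample data this gives approximately:

outlook: 0.246 / 1.577 ≈ 0.156
temperature: 0.029 / 1.557 ≈ 0.019
humidity: 0.152 / 1.000 ≈ 0.152
windy: 0.048 / 0.985 ≈ 0.049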
Comparing the information gain ratios of the four features, outlook's is the largest, so outlook again becomes the root node. The remaining nodes are computed in the same way.
2.3 Decision Trees in Practice
2.3.1 Implementing a Decision Tree in Code
The code below implements tree construction based on the ID3 algorithm's information gain.
#!/usr/bin/python
# encoding: utf-8
from math import log
import operator
import treePlotter

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']  # discrete-valued features
    return dataSet, labels

# Compute the Shannon entropy of a data set
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # number of samples
    labelCounts = {}  # key: class label (the last column, i.e. the target class); value: sample count for that class
    for featVec in dataSet:  # walk the data set one row at a time
        currentLabel = featVec[-1]  # the label sits in the last column
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt

# Split the data set on a given feature: axis is the column index of the feature
# (e.g. outlook is column 0) and value is one of that column's values (e.g. sunny).
# The returned rows have the axis column removed.
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out the axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# Choose the best feature for splitting the current data set
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the current data set
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # every value of feature i
        uniqueVals = set(featList)  # its distinct values, e.g. sunny/overcast/rainy for outlook
        newEntropy = 0.0
        for value in uniqueVals:  # conditional entropy of this candidate split
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain
        if infoGain > bestInfoGain:  # keep only the best gain seen so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # returns an integer column index

# Take a list of class names, build a dict keyed by the unique values in
# classList whose values are the frequencies of each class label, sort it by
# count with operator.itemgetter, and return the most frequent class name.
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Main routine: recursively build the decision tree
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]  # all label values in the current data set
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when every sample has the same class
    if len(dataSet[0]) == 1:
        # All features are used up but the groups are still mixed, so no unique
        # class label can be returned; fall back to the most frequent class.
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best split feature
    bestFeatLabel = labels[bestFeat]  # its name
    # Store the tree in a nested dict; this representation also drives the plotting code.
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])  # the chosen feature has been consumed
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy the labels so recursion doesn't clobber the caller's list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# Classify a test vector by walking the tree
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):  # an internal node: keep descending
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:  # a leaf: this is the predicted class
        classLabel = valueOfFeat
    return classLabel

def storeTree(inputTree, filename):
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

if __name__ == '__main__':
    fr = open('play.tennies.txt')
    lenses = [inst.strip().split(' ') for inst in fr.readlines()]
    lensesLabels = ['outlook', 'temperature', 'humidity', 'windy']
    lensesTree = createTree(lenses, lensesLabels)
    treePlotter.createPlot(lensesTree)

The treePlotter module imported above draws the resulting tree with matplotlib:

#!/usr/bin/python
# encoding: utf-8
# treePlotter.py
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # styles for the text boxes and arrows
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Count the tree's leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict value is an internal node, so recurse
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Compute the tree's depth
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # dict values are subtrees; anything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Draw a node with an annotation arrow
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Compute the midpoint between a parent and child node and place the edge label there
def plotMidText(cntrPt, parentPt, txtString):
    pass  # body truncated in the source, along with the plotTree() and createPlot() functions that complete treePlotter.py

The same example can also be run with scikit-learn's DecisionTreeClassifier:

#!/usr/bin/python
# encoding: utf-8
# Split the raw data into training data and test data
import numpy as np
from sklearn import tree
from sklearn.model_selection import train_test_split
import pydotplus

def outlook_type(s):
    it = {'sunny': 1, 'overcast': 2, 'rainy': 3}
    return it[s]

def temperature(s):
    it = {'hot': 1, 'mild': 2, 'cool': 3}
    return it[s]

def humidity(s):
    it = {'high': 1, 'normal': 0}
    return it[s]

def windy(s):
    it = {'TRUE': 1, 'FALSE': 0}
    return it[s]

def play_type(s):
    it = {'yes': 1, 'no': 0}
    return it[s]

play_feature_E = 'outlook', 'temperature', 'humidity', 'windy'
play_class = 'yes', 'no'

# 1. Read the data, converting the raw strings to numbers
data = np.loadtxt("play.tennies.txt", delimiter=" ", dtype=str,
                  converters={0: outlook_type, 1: temperature, 2: humidity, 3: windy, 4: play_type})
x, y = np.split(data, (4,), axis=1)

# 2. Split off training and test data for cross-validation
# x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=2)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

# 3. Train the decision tree, using information entropy as the split criterion
clf = tree.DecisionTreeClassifier(criterion='entropy')
print(clf)
clf.fit(x_train, y_train)

# 4. Write the tree structure to a file
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=play_feature_E,
                                class_names=play_class,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf('play1.pdf')

# The coefficients reflect each feature's influence; a larger value means the feature matters more to the classification
print(clf.feature_importances_)

# 5. Predict on the training data; the predictions match exactly
answer = clf.predict(x_train)
y_train = y_train.reshape(-1)
print(answer)
print(y_train)
print(np.mean(answer == y_train))

# 6. Predict on the test data; markedly lower accuracy here would indicate overfitting
answer = clf.predict(x_test)
y_test = y_test.reshape(-1)
print(answer)
print(y_test)
print(np.mean(answer == y_test))

Run output:

DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
    max_features=None, max_leaf_nodes=None, min_impurity_split=1e-07,
    min_samples_leaf=1, min_samples_split=2,
    min_weight_fraction_leaf=0.0, presort=False, random_state=None,
    splitter='best')
[ 0. 0. 0. 0.]
['0' '1' '0' '1' '0' '0' '1' '1' '1']
['0' '1' '0' '1' '0' '0' '1' '1' '1']
1.0
['0' '1' '1' '1' '1']
['0' '1' '1' '1' '1']
1.0
Reprinted from https://blog.csdn.net/huahuazhu/article/details/
