Decision Tree Pruning Algorithms in Python: A Detailed Walkthrough


This article walks through a Python implementation of decision tree pruning algorithms, shared here for your reference. The details are as follows.

A decision tree is a tree built up from decisions. In machine learning, a decision tree is a predictive model that represents a mapping between object attributes and object values: each internal node tests an attribute, each branch leaving a node corresponds to a possible value of that attribute, and each leaf holds the value predicted for the objects described by the path from the root to that leaf. A decision tree has a single output; when multiple outputs are needed, build an independent tree for each output.
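In the implementation below, such a tree is stored as nested Python dicts, one dict level per attribute test. A tiny hand-built example (the attribute name and values are hypothetical):

```
# {attribute: {attribute_value: subtree or leaf class label}}
toy_tree = {
    'outlook': {            # internal node: tests the 'outlook' attribute
        'sunny': 'no',      # leaf: predicted class label
        'overcast': 'yes',
        'rain': 'yes',
    }
}
```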

The ID3 algorithm: ID3 (Iterative Dichotomiser 3), invented by Ross Quinlan, is a decision tree algorithm built on the principle of Occam's razor: do as much as possible with as little as possible, i.e. prefer smaller decision trees over larger ones. It does not always produce the smallest possible tree, though; it is a heuristic. In information-theoretic terms, the smaller the expected information, the larger the information gain and the higher the purity. The core idea of ID3 is therefore to use information gain as the criterion for attribute selection: at each node, split on the attribute whose split yields the largest information gain. The algorithm traverses the space of possible decision trees with a top-down greedy search.
Information entropy is the expected information content of a discrete random event. The more ordered a system is, the lower its entropy; the more chaotic a system is, the higher its entropy. Entropy can therefore be read as a measure of a system's degree of order.
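For reference, the entropy the code below computes (base-2 logarithm) and the information gain that ID3 maximizes are \( H(S) = -\sum_i p_i \log_2 p_i \) and \( Gain(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v) \). As a worked example on a hypothetical set of 14 samples, 9 positive and 5 negative: \( H = -\frac{9}{14}\log_2\frac{9}{14} - \frac{5}{14}\log_2\frac{5}{14} \approx 0.940 \).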

The Gini index: CART splits nodes using the Gini index, defined as \( gini(T) = 1 - \sum_{j=1}^{n} p_j^2 \), where \( p_j \) is the relative frequency of class \( j \) in \( T \); \( gini(T) \) is smallest when the class distribution in \( T \) is skewed. After partitioning \( T \) into two subsets \( T_1 \) (with \( N_1 \) instances) and \( T_2 \) (with \( N_2 \) instances), the Gini index of the split is \( gini_{split}(T) = \frac{N_1}{N} gini(T_1) + \frac{N_2}{N} gini(T_2) \), and the split with the smallest \( gini_{split}(T) \) is chosen for the node.
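The implementation in this article selects splits by entropy (ID3) rather than Gini, but as a point of comparison, a minimal sketch of the two formulas above:

```
def gini(dataSet):
    # gini(T) = 1 - sum_j p_j^2, with the class label in the last column
    counts = {}
    for featVec in dataSet:
        counts[featVec[-1]] = counts.get(featVec[-1], 0) + 1
    total = float(len(dataSet))
    return 1.0 - sum((c / total) ** 2 for c in counts.values())

def giniSplit(subsets):
    # gini_split(T) = sum_k (N_k / N) * gini(T_k) over the subsets of a split
    total = float(sum(len(s) for s in subsets))
    return sum(len(s) / total * gini(s) for s in subsets)
```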
Implementation
First, the function calcShannonEnt computes the Shannon entropy of the data set, building a count dictionary over all possible classes:

```
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # Build a count dictionary over all possible classes (last column = label)
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # Shannon entropy with base-2 logarithm
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
```
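A quick sanity check on a toy data set (the values are hypothetical; the last column is the class label):

```
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(calcShannonEnt(dataSet))  # ~0.971 for a 2-vs-3 class split
```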

Split the data set on a discrete feature, returning every sample whose feature at index axis equals value (the split column itself is removed):

```
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Drop the split column, keep the rest of the sample
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
```
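For example, splitting the toy set above on feature 0 with value 1:

```
print(splitDataSet(dataSet, 0, 1))  # -> [[1, 'yes'], [1, 'yes'], [0, 'no']]
```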

To split the data set on a continuous feature, direction selects the side of the split: it decides whether to return the samples whose value is less than (or equal to) value, or those whose value is greater than value.
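The listing for splitContinuousDataSet is missing at this point in the post; below is a minimal sketch consistent with how chooseBestFeatureToSplit calls it, assuming direction 1 returns the samples with feature <= value and direction 0 those with feature > value (this assignment of directions is an assumption, chosen to match the later binarization convention where 1 means <= the split point):

```
def splitContinuousDataSet(dataSet, axis, value, direction):
    retDataSet = []
    for featVec in dataSet:
        # direction 1: keep samples with feature <= value (assumed convention);
        # direction 0: keep samples with feature > value
        if (direction == 1 and featVec[axis] <= value) or \
           (direction == 0 and featVec[axis] > value):
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
```

Since chooseBestFeatureToSplit only uses the two subsets to compute entropy, which is symmetric in the two halves, the choice of convention does not affect the selected split point.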

chooseBestFeatureToSplit selects the feature with the largest information gain, handling both continuous and discrete features; if a continuous feature wins, it is binarized in place around its best split point:

```
def chooseBestFeatureToSplit(dataSet, labels):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    bestSplitDict = {}
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # Continuous feature
        if type(featList[0]).__name__ == 'float' or type(featList[0]).__name__ == 'int':
            # Generate n-1 candidate split points (midpoints of sorted values)
            sortfeatList = sorted(featList)
            splitList = []
            for j in range(len(sortfeatList) - 1):
                splitList.append((sortfeatList[j] + sortfeatList[j + 1]) / 2.0)
            bestSplitEntropy = 10000
            slen = len(splitList)
            # Entropy when splitting at the j-th candidate point; record the best point
            for j in range(slen):
                value = splitList[j]
                newEntropy = 0.0
                subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
                subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)
                prob0 = len(subDataSet0) / float(len(dataSet))
                newEntropy += prob0 * calcShannonEnt(subDataSet0)
                prob1 = len(subDataSet1) / float(len(dataSet))
                newEntropy += prob1 * calcShannonEnt(subDataSet1)
                if newEntropy < bestSplitEntropy:
                    bestSplitEntropy = newEntropy
                    bestSplit = j
            # Record this feature's best split point in a dictionary
            bestSplitDict[labels[i]] = splitList[bestSplit]
            infoGain = baseEntropy - bestSplitEntropy
        # Discrete feature
        else:
            uniqueVals = set(featList)
            newEntropy = 0.0
            # Entropy over every partition induced by this feature
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    # If the best feature is continuous, binarize it in place around the recorded
    # split point: 1 if value <= bestSplitValue, else 0
    if type(dataSet[0][bestFeature]).__name__ == 'float' or type(dataSet[0][bestFeature]).__name__ == 'int':
        bestSplitValue = bestSplitDict[labels[bestFeature]]
        labels[bestFeature] = labels[bestFeature] + '<=' + str(bestSplitValue)
        for i in range(len(dataSet)):
            if dataSet[i][bestFeature] <= bestSplitValue:
                dataSet[i][bestFeature] = 1
            else:
                dataSet[i][bestFeature] = 0
    return bestFeature
```

classify walks the tree for one test vector; a node label containing '<=' marks a binarized continuous feature:

```
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    if u'<=' in firstStr:
        # Binarized continuous feature: compare against the stored threshold
        featvalue = float(firstStr.split(u"<=")[1])
        featkey = firstStr.split(u"<=")[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(featkey)
        if testVec[featIndex] <= featvalue:
            judge = 1
        else:
            judge = 0
        for key in secondDict.keys():
            if judge == int(key):
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    else:
        # Discrete feature: follow the branch matching the test value
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
    return classLabel
```
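Applied to the hand-built toy tree from earlier:

```
print(classify(toy_tree, ['outlook'], ['overcast']))  # -> 'yes'
```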
majorityCnt returns the most frequent class label, and testing_feat estimates the test-set error of splitting on one candidate feature (used by pre-pruning):

```
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # Return the label with the highest count
    # (a bare max(classCount) would compare keys, not counts)
    return max(classCount, key=classCount.get)

def testing_feat(feat, train_data, test_data, labels):
    class_list = [example[-1] for example in train_data]
    bestFeatIndex = labels.index(feat)
    train_data = [example[bestFeatIndex] for example in train_data]
    test_data = [(example[bestFeatIndex], example[-1]) for example in test_data]
    all_feat = set(train_data)
    error = 0.0
    for value in all_feat:
        # Majority class among training samples that take this feature value
        class_feat = [class_list[i] for i in range(len(class_list)) if train_data[i] == value]
        major = majorityCnt(class_feat)
        # Count test samples with this value whose label differs from the majority
        for data in test_data:
            if data[0] == value and data[1] != major:
                error += 1.0
    return error
```

Testing

testing computes the error of a finished tree on the test set (its def line was lost in extraction and is reconstructed from how createTree calls it), and testingMajor computes the error of always predicting the majority class:

```
def testing(myTree, data_test, labels):
    error = 0.0
    for i in range(len(data_test)):
        if classify(myTree, labels, data_test[i]) != data_test[i][-1]:
            error += 1
    return float(error)

def testingMajor(major, data_test):
    error = 0.0
    for i in range(len(data_test)):
        if major != data_test[i][-1]:
            error += 1
    return float(error)
```

**Build the decision tree recursively**

createTree builds the tree recursively; mode selects unpruned ("unpro"), pre-pruned ("prev"), or post-pruned ("post") construction:

```
import copy

def createTree(dataSet, labels, data_full, labels_full, test_data, mode):
    classList = [example[-1] for example in dataSet]
    # All samples share one class: return that class as a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    labels_copy = copy.deepcopy(labels)
    bestFeat = chooseBestFeatureToSplit(dataSet, labels)
    bestFeatLabel = labels[bestFeat]
    if mode == "unpro" or mode == "post":
        myTree = {bestFeatLabel: {}}
    elif mode == "prev":
        # Pre-pruning: only split if it beats predicting the majority class
        if testing_feat(bestFeatLabel, dataSet, test_data, labels_copy) < testingMajor(majorityCnt(classList), test_data):
            myTree = {bestFeatLabel: {}}
        else:
            return majorityCnt(classList)
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    # Discrete (string) feature; this code targets Python 3,
    # the original Python 2 code checked for 'unicode' here
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        # Track values seen in the full data set but absent from this branch
        currentlabel = labels_full.index(labels[bestFeat])
        featValuesFull = [example[currentlabel] for example in data_full]
        uniqueValsFull = set(featValuesFull)
    del (labels[bestFeat])
    for value in uniqueVals:
        subLabels = labels[:]
        if type(dataSet[0][bestFeat]).__name__ == 'str':
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, data_full,
            labels_full, splitDataSet(test_data, bestFeat, value), mode=mode)
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        # Values missing from this branch's data default to the majority class
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value] = majorityCnt(classList)
    if mode == "post":
        # Post-pruning: collapse the subtree if the majority class does no worse
        if testing(myTree, test_data, labels_copy) > testingMajor(majorityCnt(classList), test_data):
            return majorityCnt(classList)
    return myTree
```

**Read in the data**

```
import pandas as pd

def load_data(file_name):
    df = pd.read_csv(file_name, sep=",")
    # First 11 rows as training data, the rest as test data;
    # column 0 is an index and the last column is the class label
    train_data = df.values[:11, 1:].tolist()
    test_data = df.values[11:, 1:].tolist()
    labels = df.columns.values[1:-1].tolist()
    return train_data, test_data, labels
```

Test and plot the tree:

```
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="round4", color='red')  # style of internal (decision) nodes
leafNode = dict(boxstyle="circle", color='grey')     # style of leaf nodes
arrow_args = dict(arrowstyle="<-", color='blue')     # style of edges

# Count the leaves of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Compute the maximum depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Draw a node
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center",
                            bbox=nodeType, arrowprops=arrow_args)

# Draw the text on an edge
def plotMidText(cntrPt, parentPt, txtString):
    lens = len(txtString)
    xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
```
if __name__ == "__main__":  train_data, test_data, labels = load_data("dd.csv")  data_full = train_data[:]  labels_full = labels[:]  mode="post"  mode = "prev"  mode="post"  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)  createPlot(myTree)  print(json.dumps(myTree, ensure_ascii=False, indent=4))

Switching mode selects one of the three tree plots: "unpro" builds the tree without pruning, "prev" applies pre-pruning, and "post" applies post-pruning.

if __name__ == "__main__":  train_data, test_data, labels = load_data("dd.csv")  data_full = train_data[:]  labels_full = labels[:]  mode="post"  mode = "prev"  mode="post"  myTree = createTree(train_data, labels, data_full, labels_full, test_data, mode=mode)  createPlot(myTree)  print(json.dumps(myTree, ensure_ascii=False, indent=4))

选择mode就可以分别得到三种树图

Readers interested in more Python content can browse this site's topic collections: 《Python数据结构与算法教程》, 《Python加密解密算法与技巧总结》, 《Python编码操作技巧总结》, 《Python函数使用技巧总结》, 《Python字符串操作技巧汇总》, and 《Python入门与进阶经典教程》.

We hope this article helps you with your Python programming.

