《机器学习实战》中用matplotlib绘制决策树, python3

论坛 期权论坛 脚本     
匿名技术用户   2020-12-30 13:44   11   0

人笨, 绘制树形图那里的代码看了几次也没看懂(很多莫名其妙的(全局?)变量), 然后就自己想办法写了个

import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties


def getTreeDB(mytree):
 """
 利用递归获取字典最大深度, 子叶数目
 :param mytree:一个字典树, 或者树的子叶节点(字符型)
 :return:返回 树的深度, 子叶数目
 """
 if not isinstance(mytree, dict):  # 如果是子叶节点, 返回1
  return 1, 1
 depth = []  # 储存每条树枝的深度
 leafs = 0  # 结点当前的子叶数目
 keys = list(mytree.keys())  # 获取字典的键
 if len(keys) == 1:  # 如果键只有一个(说明是个结点而不是树枝)
  mytree = mytree[keys[0]]  # 结点的value一定是树枝(判断的是每条支路的深度而不是结点)
 for key in mytree.keys():  # 遍历每条树枝
  res = getTreeDB(mytree[key])  # 获取子树的深度, 子叶数目
  depth.append(1 + res[0])  # 把每条树枝的深度(加上自身)放在节点的深度集合中
  leafs += res[1]  # 累积子叶数目
 return max(depth), leafs  # 返回最大的深度值, 子叶数目


def plotArrow(what, xy1, xy2, which):
 """
 画一个带文字描述的箭头, 文字在箭头中间
 :param what: 文字内容
 :param xy1: 箭头起始坐标
 :param xy2: 箭头终点坐标
 :param which: 箭头所在的图对象
 :return: suprise
 """

 # 画箭头
 which.arrow(
  xy1[0], xy1[1], xy2[0] - xy1[0], xy2[1] - xy1[1],
  length_includes_head = True,  # 增加的长度包含箭头部分
  head_width = 0.15, head_length = 0.5, fc = 'r', ec = 'brown')

 tx = (xy1[0] + xy2[0]) / 2
 ty = (xy1[1] + xy2[1]) / 2

 zhfont = FontProperties(fname = 'msyh.ttc')  # 显示中文的方法

 # 画文字
 which.annotate(
  what,
  size = 10,
  xy = (tx, ty),
  xytext = (-5, 5),  # 偏移量
  textcoords = 'offset points',
  bbox = dict(boxstyle = "square", ec = (1., 0.5, 0.5), fc = (1., 0.8, 0.8)),  # 外框, fc 内部颜色, ec 边框颜色
  fontproperties = zhfont)  # 字体


def plotNode(what, xy, which, mod = 'any'):
 """
 画树的节点
 :param what: 节点的内容
 :param xy: 节点的坐标
 :param which: 节点所在的图对象
 :param mod: 判断节点是子叶还是非子叶(颜色不同)
 :return: suprise
 """
 zhfont = FontProperties(fname = 'msyh.ttc')  # 显示中文的方法, msyh.ttc是微软雅黑的字体文件
 if mod == 'leaf':
  color = 'yellow'
 else:
  color = 'greenyellow'
 which.text(
  xy[0], xy[1],
  what, size = 18,
  ha = "center", va = "center",
  bbox = dict(boxstyle = "round", ec = (1., 0.5, 0.5), fc = color),
  fontproperties = zhfont)


def plotInfo(what, which):
 """
 提示图中内容
 :param what: 子叶标签
 :param which: 所在的图对象
 :return: suprise
 """
 what = '绿色: 特征,  粉红: 特征值,  黄色: ' + what
 zhfont = FontProperties(fname = 'msyh.ttc')  # 显示中文的方法
 which.text(
  2, 2,
  what, size = 18,
  ha = "center", va = "center",
  bbox = dict(boxstyle = "round", ec = (1., 0.5, 0.5), fc = '#BB91A6'),
  fontproperties = zhfont)


def plotTree(mytree, figxsize, figysize, what):
 """
 利用递归画决策树
 所有子叶节点两两之间的间距都是xsize
 每一层节点之间的间距都是ysize
 子叶节点的数目都是确定的, 所以横坐标也是确定的, 从左往右第leafnum个子叶节点的横坐标x = leafs * xsize
 非子叶节点的横坐标由该节点孩子的横坐标确定, x = 孩子横坐标平均值
 每一层节点的纵坐标由层数deep确定, y = ylen - deep * ysize, 其中ylen为画板高度
 :param mytree: 要画的字典树
 :param figxsize: 画布的x长度    (两者会影响显示效果)
 :param figysize: 画布的y长度    (这两个值很影响树的分布,(不宜过大)(?) ))
 :param what: 子叶的标签(用于提示图的结果是什么)
 :return: suprise
 """

 def plotAll(subtree, deep, leafnum):
  """
  内部函数, 递归画图, 会使用外部的变量
  :param subtree: 要画的子树
  :param deep: 子树根节点所在的深度
  :param leafnum: 下一个子叶节点从左到右的排号(用来决定下一个子叶节点的横坐标)
  :return:suprise
  """
  if not isinstance(subtree, dict):  # 如果是子叶节点(非字典)
   x = leafnum * xsize  # 计算横坐标
   y = ylen - deep * ysize  # 计算纵坐标
   plotNode(subtree, (x, y), ax, 'leaf')  # 画节点
   return x, y, leafnum + 1  # 返回子叶节点的坐标, 已画子叶数目+1

  key = list(subtree.keys())  # 获取子树的根节点的键(节点的名称)
  if len(key) != 1:  # 传进来的子树应该只有一个根节点
   raise TypeError("非字典树")  # 不满足就报错
  xlist = []  # 储存根节点孩子的横坐标
  ylist = []  # 储存根节点孩子的纵坐标
  keyvalue = subtree[key[0]]  # 根节点的孩子(子字典, 子字典的key为权值, value为子树)
  for k in keyvalue:  # k为每一格权值(每一个选择)
   res = plotAll(keyvalue[k], deep + 1, leafnum)  # 获取这个孩子的坐标
   leafnum = res[2]  # 更新已画的子叶树
   xlist.append(res[0])  # 储存孩子的坐标
   ylist.append(res[1])
  x = sum(xlist) / len(xlist)  # 求平均得出该根节点的横坐标
  y = ylen - deep * 3  # 计算该根节点的纵坐标
  plotNode(key[0], (x, y), ax)  # 画该节点

  i = 0
  for k in keyvalue:  # 依次画出根节点与孩子之间的箭头
   plotArrow(k, (x, y), (xlist[i], ylist[i]), ax)
   i += 1

  return x, y, leafnum  # 返回该节点的坐标

 xsize, ysize = 4, 3  # 默认子叶间距为4, 每层的间距为3 (设置为这两个值的原因...我觉得这样好看些...可以试试别的值)
 fig = plt.figure(figsize = (figxsize, figysize))  # 一张画布
 axprops = dict(xticks = [], yticks = [])  # 横纵坐标显示的数字(设置为空, 不显示)
 ax = fig.add_subplot(111, frameon = False, **axprops)  # 隐藏坐标轴
 depth, leaf = getTreeDB(mytree)  # 获取深度, 子叶节点数目
 xlen, ylen = 4 * (leaf + 1), 3 * (depth + 1)  # 计算横纵间距
 ax.set_xlim(0, xlen)  # 设置坐标系x, y的范围
 ax.set_ylim(0, ylen)
 plotAll(mytree, 1, 1)  # 画树
 plotInfo(what, ax)  # 提示标签
 plt.show()  # show show show show show


testtree = {'有自己的房子': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}}  # 一个树
testlabel = ['年龄', '有工作', '有自己的房子', '信贷情况']  #训练数据的标签
plotTree(testtree, 10, 6, testlabel[-1])

看起来还是不错

代码的注释可能有(fei)点(chang)令人费解... 有问题的地方很多...

测试数据来源 机器学习 决策树算法实战(理论+详细的python3代码实现)

画箭头方法的来源 180122 利用matplotlib绘制箭头的2种方法, 自己改了下颜色,比例

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:7942463
帖子:1588486
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP