女神也用的约会决策:决策树算法实践

一、决策树

决策树是一个应用非常广泛的模型。由于决策树算法模型非常有价值,还衍生出了很多高级版本,比如随机森林、梯度提升决策树算法(GBDT)。

今天要介绍的是一个应用非常广泛的机器学习模型——决策树。首先从一个例子出发,看看女神是怎样决策要不要约会的;然后分析它的算法原理、思路形成的过程;由于决策树非常有价值,还衍生出了很多高级版本。决策树是机器学习中强大的有监督学习模型,本质上是一个二叉树的流程图,其中每个节点根据某个特征变量将一组观测值拆分。决策树的目标是将数据分成多个组,这样一个组中的每个元素都属于同一个类别。决策树也可以用来近似连续的目标变量。在这种情况下,树将进行拆分,使每个组的均方误差最小。决策树的一个重要特性可解释性好,即使你不熟悉机器学习技术,也可以理解决策树在做什么。

举个例子:
一位女神有很多的追求者,她肯定不会和每个追求者都约会,因为时间不够,必须要好好管理自己的时间才行。于是女神给每个想要约会的人发信息说:“发一下你的简历叭,我考虑一下~~。” 接收到简历,第一眼先看照片,颜值打几分?然后再看年收入,长得帅的可以少挣点,毕竟“帅也可以当饭吃啊”。不帅的呢?那收入要求高一点,“颜值不够,薪资来凑”。薪资还差点的,再看看学历是不是985/211研究生,看看身高有没有180……所以这下你就可以对号入座了,发现自己哪条都不符合,好了,还是去好好干活挣钱吧。

由此可见,女神是否答应你的约会的筛选条件有颜值、身高、收入、学历等,每一项都会对最后是否约会的结果产生影响,即女神通过对这几种条件的判断,决定是否要安排约会。上面这个举例体现了决策树的基本思路,

1. CART概述

所谓 CART 算法,全名叫Classification and Regression Tree,即分类与回归树。顾名思义,相较于此前的 ID3 算法和 C4.5 算法,CART除了可以用于分类任务外,还可以完成回归分析。完整的 CART 算法包括特征选择决策树生成决策树剪枝三个部分。

CART是在给定输入随机变量 X 条件下输出随机变量 Y 的条件概率分布的学习方法。CART算法通过选择最优特征和特征值进行划分,将输入空间也就是特征空间划分为有限个单元,并在这些单元上确定预测的概率分布,也就是在输入给定的条件下输出条件概率分布。

CART算法主要包括回归树和分类树两种。回归树用于目标变量为连续型的建模任务,其特征选择准则用的是平方误差最小准则。分类树用于目标变量为离散型的的建模任务,其特征选择准则用的是基尼指数(Gini Index),这也有别于此前 ID3 的信息增益准则和 C4.5 的信息增益比准则。无论是回归树还是分类树,其算法核心都在于递归地选择最优特征构建决策树。

除了选择最优特征构建决策树之外,CART算法还包括另外一个重要的部分:剪枝。剪枝可以视为决策树算法的一种正则化手段,作为一种基于规则的非参数监督学习方法,决策树在训练很容易过拟合,导致最后生成的决策树泛化性能不高。

另外,CART作为一种单模型,也是 GBDT 的基模型。当很多棵 CART 分类树或者回归树集成起来的时候,就形成了 GBDT 模型。

2. 回归树

给定输入特征向量 X 和输出连续型变量Y,一个回归树的生成就对应着输入空间的一个划分以及在划分的单元上的输出值。假设输入空间被划分为 M 个单元R1,R2…,RM,在每一个单元 Rm 上都有一个固定的输出值Cm,所以回归树模型可以表示为


在输入空间划分确定时,回归树算法使用最小平方误差准则来选择最优特征和最优且切分点。具体来说就是对全部特征进行遍历,按照最小平方误差准则来求解最优切分变量和切分点。即求解如下公式:

这种按照最小平方误差准则来递归地寻找最佳特征和最优切分点构造决策树的过程就是最小二乘回归树算法。

完整的最小二乘回归树生成算法如下:(来自统计学习方法)

最小二乘回归树拟合数据如下图所示。可以看到,回归树的树深度越大的情况下,模型复杂度越高,对数据的拟合程度就越好,但相应的泛化能力就得不到保证。

3. 分类树

CART分类树跟回归树大不相同,但与此前的 ID3 和 C4.5 基本套路相同。ID3和 C4.5 分别采用信息增益和信息增益比来选择最优特征,但CART分类树采用Gini指数来进行特征选择。先来看 Gini 指数的定义。

Gini指数是针对概率分布而言的。假设在一个分类问题中有 K 个类,样本属于第 k 个类的概率为Pk,则该样本概率分布的基尼指数为

具体到实际的分类计算中,给定样本集合 D 的 Gini 指数计算如下

相应的条件 Gini 指数,也即给定特征 A 的条件下集合 D 的 Gini 指数计算如下:

实际构造分类树时,选择条件 Gini 指数最小的特征作为最优特征构造决策树。完整的分类树构造算法如下:

一棵基于 Gini 指数准则选择特征的分类树构造:

4. 剪枝

基于最小平方误差准则和 Gini 指数准则构造好决策树只能算完成的模型的一半。为了构造好的决策树能够具备更好的泛化性能,通过我们需要对其进行剪枝(pruning)。在特征选择算法效果趋于一致的情况下,剪枝逐渐成为决策树更为重要的一部分。

所谓剪枝,就是将构造好的决策树进行简化的过程。具体而言就是从已生成的树上裁掉一些子树或者叶结点,并将其根结点或父结点作为新的叶结点。

通常来说,有两种剪枝方法。一种是在决策树生成过程中进行剪枝,也叫预剪枝(pre-pruning)。另一种就是前面说的基于生成好的决策树自底向上的进行剪枝,又叫后剪枝(post-pruning)。

先来看预剪枝。预剪枝是在树生成过程中进行剪枝的方法,其核心思想在树中结点进行扩展之前,先计算当前的特征划分能否带来决策树泛化性能的提升,如果不能的话则决策树不再进行生长。预剪枝比较直接,算法也简单,效率高,适合大规模问题计算,但预剪枝可能会有一种”早停”的风险,可能会导致模型欠拟合。

后剪枝则是等树完全生长完毕之后再从最底端的叶子结点进行剪枝。CART剪枝正是一种后剪枝方法。简单来说,就是自底向上对完全树进行逐结点剪枝,每剪一次就形成一个子树,一直到根结点,这样就形成一个子树序列。然后在独立的验证集数据上对全部子树进行交叉验证,哪个子树误差最小,哪个就是最优子树。

二、算法的优缺点

决策树最初的版本称为ID3( Iterative Dichotomiser 3 ),ID3的缺点是无法处理数据是连续值的情况,也无法处理数据存在缺失的问题,需要在准备数据环节把缺失字段进行补齐或者删除数据。

后来有人提出了改进方案称为C4.5,加入了对连续值属性的处理,同时也可以处理数据缺失的情况。
还有一种目前应用最多的 CART( Classification And Regression Tree)分类与回归树,每次分支只使用二叉树划分,同时可以用于解决回归问题。

关于这三种决策树,我列了一个对比的表格,可以看到它们之间的区别:

下面的优缺点是针对 CART 树来讲,因为现在 CART 是主流的决策树算法,而且在 sklearn 工具包中使用的也是 CART 决策树。

优点

  • 非常直观,可解释极强。 在生成的决策树上,每个节点都有明确的判断分支条件,所以非常容易看到为什么要这样处理,比起神经网络模型的黑盒处理,高解释性的模型非常受金融保险行业的欢迎。在后面的动手环节,我们能看到训练完成的决策树可以直接输出出来,以图形化的方式展示给我们生成的决策树每一个节点的判断条件是什么样子的。
  • 预测速度比较快。 由于最终生成的模型是一个树形结构,对于一条新数据的预测,只需要按照条件在每一个节点进行判定就可以。通常来说,树形结构都有助于提升运算速度。
  • 既可以处理离散值也可以处理连续值,还可以处理缺失值。

缺点

  • 容易过拟合。 试想在极端的情况下,我们根据样本生成了一个最完美的树,那么样本中出现的每一个值都会有一条路径来拟合,所以如果样本中存在一些问题数据,或者样本与测试数据存在一定的差距时,就会看出泛化性能不好,出现了过拟合的现象。
  • 需要处理样本不均衡的问题。 如果样本不均衡,某些特征的样本比例过大,最终的模型结果将会更偏向这些特征。
  • 样本的变化会引发树结构巨变。

关于剪枝

  • 决策树容易过拟合,那么我们需要使用剪枝的方式来使得模型的泛化能力更好,所以剪枝可以理解为简化我们的决策树,去掉不必要的节点路径以提高泛化能力。
  • 预剪枝: 在决策树构建之初就设定一个阈值,当分裂节点的熵阈值小于设定值的时候就不再进行分裂了;然而这种方法的实际效果并不是很好,因为谁也没办法预料到我们设定的恰好是我们想要的。
  • 后剪枝: 后剪枝方法就是在我们的决策树已经构建完成以后,再根据设定的条件来判断是否要合并一些中间节点,使用叶子节点来代替。在实际的情况下,通常都是采用后剪枝的方案。

三、算法实践

以经典的鸢尾花数据分类为例,熟悉决策树算法基本原理。使用 sklearn 自带的鸢尾花数据集,这个数据集里面有 150 条数据,共有 3 个类别,即 Setosa 鸢尾花、Versicolour 鸢尾花和 Virginica 鸢尾花,每个类别有 50 条数据,每条数据有 4 个维度,分别记录了鸢尾花的花萼长度、花萼宽度、花瓣长度和花瓣宽度。

导入需要的依赖库

from sklearn import datasets   # sklearn自带的数据集
from sklearn.tree import DecisionTreeClassifier   # 引入决策树算法包
import numpy as np    # 矩阵运算库numpy

# 设置随机种子,不设置的话默认是按系统时间作为参数
# 设置后可以保证我们每次产生的随机数是一样的,便于测试
np.random.seed(6)

加载数据

iris = datasets.load_iris()
iris_x = iris.data      # 数据部分
iris_y = iris.target    # 类别部分
print(iris_x)
print(iris_y)

结果如下:

这个数据集里面有 150 条数据,共有 3 个类别,即 Setosa 鸢尾花、Versicolour 鸢尾花和 Virginica 鸢尾花,每个类别有 50 条数据,每条数据有 4 个维度,分别记录了鸢尾花的花萼长度、花萼宽度、花瓣长度和花瓣宽度。

决策树算法预测,在模型训练时,设置了树的最大深度为4。

# permutation 接收一个数作为参数(这里为数据集长度150) 产生一个0-149乱序一维数组
randomarr= np.random.permutation(len(iris_x))
# 随机从150条数据中选120条作为训练集,30条作为测试集
iris_x_train = iris_x[randomarr[:-30]] # 训练集数据
iris_y_train = iris_y[randomarr[:-30]] # 训练集标签
iris_x_test = iris_x[randomarr[-30:]]  # 测试集数据
iris_y_test = iris_y[randomarr[-30:]]  # 测试集标签

# 在模型训练时,设置了树的最大深度为 4
dct = DecisionTreeClassifier(max_depth=4)
dct.fit(iris_x_train, iris_y_train)
# 调用预测方法,主要接收一个参数:测试数据集
iris_y_predict = dct.predict(iris_x_test)
# 计算各测试样本预测的概率值 这里我们没有用概率值,但是在实际工作中可能会参考概率值来进行最后结果的筛选,而不是直接使用给出的预测标签
probility = dct.predict_proba(iris_x_test)
print(probility)
print('------------------------------------------------------------------')
# 调用该对象的打分方法,计算出准确率
score = dct.score(iris_x_test, iris_y_test, sample_weight=None)
# 输出测试的结果
print('iris_y_predict = ')
print(iris_y_predict)
print('------------------------------------------------------------------')
# 输出原始测试数据集的正确标签,以方便对比
print('iris_y_test = ')
print(iris_y_test)
print('------------------------------------------------------------------')
# 输出准确率计算结果
print('Accuracy:', score)

结果如下:


简单测试了一下,预测分类的准确率还不错~~

决策树可视化:

# 引入画图相关的包 
from IPython.display import Image
from sklearn import tree
# dot是一个程式化生成流程图的简单语言
import pydotplus

dot_data = tree.export_graphviz(dct, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)

graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())


调用 fit 方法进行模型训练,决策树算法会生成一个树形的判定模型,我们可以把决策树算法生成的模型可视化展示出来,便于我们更直观地理解决策树算法。图中可以看到每一次的判定条件以及基尼系数,还有能够落入此决策的样本数量和分类的类别。

打赏
文章很值,打赏犒劳作者一下
相关推荐
<p style="text-align:left;"> <span> </span> </p> <p class="ql-long-24357476" style="font-size:11pt;color:#494949;"> <span style="font-family:"color:#E53333;font-size:14px;background-color:#FFFFFF;line-height:24px;"><span style="line-height:24px;">限时福利1:</span></span><span style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;">购课进答疑群专享柳峰(刘运强)老师答疑服务。</span> </p> <p> <br /> </p> <p class="ql-long-24357476"> <strong><span style="color:#337FE5;font-size:14px;">为什么说每一个程序员都应该学习MySQL?</span></strong> </p> <p class="ql-long-24357476"> <span style="font-size:14px;">根据《2019-2020年中国开发者调查报告》显示,超83%的开发者都在使用MySQL数据库。</span> </p> <p class="ql-long-24357476"> <img src="https://img-bss.csdn.net/202003301212574051.png" alt="" /> </p> <p class="ql-long-24357476"> <span style="font-size:14px;">使用量大同时,掌握MySQL早已是运维、DBA的必备技能,甚至部分IT开发岗位也要求对数据库使用和原理有深入的了解和掌握。</span><br /> <br /> <span style="font-size:14px;">学习编程,你可能会犹豫选择 C++ 还是 Java;入门数据科学,你可能会纠结于选择 Python 还是 R;但无论如何, MySQL 都是 IT 从业人员不可或缺的技能!</span> </p> <span></span> <p> <br /> </p> <p> <span> </span> </p> <h3 class="ql-long-26664262"> <p style="font-size:12pt;"> <strong class="ql-author-26664262 ql-size-14"><span style="font-size:14px;color:#337FE5;">【课程设计】</span></strong> </p> <p style="font-size:12pt;"> <span style="color:#494949;font-weight:normal;"><br /> </span> </p> <p style="font-size:12pt;"> <span style="color:#494949;font-weight:normal;font-size:14px;">在本课程中,刘运强老师会结合自己十多年来对MySQL的心得体会,通过课程给你分享一条高效的MySQL入门捷径,让学员少走弯路,彻底搞懂MySQL。</span> </p> <p style="font-size:12pt;"> <span style="color:#494949;font-weight:normal;"><br /> </span> </p> <p style="font-size:12pt;"> <span style="font-weight:normal;font-size:14px;">本课程包含3大模块:</span><span style="font-weight:normal;font-size:14px;"> </span> </p> </h3> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <strong class="ql-author-26664262"><span style="font-size:14px;">一、基础篇:</span></strong> </p> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <span class="ql-author-26664262" style="font-size:14px;">主要以最新的MySQL8.0安装为例帮助学员解决安装与配置MySQL的问题,并对MySQL8.0的新特性做一定介绍,为后续的课程展开做好环境部署。</span> </p> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <span class="ql-author-26664262" style="font-size:14px;"><br /> </span> </p> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <strong class="ql-author-26664262"><span style="font-size:14px;">二、SQL语言篇</span></strong><span class="ql-author-26664262" style="font-size:14px;">:</span> </p> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <span class="ql-author-26664262" style="font-size:14px;">本篇主要讲解SQL语言的四大部分数据查询语言DQL,数据操纵语言DML,数据定义语言DDL,数据控制语言DCL,</span><span style="font-size:14px;">学会熟练对库表进行增删改查等必备技能。</span> </p> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <span style="font-size:14px;"><br /> </span> </p> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <strong class="ql-author-26664262"><span style="font-size:14px;">三、MySQL进阶篇</span></strong><span style="font-size:14px;">:</span> </p> <p class="ql-long-26664262" style="font-size:11pt;color:#494949;"> <span style="font-size:14px;">本篇可以帮助学员更加高效的管理线上的MySQL数据库;具备MySQL的日常运维能力,语句调优、备份恢复等思路。</span> </p> <span><span> <p style="font-size:11pt;color:#494949;"> <span style="font-size:14px;"> </span><img src="https://img-bss.csdn.net/202004220208351273.png" alt="" /> </p> </span></span>
©️2020 CSDN 皮肤主题: 书香水墨 设计师:CSDN官方博客 返回首页

打赏

叶庭云

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付 29.90元
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值