糖尿病康复,内容丰富有趣,生活中的好帮手!
糖尿病康复 > 机器学习基础算法17-决策树-鸢尾花数据集分类及决策树深度与过拟合

机器学习基础算法17-决策树-鸢尾花数据集分类及决策树深度与过拟合

时间:2019-03-01 15:36:14

相关推荐

机器学习基础算法17-决策树-鸢尾花数据集分类及决策树深度与过拟合

文章目录

决策树代码运行结果多个决策树

决策树

决策树(Decision Tree)是一种基本的分类与回归方法,当决策树用于分类时称为分类树,用于回归时称为回归树。主要介绍分类树。

决策树由结点和有向边组成。结点有两种类型:内部结点和叶结点,其中内部结点表示一个特征或属性,叶结点表示一个类。

决策树学算法通常是一个递归地选择最优特征,并根据该特征对训练数据进行分割,使得对各个子数据集有一个最好的分类的过程。根据信息增益准则的特征选择方法:对于训练数据集(或子集),计算其每个特征的信息增益,并比较它们的大小,选择信息增益最大的特征。

代码

import numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport matplotlib as mplfrom sklearn import treefrom sklearn.tree import DecisionTreeClassifierfrom sklearn.model_selection import train_test_splitfrom sklearn.pipeline import Pipelinefrom sklearn.metrics import accuracy_score# 生成dot格式文件import pydotplus# 花萼长度、花萼宽度,花瓣长度,花瓣宽度iris_feature_E = 'sepal length', 'sepal width', 'petal length', 'petal width'iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'iris_class = 'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'if __name__ == "__main__":# 解决matplotlib不能读取中文的问题mpl.rcParams['font.sans-serif'] = [u'SimHei']mpl.rcParams['axes.unicode_minus'] = False# 加载数据path = 'iris.data' # 数据文件路径# header=None文件指定了列名data = pd.read_csv(path, header=None)# print(data)# 特征值x = data[range(4)]# 目标值-由于y为字符串,需要转换成类别数据,再转换成编码-0,1,2y = pd.Categorical(data[4]).codes# print(y)# 为了可视化,仅使用前两列特征x = x.iloc[:, :2]# 相当于x = x[[0,1]]# 分割数据集-训练集与测试集x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=1)print(y_test.shape)'''训练模型'''# 决策树参数估计# DecisionTreeClassifier:决策树分类器# 损失函数criterion:gini或者entropy,前者是基尼系数,后者是信息熵。# min_samples_split = 10:如果该结点包含的样本数目大于10,则(有可能)对其分支# min_samples_leaf = 10:若将某结点分支后,得到的每个子结点样本数目都大于10,则完成分支;否则,不进行分支model = DecisionTreeClassifier(criterion='entropy')model.fit(x_train, y_train)y_test_hat = model.predict(x_test) # 测试数据print('accuracy_score = ', accuracy_score(y_test_hat, y_test))# 保存# dot -Tpng my.dot -o my.png# 1、输出with open('iris.dot', 'w') as f:tree.export_graphviz(model, out_file=f)# 2、给定文件名# tree.export_graphviz(model, out_file='iris1.dot')# 画图N, M = 50, 50 # 横纵各采样多少个值x1_min, x2_min = x.min()x1_max, x2_max = x.max()t1 = np.linspace(x1_min, x1_max, N)t2 = np.linspace(x2_min, x2_max, M)x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点# x1.flat生成一个迭代器x_show = np.stack((x1.flat, x2.flat), axis=1) # 测试点print(x_show.shape) # (2500, 2)# # 无意义,只是为了凑另外两个维度# # 打开该注释前,确保注释掉x = x[:, :2]# x3 = np.ones(x1.size) * np.average(x[:, 2])# x4 = np.ones(x1.size) * np.average(x[:, 3])# x_test = np.stack((x1.flat, x2.flat, x3, x4), axis=1) # 测试点# 颜色cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])y_show_hat = model.predict(x_show) # 预测值print(y_show_hat.shape)print(y_show_hat)# 使之与输入的形状相同y_show_hat = y_show_hat.reshape(x1.shape)print(y_show_hat)plt.figure(facecolor='w')# plt.pcolormesh的作用在于能够直观表现出分类边界plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light) # 预测值的显示# y_test.ravel()表示将多维数组转换为一维数组# 测试数据plt.scatter(x_test[0], x_test[1], c=y_test.ravel(), edgecolors='k', s=150, zorder=10, cmap=cm_dark,marker='*')# print(y_test.ravel())# 全部数据plt.scatter(x[0], x[1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark)plt.xlabel(iris_feature[0], fontsize=15)plt.ylabel(iris_feature[1], fontsize=15)plt.xlim(x1_min, x1_max)plt.ylim(x2_min, x2_max)plt.grid(True)plt.title(u'鸢尾花数据的决策树分类', fontsize=17)plt.show()# 自己求一个测试集的精确度y_test = y_test.reshape(-1)print(y_test_hat)print(y_test)result = (y_test_hat == y_test) # True则预测正确,False则预测错误acc = np.mean(result)print('准确度: %.2f%%' % (100 * acc))# 分析决策树深度对准确度的影响# 过拟合:错误率depth = np.arange(1, 15)err_list = []for d in depth:clf = DecisionTreeClassifier(criterion='entropy', max_depth=d)clf.fit(x_train, y_train)y_test_hat = clf.predict(x_test) # 测试数据result = (y_test_hat == y_test) # True则预测正确,False则预测错误if d == 1:print(result)err = 1 - np.mean(result)err_list.append(err)# print d, ' 准确度: %.2f%%' % (100 * err)print(d, ' 错误率: %.2f%%' % (100 * err))plt.figure(facecolor='w')plt.plot(depth, err_list, 'ro-', lw=2)plt.xlabel(u'决策树深度', fontsize=15)plt.ylabel(u'错误率', fontsize=15)plt.title(u'决策树深度与过拟合', fontsize=17)plt.grid(True)plt.show()

运行结果

(45,)accuracy_score = 0.6222222222222222(2500, 2)(2500,)[0 0 0 ... 2 2 2][[0 0 0 ... 1 1 1][0 0 0 ... 1 1 1][0 0 0 ... 1 1 1]...[0 0 0 ... 2 2 2][0 0 0 ... 2 2 2][0 0 0 ... 2 2 2]][0 1 2 0 2 2 1 0 0 2 2 0 1 2 1 0 2 1 0 0 1 0 2 0 2 1 0 0 1 1 2 2 2 2 1 0 10 2 1 2 0 1 1 1][0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 1 0 2 1 0 0 1 2 1 2 1 2 2 0 10 1 2 2 0 2 2 1]准确度: 62.22%[False False False True True False True True True True False TrueTrue False False True False False True True False False False TrueTrue False False True False True False True False True True TrueFalse True False True True True True True False]1 错误率: 44.44%2 错误率: 40.00%3 错误率: 20.00%4 错误率: 24.44%5 错误率: 24.44%6 错误率: 26.67%7 错误率: 35.56%8 错误率: 37.78%9 错误率: 37.78%10 错误率: 40.00%11 错误率: 37.78%12 错误率: 40.00%13 错误率: 37.78%14 错误率: 37.78%

多个决策树

import numpy as npimport pandas as pdimport matplotlib as mplimport matplotlib.pyplot as pltfrom sklearn.tree import DecisionTreeClassifier# 'sepal length', 'sepal width', 'petal length', 'petal width'iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'if __name__ == "__main__":mpl.rcParams['font.sans-serif'] = [u'SimHei'] # 黑体 FangSong/KaiTimpl.rcParams['axes.unicode_minus'] = Falsepath = 'iris.data' # 数据文件路径data = pd.read_csv(path, header=None)x_prime = data[range(4)]y = pd.Categorical(data[4]).codes# 特征两两组合,共有6对feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]plt.figure(figsize=(10, 9), facecolor='#FFFFFF')# enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,# 同时列出数据和数据下标,一般用在 for 循环当中。for i, pair in enumerate(feature_pairs):# 准备数据x = x_prime[pair]# 决策树学习clf = DecisionTreeClassifier(criterion='entropy', min_samples_leaf=3)clf.fit(x, y)# 画图N, M = 500, 500 # 横纵各采样多少个值x1_min, x2_min = x.min()x1_max, x2_max = x.max()t1 = np.linspace(x1_min, x1_max, N)t2 = np.linspace(x2_min, x2_max, M)x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点x_test = np.stack((x1.flat, x2.flat), axis=1) # 测试点# 训练集上的预测结果y_hat = clf.predict(x)y = y.reshape(-1)c = np.count_nonzero(y_hat == y) # 统计预测正确的个数print('特征: ', iris_feature[pair[0]], ' + ', iris_feature[pair[1]])print('\t预测正确数目:', c)print('\t准确率: %.2f%%' % (100 * float(c) / float(len(y))))# 显示cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])y_hat = clf.predict(x_test) # 预测值y_hat = y_hat.reshape(x1.shape) # 使之与输入的形状相同plt.subplot(2, 3, i+1)plt.pcolormesh(x1, x2, y_hat, cmap=cm_light) # 预测值plt.scatter(x[pair[0]], x[pair[1]], c=y, edgecolors='k', cmap=cm_dark) # 样本plt.xlabel(iris_feature[pair[0]], fontsize=14)plt.ylabel(iris_feature[pair[1]], fontsize=14)plt.xlim(x1_min, x1_max)plt.ylim(x2_min, x2_max)plt.grid()plt.suptitle(u'决策树对鸢尾花数据的两特征组合的分类结果', fontsize=18)plt.tight_layout(2)plt.subplots_adjust(top=0.92)plt.show()

运行结果

特征: 花萼长度 + 花萼宽度预测正确数目: 123准确率: 82.00%特征: 花萼长度 + 花瓣长度预测正确数目: 145准确率: 96.67%特征: 花萼长度 + 花瓣宽度预测正确数目: 144准确率: 96.00%特征: 花萼宽度 + 花瓣长度预测正确数目: 143准确率: 95.33%特征: 花萼宽度 + 花瓣宽度预测正确数目: 145准确率: 96.67%特征: 花瓣长度 + 花瓣宽度预测正确数目: 147准确率: 98.00%

如果觉得《机器学习基础算法17-决策树-鸢尾花数据集分类及决策树深度与过拟合》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。