python中使用scikit-learn和pandas決策樹(shù)進(jìn)行iris鳶尾花數(shù)據(jù)分類(lèi)建模和交叉驗(yàn)證
原文鏈接:http://tecdat.cn/?p=9326
?
?
在這篇文章中,我將使用python中的決策樹(shù)(用于分類(lèi))。重點(diǎn)將放在基礎(chǔ)知識(shí)和對(duì)最終決策樹(shù)的理解上。
?
導(dǎo)入
因此,首先我們進(jìn)行一些導(dǎo)入。
from __future__ import print_function
import os
import subprocess
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz
?
數(shù)據(jù)
接下來(lái),我們需要考慮一些數(shù)據(jù)。我將使用著名的iris數(shù)據(jù)集,該數(shù)據(jù)集可對(duì)各種不同的iris類(lèi)型進(jìn)行各種測(cè)量。pandas和sckit-learn都可以輕松導(dǎo)入這些數(shù)據(jù),我將使用pandas編寫(xiě)一個(gè)從csv文件導(dǎo)入的函數(shù)。這樣做的目的是演示如何將scikit-learn與pandas一起使用。因此,我們定義了一個(gè)獲取iris數(shù)據(jù)的函數(shù):
def get_iris_data():
"""從本地csv或pandas中獲取iris數(shù)據(jù)。"""
if os.path.exists("iris.csv"):
print("-- iris.csv found locally")
df = pd.read_csv("iris.csv", index_col=0)
else:
print("-- trying to download from github")
fn = "https://raw.githubusercontent.com/pydata/pandas/" + \
"master/pandas/tests/data/iris.csv"
try:
df = pd.read_csv(fn)
except:
exit("-- Unable to download iris.csv")
with open("iris.csv", 'w') as f:
print("-- writing to local iris.csv file")
df.to_csv(f)
return df
?
此函數(shù)首先嘗試在本地讀取數(shù)據(jù)。利用os.path.exists()?方法。如果在本地目錄中找到iris.csv文件,則使用pandas通過(guò)pd.read_csv()讀取文件。
如果本地iris.csv沒(méi)有發(fā)現(xiàn),抓取URL數(shù)據(jù)來(lái)運(yùn)行。
下一步是獲取數(shù)據(jù),并使用head()和tail()方法查看數(shù)據(jù)的樣子。因此,首先獲取數(shù)據(jù):
df = get_iris_data()
-- iris.csv found locally
然后 :
print("* df.head()", df.head(), sep="\n", end="\n\n")
print("* df.tail()", df.tail(), sep="\n", end="\n\n")
* df.head()
SepalLength ?SepalWidth ?PetalLength ?PetalWidth ? ? ? ? Name
0 ? ? ? ? ?5.1 ? ? ? ? 3.5 ? ? ? ? ?1.4 ? ? ? ? 0.2 ?Iris-setosa
1 ? ? ? ? ?4.9 ? ? ? ? 3.0 ? ? ? ? ?1.4 ? ? ? ? 0.2 ?Iris-setosa
2 ? ? ? ? ?4.7 ? ? ? ? 3.2 ? ? ? ? ?1.3 ? ? ? ? 0.2 ?Iris-setosa
3 ? ? ? ? ?4.6 ? ? ? ? 3.1 ? ? ? ? ?1.5 ? ? ? ? 0.2 ?Iris-setosa
4 ? ? ? ? ?5.0 ? ? ? ? 3.6 ? ? ? ? ?1.4 ? ? ? ? 0.2 ?Iris-setosa
* df.tail()
SepalLength ?SepalWidth ?PetalLength ?PetalWidth ? ? ? ? ? ?Name
145 ? ? ? ? ?6.7 ? ? ? ? 3.0 ? ? ? ? ?5.2 ? ? ? ? 2.3 ?Iris-virginica
146 ? ? ? ? ?6.3 ? ? ? ? 2.5 ? ? ? ? ?5.0 ? ? ? ? 1.9 ?Iris-virginica
147 ? ? ? ? ?6.5 ? ? ? ? 3.0 ? ? ? ? ?5.2 ? ? ? ? 2.0 ?Iris-virginica
148 ? ? ? ? ?6.2 ? ? ? ? 3.4 ? ? ? ? ?5.4 ? ? ? ? 2.3 ?Iris-virginica
149 ? ? ? ? ?5.9 ? ? ? ? 3.0 ? ? ? ? ?5.1 ? ? ? ? 1.8 ?Iris-virginica
?
從這些信息中,我們可以討論我們的目標(biāo):給定特征SepalLength,?SepalWidth,?PetalLength?和PetalWidth來(lái)預(yù)測(cè)iris類(lèi)型。
預(yù)處理
為了將這些數(shù)據(jù)傳遞到scikit-learn,我們需要將Names編碼為整數(shù)。為此,我們將編寫(xiě)另一個(gè)函數(shù),并返回修改后的數(shù)據(jù)框以及目標(biāo)(類(lèi))名稱(chēng)的列表:
?
讓我們看看有什么:
* df2.head()
Target ? ? ? ? Name
0 ? ? ? 0 ?Iris-setosa
1 ? ? ? 0 ?Iris-setosa
2 ? ? ? 0 ?Iris-setosa
3 ? ? ? 0 ?Iris-setosa
4 ? ? ? 0 ?Iris-setosa
* df2.tail()
Target ? ? ? ? ? ?Name
145 ? ? ? 2 ?Iris-virginica
146 ? ? ? 2 ?Iris-virginica
147 ? ? ? 2 ?Iris-virginica
148 ? ? ? 2 ?Iris-virginica
149 ? ? ? 2 ?Iris-virginica
* targets
['Iris-setosa' 'Iris-versicolor' 'Iris-virginica']
?
接下來(lái),我們獲得列的名稱(chēng):
features = list(df2.columns[:4])
print("* features:", features, sep="\n")
* features:
['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
?
用scikit-learn擬合決策樹(shù)
現(xiàn)在,我們可以使用?上面導(dǎo)入的DecisionTreeClassifier擬合決策樹(shù),如下所示:
?
?
我們使用簡(jiǎn)單的索引從數(shù)據(jù)框中提取X和y數(shù)據(jù)。
開(kāi)始時(shí)導(dǎo)入的決策樹(shù)用兩個(gè)參數(shù)初始化:min_samples_split = 20需要一個(gè)節(jié)點(diǎn)中的20個(gè)樣本才能拆分,并且?random_state = 99進(jìn)行種子隨機(jī)數(shù)生成器。
可視化樹(shù)
我們可以使用以下功能生成圖形:
?
?
從上面的scikit-learn導(dǎo)入的export_graphviz方法寫(xiě)入一個(gè)點(diǎn)文件。此文件用于生成圖形。
生成圖形?dt.png。
運(yùn)行函數(shù):
visualize_tree(dt, features)
結(jié)果?

我們可以使用此圖來(lái)了解決策樹(shù)發(fā)現(xiàn)的模式:
所有數(shù)據(jù)(所有行)都從樹(shù)頂部開(kāi)始。
考慮了所有功能,以了解如何以最有用的方式拆分?jǐn)?shù)據(jù)-默認(rèn)情況下使用基尼度量。
在頂部,我們看到最有用的條件是?PetalLength <= 2.4500。
這種分裂一直持續(xù)到
拆分后僅具有一個(gè)類(lèi)別。
或者,結(jié)果中的樣本少于20個(gè)。
?
決策樹(shù)的偽代碼
最后,我們考慮生成代表學(xué)習(xí)的決策樹(shù)的偽代碼。
目標(biāo)名稱(chēng)可以傳遞給函數(shù),并包含在輸出中。
使用spacer_base?參數(shù),使輸出更容易閱讀。
?
應(yīng)用于iris數(shù)據(jù)的結(jié)果輸出為:
get_code(dt, features, targets)
if ( PetalLength <= 2.45000004768 ) {
return Iris-setosa ( 50 examples )
}
else {
if ( PetalWidth <= 1.75 ) {
if ( PetalLength <= 4.94999980927 ) {
if ( PetalWidth <= 1.65000009537 ) {
return Iris-versicolor ( 47 examples )
}
else {
return Iris-virginica ( 1 examples )
}
}
else {
return Iris-versicolor ( 2 examples )
return Iris-virginica ( 4 examples )
}
}
else {
if ( PetalLength <= 4.85000038147 ) {
return Iris-versicolor ( 1 examples )
return Iris-virginica ( 2 examples )
}
else {
return Iris-virginica ( 43 examples )
}
}
}
?
將其與上面的圖形輸出進(jìn)行比較-這只是決策樹(shù)的不同表示。
在python中進(jìn)行決策樹(shù)交叉驗(yàn)證
?
導(dǎo)入
首先,我們導(dǎo)入所有代碼:
from __future__ import print_function
import os
import subprocess
from time import time
from operator import itemgetter
from scipy.stats import randint
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn.grid_search import GridSearchCV
from sklearn.grid_search import RandomizedSearchCV
from sklearn.cross_validation import ?cross_val_score
?
主要添加的內(nèi)容是sklearn.grid_search中的方法,它們可以:
時(shí)間搜索
使用itemgetter對(duì)結(jié)果進(jìn)行排序
使用scipy.stats.randint生成隨機(jī)整數(shù)。
現(xiàn)在我們可以開(kāi)始編寫(xiě)函數(shù)了。
包括:
get_code?–為決策樹(shù)編寫(xiě)偽代碼,
visualize_tree?–生成決策樹(shù)的圖形。
encode_target?–處理原始數(shù)據(jù)以與scikit-learn一起使用。
get_iris_data?–如果需要,從網(wǎng)絡(luò)上獲取?iris.csv,并將副本寫(xiě)入本地目錄。
?
新功能
接下來(lái),我們添加一些新功能來(lái)進(jìn)行網(wǎng)格和隨機(jī)搜索,并報(bào)告找到的主要參數(shù)。首先是報(bào)告。此功能從網(wǎng)格或隨機(jī)搜索中獲取輸出,輸出模型的報(bào)告并返回最佳參數(shù)設(shè)置。
?
網(wǎng)格搜索
接下來(lái)是run_gridsearch。該功能需要
特征X,
目標(biāo)y,
(決策樹(shù))分類(lèi)器clf,
嘗試參數(shù)字典的param_grid
交叉驗(yàn)證cv的倍數(shù),默認(rèn)為5。
param_grid是一組參數(shù),這將是作測(cè)試,要注意不要列表中有太多的選擇。
?
隨機(jī)搜尋
接下來(lái)是run_randomsearch函數(shù),該函數(shù)從指定的列表或分布中采樣參數(shù)。與網(wǎng)格搜索類(lèi)似,參數(shù)為:
功能X
目標(biāo)y
(決策樹(shù))分類(lèi)器clf
交叉驗(yàn)證cv的倍數(shù),默認(rèn)為5?
n_iter_search的隨機(jī)參數(shù)設(shè)置數(shù)目,默認(rèn)為20。
?
好的,我們已經(jīng)定義了所有函數(shù)。
交叉驗(yàn)證
獲取數(shù)據(jù)
接下來(lái),讓我們使用上面設(shè)置的搜索方法來(lái)找到合適的參數(shù)設(shè)置。首先進(jìn)行一些初步準(zhǔn)備-獲取數(shù)據(jù)并構(gòu)建目標(biāo)數(shù)據(jù):
print("\n-- get data:")
df = get_iris_data()
print("")
features = ["SepalLength", "SepalWidth",
"PetalLength", "PetalWidth"]
df, targets = encode_target(df, "Name")
y = df["Target"]
X = df[features]
-- get data:
-- iris.csv found locally
?
第一次交叉驗(yàn)證
在下面的所有示例中,我將使用10倍交叉驗(yàn)證。
將數(shù)據(jù)分為10部分
擬合9個(gè)部分
其余部分的測(cè)試準(zhǔn)確性
使用當(dāng)前參數(shù)設(shè)置,在所有組合上重復(fù)此操作,以產(chǎn)生十個(gè)模型精度估計(jì)。通常會(huì)報(bào)告十個(gè)評(píng)分的平均值和標(biāo)準(zhǔn)偏差。
print("-- 10-fold cross-validation "
"[using setup from previous post]")
dt_old = DecisionTreeClassifier(min_samples_split=20,
random_state=99)
dt_old.fit(X, y)
scores = cross_val_score(dt_old, X, y, cv=10)
print("mean: {:.3f} (std: {:.3f})".format(scores.mean(),
scores.std()),
end="\n\n" )
-- 10-fold cross-validation [using setup from previous post]
mean: 0.960 (std: 0.033)
?
0.960還不錯(cuò)。這意味著平均準(zhǔn)確性(使用經(jīng)過(guò)訓(xùn)練的模型進(jìn)行正確分類(lèi)的百分比)為96%。該精度非常高,但是讓我們看看是否可以找到更好的參數(shù)。
網(wǎng)格搜索的應(yīng)用
首先,我將嘗試網(wǎng)格搜索。字典para_grid提供了要測(cè)試的不同參數(shù)設(shè)置。
print("-- Grid Parameter Search via 10-fold CV")
dt = DecisionTreeClassifier()
ts_gs = run_gridsearch(X, y, dt, param_grid, cv=10)
-- Grid Parameter Search via 10-fold CV
GridSearchCV took 5.02 seconds for 288 candidate parameter settings.
Model with rank: 1
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 10, 'max_leaf_nodes': 5,
'criterion': 'gini', 'max_depth': None, 'min_samples_leaf': 1}
Model with rank: 2
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 20, 'max_leaf_nodes': 5,
'criterion': 'gini', 'max_depth': None, 'min_samples_leaf': 1}
Model with rank: 3
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 10, 'max_leaf_nodes': 5,
'criterion': 'gini', 'max_depth': 5, 'min_samples_leaf': 1}
?
在大多數(shù)運(yùn)行中,各種參數(shù)設(shè)置的平均值為0.967。這意味著從96%改善到96.7%!我們可以看到最佳的參數(shù)設(shè)置ts_gs,如下所示:
print("\n-- Best Parameters:")
for k, v in ts_gs.items():
print("parameter: {:<20s} setting: {}".format(k, v))
-- Best Parameters:
parameter: min_samples_split ? ?setting: 10
parameter: max_leaf_nodes ? ? ? setting: 5
parameter: criterion ? ? ? ? ? ?setting: gini
parameter: max_depth ? ? ? ? ? ?setting: None
parameter: min_samples_leaf ? ? setting: 1
?
并復(fù)制交叉驗(yàn)證結(jié)果:
#測(cè)試最佳參數(shù)
print("\n\n-- Testing best parameters [Grid]...")
dt_ts_gs = DecisionTreeClassifier(**ts_gs)
scores = cross_val_score(dt_ts_gs, X, y, cv=10)
print("mean: {:.3f} (std: {:.3f})".format(scores.mean(),
scores.std()),
end="\n\n" )
-- Testing best parameters [Grid]...
mean: 0.967 (std: 0.033)
?
接下來(lái),讓我們使用獲取最佳樹(shù)的偽代碼:
print("\n-- get_code for best parameters [Grid]:", end="\n\n")
dt_ts_gs.fit(X,y)
get_code(dt_ts_gs, features, targets)
-- get_code for best parameters [Grid]:
if ( PetalWidth <= 0.800000011921 ) {
return Iris-setosa ( 50 examples )
}
else {
if ( PetalWidth <= 1.75 ) {
if ( PetalLength <= 4.94999980927 ) {
if ( PetalWidth <= 1.65000009537 ) {
return Iris-versicolor ( 47 examples )
}
else {
return Iris-virginica ( 1 examples )
}
}
else {
return Iris-versicolor ( 2 examples )
return Iris-virginica ( 4 examples )
}
}
else {
return Iris-versicolor ( 1 examples )
return Iris-virginica ( 45 examples )
}
}
?
我們還可以制作決策樹(shù)的圖形:
visualize_tree(dt_ts_gs, features, fn="grid_best")
?

隨機(jī)搜索的應(yīng)用
接下來(lái),我們嘗試使用隨機(jī)搜索方法來(lái)查找參數(shù)。在此示例中,我使用288個(gè)樣本,以便測(cè)試的參數(shù)設(shè)置數(shù)量與上面的網(wǎng)格搜索相同:
?
與網(wǎng)格搜索一樣,這通常會(huì)找到平均精度為0.967或96.7%的多個(gè)參數(shù)設(shè)置。如上所述,最佳交叉驗(yàn)證的參數(shù)為:
print("\n-- Best Parameters:")
for k, v in ts_rs.items():
print("parameters: {:<20s} setting: {}".format(k, v))
-- Best Parameters:
parameters: min_samples_split ? ?setting: 12
parameters: max_leaf_nodes ? ? ? setting: 5
parameters: criterion ? ? ? ? ? ?setting: gini
parameters: max_depth ? ? ? ? ? ?setting: 19
parameters: min_samples_leaf ? ? setting: 1
?
并且,我們可以再次測(cè)試最佳參數(shù):
#測(cè)試最佳參數(shù)
)
-- Testing best parameters [Random]...
mean: 0.967 (std: 0.033)
?
要查看決策樹(shù)是什么樣的,我們可以生成偽代碼以獲得最佳隨機(jī)搜索結(jié)果
并可視化樹(shù)
visualize_tree(dt_ts_rs, features, fn="rand_best")
?

結(jié)論
因此,我們使用了帶有交叉驗(yàn)證的網(wǎng)格和隨機(jī)搜索來(lái)調(diào)整決策樹(shù)的參數(shù)。在這兩種情況下,從96%到96.7%的改善都很小。當(dāng)然,在更復(fù)雜的問(wèn)題中,這種影響會(huì)更大。最后幾點(diǎn)注意事項(xiàng):
通過(guò)交叉驗(yàn)證搜索找到最佳參數(shù)設(shè)置后,通常使用找到的最佳參數(shù)對(duì)所有數(shù)據(jù)進(jìn)行訓(xùn)練。
傳統(tǒng)觀點(diǎn)認(rèn)為,對(duì)于實(shí)際應(yīng)用而言,隨機(jī)搜索比網(wǎng)格搜索更有效。網(wǎng)格搜索確實(shí)花費(fèi)的時(shí)間太長(zhǎng),這當(dāng)然是有意義的。
此處開(kāi)發(fā)的基本交叉驗(yàn)證想法可以應(yīng)用于許多其他scikit學(xué)習(xí)模型-隨機(jī)森林,邏輯回歸,SVM等。

最受歡迎的見(jiàn)解
1.從決策樹(shù)模型看員工為什么離職
2.R語(yǔ)言基于樹(shù)的方法:決策樹(shù),隨機(jī)森林
3.python中使用scikit-learn和pandas決策樹(shù)
4.機(jī)器學(xué)習(xí):在SAS中運(yùn)行隨機(jī)森林?jǐn)?shù)據(jù)分析報(bào)告
5.R語(yǔ)言用隨機(jī)森林和文本挖掘提高航空公司客戶(hù)滿(mǎn)意度
6.機(jī)器學(xué)習(xí)助推快時(shí)尚精準(zhǔn)銷(xiāo)售時(shí)間序列
7.用機(jī)器學(xué)習(xí)識(shí)別不斷變化的股市狀況——隱馬爾可夫模型的應(yīng)用
8.python機(jī)器學(xué)習(xí):推薦系統(tǒng)實(shí)現(xiàn)(以矩陣分解來(lái)協(xié)同過(guò)濾)
9.python中用pytorch機(jī)器學(xué)習(xí)分類(lèi)預(yù)測(cè)銀行客戶(hù)流失