コンサルでデータサイエンティスト

仕事でPythonを書いてます。機械学習、Webマーケティングに興味があります。趣味は旅です。

Pythonで決定木分類器をフルスクラッチで実装してみた

機械学習モデルをスクラッチから実装しようと思い立ったので、第一歩として決定木分類器(Decision Tree Classifer) Pythonで実装してみました。RandomForestやXGBoostなどといった決定木系の機械学習アルゴリズムを使う場面も多いと思うので、その基礎となる決定木分類器について実装をしながら仕組みを理解していきます。

決定木とは

決定木とは木構造を用いて分類または回帰のモデルを構築する手法です。変数を基準に条件分岐を行い、末端のノード(葉ノード)にたどりつくまで分岐をたどることで、分類クラスや回帰の結果を決定するものです。木が深くなるほど過学習が起きやすくなってしまうため、学習時には少ない分岐でより高い識別能力を持つ木を構築することが求められます。

分割の基準となる説明変数とその閾値が優先度の高いものから順に並ぶため、分類の解釈がしやすいといったメリットがあります。また、変数の重要度 (Feature importance) を算出することで、どの説明変数が識別に寄与しているかということを定量的に把握できるという特長もあります。

f:id:hktech:20181002225736p:plain
参考: wikipedia 決定木


決定木の学習を実装

Pythonを使って決定木の学習をスクラッチから実装していきます。スクラッチとはいえnumpyを使用しているという点についてはご容赦ください。決定木アルゴリズムにはいくつかの種類がありますが、今回は2分木を前提とし、ノードの分割条件を求める評価基準としてGini係数を使うCARTアルゴリズムで実装を進めます。実装はPython3で、numpyおよびscikit-learnがインストールされていることを前提としています。手早く環境構築をしたい方はanacondaを使いましょう。

hktech.hatenablog.com

Gini係数による不純度の計算

決定木を構築する際の最も重要な指標として不純度 (impurity) というものがあります。これは、あるノードに分岐したサンプルのクラスがどれほど散らばっているかということを表します。分岐させた結果、あるノードに1種類のクラスだけが分類された場合(きれいに分割できた場合)、そのノードの不純度は0となります。このように、不純度を求める方法としては交差エントロピーとGini係数がありますが、今回はCARTアルゴリズムに従ってGini係数を使います。
ノードtにおけるクラス C_iの割合を P(C_i|t)とすると、Gini係数は次の式を使って求めることができます。

 1-\sum_{i=1}^{K}P^2(C_i|t)

よくわからないので具体例を使って説明をしていきます。
文系・理系のサンプルが各200人ずついるとし、数学または国語の点数でこれらを分類したいとします。

f:id:hktech:20181005002646p:plain


より識別性能が高いものを木の上のほうに配置したいため、分岐条件AとBの不純度をそれぞれ求めます。

分割条件Aの ’yes’ グループのGini係数は
【Gini(yes)】1-((\frac{40}{170})^2+(\frac{130}{170})^2) = 0.36

分割条件Aの ’no’ グループのGini係数は
【Gini(no)】 1-((\frac{160}{230})^2+(\frac{70}{230})^2) = 0.42

分割ルールAによる不純度は次のよう二求めることができます。
 \frac{170}{400} \times \mathrm{Gini(yes)} + \frac{230}{400} \times \mathrm{Gini(no)}=0.39


同様に、分割条件Bの不純度を計算すると0.49となりました。
分割条件Aのほうが不純度が低いため、分割条件Aのほうが識別能力が高い分岐条件であることがわかります。Gini係数を計算することで、ノードを分割する最適な説明変数および閾値を求めていき、決定木を構築していきます。

Gini係数を求める関数は次のように実装しました。

def gini_score(data, target, feat_idx, threshold):
    gini = 0
    sample_num = len(target)
   
    div_target = [target[data[:, feat_idx] >= threshold], target[data[:, feat_idx] < threshold]]
   
    for group in div_target:
        score = 0
        classes = np.unique(group)
        for cls in classes:
            p = np.sum(group == cls)/len(group)
            score += p * p
        gini += (1- score) * (len(group)/sample_num)
    return gini

ノードを分割する最適な説明変数と閾値の選択

Gini係数が小さい分岐条件を見つけることで、識別能力の高い木が構築できるということがわかりました。今回は、最適な分岐条件をみつけるためにすべての説明変数について閾値を変化させていき、Gini係数が最小となる説明変数と閾値のペアをノードの分岐条件とする方針で実装していきます。
ノードに流れてくるdataおよびそれらに付与されたクラスラベル targetを渡すことで、変数giniが最小となるときの閾値best_thrsおよび説明変数best_fを返り値として得ることができます。

def search_best_split(data, target):   
    features = data.shape[1]
    best_thrs = None
    best_f = None
    gini = None
    gini_min = 1
 
    for feat_idx in range(features):
        values = data[:, feat_idx]
        for val in values:
            gini = gini_score(data, target, feat_idx, val)
            if gini_min > gini:
                gini_min = gini
                best_thrs = val
                best_f = feat_idx
    return gini_min, best_thrs, best_f    

再帰関数を使った木の構築

木の構築方法はいくつか方法が考えられますが、ここではオブジェクトと再帰関数を使った実装を行います。まずDecisionTreeNodeクラスを定義し、splitメソッドを用いることでノードに対応するDecisionTreeNodeの子オブジェクトを再帰的に生成していきます。停止条件としては、以下の2通りを定義しています。

  • ノードのGini係数が0になったとき(ノードに1つのクラスのサンプルしかないため、以降分岐の必要がない)
  • 木の深さがあらかじめ定義したmax_depthに達したとき


各ノードはleftとrightというクラス変数を持ち、ここに子ノードのオブジェクトが格納されていくため、仮想的に木を構築することができます。

def split(self, depth):
    self.depth = depth
   
    self.gini_min, self.threshold, self.feature = search_best_split(self.data, self.target)
    print('Depth: {}, Sep at Feature: {},Threshold: {}, Label: {}'.format(self.depth, self.feature, self.threshold, self.label))
   
   if self.depth == self.max_depth or self.gini_min == 0:
        return
   
    idx_left = self.data[:, self.feature] >= self.threshold
    idx_right = self.data[:, self.feature] < self.threshold
   
    self.left = DecisionTreeNode(self.data[idx_left],  self.target[idx_left], self.max_depth)
    self.right = DecisionTreeNode(self.data[idx_right], self.target[idx_right], self.max_depth)
    self.left.split(self.depth +1)
    self.right.split(self.depth +1)

DecisionTreeNodeクラス全体は次のように実装しました。

class DecisionTreeNode(object):
    def __init__(self, data, target, max_depth):
        self.left = None
        self.right = None
        self.max_depth = max_depth
        self.depth = None
        self.data = data
        self.target = target
        self.threshold = None
        self.feature = None
        self.gini_min = None
        self.label = np.argmax(np.bincount(target))
   
    def split(self, depth):
        self.depth = depth
        self.gini_min, self.threshold, self.feature = search_best_split(self.data, self.target)
        print('Depth: {}, Sep at Feature: {},Threshold: {}, Label: {}'.format(self.depth, self.feature, self.threshold, self.label))
       
        if self.depth == self.max_depth or self.gini_min == 0:
            return       
        idx_left = self.data[:, self.feature] >= self.threshold
        idx_right = self.data[:, self.feature] < self.threshold
   
        self.left = DecisionTreeNode(self.data[idx_left],  self.target[idx_left], self.max_depth)
        self.right = DecisionTreeNode(self.data[idx_right], self.target[idx_right], self.max_depth)
        self.left.split(self.depth +1)
        self.right.split(self.depth +1)
 
    def predict(self, data):
        if self.gini_min == 0.0 or self.depth == self.max_depth:
            return self.label
        else:
            if data[self.feature] > self.threshold:
                return self.left.predict(data)
            else:
               return self.right.predict(data)



決定木のテストを実装

さきほど作成したDecisionTreeNodeクラスに、テスト時の予測メソッドを実装しました。各ノードのオブジェクトには、best_splitで求めた最適な説明変数と閾値が格納されているためこれを使用します。先ほどの停止条件と同様に、Gini係数が0または深さがmax_depthに達した場合は末端ノード(葉ノード)となるため、ここでラベルを返します。これはノードに含まれているサンプルのうち最も多いクラスを返しており、これが予測ラベルとなります。葉ノードに達していない場合は、閾値に従って分岐を進んでいき、末端に達した時点で予測ラベルを返すという実装になります。

def predict(self, data):
    if self.gini_min == 0.0 or self.depth == self.max_depth:
        return self.label
    else:
        if data[self.feature] > self.threshold:
            return self.left.predict(data)
        else:
           return self.right.predict(data)

自作Pythonコード

全体のPythonコードは次の通りです。scikit-learnのメソッドにならって、fitとpredictを持つクラスDesicionTreeClassifierを定義しました。

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
 
def gini_score(data, target, feat_idx, threshold):
    gini = 0
    sample_num = len(target)
   
    div_target = [target[data[:, feat_idx] >= threshold], target[data[:, feat_idx] < threshold]]
   
    for group in div_target:
        score = 0
        classes = np.unique(group)
        for cls in classes:
            p = np.sum(group == cls)/len(group)
            score += p * p
        gini += (1- score) * (len(group)/sample_num)
    return gini
 
def search_best_split(data, target):   
    features = data.shape[1]
    best_thrs = None
    best_f = None
    gini = None
    gini_min = 1
 
    for feat_idx in range(features):
        values = data[:, feat_idx]
        for val in values:
            gini = gini_score(data, target, feat_idx, val)
            if gini_min > gini:
                gini_min = gini
                best_thrs = val
                best_f = feat_idx
    return gini_min, best_thrs, best_f       
 
class DecisionTreeNode(object):
    def __init__(self, data, target, max_depth):
        self.left = None
        self.right = None
        self.max_depth = max_depth
        self.depth = None
        self.data = data
        self.target = target
        self.threshold = None
        self.feature = None
        self.gini_min = None
        self.label = np.argmax(np.bincount(target))
   
    def split(self, depth):
        self.depth = depth
        self.gini_min, self.threshold, self.feature = search_best_split(self.data, self.target)
        print('Depth: {}, Sep at Feature: {},Threshold: {}, Label: {}'.format(self.depth, self.feature, self.threshold, self.label))
       
        if self.depth == self.max_depth or self.gini_min == 0:
            return       
        idx_left = self.data[:, self.feature] >= self.threshold
        idx_right = self.data[:, self.feature] < self.threshold
   
        self.left = DecisionTreeNode(self.data[idx_left],  self.target[idx_left], self.max_depth)
        self.right = DecisionTreeNode(self.data[idx_right], self.target[idx_right], self.max_depth)
        self.left.split(self.depth +1)
        self.right.split(self.depth +1)
 
    def predict(self, data):
        if self.gini_min == 0.0 or self.depth == self.max_depth:
            return self.label
        else:
            if data[self.feature] > self.threshold:
                return self.left.predict(data)
            else:
               return self.right.predict(data)
 
class DesicionTreeClassifier(object):
    def __init__(self, max_depth):
        self.max_depth = max_depth
        self.tree = None
   
    def fit(self, data, target):
        initial_depth = 0
        self.tree = DecisionTreeNode(data, target, 3)
        self.tree.split(initial_depth)
   
    def predict(self, data):
        pred = []
        for s in data:
            pred.append(self.tree.predict(s))
        return np.array(pred)
 
if __name__ == '__main__':
iris = load_iris()
data = iris.data
target = iris.target
 
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.3, random_state=0)
 
clf = DesicionTreeClassifier(max_depth=3)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
score = sum(y_pred == y_test)/float(len(y_test))
print('Classification accuracy: {}'.format(score))




自作コードを使った分類結果

フルスクラッチで書いたコードを使って、irisデータセットをもとに学習および評価を行いました。
学習した決定木は次のようになりました。可視化していないのでわかりにくいですが、深さ1の右側の木(一番下の行)はそれ以降分岐がないためGini係数が0であったことがわかります。このときの予測ラベルは0番のクラスです。また、逆側のブランチは深さ3まで伸びています。

Depth: 0, Sep at Feature: 2,Threshold: 3.0, Label: 2
Depth: 1, Sep at Feature: 2,Threshold: 5.0, Label: 2
Depth: 2, Sep at Feature: 2,Threshold: 5.1, Label: 2
Depth: 3, Sep at Feature: 0,Threshold: 6.5, Label: 2
Depth: 3, Sep at Feature: 0,Threshold: 6.7, Label: 2
Depth: 2, Sep at Feature: 3,Threshold: 1.7, Label: 1
Depth: 3, Sep at Feature: 1,Threshold: 3.2, Label: 2
Depth: 3, Sep at Feature: 0,Threshold: 5.0, Label: 1
Depth: 1, Sep at Feature: 0,Threshold: 4.8, Label: 0



最後に、正解率を求めることで分類器が機能していることを確認しました。
学習データ:テストデータ=7:3、max_depth=3としたときの識別精度は以下の通りでした。

Classification accuracy: 0.977

いい感じに分類できていることがわかりますね。

まとめ

Pythonで決定木分類器をフルスクラッチで実装してみました。とても単純なアルゴリズムかと思いきや、再帰関数などを使って実装を工夫する必要があったためなかなか大変でした。決定木アルゴリズムの中でもポピュラーなRandomForestなども実装してみたいと思います。