機械学習エンジニアの備忘録

主に自分が勉強したことのメモ

アンサンブル学習 アルゴリズム入門 〜〜決定木その2〜〜

前回記事からの続きになります。
rikeiin.hatenablog.com

前回の記事で作成したDecisionStumpは、深さが1の決定木でしたが、同じ方法によるデータの分割を再帰的に行えば、より深い階層を持つ決定木が作成できます。

この記事では先ほど作成したDecisionStumpの応用として、木の深さを指定できる決定木アルゴリズムを作成します。
深さが可変の決定木アルゴリズムはDecisionTreeという名前のクラスとして作成します。このクラスは先ほど作成したDecisionStumpクラスの派生クラスとして作成することで、決定木アルゴリズムに必要となる木分割の関数をDecisionStumpクラスから継承して利用できるようにします。
まずはdtree.pyという名前のファイルを作成し次のクラスを作成します。このクラスには学習させる決定木の深さを表すmax_depth変数と再帰的に生み出される際に使用する現在のノードの深さを表すdepth変数を作成します。

import numpy as np
import support
import entropy
from zeror import ZeroRule
from linear import Linear
from dstump import DecisionStump

class DecisionTree(DecisionStump):
    def __init__(self, max_depth=5, metric=entropy.gini, leaf=ZeroRule, depth=1):
        super().__init__(metric=metric, leaf=leaf)
        self.max_depth = max_depth
        self.depth = depth

このDecisionStumpは決定木内の一つのノードを表しており、葉となるノード自分自身のクラスで置き換えることで、高さが可変の決定を作成します。
それには次のようにfit()をオーバーライドし、現在のノードの深さが最大深さに達していないならば左右の葉をget_node()から取得する新しいノードで置き換えます。
新しく子ノードとなるDecisionTreeでは引数のdepthに現在のノードの深さに1を加えた値を入れることで、現在のノードの深さを増やしていきます。
現在のノードの深さが最初に指定したmax_depthに達するとDecisionTreeクラスの動作はDecisionStumpと同じになり、左右の端に対して学習を行います。

def fit(self, x, y):
    # create leaf node of left and right
    self.left = self.leaf()
    self.right = self.leaf()
    # split data into left and right node
    left, right = self.split_tree(x, y)

    if self.depth < self.max_depth:
        if len(left) > 0:
            self.left = self.get_node()
        if len(right) > 0:
            self.right = self.get_node()

    # learn left and right node
    if len(left) > 0:
        self.left.fit(x[left], y[left])
    if len(right) > 0:
        self.right.fit(x[right], y[right])
    return self

新しいノードを生成して返すget_node()は以下の様になります。

def get_node(self):
    return DecisionTree(max_depth=self.max_depth, metric=self.metric,
                        leaf=self.leaf, depth=self.depth + 1)

DecisionTreeでは、推論を行うpredict()は親クラスのDecisionTreeがそのまま使えます。

以上でDecisionTreeの実装ができました。
検証用データに対して学習して評価するコードは下記にあるので参考にしてみてください。

dtree.py全体のコードはこちら

次の記事ではプルーニングのアルゴリズムを実装していきます。
rikeiin.hatenablog.com

参考文献

作ってわかる! アンサンブル学習アルゴリズム入門

作ってわかる! アンサンブル学習アルゴリズム入門