機械学習エンジニアの備忘録

主に自分が勉強したことのメモ

アンサンブル学習 アルゴリズム入門 〜〜プルーニング〜〜

前回の記事では、決定木による機械学習プログラムを作成しました。

rikeiin.hatenablog.com

この決定木では、作成する木の深さに上限がなくパラメータの設定を深くすればするほど複雑なモデルが作成できます。なので、単に学習データをよく再現できるモデルを作成するという観点からは、決定木アルゴリズムはメモリと計算時間が許す限り、学習データを完全に再現できるモデルを作成できます。しかし、これだと学習データに過学習してしまい、学習していないデータに対する性能が低下してしまいます。そのため、いかに過学習を防止し、汎化誤差を少なくするかが問題となりますが、本記事では決定木アルゴリズムにおける過学習を防止するための手法であるプルーニングを実装してみます。

決定木のプルーニング

プルーニングとは

プルーニング(枝刈り)は木構造をしたデータに対して適用されるアルゴリズムです。
前記事で作成した決定木アルゴリズムでは、作成される木は常に葉まで同じ深さを持つ完全二分木と呼ばれるデータ構造をしていました。
しかし、このようなモデルで深い木を作成すると無駄な判断を行うノードが増えてしまうので過学習が起こりやすくなります。そこで一度深い深さで決定木を作成した後で不要な枝を削除することで、よりシンプルな決定木を作成する手法がとられます。この木構造のデータから不要な枝を削除することをプルーニングと呼びます。

f:id:rikeiin:20191022235901p:plain
深さが2で葉の数が4つある決定木からプルーニングによって葉Dが取り除かれる例

プルーニングのルール

プルーニングのアルゴリズムでは決定木を走査して削除すべき枝と残すべき枝を選択します。どの枝を削除するべきかの判断がプルーニングアルゴリズムのキモとなるのですが、最も単純なルールは「枝を削除しても結果が変わらないならばその枝を削除する」というものになります。すなわちその枝を削除する前後でモデルを実行した結果のスコアが悪化しないならばその枝を削除するということになります。プルーニングの際に使用するスコアの計算には決定木に対する学習データと同じデータを使用する方法と学習データを決定木の学習用とプルーニング用のテストデータに二分する方法があります 。

再帰関数によるプルーニング

上記のルールをプログラムで実装すると、決定木のノードをたどりながら、次の2つのケースでノード内の枝を削除すればプルーニングのアルゴリズムを実装することができあす。

  1. 学習データのすべてが左右のノードどちらか一方に振り分ける場合
  2. ノード内どちらかの枝を削除してもノード全体のスコアが悪化しない場合

図に表すと以下のようになります。

f:id:rikeiin:20191104152704p:plain

決定木のノードをたどるには再帰関数を仕様し、与えられたノードの枝を関数内で都度更新することでプルーニングを行います。
本記事で作成する再帰関数の基本的な構造は次のようになります。

def プルーニング関数(決定木内のノード, 学習データ):
    if 左右にデータが振り分けられるか:
        # 1つの枝のみの場合、その枝を置き換える
        return プルーニング関数(枝, 学習データ)
    # 再帰呼び出しで枝を辿る
    左の枝 = プルーニング関数(左の枝, 学習データ)
    右の枝 = プルーニング関数(右の枝, 学習データ)
    # 枝刈りを行う
    if 枝刈りを行うべき:
        return ノード内の残す方の枝
    return 決定木内のノード

関数の中では一つのノードを見て再帰的にそのノードの左右の枝を更新していきます。再帰関数は枝刈り行う場合は残された枝を、そうでない場合は現在のノードを返します。そして再帰的に呼び出される関数の戻り値でノートの枝を更新することで、不要な枝を削除して行きます。

プルーニングの実装

実際にプルーニングを行うアルゴリズムを実装していきます。
ここでは、前記事で作成した決定木と同じ構造を持つPrunedTreeというクラスがあるものとして再帰関数を実装していきます。

スコアによる判定

再帰関数の名前はreducederrorとして、引数として決定木内の現在のノードを表すnode, 学習データと正解データを表すxとyを引数にとります。
まずはreducederror関数のひな形として、ノードが葉でないことを確認して、現在のノードを返す関数を実装します。

def reducederror(node, x, y):
    # ノードが葉でなかったら
    if isinstance(node, PrunedTree):
        # ここにプルーニングの処理が入る
        
    # 現在のノードを返す
    return node

次に上記の「# ここにプルーニングの処理が入る」という部分に学習データに対するノードのスコアと左右の端のスコアを計算する部分を作成します。
この関数では決定木の種類がクラス分類であるかは回帰であるかによって処理を分けますが、それは目的変数の次元数で判断できます。
クラス分類の場合、ノードそれ自体と左右の葉に対してpredict関数を呼び出した結果から、間違いの個数をカウントします。
回帰の場合は、同様に正解データとの二乗誤差を作成して、その値をスコアとします。
その後、枝を削除してもスコアが悪化しないようならスコアが良い方の枝を返すようにします。

# ここに枝刈りのコードが入る
# calculate score with train data
p1 = node.predict(x)
p2 = node.left.predict(x)
p3 = node.right.predict(x)
# if classification
if y.shape[1] > 1:
    # socre as a number of misclassifications
    ya = y.argmax(axis=1)
    d1 = np.sum(p1.argmax(axis=1) != ya)
    d2 = np.sum(p2.argmax(axis=1) != ya)
    d3 = np.sum(p3.argmax(axis=1) != ya)
else:
    # score as mean squared error
    d1 = np.mean((p1 - y) ** 2)
    d2 = np.mean((p2 - y) ** 2)
    d3 = np.mean((p3 - y) ** 2)

if d2 <= d1 or d3 <= d1:  # score is not worse with which left or right
    # return node with better score
    if d2 < d3:
        return node.left
    else:
        return node.right

決定木のプルーニング

上記の「# ここに枝刈りのコードが入る」という箇所ではノードに左右の枝が両方あるかどうかをチェックし、一つの枝のみにデータが振り分けられるのであれば、現在のノードをその枝で置き換えます。
そしてその後、関数を再帰的に呼び出して左右の枝を更新します。

feat = x[:, node.feat_index]
val = node.feat_val
l, r = node.make_split(feat, val)
if val is np.inf or len(r) == 0:
    return reducederror(node.left, x, y)
elif len(l) == 0:
    return reducederror(node.right, x, y)

# update the branch of right and left
node.left = reducederror(node.left, x[l], y[l])
node.right = reducederror(node.right, x[r], y[r])

以上で「Reduce Error」プルーニングのための再帰関数が完成します。

Critical Valueによるプルーニング

先ほど作成したReduce Errorプルーニングは、実際の実行結果をもとに処理を行うので、シンプルのアルゴリズムながら性能の良い結果を得ることができるという特徴があります。
一方でReduce Errorプルーニングは、プルーニングの際に木とその各枝に対して毎回決定木の実行を行われるため、処理時間の面で不利になるという欠点があります。決定木に対するプルーニングの処理については他にもさまざまなものがありますが、ここではプルーニングのアルゴリズムとして、もう一つ Critical Value プルーニングというアルゴリズムを実装してみます。

Critical Valueの概要

Critical Value プルーニングは決定木の学習時に使用した分割スコアを元にプルーニングの処理を行います。
決定木の学習ではすべてのノードにおいてMetrics関数の値から分割のスコアが一度求められていました。そこでプルーニングの処理を行う際には決定木全体の中での最も良い分割スコアを求め、その値をもとにある程度以下の値で分割されたノードを枝刈りすることでプルーニングの処理を行います。また削除するノードのしきい値は全ての枝のスコアから削除する枝の割合を指定することで求めます。
Critical Value プルーニングは決定木の分割のみを完成した時点で行うことができるので、必ずしも全ての葉を学習する必要はありません。葉の学習をプルーニングの後に行うことで深い階層の決定木を作成しても学習の時間が指数関数的に増加することを防ぐことができます。また決定木の分割が完成した時点でCritical Valueプルーニングを行い、さらに葉の学習が終了したらReduce Errorプルーニングを行うといったこともできます 。

f:id:rikeiin:20191109123744p:plain
Critical Valueプルーニングの例

Critical Valueの実装

それでは実際にCritical Valueを行うコードの実装をしていきます。先程と同じくpruning.pyの中に書いていきます。

全ノードのスコア

まずは決定木全体から分割時に使用した最も良いメトリクス関数の値を求めるために、すべてのノードのスコアを取得する関数を作成します。
前々回の記事でDecisionStumpクラスを作成した際に、分割で求めた良いメトリクス関数の値はそのノードの分割のスコアとしてスコア変数に保存しておきました。本記事で作成するPrunedTreeクラスもDecisionStumpクラスの派生クラスなので、同じくスコア変数から分割のスコアを取得することができます。ここで気をつけておくことは、このスコア変数が表すスコアは値が小さいほど良い値だという点です。全てのノードをたどるには先ほどと同じく再帰関数を使用し、引数に与えられたリストにノードのスコア変数の値を追加していきます。ここではgetscore()という名前ですべてのノードのスコアを取得する関数を作成します。

def getscore(node, score):
    if isinstance(node, PrunedTree):
        if node.score >= 0 and node.score is not np.inf:
            score.append(node.score)
        getscore(node.left, score)
        getscore(node.right, score)
Critical Valueを行う関数

次に、実際にプルーニングを行う関数criticalscore()を実装します。この関数も再帰関数として実装し、前節と同じ構造をしています。
今回は引数としてしきい値となるscore_maxが追加されています。

def criticalscore(node, score_max):
    if type(node) is PrunedTree:
        # pruning process
        # update the branch of right and left
        node.left = criticalscore(node.left, score_max)
        node.right = criticalscore(node.right, score_max)
        # delete node
        if node.score > score_max:
            leftisleaf = not isinstance(node.left, PrunedTree)
            rightisleaf = not isinstance(node.right, PrunedTree)
            # leave one leaf if both are leaf
            if leftisleaf and rightisleaf:
                return node.left
            # leave branch if which one is leaf
            elif leftisleaf and not rightisleaf:
                return node.right
            elif not leftisleaf and rightisleaf:
                return node.left
            # leave node with better score if both are branch
            elif node.left.score < node.right.score:
                return node.left
            else:
                return node.right

プルーニング用決定木クラスの実装

先程までは前記事で作成した決定木と同じ構造を持つPrunedTreeというクラスがあるものとして、再帰関数を実装していました。
ここではそのPrunedTreeクラスを実装していきます

プルーニング用の決定木の作成

PrunedTreeクラスは前記事で作成したDecisionTreeクラスの派生クラスとして作成します。本記事で使用するPrunedTreeクラスでは、クラス内の変数でプルーニング用の関数名を表すprunfncと、学習データとプルーニング用のテストデータを別にするかどうかを表すpruntestを作成します。ここでprunfnc変数は、文字列型で"reduce"または"critical"いずれかの値が入るものとします。またプルーニング用のテストデータを別にするかどうかを表すpruntest、 テストデータの割合を表すsplitratio、Critical Value プルーニングで使用するパーセンテージの変数であるcriticalも作成します。

class PrunedTree(DecisionTree):
    def __init__(self, prunfnc='critical', pruntest=False, splitratio=0.5, critical=0.8,
                 max_depth=5, metric=entropy.gini, leaf=ZeroRule, depth=1):
        super().__init__(max_depth=max_depth, metric=metric, leaf=leaf, depth=depth)
        self.prunfnc = prunfnc # プルーニング用関数
        self.pruntest = pruntest # プルーニング用にテストデータを取り分けるか
        self.splitratio = splitratio # プルーニング用テストデータの割合
        self.critical = critical # criticalプルーニング用のしきい値

新しいノードを作成するためのget_node()も次のようにオーバーライドしておきます。

def get_node(self):
    return PrunedTree(prunfnc=self.prunfnc, pruntest=self.pruntest, splitratio=self.splitratio, critical=self.critical,
                        max_depth=self.max_depth, metric=self.metric, leaf=self.leaf, depth=self.depth + 1)

次に学習を行うfit()関数をオーバーライドします。 fit()関数内ではまず根のノードとそうでない時で処理が異なり、根のノードの時のみプルーニング用の処理が行われます。これはPrunedTreeクラスは決定木内のノードの一つを表すのに対して、プルーニング用の関数が学習後の決定木に対して再帰的に呼び出されるためです。PrunedTreeクラスで根のノードの時には、まずプルーニングの際に枝の削除を行うかどうかを判断するためのテストデータを用意します。テストデータはself.pruntestがTrueであれば学習データからランダムにself.splitratioで指定された割合のデータをテストデータとして取り分けます。self.pruntestがFalseの場合は学習データと同じデータを使用してプルーニングの処理を行います。

def fit(self, x, y):
    # if depth=1, root node
    if self.depth == 1 and self.prunfnc is not None:
        # data for pruning
        x_t, y_t = x, y

        if self.pruntest:
            n_test = int(round(len(x) * self.splitratio))
            n_idx = np.random.permutation(len(x))
            tmpx = x[n_idx[n_test:]]
            tmpy = y[n_idx[n_test:]]
            x_t = x[n_idx[:n_test]]
            y_t = y[n_idx[:n_test]]
            x = tmpx
            y = tmpy

        # ここで決定木の学習を行う

        return self
            

上記の「 # ここで決定木の学習を行う」の部分には以下のコードが入ります。前記事でのDecisionTreeの学習アルゴリズムとほぼ同じですが、Critical Valueプルーニングの場合は葉の学習は行わず、木の分割のみ学習します。

# 決定木の学習
self.left = self.leaf()
self.right = self.leaf()
left, right = self.split_tree(x, y)
if self.depth < self.max_depth:
    self.left = self.get_node()
    self.right = self.get_node()
if self.depth < self.max_depth or self.prunfnc != 'critical':
    if len(left) > 0:
        self.left.fit(x[left], y[left])
    if len(right) > 0:
        self.right.fit(x[right], y[right])

# ここでプルーニングの処理を行う。
            

上記の「# ここでプルーニングの処理を行う」には根のノードの時にself.prunfncに入っている関数の名前から、再帰関数を呼び出してプルーニングの処理を行います。Reduce Errorプルーニングの場合、再帰関数の引数には自分自身のインスタンスとプルーニング用のテストデータを渡してreducederror()関数を呼び出します。また Critical Value プルーニングの場合、まずgetscore()関数ですべてのノードの分割の際のスコアを取得し、その数からパラメーターで指定された割合を求め、しきい値となる値を計算します。そしてそのしきい値を引数にcriticalscore()関数呼び出し、プルーニングの処理を行った後、学習させていなかった葉について改めて学習を行います。葉のみ学習を行う関数はfit_leaf()という名前で作成します。

# pruning process
# only whene depth = 1, root node
if self.depth == 1 and self.prunfnc is not None:
    if self.prunfnc == 'reduce':
        reducederror(self, x_t, y_t)
    elif self.prunfnc == 'critical':
        # get score of metrics function when training
        score = []
        getscore(self, score)
        if len(score) > 0:
            # calculate max score of branch left
            i = int(round(len(score) * self.critical))
            score_max = sorted(score)[min(i, len(score) - 1)]
            # pruning
            criticalscore(self, score_max)

        # learn leaf
        self.fit_leaf(x, y)
            

fit_leaf()関数は、すでに作成されている枝に従ってデータを分割し、枝の指しているノードが葉であればfit()関数を呼び出します。

def fit_leaf(self, x, y):
    feat = x[:, self.feat_index]
    val = self.feat_val
    l, r = self.make_split(feat, val)

    # learn only leaf
    if len(l) > 0:
        if isinstance(self.left, PrunedTree):
            self.left.fit_leaf(x[l], y[l])
        else:
            self.left.fit(x[l], y[l])
    if len(r) > 0:
        if isinstance(self.right, PrunedTree):
            self.right.fit_leaf(x[r], y[r])
        else:
            self.right.fit(x[r], y[r])
            

以上でPrunedTreeの実装ができました。
検証用データに対して学習して評価するコードは下記にあるので参考にしてみてください。

ensumble-learning-introduction/pruning.py at master · wdy06/ensumble-learning-introduction · GitHub

参考文献

作ってわかる! アンサンブル学習アルゴリズム入門

作ってわかる! アンサンブル学習アルゴリズム入門