🔍

【AOJ解説(python)】ALDS1_8_C 二分探索木の実装

2023/09/02に公開

本記事ではAizu Online Judgeより、ALDS1_8_Cの考え方と実際の解答をpythonで解説します。問題のリンク先は下記となります。
https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ALDS1_8_C&lang=ja

方針

二分探索木の実装に関する問題です。まず二分探索木がどのような条件を満たす木であるかを確認します。
その後、条件を満たすよう実装していきますが、実装には関数で行う方法とクラスで行う方法の2通りが考えられます。
今回はクラスでの実装を試みたいと思います。

二分探索木の特徴

二分探索木とは、下記特徴を持つ木のことです。
1.親が持つ子ノードは必ず2個以下。
2. 左の子<親<右の子が必ず成り立つ。
3.左部分木<親<右部分木が必ず成り立つ。

上記を満たすような二分探索木の実装を試みます。

二分探索木の実装 -事前準備-

通常の木構造と同様に、以下のようにNodeと木のクラスを準備します。
BinarySearchTreeクラスでは、root(親ノード)を保持する事とし、かつinsertfind関数を備えるものとします。

class Node:
    def __init__(self, key: int) -> None:
        self.key = key
        self.left = None
        self.right = None

class BinarySearchTree:
    def __init__(self) -> None:
        self.root = None
    def insert(self):
        pass
    def find(self):
        pass
    def delete(self):
        pass

# 呼び出し例
if __name__ == "__main__":
    binary_tree = BinarySearchTree()
    binary_tree.insert(1)
    binary_tree.insert(2)
    binary_tree.find(2)
    binary_tree.delete(3)

二分探索木の実装 -insert-

それでは、まずinsert関数の実装から着手します。
引数で与えられた値のノードを、木の中のどこに挿入すればよいかを考えていきます。

木にノードが一つも無い場合

ノードが一つもない場合、新規追加するノードを親ノードに設定して完了です。

木にノードが複数ある場合

下記の再帰ロジックに基づいて実装を行います。
大小比較の箇所で左の子<親<右の子が必ず成り立つという二分探索木の条件を満たしています。さらに再帰で子ノードに対してもそれぞれ大小比較を行うことで、左部分木<親<右部分木が必ず成り立つという条件も満たすよう実装しています。

再帰ロジック

引数
 node: 確認対象のノード
 key: ノードとして挿入する値
現在ノードの有無確認 
 有り:処理なし
 無し:keyを値としたノードを返す
大小比較
 key < nodeの値:左の子nodeを再帰で確認
 key > nodeの値:右の子nodeを再帰で確認
上記の確認完了
 確認対象のノードを返す

コードで示すと以下の通りとなります。

def insert(self, key: int) -> None:
    if self.root is None:
        self.root = Node(key)
        return
    def _insert(node: Node, key: int) -> Node:
        # 現在ノードの有無確認
        if node is None:
            return Node(key)
        # 大小比較
        if key < node.key:
            node.left = _insert(node.left, key)
        else:
            node.right = _insert(node.right, key)
        # 上記の確認完了
        return node
    _insert(self.root, key)

補足1:インナー関数を用いた呼び出しについて

上記のコードでは、insert関数の中にインナー関数である_insert関数を定義しています。insert関数では、根ノードが有る場合に再帰を呼び出す処理を記載し、_insert関数では呼び出されて実行される再帰処理の記載をしています。このように書き分ける事で、再帰部分への外部からのアクセスを防いでいます。(カプセル化)
なお、insertの処理は必ず木の根ノード(root)から行いたいため、_insert(self.root, key)として_insert()の第一引数にBinarySearchTreeクラスで保持しているroot属性のノードを指定しています。

補足2:戻り値の設定について

戻り値returnの設定について補足します。現在ノードが無い場合はNode(key)を返し、そうでない場合は大小比較を行ったのちにnodeを返しています。
このようにすることで、①挿入箇所が見つかるまで再帰を繰り返し、挿入箇所が見つかった場合はNode(key)を返して再帰を遡るという処理を実装しています。そして、➁再帰を遡る際、戻り値にnodeを指定することで木の復元を実現しています。

先ほどの図で示すと下記の通りとなります。
1.再帰2までは確認対象ノードが存在するため、大小比較を行って再帰を繰り返す。

2.再帰3の時点で確認対象ノードが存在しなくなる(挿入箇所が見つかる)ため、Node(key)であるノード7を再帰2へ返す。ここから再帰を遡る。 ※①の箇所

3.再帰2へ遡るとnodeが指し示すのはノード5であるため、node.rightに再帰3の戻り値であるノード7を入れる。大小比較が終了するため、戻り値としてnodeのノード5を再帰1へ返す。

4.再帰1へ遡るとnodeが指し示すのはノード8であるため、node.leftに再帰2の戻り値であるノード5を入れる。
※➁の箇所: ノード5をノード8の左子ノードへ入れている箇所が該当。これにより元の木を復元している

二分探索木の実装 -find-

続いて、find関数の実装します。基本的なロジックはinsert関数と同様です。
戻り値がtrue/falsebool値となっている点が主な変更点となります。
なお、戻り値であるtrue/falseを呼び出し階層を遡って元のfind関数へ返すため、それぞれの戻り値をreturnで返しています。

def find(self, key) -> bool:
    def _find(node: Node, key: int) -> bool:
        if node is None:
            return False
        if key == node.key:
            return True
        elif key < node.key:
            return _find(node.left, key)
        else:
            return _find(node.right, key)
    return _find(self.root, key)

二分探索木の実装 -delete-

次に、delete関数を実装します。
対象となるノードを再帰により探索するロジックは、上述のinsertfindと変わりありません。ただし、deleteの場合は、ノードを削除した後の処理に工夫が必要です。

以下の場合分けで考えていきます。
1.削除対象ノードに子ノードが一つもない場合
この場合は、単純に対象ノードを削除して完了です。

2.削除対象ノードに子ノードが一つだけある場合
削除対象ノードに子ノードが一つだけある場合、対象ノードを削除した後に子ノードを削除ノードの場所へ移す必要があります。

図に記載の通り、子ノードが一つだけである場合、その子ノードの下にさらに孫ノードが付いていたとしても対応は変わりません。

3.削除対象ノードに子ノードが二つある場合
削除ノードに子ノードが二つある場合は、対象ノードを削除した後に右部分木の中で最小となる子ノードを削除ノードの場所へ移すという処理が必要になります。

対象ノードを削除した後、最小となる子ノードを選択する理由は図の通りです。削除ノードにあるノードを移した後も、木全体が二分探索木の条件を満たす必要がある事がポイントとなります。

上記を踏まえてコードに落とし込んだものが下記となります。
再帰処理を行う_delete関数の引数は_insert_findと同様で、それぞれ node: 確認対象のノード、key: 削除対象のノードの値となります。

def min_key(self, node) -> Node:
    current_node = node
    while current_node.left is not None:
        current_node = current_node.left
    return current_node

def delete(self, key) -> None:
    def _delete(node: Node, key: int) -> Node:
        # 確認対象ノードが見つからなかった場合
        if node is None:
            return node
        # keyの値を持つノードを探索
        if key < node.key:
            node.left = _delete(node.left, key)
        elif key > node.key:
            node.right = _delete(node.right, key)
        # keyの値を持つノードが見つかった場合
        else:
            # 子ノードが一つもない場合:if node.left内を実行。node.right→Noneを返す
            # 子ノードが一つの場合:子ノードを削除対象ノードkeyの場所へ代入
            if node.left is None:
                return node.right
            elif node.right is None:
                return node.left
            # 子ノードが二つの場合
            # 右部分木の最小値ノードを取得
            tmp = self.min_key(node.right)
            # 削除するノードの値へ最小値ノードの値をコピー
            node.key = tmp.key
            # 右部分木内に残った最小値ノードを削除して、削除対象ノードkeyの場所へ移す
            node.right = _delete(node.right, tmp.key)
        return node
    _delete(self.root, key)

二分探索木の実装 -print-

最後に、print関数を実装します。
こちらは先行順巡回(preorder)と中間順巡回(inorder)の節点の列を出力する箇所となります。
先行順巡回と中間順巡回については、下記記事にて説明しているのでご確認頂けると幸いです。
https://zenn.dev/usma11dia0/articles/solve-alds-1-7-d

なお、本問題における実装例は以下の通りとなります。

def print(self) -> str:
    #中間順巡回(inorder)を出力
    def inorder(node: Node) -> None:
        if node is not None:
            inorder(node.left)
            output.append(str(node.key))
            inorder(node.right) 
    #先行順巡回(preorder)を出力
    def preorder(node: Node) -> None:
        if node is not None:
            output.append(str(node.key))
            preorder(node.left)
            preorder(node.right)

解答例

最後に、上述で示したコードを一つにまとめて解答を完成させます。

import sys

class Node:
    def __init__(self, key) -> None:
        self.key = key
        self.left = None
        self.right = None

class BinarySearchTree:
    def __init__(self) -> None:
        self.root = None

    def insert(self, key: int) -> None:
        if self.root is None:
            self.root = Node(key)
            return
        def _insert(node: Node, key: int) -> Node:
            if node is None:
                return Node(key)
            if key < node.key:
                node.left = _insert(node.left, key)
            else:
                node.right = _insert(node.right, key)
            return node
        _insert(self.root, key)
    
    def find(self, key) -> bool:
        def _find(node: Node, key: int) -> bool:
            if node is None:
                return False
            if key == node.key:
                return True
            elif key < node.key:
                return _find(node.left, key)
            else:
                return _find(node.right, key)
        return _find(self.root, key)

    def min_key(self, node) -> Node:
        current_node = node
        while current_node.left is not None:
            current_node = current_node.left
        return current_node

    def delete(self, key) -> None:
        def _delete(node: Node, key: int) -> Node:
            if node is None:
                return node
            if key < node.key:
                node.left = _delete(node.left, key)
            elif key > node.key:
                node.right = _delete(node.right, key)
            else:
                if node.left is None:
                    return node.right
                elif node.right is None:
                    return node.left
                tmp = self.min_key(node.right)
                node.key = tmp.key
                node.right = _delete(node.right, tmp.key)
            return node
        _delete(self.root, key)

    def print(self) -> str:
        def inorder(node: Node) -> None:
            if node is not None:
                inorder(node.left)
                output.append(str(node.key))
                inorder(node.right) 
        def preorder(node: Node) -> None:
            if node is not None:
                output.append(str(node.key))
                preorder(node.left)
                preorder(node.right)
        
        output = []
        inorder(self.root)
        inorder_output = " ".join(output)
        output = []
        preorder(self.root)
        preorder_output = " ".join(output)

        return " " + inorder_output + "\n" + " " + preorder_output

m = int(input())
T = BinarySearchTree()
output = []

for _ in range(0, m):
    command = sys.stdin.readline().split()
    if command[0] == "insert":
        T.insert(int(command[1]))
    elif command[0] == "find":
        if T.find(int(command[1])):
            output.append("yes")
        else:
            output.append("no")
    elif command[0] == "delete":
        T.delete(int(command[1]))
    elif command[0] == "print":
        output.append(T.print())

print("\n".join(output))

参考

https://www.udemy.com/course/python-algo/
https://www.momoyama-usagi.com/entry/info-algo-tree

Discussion