【AOJ解説(python)】ALDS1_8_C 二分探索木の実装
本記事ではAizu Online Judgeより、ALDS1_8_Cの考え方と実際の解答をpythonで解説します。問題のリンク先は下記となります。
方針
二分探索木の実装に関する問題です。まず二分探索木がどのような条件を満たす木であるかを確認します。
その後、条件を満たすよう実装していきますが、実装には関数で行う方法とクラスで行う方法の2通りが考えられます。
今回はクラスでの実装を試みたいと思います。
二分探索木の特徴
二分探索木とは、下記特徴を持つ木のことです。
1.親が持つ子ノードは必ず2個以下。
2. 左の子<親<右の子が必ず成り立つ。
3.左部分木<親<右部分木が必ず成り立つ。
上記を満たすような二分探索木の実装を試みます。
二分探索木の実装 -事前準備-
通常の木構造と同様に、以下のようにNodeと木のクラスを準備します。
BinarySearchTree
クラスでは、root
(親ノード)を保持する事とし、かつinsert
やfind
関数を備えるものとします。
class Node:
def __init__(self, key: int) -> None:
self.key = key
self.left = None
self.right = None
class BinarySearchTree:
def __init__(self) -> None:
self.root = None
def insert(self):
pass
def find(self):
pass
def delete(self):
pass
# 呼び出し例
if __name__ == "__main__":
binary_tree = BinarySearchTree()
binary_tree.insert(1)
binary_tree.insert(2)
binary_tree.find(2)
binary_tree.delete(3)
二分探索木の実装 -insert-
それでは、まずinsert
関数の実装から着手します。
引数で与えられた値のノードを、木の中のどこに挿入すればよいかを考えていきます。
木にノードが一つも無い場合
ノードが一つもない場合、新規追加するノードを親ノードに設定して完了です。
木にノードが複数ある場合
下記の再帰ロジックに基づいて実装を行います。
大小比較の箇所で左の子<親<右の子が必ず成り立つという二分探索木の条件を満たしています。さらに再帰で子ノードに対してもそれぞれ大小比較を行うことで、左部分木<親<右部分木が必ず成り立つという条件も満たすよう実装しています。
再帰ロジック
・引数
node: 確認対象のノード
key: ノードとして挿入する値
・現在ノードの有無確認
有り:処理なし
無し:keyを値としたノードを返す
・大小比較
key < nodeの値:左の子nodeを再帰で確認
key > nodeの値:右の子nodeを再帰で確認
・上記の確認完了
確認対象のノードを返す
コードで示すと以下の通りとなります。
def insert(self, key: int) -> None:
if self.root is None:
self.root = Node(key)
return
def _insert(node: Node, key: int) -> Node:
# 現在ノードの有無確認
if node is None:
return Node(key)
# 大小比較
if key < node.key:
node.left = _insert(node.left, key)
else:
node.right = _insert(node.right, key)
# 上記の確認完了
return node
_insert(self.root, key)
補足1:インナー関数を用いた呼び出しについて
上記のコードでは、insert
関数の中にインナー関数である_insert
関数を定義しています。insert
関数では、根ノードが有る場合に再帰を呼び出す処理を記載し、_insert
関数では呼び出されて実行される再帰処理の記載をしています。このように書き分ける事で、再帰部分への外部からのアクセスを防いでいます。(カプセル化)
なお、insertの処理は必ず木の根ノード(root)から行いたいため、_insert(self.root, key)
として_insert()の第一引数にBinarySearchTree
クラスで保持しているroot
属性のノードを指定しています。
補足2:戻り値の設定について
戻り値return
の設定について補足します。現在ノードが無い場合はNode(key)
を返し、そうでない場合は大小比較を行ったのちにnode
を返しています。
このようにすることで、①挿入箇所が見つかるまで再帰を繰り返し、挿入箇所が見つかった場合はNode(key)
を返して再帰を遡るという処理を実装しています。そして、➁再帰を遡る際、戻り値にnode
を指定することで木の復元を実現しています。
先ほどの図で示すと下記の通りとなります。
1.再帰2までは確認対象ノードが存在するため、大小比較を行って再帰を繰り返す。
2.再帰3の時点で確認対象ノードが存在しなくなる(挿入箇所が見つかる)ため、Node(key)
であるノード7を再帰2へ返す。ここから再帰を遡る。 ※①の箇所
3.再帰2へ遡るとnode
が指し示すのはノード5であるため、node.right
に再帰3の戻り値であるノード7を入れる。大小比較が終了するため、戻り値としてnode
のノード5を再帰1へ返す。
4.再帰1へ遡るとnode
が指し示すのはノード8であるため、node.left
に再帰2の戻り値であるノード5を入れる。
※➁の箇所: ノード5をノード8の左子ノードへ入れている箇所が該当。これにより元の木を復元している
二分探索木の実装 -find-
続いて、find
関数の実装します。基本的なロジックはinsert
関数と同様です。
戻り値がtrue/false
のbool
値となっている点が主な変更点となります。
なお、戻り値であるtrue/false
を呼び出し階層を遡って元のfind関数へ返すため、それぞれの戻り値をreturn
で返しています。
def find(self, key) -> bool:
def _find(node: Node, key: int) -> bool:
if node is None:
return False
if key == node.key:
return True
elif key < node.key:
return _find(node.left, key)
else:
return _find(node.right, key)
return _find(self.root, key)
二分探索木の実装 -delete-
次に、delete
関数を実装します。
対象となるノードを再帰により探索するロジックは、上述のinsert
やfind
と変わりありません。ただし、delete
の場合は、ノードを削除した後の処理に工夫が必要です。
以下の場合分けで考えていきます。
1.削除対象ノードに子ノードが一つもない場合
この場合は、単純に対象ノードを削除して完了です。
2.削除対象ノードに子ノードが一つだけある場合
削除対象ノードに子ノードが一つだけある場合、対象ノードを削除した後に子ノードを削除ノードの場所へ移す必要があります。
図に記載の通り、子ノードが一つだけである場合、その子ノードの下にさらに孫ノードが付いていたとしても対応は変わりません。
3.削除対象ノードに子ノードが二つある場合
削除ノードに子ノードが二つある場合は、対象ノードを削除した後に右部分木の中で最小となる子ノードを削除ノードの場所へ移すという処理が必要になります。
対象ノードを削除した後、最小となる子ノードを選択する理由は図の通りです。削除ノードにあるノードを移した後も、木全体が二分探索木の条件を満たす必要がある事がポイントとなります。
上記を踏まえてコードに落とし込んだものが下記となります。
再帰処理を行う_delete
関数の引数は_insert
や_find
と同様で、それぞれ node
: 確認対象のノード、key
: 削除対象のノードの値となります。
def min_key(self, node) -> Node:
current_node = node
while current_node.left is not None:
current_node = current_node.left
return current_node
def delete(self, key) -> None:
def _delete(node: Node, key: int) -> Node:
# 確認対象ノードが見つからなかった場合
if node is None:
return node
# keyの値を持つノードを探索
if key < node.key:
node.left = _delete(node.left, key)
elif key > node.key:
node.right = _delete(node.right, key)
# keyの値を持つノードが見つかった場合
else:
# 子ノードが一つもない場合:if node.left内を実行。node.right→Noneを返す
# 子ノードが一つの場合:子ノードを削除対象ノードkeyの場所へ代入
if node.left is None:
return node.right
elif node.right is None:
return node.left
# 子ノードが二つの場合
# 右部分木の最小値ノードを取得
tmp = self.min_key(node.right)
# 削除するノードの値へ最小値ノードの値をコピー
node.key = tmp.key
# 右部分木内に残った最小値ノードを削除して、削除対象ノードkeyの場所へ移す
node.right = _delete(node.right, tmp.key)
return node
_delete(self.root, key)
二分探索木の実装 -print-
最後に、print関数を実装します。
こちらは先行順巡回(preorder)と中間順巡回(inorder)の節点の列を出力する箇所となります。
先行順巡回と中間順巡回については、下記記事にて説明しているのでご確認頂けると幸いです。
なお、本問題における実装例は以下の通りとなります。
def print(self) -> str:
#中間順巡回(inorder)を出力
def inorder(node: Node) -> None:
if node is not None:
inorder(node.left)
output.append(str(node.key))
inorder(node.right)
#先行順巡回(preorder)を出力
def preorder(node: Node) -> None:
if node is not None:
output.append(str(node.key))
preorder(node.left)
preorder(node.right)
解答例
最後に、上述で示したコードを一つにまとめて解答を完成させます。
import sys
class Node:
def __init__(self, key) -> None:
self.key = key
self.left = None
self.right = None
class BinarySearchTree:
def __init__(self) -> None:
self.root = None
def insert(self, key: int) -> None:
if self.root is None:
self.root = Node(key)
return
def _insert(node: Node, key: int) -> Node:
if node is None:
return Node(key)
if key < node.key:
node.left = _insert(node.left, key)
else:
node.right = _insert(node.right, key)
return node
_insert(self.root, key)
def find(self, key) -> bool:
def _find(node: Node, key: int) -> bool:
if node is None:
return False
if key == node.key:
return True
elif key < node.key:
return _find(node.left, key)
else:
return _find(node.right, key)
return _find(self.root, key)
def min_key(self, node) -> Node:
current_node = node
while current_node.left is not None:
current_node = current_node.left
return current_node
def delete(self, key) -> None:
def _delete(node: Node, key: int) -> Node:
if node is None:
return node
if key < node.key:
node.left = _delete(node.left, key)
elif key > node.key:
node.right = _delete(node.right, key)
else:
if node.left is None:
return node.right
elif node.right is None:
return node.left
tmp = self.min_key(node.right)
node.key = tmp.key
node.right = _delete(node.right, tmp.key)
return node
_delete(self.root, key)
def print(self) -> str:
def inorder(node: Node) -> None:
if node is not None:
inorder(node.left)
output.append(str(node.key))
inorder(node.right)
def preorder(node: Node) -> None:
if node is not None:
output.append(str(node.key))
preorder(node.left)
preorder(node.right)
output = []
inorder(self.root)
inorder_output = " ".join(output)
output = []
preorder(self.root)
preorder_output = " ".join(output)
return " " + inorder_output + "\n" + " " + preorder_output
m = int(input())
T = BinarySearchTree()
output = []
for _ in range(0, m):
command = sys.stdin.readline().split()
if command[0] == "insert":
T.insert(int(command[1]))
elif command[0] == "find":
if T.find(int(command[1])):
output.append("yes")
else:
output.append("no")
elif command[0] == "delete":
T.delete(int(command[1]))
elif command[0] == "print":
output.append(T.print())
print("\n".join(output))
参考
Discussion