🕌

heapqを使うときの注意点

2024/12/22に公開

この記事の経緯

ゼロから作るDeep Learning 3のフレームワーク編での読者への宿題でheapqに関して詰まったので記載しておきます。書籍を持っている方向けに書くと、p.195です。

問題概要

ざっくりと問題を抽象化すると、

class HogeHoge: 
    def __init__(self, val):
        self.value = value

hogehoge_list: List[HogeHoge] = []  

このhogehoge_listは、都度append()とpop()がされます。
ただし、pop()するのは最もvalueが大きいものとします。

書籍では都度sortをしているのですが、読者向けへの宿題として優先度つきキューであるheapqを使えば効率よく取得できるとあります。

最初に考えたこと

heapqを使えばいいと書いてあるので、素朴に以下のようなことを考えました。

import heapq

hogehoge_list = []

# appendするとき
heapq.heappush(hogehoge_list, (-hogehoge.value, hogehoge) #大きいものから取得したいので-を付ける。

# popするとき
_, hogehoge = heapq.heappop(hogehoge_list) # hogehogeだけほしい。

heapqの仕様では、リストの要素がタプルの1番目の要素に対する順序が小さいものから順に格納されます。(今回の場合はhogehoge.value) そのため、これでうまくいくと思ったのですがこれでは、うまくいきませんでした。

なぜなら、タプルの1番目の要素が同じもの同士は、タプルの2番目の要素で比較するからです。
今回のケースの場合、タプルの2番目の要素はHogeHogeクラスのインスタンスで順序が定義されていないのでエラーがでます。(実際は、Functionクラスという関数のクラスを定義しています。)

どうしたか

HogeHogeクラスには順序が定義できないので、タプルに順序を入れることにしました。
つまり、以下のようにすればいいです。

class TupleOrder:

    def __init__(self, value, hogehoge):
        self.value = value # 比較に使う値
        self.hogehoge = hogehoge # 比較には使わない

    def __lt__(self, other):
        return self.value < other.value # hogehogeは使わない  

このようにすることで、冒頭の問題は以下のようにして解けました。

import heapq

hogehoge_list = []

# appendするとき
heapq.heappush(hogehoge_list, TupleOrder(-hogehoge.value, hogehoge))

# popするとき
tupleorder = heapq.heappop(hogehoge_list)
hogehoge = tupleorder.hogehoge # hogehogeだけ取り出す。

感想

データ構造を雰囲気で使ってると、痛い目にあいますね・・・

Discussion