Python競技プログラミング:CFFIを利用してPyPyで爆速のordered setとmultisetを使う
急いでる人向け
- C言語をwrapして作ったC言語並の速度のordered set(とmultiset)をPyPyで利用できます。
- バイナリ提出を使用するので注意してください。AtCoderでは使えますがCodeForcesでの使用は規約に抵触する可能性があります。
- ここにあるコードを改変すれば(コピペはRawボタンをクリックし、
Ctrl+a
で全選択がおススメです。コピーしたコードはAtCoderのコードテストにペーストして一番上まで移動すると見やすいと思います。下2000行の部分はほぼdocumentで弄る必要はありません。)即時使えるようになります。(自分でビルドしたい方向けのやり方は後日書きます) - 使用方法は下の方の「使い方」を参照してください。
- 検証用に8問程度ACしましたが、バグがあったらすみません。
- 同様にCFFIを用いてACLと同等の機能を持つ高速なPyPy用ライブラリを作成中です。
動機
近年、機械学習人気の高まりなどを背景にPythonが脚光を浴びるようになりました。科学計算・データ解析をはじめとした幅広い分野の豊富なライブラリと、簡便な文法が強みの言語で、競技プログラミングサイトAtCoderにおいても多くの人が愛用しています。一方で、AtCoderで最もよく使われているC++と比較すると速度が出にくく、また、C++のSTLに存在するordered setが標準で備わっていないなど、競プロをやる上でいくつか不便な点があります。
速度面の課題については、JITコンパイルによる高速化機能が備わっているPyPyを使用することでほぼ解決できます[1]。一方、C/C++のような高速な言語でコーディングし、コンパイル済みのバイナリをPythonで呼び出すという、より踏み込んだ高速化の手法も存在します。この手法をAtCoder上でやる方法についてはいくつかノウハウが知られています[2]。特に事前にライブラリを用意して使う場合にこの方法は強力で、準備さえ怠らなければPyPyで同様に実装したものよりもかなり高速なコードを軽い実装コストで提出できることがあります。
私はCPython(PyPyではない普通のPython)でこの手法を時たま使っていたのですが(例えばこれがコンテスト中のACコード■ ■)、調べてみたところCFFI(公式ドキュメント)というものを使えばPyPyでも同様のことができそう(ただバイナリ提出が必要なのでCodeForcesでは使えなさそう)なので、今回は特にordered setとmultisetでやってみました。Pythonにordered setがなくて困っているという人の助けになれれば幸いです。PyPyによるpure Python部分の高速化とC言語によるボトルネック部分の高速化が合わされば、今まで使いにくかった
コード
AtCoderですぐ使えるテンプレ
C ordered set (coset)
C ordered multiset (comultiset)
テンプレの方のリンク先のコードをコピー(Rawボタンをクリックし、Ctrl+a
で全選択がおススメです。コピーしたコードはAtCoderのコードテストにペーストして一番上まで移動すると見やすいと思います。下2000行の部分はほぼdocumentで弄る必要はありません。)し、#write your code here
の部分に自分のコードを書き込んでください。AtCoderのコードテスト(PyPy3(7.3.0))で動作確認できます。手元にPyPy(7.3.0)がある人はもちろんそちらでも確認できます(おそらくversionが異なると動かないので注意)。また、他人の作ったバイナリの実行には少し抵抗がある方もいると思うので、そのような方向けに自分でビルドする方法については後日別記事に書きます。
解説
方法論など詳細な解説は後日します。今回はコードの解説に留めます。
import base64
提出用に符号化してあるバイナリを復号するためのライブラリです。
binary = b'xxxxxxxxxxxxxxxxxxxxxx...(中略)...xxxxxxxxxxx'
バイナリの本体です。提出用に符号化済です。このあたりのやり方はkyomukyomupurinさんのこちらのリポジトリを参考にしました。バイナリ自体をどうコンパイルして作ったかについては別記事に書く予定ですが、今回使用したC言語のコードはこちらにあります(C言語にとても自信があるわけではないので、問題点を見つけた方がいれば積極的にマサカリを投げていただけると幸いです)。内部実装はAVL木を採用しています[3]。
open('./_compprog_cffi.pypy36-pp73-x86_64-linux-gnu.so','wb').write(base64.b85decode(binary))
バイナリファイルを復号してAtCoderのジャッジサーバーのディレクトリ内に書き込んでいます。
from _compprog_cffi import ffi, lib
今回はCFFIというライブラリでC言語のコードをwrapしています。ffiとlibを通じてライブラリを呼び出すことができます。
# write your code here!
ここに自分のコードを入れましょう。
"""
This code was created from PyPy CFFI.
(中略 2000行ぐらい?)
"""
この部分がなくてもAtCoder上で動作します。しかし、他のユーザーの提出ファイルを読んで学習できることが競プロのよいところだと思っているので、どのようなコードに従ってバイナリが生成されたのかを載せることにしました(さすがに長すぎるしコピペするだけでもめんどくさいので、圧縮して適宜展開する形式への変更を検討しています→変更しました。2021.9.13追記)。なくても問題ない[4]と思っていますが、できれば併記していただけると幸いです。
使い方
具体的な使い方をここに書きます。64bit整数型用のordered setとmultisetです(将来的には文字列やタプルなどでも使えるようにします)。このページのローカルルールで、64bit整数型には型annotationとしてll
を使用します。int
は32bit整数型です。用途上あまりないと思いますがオーバーフローに注意しましょう(64bit整数は大体-9 * 10^18~9 * 10^18の範囲です)。時間計算量を付記しましたが、これは要素数
coset
基本
初期化
coset_init() -> [coset_ll *]
空のordered setを作ります。返り値coset_ll *
型はcoset_ll
(ordered setのC言語上の型)へのポインタですが、PyPy上では他のオブジェクトと同じような普通のオブジェクトだと思って大丈夫です。この返り値のオブジェクトをこれから紹介する関数に引数として渡し、ordered set関連のいろいろな処理をするというのが基本の流れになります。
cs = coset_init()
挿入
insert(cs: [coset_ll *], x: [ll]) -> None
ordered set cs
に整数x
を挿入します。既に整数x
がcs
内に存在している場合何もしません(計算自体は行われるので計算量的に定数倍がシビアな場合は注意)。
cs = coset_init()
insert(cs, 0) # 0を追加
insert(cs, 5) # 5を追加
insert(cs, -1) # -1を追加
削除
remove(cs: [coset_ll *], x: [ll]) -> None
ordered set cs
から整数x
を削除します。整数x
がcs
内に存在しない場合何もしません(計算自体は行われるので計算量的に定数倍がシビアな場合は注意)。
cs = coset_init()
insert(cs, 0) # 0を追加
insert(cs, 5) # 5を追加
remove(cs, 5) # 5を削除
検索
find(cs: [coset_ll *], x: [ll]) -> [cs_node_ll *]
ordered set cs
内の整数x
が存在するかどうか判定できます。ややこしいのですが、返り値はbool
ではないです。判定方法はコード例を参考にしてください。
cs = coset_init()
insert(cs, 0) # 0を追加
insert(cs, 5) # 5を追加
if find(cs, 5)==ffi.NULL: # 見つからなければNULLを返す。
print("did not find 5") # NULLか否かはffi.NULLとの比較で判定可能。
else:
print("found 5")
"""
found 5
"""
if find(cs, 4)==ffi.NULL: # 見つからなければNULLを返す。
print("did not find 4") # NULLか否かはffi.NULLとの比較で判定可能。
else:
print("found 4")
"""
did not find 4
"""
# Tips: 頻繁にNULLと比較するようならNULL = ffi.NULLのようにして直接アクセルできるようにした方が高速だと思われる。
サイズ取得
get_s(cs: [coset_ll *]) -> [int]
ordered setcs
の要素数を返します。
cs = coset_init()
if get_s(cs)==0:
print("It's empty.")
"""
It's empty.
"""
順序関係操作
最大・最小
get_max(cs: [coset_ll *]) -> [cs_node_ll *]
get_min(cs: [coset_ll *]) -> [cs_node_ll *]
ordered set cs
中の最大値及び最小値を持つノードを返します。.key
でノードが持つ値にアクセスできます。cs
が空の場合はffi.NULL
が返されます。
cs = coset_init()
insert(cs, 1)
insert(cs, 4)
insert(cs, 2)
max_node = get_max(cs)
print(max_node.key)
"""
4
"""
min_node = get_min(cs)
print(min_node.key)
"""
1
"""
bound系
lower_bound(cs: [coset_ll *], x: [ll]) -> [cs_node_ll *]
rlower_bound(cs: [coset_ll *], x: [ll]) -> [cs_node_ll *]
upper_bound(cs: [coset_ll *], x: [ll]) -> [cs_node_ll *]
rupper_bound(cs: [coset_ll *], x: [ll]) -> [cs_node_ll *]
ordered set cs
において、
-
lower_bound
:x
以上である整数の中で最小の整数のノード -
rlower_bound
:x
以下である整数の中で最大の整数のノード -
upper_bound
:x
より大きい整数の中で最小の整数のノード -
rupper_bound
:x
より小さい整数の中で最大の整数のノード
を返します。存在しない場合はffi.NULL
を返します。値へのアクセスは.key
でしましょう。
cs = coset_init()
insert(cs, 1)
insert(cs, 5)
insert(cs, 10)
print(lower_bound(cs, 5).key)
"""
5
"""
print(upper_bound(cs, 5).key)
"""
10
"""
print(rlower_bound(cs, 5).key)
"""
5
"""
print(rupper_bound(cs, 5).key)
"""
1
"""
k番目の値を取得
get_k(cs: [coset_ll *], k: [int]) -> [cs_node_ll *]
ordred set cs
中で下から数えてk番目(0-indexed)の値を持つノードを返します。この機能はC++ STLのstd::setにはない割に実装がそこまで難しくないので、あったら嬉しいと思ってつけました。indexが不正な場合、ffi.NULL
が返されます。
cs = coset_init()
insert(cs, 1)
insert(cs, 5)
insert(cs, 10)
print(get_k(cs, 0).key)
"""
1
"""
print(get_k(cs, 1).key)
"""
5
"""
print(get_k(cs, 2).key)
"""
10
"""
その他
clear(cs: [coset_ll *]) -> None
ordered set cs
を空にします。オブジェクトを破棄したいときにこれを呼ぶとメモリを解放できます。CFFIとPyPyのgarbage collectionの関係をまだいまいち理解していないのでつけましたが、PyPyが自動でやってくれる可能性があり要らないかもしれません。特に競プロでは書く必要がない気がしてきたので、覚えなくてもいいです。
comultiset
大体cosetと同じですが、違うものだけここに追記します。計算量の
基本
初期化
comultiset_init() -> [comultiset_ll *]
空のmultisetを作ります。
cms = comultiset_init()
挿入
insert(cms: [comultiset_ll *], x: [ll], int n) -> None
multiset cms
に整数x
をn個挿入します。既にx
がcs
中に存在する場合も追加されます。
cms = coset_init()
insert(cms, 0, 4) # 0を4つ追加
insert(cms, 5, 2) # 5を2つ追加
insert(cms, 5, 4) # 5を4つ追加
insert(cms, -1, 1) # -1を1つ追加
削除
remove(cms: [comultiset_ll *], x: [ll],remove_all: [bool]) -> None
multiset cms
から整数x
を削除します。このとき、remove_all
がTrue
ならば全て消去します。False
ならば1つだけ消去します[5]。
cms = comultiset_init()
insert(cms, 0, 4) # 0を4つ追加
insert(cms, 5, 2) # 5を2つ追加
insert(cms, 5, 4) # 5を4つ追加
insert(cms, -1, 1) # -1を1つ追加
remove(cms, 0, False) # 0を1つ削除
remove(cms, 5, True) # 5を全て削除
数える
count(cms: [comultiset_ll *], x: [ll]) -> [int]
multiset cms
中に整数x
がいくつ入っているかを返します。
cms = comultiset_init()
insert(cms, 0, 4) # 0を4つ追加
insert(cms, 5, 2) # 5を2つ追加
insert(cms, 5, 4) # 5を4つ追加
insert(cms, -1, 1) # -1を1つ追加
remove(cms, 0, False) # 0を1つ削除
print(count(cms, 0))
"""
3
"""
入出力関係(おまけ)
入出力もC言語で書いたものをwrapしました。おまけと書きましたが、地味にこれによってかなり速くなります。そのため、後述の性能比較では同じ土俵で比較するため、普通のPyPyの評価コードでもこれを使用することにしました。(まだ比較用のコードはできていませんが)
入力
scan() -> [ll]
整数を1つ受け取ります。
出力
prin(x: [ll]) -> None
整数を1つ出力し、改行します。
検証
いくつかの問題で検証してみました(今回使ったのはすべてmultiset、multisetが最良ではない問題もあるので、そのような問題はheapqやordered setを使う想定解よりも遅くなっているので注意)。時間の測定について、たまにジャッジの仕組み上最初のケースで長い時間かかってしまうことがあるのでそのような場合については再度提出してやり直しました。
なお、PyPyはそもそもオーバーヘッドが100 ms程度存在するので、最長ケースの時間 - 最短ケースの時間で見た方がより実態が分かるかもしれません。というわけで、下にそちらの方の数字も併記しています。
最長ケースの時間
問題名 | PyPy CFFI | C++ STL | C言語 | PyPy |
---|---|---|---|---|
ABC137 D - Summer Vacation | 171 ms (コード) | 59 ms (コード) | ||
ABC140 E - Second Sum | 189 ms (コード) | |||
ABC170 E - Smart Infants | 780 ms (コード) | 382 ms (コード) | ||
ABC212 D - Querying Multiset | 186 ms (コード) | |||
ABC217 D - Cutting Woods | 184 ms (コード) | |||
ABC217 E - Sorting Queries | 173 ms (コード) | |||
ABC218 G - Game on Tree 2 | 365 ms (コード) | |||
ARC033 C - データ構造 | 242 ms (コード) |
最長時間 - 最短時間
問題名 | PyPy CFFI | C++ STL | C言語 | PyPy |
---|---|---|---|---|
ABC137 D - Summer Vacation | 77 ms (コード) | 57 ms (コード) | ||
ABC140 E - Second Sum | 96 ms (コード) | |||
ABC170 E - Smart Infants | 690 ms (コード) | 369 ms (コード) | ||
ABC212 D - Querying Multiset | 95 ms (コード) | |||
ABC217 D - Cutting Woods | 93 ms (コード) | |||
ABC217 E - Sorting Queries | 81 ms (コード) | |||
ABC218 G - Game on Tree 2 | 284 ms (コード) | |||
ARC033 C - データ構造 | 148 ms (コード) |
ちゃんとした比較のために表の他の部分も埋めていく予定ですが、所感として他のPyPyのコードよりかなり速いです(入出力高速化の影響もかなり大きいですが大体最速クラス)。最長時間-最短時間についていえば、C/C++の提出と比べてもそこまで遅くないレベルに見えますが、他言語との比較は条件を揃えてからきちんと行いたいと思います。
課題
- バイナリ提出はAtCoderにおいて将来制限される可能性があります。他の人の提出コードを読んで勉強できる機能が競プロの魅力の一つだと思うので、コードが隠蔽されてしまうバイナリ提出は確かに微妙なところがありますね。
- C言語ベースなので開発が少し難しいです。C++ベースでwrapperを作れるcppyyというライブラリもあるのですが、少し調べた限りジャッジサーバーにもcppyyがインストールされていないと利用できなさそうです。
ジャッジサーバーにインストールされていればCFFIを使うにしてもバイナリ提出ではなくコード提出で済みますし、cppyyも使用可能になるので、次のAtCoderの言語アップデートなどで入れられるとうれしいと思うのですが、いかがでしょうか?
おまけ
heapqも実装してあります。また、同じやり方でACLと同機能のライブラリも現在作成しています(まだ完成は遠そうですが)。詳細なhow toの解説については、近いうちに書こうと思ってます。
-
PyPyを使わなくてもnumpyやnumbaによる高速化でもAtCoderで十分問題を解けます ↩︎
-
Cythonのcythonizeという機能を利用する方法(ACLのwrapに関するc_r_5さんの記事)、Python.hを使う方法(nagissさんによるC++ std::setのwrap) ↩︎
-
C++のSTLは基本的に赤黒木を採用しているようで、同じ平衡二分探索木でも少し違います。AVL木は赤黒木と比べ、検索は定数倍速めですが、挿入削除は定数倍遅めです。そんなにここの速度差が問題になる場面はないと思いますが、赤黒木も今後作る予定です。 ↩︎
-
chokudaiさんによると、将来的にはバイナリ提出自体に制限が加わる可能性があるそうで(ソース)、そこまで好ましいやり方だとは思われていないようです。ただし、Rustのローカルでビルドしたバイナリの提出ツールなども存在し、現時点では特に問題ないと思います(2021.9.9時点)。しかし、やはり上述のように、コードをトレースできた方がよいと思うので、元のコードは載せておきます。 ↩︎
-
よく考えたら消す個数を指定できた方がよいので将来的にはそうする予定です ↩︎
Discussion
リポジトリを見てもらえて、とても嬉しいです。ありがとうございます!
こちらこそとても参考になるリポジトリありがとうございます!
バイナリ提出を行うPythonでのコード例がありこの記事の方法にピッタリでした!