nまでの素数の{個数, リスト}を求める・またはPythonの並列処理を巡る冒険
なんでだか、Pythonでの並列処理、または
Pythonでの並列処理は、円周率を求めるときにChudnovskyの公式を並列化するのに使ったことがある。
さいきん「Pythonの並列処理を学ぶのに良いネタはないですかね?」という問いに、「
だから、高速にループ回すのになんでPython? とかの突っ込みは、なしの方向で。
2024年5月頃の3日間の記録。
素数とわし
素数のリストと言えば、黒歴史もある。
大学院生のとき、研究室に4月に配属されてきた4年生相手にプログラミングコンテストをやっていた。初級編の最後で、10万までだったか100万までだったかの素数のリストを何秒でつくれるか、というのをやっていたのだが……ある年あろうことか「わしに勝ったら、裸で逆立ちしてグランド一周してやる!!」と宣言してしまった。まあ若気の至りっていうかワカメのイタリアンっていうやつですかね。
そのときわしが使ったのは、いわゆる『
その女の子が顔を赤らめながら「カップ麺奢ってくれればいいです」と言ってくれたので、裸でグランド一周は免れたが……言い訳するが、ちょうど時代の変わり目で、個人向けのメモリ64KBの8bit MS-DOSマイコンではなく大学では実メモリ4MB・swapその10倍くらいの32bit UNIXワークステーション (確か使ったのはSun-4 workstationだった) の威力が使えるようになった。つまり配列を100Kだか1Mだか取っても、unsigned short
を超える数を扱っても、それなりに速い時代になったのだ。
時代は変わるのだ、という教訓だ。
この記事でも、きょうびの流行りで、ループ回したら速度の遅さには定評があるPython縛りで、またMacbook Air/Pro程度の個人向けラップトップにてプレイしてみることにする。
multiprocessingまたはthreading
Python3で並列処理をするには、multiprocessing (プロセス並列) またはthreading (スレッド並列) というライブラリがある。どちらも、オブジェクトを作って自分自身の中の呼びたい関数を渡し、スタートして待ち合わせするだけである。書式もほとんど同じなので、円周率のときには簡単にいろいろ試せた。
(スレッド版)
import threading
import time
def testthread(x: float) -> None:
print('start:', x)
time.sleep(x)
print('finish:', x)
return
if __name__ == '__main__':
th = []
th.append(threading.Thread(target = testthread,
args = (2., )))
th[0].start()
th.append(threading.Thread(target = testthread,
args = (5., )))
th[1].start()
for t in th:
t.join()
print('*')
円周率の記事にもあるが、注意点としては『メインを地べたに書かないこと』くらいである。どうやらスレッドなりプロセスなりを立てるときに、自分が実行しているのと同じスクリプトを再度読む (このときには__name__
は'__main__'
にならない) みたいなので、地べたに書くと初期化でなんかおかしなことになってエラーになってしまう。まあこれは、エラーメッセージを読んでその通りにメインを書き直せばすぐ解決する。
n までの素数を並列処理で求める
nproc
個のスレッドまたはプロセス (以下面倒なので、どちらでも当てはまるときは『プロセス』って書いちゃったりしますね) で、並列処理により
- 適当に範囲を
- 等分して
nproc
個に分ける - より細かい範囲に分けて→空いてるプロセスに与え→終わって暇になったら次の仕事を渡す
- 掛かる時間が均等になるように予測して範囲を分ける
- 等分して
- 剰余類で分ける
- ……などなど
範囲等分のサンプルは、次のようになるだろう (シングル版・マルチスレッド版・マルチプロセス版)。
範囲等分のサンプル
import numpy as np
import time
import multiprocessing
import threading
def isprime(n: int) -> None:
for i in range(2, int(np.sqrt(n)) + 1):
if n % i == 0:
return False
return True
def printprimes(start: int, end: int) -> None:
# print('entered printprimes(', start, ',', end, ')')
for i in range(start, end):
if isprime(i):
pass
## to get output, uncomment here:
# print(i, end = ' ')
# print()
if __name__ == '__main__':
MAX = 2000000
NPROCS = 10
# single process/thread
start = time.time()
printprimes(0, MAX)
print()
print(time.time() - start)
# multithreads
start = time.time()
width = MAX // NPROCS
th = []
for pn in range(NPROCS):
th.append(threading.Thread(target = printprimes,
args = (pn * width, (pn + 1) * width)))
print('thread', pn, 'starts')
th[pn].start()
for pn in range(NPROCS):
th[pn].join()
print('thread', pn, 'ended')
print(time.time() - start)
# multiprocess
start = time.time()
width = MAX // NPROCS
th = []
for pn in range(NPROCS):
th.append(multiprocessing.Process(target = printprimes,
args = (pn * width, (pn + 1) * width)))
print('thread', pn, 'starts')
th[pn].start()
for pn in range(NPROCS):
th[pn].join()
print('thread', pn, 'ended')
print(time.time() - start)
ここではシンプルに、2からprintprimes()
のコメントアウトを外せば標準出力に素数リストが出るが、並列処理に馴染みがないひとにとっては、プロセスがごちゃごちゃな順番で出力するのが逆に面白かろう。と思うが、小さい数までだと次のプロセスを起動する前にひとつのプロセスが終わってしまうので、あまり面白くない。
『素数のリスト』など見飽きているので(笑) ってか何億とかになると長いので、素数が求まったら出力せずカウントアップして『素数の個数』だけを求めればよし、と、この時点で問題を変更している(笑)。どーせ素数のリスト作らなきゃ、素数の個数なんて求められないだろ……?
これが後でまたハマる原因になる。
なぜ除数は$\sqrt{n}$まででよいか
おわり。
閑話休題。
範囲を等分するやり方は、少しプロセスの負担均等化に問題がある。この方法で
だからといって、『より細かい範囲に分ける』と、終わったプロセスの監視などで余計なことを考えないといけなくなり、またプロセス起動のオーバヘッドも大きくなってしまう。
はじめから『掛かる時間が均等になるように予測』すれば無駄はなさそうだ。範囲ごとに素数を含む割合は近似式があるが、非素数の場合に約数にたどりつくまでの時間など、考えなければいけないことが多すぎる。
というわけで、剰余類で分けてみることにした。プロセスがnproc
個のとき、プロセス番号pn
番は i % nproc == pn
なる候補数i
を受け持つことになる。
だがこれには問題がある。各プロセスごとに求まった素数の個数を表示させてみれば明らかだが (させなくてもちょっと考えりゃわかるのだがw)、例えばpn
が偶数番号のプロセスは偶数ばかりを検査することになるので、明らかに素数は見つからず無駄に帰ってくることになる。
(6n \pm 1) 法
並列処理からちょっと離れる。
素数のリストを作るとき、効率的な方法として有名なのは
なぜかというと、
ということは、中学生くらいでも気の利いた奴なら知っている。
これを使えば、割られる数の候補は
同様の手法で
だから普通はやらない(笑)。
(6n \pm 1) 法と並列処理
さて、これでプロセス番号を分けるとしたら
- プロセス数
nproc
は偶数とする -
width = (nproc // 2) * 6
と置く (『a // b
』は「a
をb
で割った整数部」「 」の意)。\lfloor a / b\rfloor -
pn
番のプロセスには- もし偶数なら
6 * (pn // 2) + 1
から始まる - もし奇数なら
6 * (pn // 2) + 5
から始まる
width
おきの数を、被除数として検査してもらう
- もし偶数なら
たとえば、nproc = 10
個だとしたら
width = 10 // 2 * 6 = 30
- プロセスごとに以下の被除数を検査してもらう
- 0番のプロセスは、
\{(1,) 31, 61, \ldots\} - 1番のプロセスは、
\{5, 35, 65, \ldots\} - 2番のプロセスは、
\{7, 37, 67, \ldots\} - ……
- 9番のプロセスは、
\{59, 89, 119, \ldots\}
- 0番のプロセスは、
とすれば、なかなか均等に割れるのではないか。
まあこれも試してみると、仕事しないで いやしてるのだが空振りで帰って来るプロセスがいる。
そりゃそうだ。たとえば上記の例では、1番と8番のプロセスは明らかに5の倍数ばかりの被除数を検査している。
一般化すると、分割数nproc // 2
が
まあプロセス分割ではなくてスレッド分割の場合 (スレッドで有効に並列化するはなしは、この後すぐ!!)、多少CPUコア数より多くても無駄時間は生じないし、空振りスレッドが早めに仕事を終えて帰ってきたら、その資源は別のスレッドに割り当てられるため、そこまで無駄にはならない。だが無駄は無駄である。
試してみたMacbook Air/Proくらいであれば (コア数8とか10とか)、スレッド数nproc = 8, 12, 16, 18, 24,...
あたりが、無駄仕事が割り振られないスレッド数ということになる。
(6n \pm 1) 法
ダブルわしは
上記は被除数を
なんで奇数でよいかというと、被除数が偶数だったら2で割り切れるから素数じゃないし (当たり前)、被除数が奇数かつ割り切れる数だったら奇数掛ける奇数だからだ。偶数に何を掛けても奇数にはならぬ。すべての(
同様に、被除数が
このことは、6で割った余りの表を書いてみればわかる (『a→』『↓b』はそれぞれ
↓b\a→ | 0 | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|---|
0 | 0 | 0 | 0 | 0 | 0 | 9 |
1 | 0 | 1 | 2 | 3 | 4 | 5 |
2 | 0 | 2 | 4 | 0 | 2 | 4 |
3 | 0 | 3 | 0 | 3 | 0 | 3 |
4 | 0 | 4 | 2 | 0 | 4 | 2 |
5 | 0 | 5 | 4 | 3 | 2 | 1 |
つまり、被除数
除数と被除数が共に
並列化についても、前記の方法で被除数を
Python3.(<=12) のGILと、subinterpreterでスレッド有効活用
以前の円周率ネタでも、なかなかthreadingを使ったマルチスレッドが速くならない、というのは経験していたが、そのときは深く考えず放置していた。まあmultiprocessingでも、数分〜数時間帰ってこないお仕事の分割なので、別にオーバヘッドは問題ではなく、CPUのコア数をうまく勘案すればプロセス分割でよいか、と思っていたからである。
ところが今回、スレッドについていろいろ調べているうちに、Python3でスレッド分割しても、ひとつのプロセス内ではロック (GIL) が掛かってしまい、スレッドが同時にふたつ以上実行されない(爆)ということが分かった。いやそんな基本的なこと知らなかったわしがアホなだけであるが。
そしてまた最近のPython3.12では、subinterpreterというものが実装されて、インタプリタごとに異なるスレッドが実行できる!! らしい……なんのこっちゃ?
基本的な書き方は、次のようなものらしい。
subinterpreterのサンプル
import time
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels
scr = '''
import time
import _xxinterpchannels as channels
print('start:', arg, chid)
time.sleep(arg)
print('finish:')
channels.send(chid, arg ** 2)
'''
if __name__ == '__main__':
ids = []
chids = []
id = interpreters.create()
ids.append(id)
chid = channels.create()
chids.append(chid)
arg = 5
interpreters.run_string(
id, scr, shared = {'arg': arg, 'chid': chid}
)
id = interpreters.create()
ids.append(id)
chid = channels.create()
chids.append(chid)
arg = 2
interpreters.run_string(
id, scr, shared = {'arg': arg, 'chid': chid}
)
for chid in chids:
print(channels.recv(chid))
for id in ids:
interpreters.destroy(id)
まだ現在のPythonでは試用版らしく、subinterpreterのライブラリは_xxsubinterpreter
なる、なんかエロい 怪しい名前になっている。
このsubinterpreterは、文字列として実行したいスクリプトを渡し (オレンジのhere document部分)、引数 (ってかシェアしたい値) をshared
で渡してやる。あと、戻り値はchannelというのを使えば (subinterpreterで実行しているスレッドからはchannels.send()
、メイン側でchannels.recv()
で) 戻って来る。
これとthreadingを組み合わせると、こうなる。
subinterpreter + threadsのサンプル
import threading
import time
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels
scr = '''
import time
import _xxinterpchannels as channels
print('start:', arg, chid)
time.sleep(arg)
print('finish:')
channels.send(chid, arg ** 2)
'''
if __name__ == '__main__':
ths = []
ids = []
chids = []
id = interpreters.create()
ids.append(id)
chid = channels.create()
chids.append(chid)
arg = 2
th = threading.Thread(target = interpreters.run_string,
args = (id, scr),
kwargs = {'shared': {'arg': arg, 'chid': chid}})
# interpreters.run_string(
# id, scr, shared = {'arg': arg, 'chid': chid}
# )
ths.append(th)
th.start()
id = interpreters.create()
ids.append(id)
chid = channels.create()
chids.append(chid)
arg = 5
th = threading.Thread(target = interpreters.run_string,
args = (id, scr),
kwargs = {'shared': {'arg': arg, 'chid': chid}})
# interpreters.run_string(
# id, scr, shared = {'arg': arg, 'chid': chid}
# )
ths.append(th)
th.start()
for th in ths:
th.join()
for chid in chids:
print(channels.recv(chid))
for id in ids:
interpreters.destroy(id)
threading.Threadにはinterpreters.run_string()
を食わせてやる。interpreters.run_string()
の引数はkwargs =
で辞書で与える必要がある。
そんなこんなで、ダブル
ダブル$(6n \pm 1)$法のスレッド並列化
import threading
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels
import time
import sys
script = '''
import _xxinterpchannels as channels
def isprime6(x: int) -> bool:
d = 5
mx = int(x ** .5)
while d <= mx:
if x % d == 0:
return False
d = d + 2
if x % d == 0:
return False
d = d + 4
return True
cnt = 0
for i in range(start, end, stps):
# print(i, end = '')
if isprime6(i):
# print('*', end = '')
cnt = cnt + 1
# print('|', end = '')
#print('\\n{}..{}: {}'.format(start, end, cnt))
channels.send(chid, cnt)
'''
# subinterpreter + threads
if __name__ == '__main__':
if len(sys.argv) != 3:
print('usage: hspmod.py MAX NPROCS', file = sys.stderr)
print(' NPROCS must be even', file = sys.stderr)
sys.exit(1)
MAX = int(sys.argv[1])
NPROCS = int(sys.argv[2])
if NPROCS % 2 != 0:
print('NPROCS must be even', file = sys.stderr)
sys.exit(1)
stt = time.time()
stps = NPROCS * 6 // 2
ids = []
chids = []
threads = []
for pn in range(NPROCS):
start = (pn // 2) * 6
if pn % 2 == 0:
start = start + 1
if start == 1:
start = start + stps
else:
start = start + 5
end = MAX + 1
id = interpreters.create()
ids.append(id)
chid = channels.create()
chids.append(chid)
print(pn, start, end, stps, chid)
thread = threading.Thread(
target = interpreters.run_string,
args = (id, script),
kwargs = {'shared':
{ 'start': start,
'end': end,
'stps': stps,
'chid': chid
}
}
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
cnt = 2 # {2, 3}
for chid in chids:
retv = channels.recv(chid)
cnt = cnt + retv
print(retv, end = ' ')
print('\n', cnt)
for id in ids:
interpreters.destroy(id)
print(time.time() - stt)
例によって、100億 (小学生かよ)
ここまでで、およその実実行時間は、Macbook Airで1千万までで数秒程度になった。
となると、わしの大好きな100億 (※ 円周率参照) が視野に入ってくる。そのくらいになると、もう少し高速化してみたいところである。
除数についてだが、
- まずはシングルスレッドで、ダブル
法求めた(6n + \{1, 5\}) のリストをつくる\sqrt{n} - 前記の剰余群にスレッド分割して、担当のスレッドにスタート数・終了数・ステップ (前記
width
)・素数リストを渡して計算してもらう- スレッド内ではスタート数から終了数までステップごとに被除数を
- 素数リストの中の被除数の平方根より小さい除数で
割り切れるかテスト
とすれば大幅に速くなるはずだ。
ただし、前者の除数リスト・後者のスレッド分割開始の数について、つなぎ目に気をつけないと素数の数をダブルカウントしてしまうなどのおそれがある。
これについては (ややこしいが)、linmax
として、2..linmax + 2
(linmax + 2
はlinmax + 6
(linmax + 6
は
で、ここで問題が生じた。各プロセスに渡す
以下長くなるが、最高速版のフルリスト。
import threading
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels
import pickle
import time
import sys
FFNAME = 'pf'
def isprime6(x: int) -> bool:
d = 5
mx = int(x ** .5)
while d <= mx:
if x % d == 0:
return False
d = d + 2
if x % d == 0:
return False
d = d + 4
return True
def primelist(end: int) -> list:
# plist = [2, 3]
plist = []
for i in range(5, end + 1, 6):
if isprime6(i):
plist = plist + [i]
if isprime6(i + 2):
plist = plist + [i + 2]
return plist
script = '''
import pickle
import _xxinterpchannels as channels
FBASE = 'p'
def isprimediv(n: int) -> bool:
mx = int(n ** .5)
for div in divlist:
if mx < div:
return True
if n % div == 0:
return False
return True
with open('plist.pkl', 'rb') as fp:
divlist = pickle.load(fp)
#print(divlist)
with open(FBASE + str(chid) + '.txt', 'w') as fp:
cnt = 0
for i in range(start, end + 1, stps):
if isprimediv(i):
print(i, file = fp)
cnt = cnt + 1
channels.send(chid, cnt)
'''
# subinterpreter + threads
def printlist(l: int, fn: str) -> None:
with open(fn, 'w') as fp:
for i in l:
print(i, file = fp)
return
if __name__ == '__main__':
if len(sys.argv) != 3:
print('usage: hspmod.py MAX NPROCS', file = sys.stderr)
print(' NPROCS must be even', file = sys.stderr)
sys.exit(1)
MAX = int(sys.argv[1])
NPROCS = int(sys.argv[2])
if NPROCS % 2 != 0:
print('NPROCS must be even', file = sys.stderr)
sys.exit(1)
r = (NPROCS // 2) % 6
if r == 1 or r == 5:
print('warning: some threads donot count any primes')
stt = time.time()
width = 6 * (NPROCS // 2)
linmax = ((int(MAX ** .5) + 5) // 6) * 6 - 1
# linmax <= sqrt(MAX), linmax % 6 == 5
# 2..(linmax + 2) are tested (by linear method) number
# (within divlist if prime).
plist = primelist(linmax)
with open ('plist.pkl', 'wb') as fp:
pickle.dump(plist, fp)
print('2..{} tested by linear'.format(linmax + 2))
print(time.time() - stt, 'sec')
printlist([2, 3] + plist, FFNAME + '.txt')
start0 = linmax + 6
# start from (linmax + 6) by multiproc method (start0 % 6 == 5).
print('{}..{} tested by multiproc'.format(start0, MAX))
ids = []
chids = []
threads = []
for pn in range(NPROCS):
start = start0 + (pn // 2) * 6 # 6n + 5
if pn % 2 == 1:
start = start + 2 # 6n + 1
id = interpreters.create()
ids.append(id)
chid = channels.create()
chids.append(chid)
print('proc #{}: {}..{} step {}, chid {}'.\
format(pn, start, MAX, width, chid))
thread = threading.Thread(
target = interpreters.run_string,
args = (id, script),
kwargs = {'shared':
{ 'start': start,
'end': MAX,
'stps': width,
'chid': chid
}
}
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
cnt = 2 + len(plist) # {2, 3}
for chid in chids:
retv = channels.recv(chid)
cnt = cnt + retv
print(retv, end = ' ')
print('\n', cnt)
for id in ids:
interpreters.destroy(id)
print(time.time() - stt, 'sec')
こんな感じになりました。
なお上のプログラムは、求めた素数をリストにして書き出して (pf.txt
、並列部分はスレッドごとにひとつのファイルp{0,...,nproc}.txt
) いるが、個数のカウントのみならず素数リストを作っているのは……次の節の理由による (このプログラムより前のベンチマークでは、求めた素数そのものは捨てて、個数のみカウント)。
いまどのへんまで求めているかは、上記のpてきとう.txt
をtail
してみればわかるので、無駄に時間を喰うモニタ出力は一切、してません。あと出力テキストファイル (1個ずつ改行) の大きさは、素数の個数を見積もる近似式 (次節) があるので、全部11桁と仮定しても5GBちょっとになると暗算して、安心して寝ますた。
Macbook Airで
(Macbook Air 8core。いろいろなnproc =
8・12・16で試してみてますが、グラフ上ではほとんど変わらん(笑))
素数の個数のカウントから素数のリストへ
そんなこんなで、時間の見積もりとかしているときに『
これを検索していたら、間違ってなんと、『
え? え? 近似じゃなくて、正確な個数を、素数そのものを求めずにカウントする方法があるの!?
ちょっと読んだだけでは理解できなかったのだが、どうやらそっちのほうが速そうだ。
ならば悔しいので『具体的な素数を求めてリストを作る』問題に、ゴールを変えちゃえ(笑)。
というわけで、100億までの素数を全部出力してみました(あほ)。これが家にあれば、ちょっと嫌なことがあっても「まあ家に帰れば素数100億まであるしな」ってなるし仕事でむかつく人に会っても「そんな口きいていいのか?私は自宅で100億までの素数の表とよろしくやってる身だぞ」ってなれる。
個数は455052511個です。
Macbook Pro (Arm M1, 16core)で、CPU load 99%♡ でぶん回して、14619秒 (4時間3分39秒) で求まりました。
ログ
2..100003 tested by linear
0.14380383491516113 sec
proc #0: 100007..10000000000 step 48, chid 0
proc #1: 100009..10000000000 step 48, chid 1
proc #2: 100013..10000000000 step 48, chid 2
proc #3: 100015..10000000000 step 48, chid 3
proc #4: 100019..10000000000 step 48, chid 4
proc #5: 100021..10000000000 step 48, chid 5
proc #6: 100025..10000000000 step 48, chid 6
proc #7: 100027..10000000000 step 48, chid 7
proc #8: 100031..10000000000 step 48, chid 8
proc #9: 100033..10000000000 step 48, chid 9
proc #10: 100037..10000000000 step 48, chid 10
proc #11: 100039..10000000000 step 48, chid 11
proc #12: 100043..10000000000 step 48, chid 12
proc #13: 100045..10000000000 step 48, chid 13
proc #14: 100049..10000000000 step 48, chid 14
proc #15: 100051..10000000000 step 48, chid 15
28441480 28437829 28439317 28440138 28438970 28439733 28439461 28439126 28441462 28437590 28441167 28440729 28441228 28441900 28441495 28441293
455052511
14618.746592998505 sec
結果のマージソート
上記はファイルが、pf.txt
と、プロセスごと (width
の剰余群ごと) のp{0,1,...}.txt
に分かれちゃっている。これをソートしなければひとつの素数リストにならないが……まず前者は後者の手前 (
したがって、前者を最終出力にコピー、それに後者をマージソートしたものをくっつければおしまい。
引数として計算に使ったのと同じ最大数
やっつけで書いたので、結構汚い。
import sys
import numpy as np
PFNAME = 'pf'
FBASE = 'p'
nmax = int(sys.argv[1])
nproc = int(sys.argv[2])
ofp = open('pout.txt', 'w')
#ofp = sys.stdout
ff = open(PFNAME + '.txt', 'r')
ns = ff.readline()
while ns:
print(ns, end = '', file = ofp)
ns = ff.readline()
ff.close()
fp = [None] * nproc
pr = np.zeros(nproc, dtype = int)
for i in range(nproc):
fp[i] = open(FBASE + str(i) + '.txt')
pr[i] = int(fp[i].readline())
opencount = nproc
while 0 < opencount:
am = np.argmin(pr)
print(pr[am], file = ofp)
ns = fp[am].readline()
if not ns:
fp[am].close()
pr[am] = nmax + 1
opencount = opencount - 1
else:
pr[am] = int(ns)
ofp.close()
Pythonでやったら遅いかな? と思ったけど、そうでもなかった。11分24秒で完了しました。
これも計算時間に含めないと、アンフェアかな?
100億までの素数 (長いのでw 頭と末尾のみ)
2
3
5
7
11
13
17
19
23
29
...
9999999769
9999999781
9999999787
9999999817
9999999833
9999999851
9999999881
9999999929
9999999943
9999999967
素数の個数のカウント (素数を計算せずに) ふたたび
実際に素数を求めずに、
しかもこれ、競技プログラミングでは定番(?) ううむ?
と思ったが、英語版Wikipediaや英語の論文は読むのが面倒くさく、日本語の説明もよくわからん。
まあ、『エラトステネスの篩』とかdynamic programmingとか、キーワードを拾い読みしていたらなんとなく分かってしまった、ってか自分で再発明した(笑)ので、以下説明。オリジナルと違ってるかもしれないし、間違ってる……ことはないだろうが、効率悪いかもしれぬ。
まず、
- 25以下の整数は25個w
- 2の倍数は素数じゃないので除外する (2は素数だけどそのはなしは後ほど)。何個あるかというと
個\lfloor 25/2 \rfloor = 12 - 同3の倍数は8個、5の倍数は5個
-
で何も残らん(笑)が、これはなんでかというと25 - (12 + 8 + 5) = 0 - 2と3の公倍数6の倍数は、2回除外されている。6の倍数は4個
- 同2と5の公倍数10の倍数は2個、3と5の公倍数15の倍数は1個
- したがってこれらを足し戻して
個残る0 + 4 + 2 + 1 = 7
- (さらに2と3と5の公倍数30の倍数は足し戻されすぎているので引き (戻し戻し) たいが、30の倍数は25以下に0個)
- これに
を加えて10個、あと1は素数じゃないので除外して9個\{2, 3, 5\}
ということになる。
あるいは、『エラトステネスの篩』法で素数のリストを作り終わったとき、取り消し線の総数が25本 (実際は下図のように22本だけど。2・3・5は消してないので)、うち7つの数はひとつの数に2本線が引かれている。それに消されない
これを一般化すると
-
以下の素数のリストを作る\sqrt{n} - はじめに
以下の整数の個数n を『カウンタ』の初期値とするn - このリストから、すべての1つ以上のコンビネーションを作る:
をコンビネーションの積としたとき、p int(n / p)
を- コンビネーションの要素数が奇数ならカウンタから引く
- 偶数ならカウンタに足す
- 最後に、素数のリストの要素数を足して、1 (整数1の分) を引く
ということでいかがだろうか。
これが速いのか……というと。
たとえば100億までの素数の個数をカウントしようとしたら、すべてに
実際には、すべてのコンビネーションを実行する必要はない。前記の25まで→
実際には、前からポインタを動かしつつ、次の数のポインタを動かしつつ、……をすべて掛け合わせて
したがって、せっかくPythonなのにitertools.combinations()
とかは使ってはいけない。DPによって打ち切ることなくすべての組み合わせを延々数え始めるので、ほんとに
ここで、ふとあることに気づく。再帰の段数は大丈夫?
Pythonの再帰可能な段数は、確か以前に (なんちゃって関数型プログラミングとかやってみたとき) 500段くらいだったか、結構浅いところでエラーが出た (※ sys.setrecursionlimit()
である程度変えられる)。
100億までの素数は上記の通り最大9592段ループできなければダメ……ってわけじゃなくて、最初から
1 2
2 6
3 30
4 210
5 2310
6 30030
7 510510
8 9699690
9 223092870
10 6469693230
11 200560490130
なんだ、小さい方から11個で素数の積が100億超えるじゃん……11段くらいの再帰でいいのね。
まあ、いったん再帰な頭が完成したら、再帰を使わないループで簡単に書き直したりできるのだが。
あと、計算量も
(関数primelist(int: n) -> list
は、さっきと同じなので略)
import sys
import time
def countcombi(level: int, start: int, lprod: int) -> None:
global count
# global lplist
for pp in range(start, len(plist)):
div = lprod * plist[pp]
if mx < div:
return
# print('prod(', lplist + [plist[pp]], ') =', div)
if level % 2 == 0:
count = count - mx // div
# print('-', mx // div, '=>', count)
else:
count = count + mx // div
# print('+', mx // div, '=>', count)
# lplist = lplist + [plist[pp]]
countcombi(level + 1, pp + 1, div)
# lplist = lplist[:-1]
if __name__ == '__main__':
if len(sys.argv) != 2:
print('usage: countprimes MAX', file = sys.stderr)
sys.exit(1)
mx = int(sys.argv[1])
stt = time.time()
plist = primelist(int(mx ** .5))
# print('plist:', plist)
count = mx
# lplist = []
countcombi(0, 0, 1)
# print('count:', count, '+', len(plist), '- 1 =>', end = ' ')
count = count + len(plist) - 1
print(count)
print(time.time() - stt)
関数countcombi(level: int, start: int, lprod: int) -> None
は、
再帰的にコンビネーションを生成し、積を計算し、倍数の個数を足し引きする。
level
は再帰レベルだが、すなわちこれまでいくつの素数を掛け算したかを示している。start
はplist[]
中のスタート要素の番号で、この番号〜最後までを走査することになる。lprod
はこれまで (つまりstart
の左側のlevel
個の素数) の積 (キャッシュみたいなもの) で、いちいちlevel + 1
個の要素を掛け算するのを防ぐ。なおデバッグプリント (プログラム中のコメントアウトしているprint()
) のためにlplist[]
にこれまでかけ合わせたの素数のリストを作ってみているが、ただ個数をカウントするだけならlprod
があるので必要ない。
デバッグプリントを有効にして、100までの素数を求めるとこんな感じ。
100までの素数の個数を求める詳細
python3 countprimes.py 100
plist: [2, 3, 5, 7]
prod( [2] ) = 2
- 50 => 50
prod( [2, 3] ) = 6
+ 16 => 66
prod( [2, 3, 5] ) = 30
- 3 => 63
prod( [2, 3, 7] ) = 42
- 2 => 61
prod( [2, 5] ) = 10
+ 10 => 71
prod( [2, 5, 7] ) = 70
- 1 => 70
prod( [2, 7] ) = 14
+ 7 => 77
prod( [3] ) = 3
- 33 => 44
prod( [3, 5] ) = 15
+ 6 => 50
prod( [3, 7] ) = 21
+ 4 => 54
prod( [5] ) = 5
- 20 => 34
prod( [5, 7] ) = 35
+ 2 => 36
prod( [7] ) = 7
- 14 => 22
count: 22 + 4 - 1 => 25
さて実行時間だが、次のような感じになる。なおオリジナルのMeissel-Lehmerアルゴリズムの計算量はこんなところで考察してくれているひとがいる (やっぱり競技プログラミング筋のようだ)。
んで100億はというと……。
% python3 countprimes.py 10000000000
455052511
511.4769809246063
Macbook Proで、8分32秒ほどで求まりました。
やっぱ数えるだけのほうが、断然速い……。
(紫はPython・緑は下の『おまけ』のC言語版、Macbook Pro (ARM M1))
ここまで来たら、再帰の最初のほうのレベルで範囲を10人くらいの子供プロセスに分割して並列……とか思ったのだが、先に下のC言語への書き直しをして爆速体験をしてしまったので、もう萎えた(笑)。
おまけ: C言語で書き直してみた
カウントするだけのほう、ただC言語に直してみました。
C言語版
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <stdbool.h>
long long mx;
long *plist;
long long count;
bool
isprime6(long x)
{
long d = 5;
long mx = sqrt(x);
while (d <= mx) {
if (x % d == 0) {
return false;
}
d = d + 2;
if (x % d == 0) {
return false;
}
d = d + 4;
}
return true;
}
void
primelist(long end, long *plist, int plsize)
{
plist[0] = 2L;
plist[1] = 3L;
int ptr = 2;
for (long i = 5; i < end + 1 && ptr < plsize; i = i + 6) {
if (isprime6(i)) {
plist[ptr++] = i;
}
if (isprime6(i + 2L)) {
plist[ptr++] = i + 2L;
}
}
plist[ptr] = -1L;
return;
}
void
countcombi(int level, int start, long long lprod)
{
long long div;
for (int pp = start; plist[pp] != -1L; pp++) {
div = lprod * (long long)plist[pp];
if (mx < div) {
return;
}
if (level % 2 == 0) {
count = count - mx / div;
} else {
count = count + mx / div;
}
countcombi(level + 1, pp + 1, div);
}
return;
}
int
main(int argc, char **argv)
{
int lenplist;
if (argc != 2) {
fprintf(stderr, "usage: countprimes MAX\n");
exit(1);
}
mx = atoll(argv[1]);
int plsize = (int)((float)mx / log((float)mx) * 1.01);
if ((plist = (long *)malloc(sizeof(long) * (size_t)plsize)) == NULL) {
fprintf(stderr, "cannot malloc()\n");
exit(11);
}
long mxsq = (long)sqrt((float)mx);
primelist(mxsq, plist, plsize);
for (lenplist = 0; plist[lenplist] != -1L; lenplist++) {
// printf("%ld ", plist[lenplist]);
}
count = mx;
countcombi(0, 0, 1LL);
count = count + (long long)lenplist - 1LL;
printf("%lld\n", count);
exit(0);
}
% time ./countprimes 10000000000
455052511
./countprimes 10000000000 8.20s user 0.05s system 99% cpu 8.253 total
% time ./countprimes 100000000000
4118054813
./countprimes 100000000000 80.24s user 0.62s system 98% cpu 1:22.34 total
ぶっ。
やっぱ、C言語速いわ……。
Discussion