🧮

nまでの素数の{個数, リスト}を求める・またはPythonの並列処理を巡る冒険

2024/05/31に公開

なんでだか、Pythonでの並列処理、またはnまでの素数を求めるのにハマってしまった。

Pythonでの並列処理は、円周率を求めるときにChudnovskyの公式を並列化するのに使ったことがある。
さいきん「Pythonの並列処理を学ぶのに良いネタはないですかね?」という問いに、「nまでの素数でも求めてみたら」とかいい加減に答えたのだが、その後自分でやってみるとなかなかの学びがあった。というはなし。
だから、高速にループ回すのになんでPython? とかの突っ込みは、なしの方向で。
2024年5月頃の3日間の記録。

素数とわし

nまでの素数のリストを求めるのは、この人生でおそらく100回くらいやっている。ちょっと試しただけなら100はくだらないであろう新しいプログラミング言語との出会い、その2回に1回はたぶんこのネタをやっている。覚えている限りZ80アセンブラからPostScriptまで、プログラムを作ってみているはずだ。

素数のリストと言えば、黒歴史もある。
大学院生のとき、研究室に4月に配属されてきた4年生相手にプログラミングコンテストをやっていた。初級編の最後で、10万までだったか100万までだったかの素数のリストを何秒でつくれるか、というのをやっていたのだが……ある年あろうことか「わしに勝ったら、裸で逆立ちしてグランド一周してやる!!」と宣言してしまった。まあ若気の至りっていうかワカメのイタリアンっていうやつですかね。
そのときわしが使ったのは、いわゆる『(6n \pm 1)法』(本編参照)で、速度には自信があった……が、4年生の女子が『エラトステネスの篩』法で愚直に組んだプログラムが1位。

その女の子が顔を赤らめながら「カップ麺奢ってくれればいいです」と言ってくれたので、裸でグランド一周は免れたが……言い訳するが、ちょうど時代の変わり目で、個人向けのメモリ64KBの8bit MS-DOSマイコンではなく大学では実メモリ4MB・swapその10倍くらいの32bit UNIXワークステーション (確か使ったのはSun-4 workstationだった) の威力が使えるようになった。つまり配列を100Kだか1Mだか取っても、unsigned shortを超える数を扱っても、それなりに速い時代になったのだ。
時代は変わるのだ、という教訓だ。

この記事でも、きょうびの流行りで、ループ回したら速度の遅さには定評があるPython縛りで、またMacbook Air/Pro程度の個人向けラップトップにてプレイしてみることにする。

multiprocessingまたはthreading

Python3で並列処理をするには、multiprocessing (プロセス並列) またはthreading (スレッド並列) というライブラリがある。どちらも、オブジェクトを作って自分自身の中の呼びたい関数を渡し、スタートして待ち合わせするだけである。書式もほとんど同じなので、円周率のときには簡単にいろいろ試せた。

(スレッド版)

thtest.py
import threading
import time

def testthread(x: float) -> None:
  print('start:', x)
  time.sleep(x)
  print('finish:', x)
  return

if __name__ == '__main__':
  th = []
  th.append(threading.Thread(target = testthread,
    args = (2., )))
  th[0].start()
  th.append(threading.Thread(target = testthread,
    args = (5., )))
  th[1].start()
  for t in th:
    t.join()
    print('*')

円周率の記事にもあるが、注意点としては『メインを地べたに書かないこと』くらいである。どうやらスレッドなりプロセスなりを立てるときに、自分が実行しているのと同じスクリプトを再度読む (このときには__name__'__main__'にならない) みたいなので、地べたに書くと初期化でなんかおかしなことになってエラーになってしまう。まあこれは、エラーメッセージを読んでその通りにメインを書き直せばすぐ解決する。

nまでの素数を並列処理で求める

nproc個のスレッドまたはプロセス (以下面倒なので、どちらでも当てはまるときは『プロセス』って書いちゃったりしますね) で、並列処理によりnまでの素数を求めるには、いろんな分け方があるだろう。

  • 適当に範囲を
    • 等分してnproc個に分ける
    • より細かい範囲に分けて→空いてるプロセスに与え→終わって暇になったら次の仕事を渡す
    • 掛かる時間が均等になるように予測して範囲を分ける
  • 剰余類で分ける
  • ……などなど

範囲等分のサンプルは、次のようになるだろう (シングル版・マルチスレッド版・マルチプロセス版)。

範囲等分のサンプル
prime_mp+th.py
import numpy as np
import time
import multiprocessing
import threading

def isprime(n: int) -> None:
  for i in range(2, int(np.sqrt(n)) + 1):
    if n % i == 0:
      return False
  return True

def printprimes(start: int, end: int) -> None:
#  print('entered printprimes(', start, ',', end, ')')
  for i in range(start, end):
    if isprime(i):
      pass
## to get output, uncomment here:
#      print(i, end = ' ')
#  print()

if __name__ == '__main__':
  MAX = 2000000
  NPROCS = 10

# single process/thread
  start = time.time()
  printprimes(0, MAX)
  print()
  print(time.time() - start)

# multithreads
  start = time.time()
  width = MAX // NPROCS
  th = []
  for pn in range(NPROCS):
    th.append(threading.Thread(target = printprimes,
                           args = (pn * width, (pn + 1) * width)))
    print('thread', pn, 'starts')
    th[pn].start()
  for pn in range(NPROCS):
    th[pn].join()
    print('thread', pn, 'ended')
  print(time.time() - start)

# multiprocess
  start = time.time()
  width = MAX // NPROCS
  th = []
  for pn in range(NPROCS):
    th.append(multiprocessing.Process(target = printprimes,
                       args = (pn * width, (pn + 1) * width)))
    print('thread', pn, 'starts')
    th[pn].start()
  for pn in range(NPROCS):
    th[pn].join()
    print('thread', pn, 'ended')
  print(time.time() - start)

ここではシンプルに、2から\sqrt{n}までの数で次々と割っている。printprimes()のコメントアウトを外せば標準出力に素数リストが出るが、並列処理に馴染みがないひとにとっては、プロセスがごちゃごちゃな順番で出力するのが逆に面白かろう。と思うが、小さい数までだと次のプロセスを起動する前にひとつのプロセスが終わってしまうので、あまり面白くない。

『素数のリスト』など見飽きているので(笑) ってか何億とかになると長いので、素数が求まったら出力せずカウントアップして『素数の個数』だけを求めればよし、と、この時点で問題を変更している(笑)。どーせ素数のリスト作らなきゃ、素数の個数なんて求められないだろ……?

これが後でまたハマる原因になる。

なぜ除数は$\sqrt{n}$まででよいか

nがふたつの素数p, qの積であったとき、p > \sqrt{n}だったらq < \sqrt{n}だからである (最悪のケースはnがある素数の2乗→p = q = \sqrt{n}のときである)。

おわり。

閑話休題。
範囲を等分するやり方は、少しプロセスの負担均等化に問題がある。この方法でnが素数かどうか判定するにはO(\sqrt{n})くらいの時間が掛かるので、後ろのほうのプロセスほど余計に時間が掛かってしまう (待ち合わせの順番はどんな順番でも、一番時間が掛かるプロセスが律速プロセスになってしまうことに変わりないので、ここは単純に考えてよい)。
だからといって、『より細かい範囲に分ける』と、終わったプロセスの監視などで余計なことを考えないといけなくなり、またプロセス起動のオーバヘッドも大きくなってしまう。
はじめから『掛かる時間が均等になるように予測』すれば無駄はなさそうだ。範囲ごとに素数を含む割合は近似式があるが、非素数の場合に約数にたどりつくまでの時間など、考えなければいけないことが多すぎる。

というわけで、剰余類で分けてみることにした。プロセスがnproc個のとき、プロセス番号pn番は i % nproc == pn なる候補数iを受け持つことになる。

だがこれには問題がある。各プロセスごとに求まった素数の個数を表示させてみれば明らかだが (させなくてもちょっと考えりゃわかるのだがw)、例えばpnが偶数番号のプロセスは偶数ばかりを検査することになるので、明らかに素数は見つからず無駄に帰ってくることになる。

(6n \pm 1)

並列処理からちょっと離れる。
素数のリストを作るとき、効率的な方法として有名なのは(6n \pm 1)法である。5以上の素数は(6n \pm 1) (n = 1, 2, \ldots)にしか含まれない。剰余群として考えたいので『5以上の素数は(6n + 1)または(6n + 5) (n = 0, 1, \ldots)である』と書き直す。
なぜかというと、(6n + 0)は2または3の倍数、(6n + 2)(6n + 4)は2の倍数、(6n + 3)は3の倍数であるからだ。

ということは、中学生くらいでも気の利いた奴なら知っている。
これを使えば、割られる数の候補は1/3に減らせる。
同様の手法で(30n + \{1, 7, 11, 13, 17, 19, 23, 29\})法だとか、(210n + \{1, 11, 13, \ldots\}) (面倒なので略w) 法とかも考えられるのだが、たとえば(30n + \ldots)法では(6n + \{1, 5\})法に比べても4/5にしか計算量が減らない割に、残った剰余群の場合分けが8倍になる。
だから普通はやらない(笑)。

(6n \pm 1)法と並列処理

さて、これでプロセス番号を分けるとしたら

  • プロセス数nprocは偶数とする
  • width = (nproc // 2) * 6 と置く (『a // b』は「abで割った整数部」「\lfloor a / b\rfloor」の意)。
  • pn番のプロセスには
    • もし偶数なら6 * (pn // 2) + 1から始まる
    • もし奇数なら6 * (pn // 2) + 5から始まる
      widthおきの数を、被除数として検査してもらう

たとえば、nproc = 10個だとしたら

  • width = 10 // 2 * 6 = 30
  • プロセスごとに以下の被除数を検査してもらう
    • 0番のプロセスは、\{(1,) 31, 61, \ldots\}
    • 1番のプロセスは、\{5, 35, 65, \ldots\}
    • 2番のプロセスは、\{7, 37, 67, \ldots\}
    • ……
    • 9番のプロセスは、\{59, 89, 119, \ldots\}

とすれば、なかなか均等に割れるのではないか。

まあこれも試してみると、仕事しないで いやしてるのだが空振りで帰って来るプロセスがいる。
そりゃそうだ。たとえば上記の例では、1番と8番のプロセスは明らかに5の倍数ばかりの被除数を検査している。
一般化すると、分割数nproc // 2(6n \pm 1)になるときに、これが起きうる。

まあプロセス分割ではなくてスレッド分割の場合 (スレッドで有効に並列化するはなしは、この後すぐ!!)、多少CPUコア数より多くても無駄時間は生じないし、空振りスレッドが早めに仕事を終えて帰ってきたら、その資源は別のスレッドに割り当てられるため、そこまで無駄にはならない。だが無駄は無駄である。

試してみたMacbook Air/Proくらいであれば (コア数8とか10とか)、スレッド数nproc = 8, 12, 16, 18, 24,...あたりが、無駄仕事が割り振られないスレッド数ということになる。

ダブル(6n \pm 1)

わしは(6n \pm 1)法による素数リストのプログラムを、生涯100回くらい書いているが、ここに書くことは今回まで気づかなかった。わしが未熟なだけだが。

上記は被除数を(6n + \{1, 5\})に限定しているが、実は除数も(6n + \{1, 5\})に限定できる。まあこのことに気づいたのは、学生が『2から\sqrt{n}までの整数で割る』のではなく、『23から\sqrt{n}までの奇数で割る』方法を使っていたからだ。

なんで奇数でよいかというと、被除数が偶数だったら2で割り切れるから素数じゃないし (当たり前)、被除数が奇数かつ割り切れる数だったら奇数掛ける奇数だからだ。偶数に何を掛けても奇数にはならぬ。すべての(\sqrt{n}以下の) 奇数で割って割り切れなければ、それは素数である。

同様に、被除数が(6n + 1)である場合(以下『n \equiv 1\ ({\rm mod}\ 6)』のように書く。「nを6で割った余りが1」の意)、これが2数 (6l + a), (6m + b) の積 (ただし a, b \in \{0, 1, 2, 3, 4, 5\}) であるためには、a \equiv b \equiv 1\ ({\rm mod}\ 6) である必要がある。同じように、(6l + a)(6m + b) = (6n + 5)であるためにはa \equiv 1, b \equiv 5\ ({\rm mod}\ 6) または a \equiv 5, b \equiv 1\ ({\rm mod}\ 6)$だ。
このことは、6で割った余りの表を書いてみればわかる (『a→』『↓b』はそれぞれa, bを6で割った余りの略、表中の数はa\times bを6で割った余り)。

↓b\a→ 0 1 2 3 4 5
0 0 0 0 0 0 9
1 0 1 2 3 4 5
2 0 2 4 0 2 4
3 0 3 0 3 0 3
4 0 4 2 0 4 2
5 0 5 4 3 2 1

つまり、被除数(6n + 1)が素数でない=何らかの2数の積であるためには、その2数は共に6で割ってみると余りが1または5、という形をしている、ということだ。(6l + \{0, 2, 3, 4\})の形の数は、何を掛けても(6n + 1)な数にならない。だから、(6l + \{1, 5\})の形の数すべてで割って余りが出れば、(6n + 1)は素数だ、といえる ((6n + 5)も同様)。

除数と被除数が共に1/3になるので、計算回数は単純に1/9になる。
並列化についても、前記の方法で被除数を(6n + \{1, 5\})のグループごとに並列化した場合でも、除数を(6n + \{1, 5\}) (n = 0, 1, \ldots)に変更するだけである。

Python3.(<=12) のGILと、subinterpreterでスレッド有効活用

以前の円周率ネタでも、なかなかthreadingを使ったマルチスレッドが速くならない、というのは経験していたが、そのときは深く考えず放置していた。まあmultiprocessingでも、数分〜数時間帰ってこないお仕事の分割なので、別にオーバヘッドは問題ではなく、CPUのコア数をうまく勘案すればプロセス分割でよいか、と思っていたからである。

ところが今回、スレッドについていろいろ調べているうちに、Python3でスレッド分割しても、ひとつのプロセス内ではロック (GIL) が掛かってしまい、スレッドが同時にふたつ以上実行されない(爆)ということが分かった。いやそんな基本的なこと知らなかったわしがアホなだけであるが。

そしてまた最近のPython3.12では、subinterpreterというものが実装されて、インタプリタごとに異なるスレッドが実行できる!! らしい……なんのこっちゃ?

基本的な書き方は、次のようなものらしい。

subinterpreterのサンプル
sitest.py
import time
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels

scr = '''
import time
import _xxinterpchannels as channels
print('start:', arg, chid)
time.sleep(arg)
print('finish:')
channels.send(chid, arg ** 2)
'''

if __name__ == '__main__':
  ids = []
  chids = []

  id = interpreters.create()
  ids.append(id)
  chid = channels.create()
  chids.append(chid)
  arg = 5
  interpreters.run_string(
    id, scr, shared = {'arg': arg, 'chid': chid}
  )
  id = interpreters.create()
  ids.append(id)
  chid = channels.create()
  chids.append(chid)
  arg = 2
  interpreters.run_string(
    id, scr, shared = {'arg': arg, 'chid': chid}
  )

  for chid in chids:
    print(channels.recv(chid))
  for id in ids:
    interpreters.destroy(id)

まだ現在のPythonでは試用版らしく、subinterpreterのライブラリは_xxsubinterpreterなる、なんかエロい 怪しい名前になっている。

このsubinterpreterは、文字列として実行したいスクリプトを渡し (オレンジのhere document部分)、引数 (ってかシェアしたい値) をsharedで渡してやる。あと、戻り値はchannelというのを使えば (subinterpreterで実行しているスレッドからはchannels.send()、メイン側でchannels.recv()で) 戻って来る。

これとthreadingを組み合わせると、こうなる。

subinterpreter + threadsのサンプル
sithtest.py
import threading
import time
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels

scr = '''
import time
import _xxinterpchannels as channels
print('start:', arg, chid)
time.sleep(arg)
print('finish:')
channels.send(chid, arg ** 2)
'''

if __name__ == '__main__':
  ths = []
  ids = []
  chids = []

  id = interpreters.create()
  ids.append(id)
  chid = channels.create()
  chids.append(chid)
  arg = 2
  th = threading.Thread(target = interpreters.run_string,
       args = (id, scr),
       kwargs = {'shared': {'arg': arg, 'chid': chid}})
#  interpreters.run_string(
#    id, scr, shared = {'arg': arg, 'chid': chid}
#  )
  ths.append(th)
  th.start()

  id = interpreters.create()
  ids.append(id)
  chid = channels.create()
  chids.append(chid)
  arg = 5
  th = threading.Thread(target = interpreters.run_string,
       args = (id, scr),
       kwargs = {'shared': {'arg': arg, 'chid': chid}})
#  interpreters.run_string(
#    id, scr, shared = {'arg': arg, 'chid': chid}
#  )
  ths.append(th)
  th.start()

  for th in ths:
    th.join()
  for chid in chids:
    print(channels.recv(chid))
  for id in ids:
    interpreters.destroy(id)

threading.Threadにはinterpreters.run_string()を食わせてやる。interpreters.run_string()の引数はkwargs = で辞書で与える必要がある。

そんなこんなで、ダブル(6n + \{1, 5\})をスレッド並列化してみた。

ダブル$(6n \pm 1)$法のスレッド並列化
hspmod.py
import threading
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels
import time
import sys


script = '''
import _xxinterpchannels as channels
def isprime6(x: int) -> bool:
  d = 5
  mx = int(x ** .5)
  while d <= mx: 
    if x % d == 0:
      return False
    d = d + 2
    if x % d == 0:
      return False
    d = d + 4
  return True

cnt = 0
for i in range(start, end, stps):
#  print(i, end = '')
  if isprime6(i):
#    print('*', end = '')
    cnt = cnt + 1
#  print('|', end = '')
#print('\\n{}..{}: {}'.format(start, end, cnt))
channels.send(chid, cnt)
'''

# subinterpreter + threads

if __name__ == '__main__':
  if len(sys.argv) != 3:
    print('usage: hspmod.py MAX NPROCS', file = sys.stderr)
    print('  NPROCS must be even', file = sys.stderr)
    sys.exit(1)
  MAX = int(sys.argv[1])
  NPROCS = int(sys.argv[2])
  if NPROCS % 2 != 0:
    print('NPROCS must be even', file = sys.stderr)
    sys.exit(1)

  stt = time.time()
  stps = NPROCS * 6 // 2
  ids = []
  chids = []
  threads = []
  for pn in range(NPROCS):
    start = (pn // 2) * 6
    if pn % 2 == 0:
      start = start + 1
      if start == 1:
        start = start + stps
    else:
      start = start + 5
    end = MAX + 1
    id = interpreters.create()
    ids.append(id)
    chid = channels.create()
    chids.append(chid)
    print(pn, start, end, stps, chid)
    thread = threading.Thread(
      target = interpreters.run_string,
      args = (id, script),
      kwargs = {'shared':
                { 'start': start,
                  'end': end,
                  'stps': stps,
                  'chid': chid
                }
               }
    )
    thread.start()
    threads.append(thread)
  for thread in threads:
    thread.join()
  cnt = 2  # {2, 3}
  for chid in chids:
    retv = channels.recv(chid)
    cnt = cnt + retv
    print(retv, end = ' ')
  print('\n', cnt)
  for id in ids:
    interpreters.destroy(id)
  print(time.time() - stt)

例によって、100億 (小学生かよ)

ここまでで、およその実実行時間は、Macbook Airで1千万までで数秒程度になった。
となると、わしの大好きな100億 (※ 円周率参照) が視野に入ってくる。そのくらいになると、もう少し高速化してみたいところである。

除数についてだが、\sqrt{n}ってのは大したことない。被除数が100億までだとしても、除数は10万以下の素数の個数 (1万個弱) なので

  • まずはシングルスレッドで、ダブル(6n + \{1, 5\})法求めた\sqrt{n}のリストをつくる
  • 前記の剰余群にスレッド分割して、担当のスレッドにスタート数・終了数・ステップ (前記width)・素数リストを渡して計算してもらう
    • スレッド内ではスタート数から終了数までステップごとに被除数を
    • 素数リストの中の被除数の平方根より小さい除数で
      割り切れるかテスト

とすれば大幅に速くなるはずだ。

ただし、前者の除数リスト・後者のスレッド分割開始の数について、つなぎ目に気をつけないと素数の数をダブルカウントしてしまうなどのおそれがある。
これについては (ややこしいが)、\sqrt{n}を超えない(6i + 5)linmaxとして、2..linmax + 2 (linmax + 2(6i + 1)群)以下の素数をリストに、linmax + 6 (linmax + 6(6i + 5)群)以降をとにかく0番からのプロセスに分割してしまった。多少余計にリストの素数を求めてしまう・多少プロセスごとの被除数の個数が凸凹してしまうが、問題はない。

で、ここで問題が生じた。各プロセスに渡す\sqrt{n}以下の素数リストであるが、sharedにリストは指定できない。まあ解決策は強引かつ単純で、リストをpickle化してファイルに書き出し、ファイル名をsubinterpreter (スレッド) に渡せばよい。

以下長くなるが、最高速版のフルリスト。

hspmodl_print.py
import threading
import _xxsubinterpreters as interpreters
import _xxinterpchannels as channels
import pickle
import time
import sys

FFNAME = 'pf'

def isprime6(x: int) -> bool:
  d = 5
  mx = int(x ** .5)
  while d <= mx: 
    if x % d == 0:
      return False
    d = d + 2
    if x % d == 0:
      return False
    d = d + 4
  return True


def primelist(end: int) -> list:
#  plist = [2, 3]
  plist = []
  for i in range(5, end + 1, 6):
    if isprime6(i):
      plist = plist + [i]
    if isprime6(i + 2):
      plist = plist + [i + 2]
  return plist


script = '''
import pickle
import _xxinterpchannels as channels
FBASE = 'p'

def isprimediv(n: int) -> bool:
  mx = int(n ** .5)
  for div in divlist:
    if mx < div:
      return True
    if n % div == 0:
      return False
  return True


with open('plist.pkl', 'rb') as fp:
  divlist = pickle.load(fp)
#print(divlist)
with open(FBASE + str(chid) + '.txt', 'w') as fp:
  cnt = 0
  for i in range(start, end + 1, stps):
    if isprimediv(i):
      print(i, file = fp)
      cnt = cnt + 1
channels.send(chid, cnt)
'''

# subinterpreter + threads

def printlist(l: int, fn: str) -> None:
  with open(fn, 'w') as fp:
    for i in l:
      print(i, file = fp)
  return

if __name__ == '__main__':
  if len(sys.argv) != 3:
    print('usage: hspmod.py MAX NPROCS', file = sys.stderr)
    print('  NPROCS must be even', file = sys.stderr)
    sys.exit(1)
  MAX = int(sys.argv[1])
  NPROCS = int(sys.argv[2])
  if NPROCS % 2 != 0:
    print('NPROCS must be even', file = sys.stderr)
    sys.exit(1)
  r = (NPROCS // 2) % 6
  if r == 1 or r == 5:
    print('warning: some threads donot count any primes')

  stt = time.time()
  width = 6 * (NPROCS // 2)
  linmax = ((int(MAX ** .5) + 5) // 6) * 6 - 1
  # linmax <= sqrt(MAX), linmax % 6 == 5
  # 2..(linmax + 2) are tested (by linear method) number
  #   (within divlist if prime).
  plist = primelist(linmax)
  with open ('plist.pkl', 'wb') as fp:
    pickle.dump(plist, fp)
  print('2..{} tested by linear'.format(linmax + 2))
  print(time.time() - stt, 'sec')
  printlist([2, 3] + plist, FFNAME + '.txt')
  start0 = linmax + 6
  # start from (linmax + 6) by multiproc method (start0 % 6 == 5).
  print('{}..{} tested by multiproc'.format(start0, MAX))

  ids = []
  chids = []
  threads = []
  for pn in range(NPROCS):
    start = start0 + (pn // 2) * 6  # 6n + 5
    if pn % 2 == 1:
      start = start + 2 # 6n + 1
    id = interpreters.create()
    ids.append(id)
    chid = channels.create()
    chids.append(chid)
    print('proc #{}: {}..{} step {}, chid {}'.\
          format(pn, start, MAX, width, chid))
    thread = threading.Thread(
      target = interpreters.run_string,
      args = (id, script),
      kwargs = {'shared':
                { 'start': start,
                  'end': MAX,
                  'stps': width,
                  'chid': chid
                }
               }
    )
    thread.start()
    threads.append(thread)
  for thread in threads:
    thread.join()
  cnt = 2 + len(plist)  # {2, 3}
  for chid in chids:
    retv = channels.recv(chid)
    cnt = cnt + retv
    print(retv, end = ' ')
  print('\n', cnt)
  for id in ids:
    interpreters.destroy(id)
  print(time.time() - stt, 'sec')

こんな感じになりました。

なお上のプログラムは、求めた素数をリストにして書き出して (\sqrt{n}まではひとつのファイルpf.txt、並列部分はスレッドごとにひとつのファイルp{0,...,nproc}.txt) いるが、個数のカウントのみならず素数リストを作っているのは……次の節の理由による (このプログラムより前のベンチマークでは、求めた素数そのものは捨てて、個数のみカウント)。

いまどのへんまで求めているかは、上記のpてきとう.txttailしてみればわかるので、無駄に時間を喰うモニタ出力は一切、してません。あと出力テキストファイル (1個ずつ改行) の大きさは、素数の個数を見積もる近似式 (次節) があるので、全部11桁と仮定しても5GBちょっとになると暗算して、安心して寝ますた。

Macbook Airでnに対する実計算時間を測ってみたら、およそ桁数\log(n)のオーダで実行時間が増えていく。これならAirでも数時間で計算が終わるだろ……と思って、一晩放置しましたが。


(Macbook Air 8core。いろいろなnproc = 8・12・16で試してみてますが、グラフ上ではほとんど変わらん(笑))

素数の個数のカウントから素数のリストへ

そんなこんなで、時間の見積もりとかしているときに『nまでの素数の個数\pi(x) (割合x / \pi(x))の近似式』があるのを思い出した。

\pi(x) \simeq x / \log(x)

これを検索していたら、間違ってなんと、『n以下の素数の個数を求める (実際に素数を求めずに) Meissel–Lehmerアルゴリズム』というのが引っかかってきてしまった(笑)。

え? え? 近似じゃなくて、正確な個数を、素数そのものを求めずにカウントする方法があるの!?

ちょっと読んだだけでは理解できなかったのだが、どうやらそっちのほうが速そうだ。
ならば悔しいので『具体的な素数を求めてリストを作る』問題に、ゴールを変えちゃえ(笑)。

というわけで、100億までの素数を全部出力してみました(あほ)。これが家にあれば、ちょっと嫌なことがあっても「まあ家に帰れば素数100億まであるしな」ってなるし仕事でむかつく人に会っても「そんな口きいていいのか?私は自宅で100億までの素数の表とよろしくやってる身だぞ」ってなれる
個数は455052511個です。
Macbook Pro (Arm M1, 16core)で、CPU load 99%♡ でぶん回して、14619秒 (4時間3分39秒) で求まりました。

ログ
2..100003 tested by linear
0.14380383491516113 sec
proc #0: 100007..10000000000 step 48, chid 0
proc #1: 100009..10000000000 step 48, chid 1
proc #2: 100013..10000000000 step 48, chid 2
proc #3: 100015..10000000000 step 48, chid 3
proc #4: 100019..10000000000 step 48, chid 4
proc #5: 100021..10000000000 step 48, chid 5
proc #6: 100025..10000000000 step 48, chid 6
proc #7: 100027..10000000000 step 48, chid 7
proc #8: 100031..10000000000 step 48, chid 8
proc #9: 100033..10000000000 step 48, chid 9
proc #10: 100037..10000000000 step 48, chid 10
proc #11: 100039..10000000000 step 48, chid 11
proc #12: 100043..10000000000 step 48, chid 12
proc #13: 100045..10000000000 step 48, chid 13
proc #14: 100049..10000000000 step 48, chid 14
proc #15: 100051..10000000000 step 48, chid 15
28441480 28437829 28439317 28440138 28438970 28439733 28439461 28439126 28441462 28437590 28441167 28440729 28441228 28441900 28441495 28441293 
 455052511
14618.746592998505 sec

結果のマージソート

上記はファイルが、1, \ldots, \sqrt{n} までの素数のリストpf.txtと、プロセスごと (widthの剰余群ごと) のp{0,1,...}.txtに分かれちゃっている。これをソートしなければひとつの素数リストにならないが……まず前者は後者の手前 (\sqrt{n} またはそれよりちょっと) までが昇順でひとつのファイル、後者は剰余類ごとに分かれているがファイルの中では昇順になっている。

したがって、前者を最終出力にコピー、それに後者をマージソートしたものをくっつければおしまい。
引数として計算に使ったのと同じ最大数nとプロセス数を与える必要があるのは、EOFを知るため・ファイル名を知るため。
やっつけで書いたので、結構汚い。

psort.py
import sys
import numpy as np

PFNAME = 'pf'
FBASE = 'p'

nmax = int(sys.argv[1])
nproc = int(sys.argv[2])

ofp = open('pout.txt', 'w')
#ofp = sys.stdout

ff = open(PFNAME + '.txt', 'r')
ns = ff.readline()
while ns:
  print(ns, end = '', file = ofp)
  ns = ff.readline()
ff.close()

fp = [None] * nproc
pr = np.zeros(nproc, dtype = int)
for i in range(nproc):
  fp[i] = open(FBASE + str(i) + '.txt')
  pr[i] = int(fp[i].readline())
opencount = nproc
while 0 < opencount:
  am = np.argmin(pr)
  print(pr[am], file = ofp)
  ns = fp[am].readline()
  if not ns:
    fp[am].close()
    pr[am] = nmax + 1
    opencount = opencount - 1
  else:
    pr[am] = int(ns)

ofp.close()

Pythonでやったら遅いかな? と思ったけど、そうでもなかった。11分24秒で完了しました。
これも計算時間に含めないと、アンフェアかな?

100億までの素数 (長いのでw 頭と末尾のみ)
2
3
5
7
11
13
17
19
23
29
...
9999999769
9999999781
9999999787
9999999817
9999999833
9999999851
9999999881
9999999929
9999999943
9999999967

素数の個数のカウント (素数を計算せずに) ふたたび

実際に素数を求めずに、n以下の素数の個数を求める『Meissel–Lehmerアルゴリズム』が存在する? とな?
しかもこれ、競技プログラミングでは定番(?) ううむ?

と思ったが、英語版Wikipediaや英語の論文は読むのが面倒くさく、日本語の説明もよくわからん。
まあ、『エラトステネスの篩』とかdynamic programmingとか、キーワードを拾い読みしていたらなんとなく分かってしまった、ってか自分で再発明した(笑)ので、以下説明。オリジナルと違ってるかもしれないし、間違ってる……ことはないだろうが、効率悪いかもしれぬ。

まず、\{2, 3, 5\}という素数のリストがあるとしよう。これを篩に使うと、5 ^2 = 25以下 (つか48以下) の素数のリストが作れる。が、あえてリストを作らずに数えるだけにする。

  • 25以下の整数は25個w
  • 2の倍数は素数じゃないので除外する (2は素数だけどそのはなしは後ほど)。何個あるかというと\lfloor 25/2 \rfloor = 12
  • 同3の倍数は8個、5の倍数は5個
  • 25 - (12 + 8 + 5) = 0で何も残らん(笑)が、これはなんでかというと
    • 2と3の公倍数6の倍数は、2回除外されている。6の倍数は4個
    • 同2と5の公倍数10の倍数は2個、3と5の公倍数15の倍数は1個
    • したがってこれらを足し戻して0 + 4 + 2 + 1 = 7個残る
  • (さらに2と3と5の公倍数30の倍数は足し戻されすぎているので引き (戻し戻し) たいが、30の倍数は25以下に0個)
  • これに\{2, 3, 5\}を加えて10個、あと1は素数じゃないので除外して9個

ということになる。\{2, 3, 5, 7, 11, 13, 17, 19, 23\} の9個と一致する。

あるいは、『エラトステネスの篩』法で素数のリストを作り終わったとき、取り消し線の総数が25本 (実際は下図のように22本だけど。2・3・5は消してないので)、うち7つの数はひとつの数に2本線が引かれている。それに消されない\{2, 3, 5\}を加えて10個、あと1は素数じゃない(以下略) と考えると分かりやすい。

これを一般化すると

  • \sqrt{n}以下の素数のリストを作る
  • はじめにn以下の整数の個数nを『カウンタ』の初期値とする
  • このリストから、すべての1つ以上のコンビネーションを作る:
    pをコンビネーションの積としたとき、int(n / p)
    • コンビネーションの要素数が奇数ならカウンタから引く
    • 偶数ならカウンタに足す
  • 最後に、素数のリストの要素数を足して、1 (整数1の分) を引く

ということでいかがだろうか。

これが速いのか……というと。
たとえば100億までの素数の個数をカウントしようとしたら、すべてに\sqrt{} 100億 = 10万までの素数のリスト約9592個 (もう100億までの素数求めちゃってるもんね(笑)) の、すべてのコンビネーションを求めないといけない。コンビネーションのすべての和は、\displaystyle \sum_{r = 1, n} \left(\begin{array}{c}n\\r\end{array}\right) = 2 ^n - 1 より2^{9592}通りになり、これは10進2800桁もの莫大な数になる。ん? 10進11桁のループのほうがはるかに速くねーか? つか、終わらん。

実際には、すべてのコンビネーションを実行する必要はない。前記の25まで→\{2, 3, 5\}の例では、この3つの数のコンビネーションをほとんど全部試さなければならなかったが、例えば100億まで→9592個の素数リストのうち最後のほうの2個を考えてみれば、ほぼ10万(\sqrt{} 100億)なので2つ以上の数を掛け算する必要はない。

実際には、前からポインタを動かしつつ、次の数のポインタを動かしつつ、……をすべて掛け合わせてnを超えたらそれ以上探索を打ち切る、という具合にダイナミックプログラミング的(?)に、ループの段数を変化させつつコンビネーションを探すのは、再帰で簡単にできそうである (よくわからん説明だな。下の実行例を見てくれい)。

したがって、せっかくPythonなのにitertools.combinations()とかは使ってはいけない。DPによって打ち切ることなくすべての組み合わせを延々数え始めるので、ほんとに2^{9592}個の候補を計算しなければならなくなる。

ここで、ふとあることに気づく。再帰の段数は大丈夫?
Pythonの再帰可能な段数は、確か以前に (なんちゃって関数型プログラミングとかやってみたとき) 500段くらいだったか、結構浅いところでエラーが出た (※ sys.setrecursionlimit()である程度変えられる)。
100億までの素数は上記の通り最大9592段ループできなければダメ……ってわけじゃなくて、最初からk個の素数の積が100億を超えたところで再帰が戻ってくるのが再帰の最大 (最悪) の段数なので、試みに最初からk個の積をちょっと計算してみると

1 2
2 6
3 30
4 210
5 2310
6 30030
7 510510
8 9699690
9 223092870
10 6469693230
11 200560490130

なんだ、小さい方から11個で素数の積が100億超えるじゃん……11段くらいの再帰でいいのね。
まあ、いったん再帰な頭が完成したら、再帰を使わないループで簡単に書き直したりできるのだが。

あと、計算量も2^{9592}通りとかではなく、最大でも\displaystyle \sum_{i = 1, 11} \left(\begin{array}{c}9592\\i\end{array}\right) \simeq 1.58 \times 10^{36}通り……ってやっぱ、でかくね!? (11個の数を乗するのは前述の通り最悪のケース。DPでの打ち切りに期待しませう)。

(関数primelist(int: n) -> list は、さっきと同じなので略)

countprimes.py
import sys
import time

def countcombi(level: int, start: int, lprod: int) -> None:
  global count
#  global lplist
  for pp in range(start, len(plist)):
    div = lprod * plist[pp]
    if mx < div:
      return
#    print('prod(', lplist + [plist[pp]], ') =', div)
    if level % 2 == 0:
      count = count - mx // div
#      print('-', mx // div, '=>', count)
    else:
      count = count + mx // div
#      print('+', mx // div, '=>', count)
#    lplist = lplist + [plist[pp]]
    countcombi(level + 1, pp + 1, div)
#    lplist = lplist[:-1]


if __name__ == '__main__':
  if len(sys.argv) != 2:
    print('usage: countprimes MAX', file = sys.stderr)
    sys.exit(1)
  mx = int(sys.argv[1])

  stt = time.time()
  plist = primelist(int(mx ** .5))
#  print('plist:', plist)

  count = mx
#  lplist = []
  countcombi(0, 0, 1)
#  print('count:', count, '+', len(plist), '- 1 =>', end = ' ')
  count = count + len(plist) - 1
  print(count)
  print(time.time() - stt)

関数countcombi(level: int, start: int, lprod: int) -> Noneは、
再帰的にコンビネーションを生成し、積を計算し、倍数の個数を足し引きする。
levelは再帰レベルだが、すなわちこれまでいくつの素数を掛け算したかを示している。startplist[]中のスタート要素の番号で、この番号〜最後までを走査することになる。lprodはこれまで (つまりstartの左側のlevel個の素数) の積 (キャッシュみたいなもの) で、いちいちlevel + 1個の要素を掛け算するのを防ぐ。なおデバッグプリント (プログラム中のコメントアウトしているprint()) のためにlplist[]にこれまでかけ合わせたの素数のリストを作ってみているが、ただ個数をカウントするだけならlprodがあるので必要ない。

デバッグプリントを有効にして、100までの素数を求めるとこんな感じ。

100までの素数の個数を求める詳細
python3 countprimes.py 100 
plist: [2, 3, 5, 7]
prod( [2] ) = 2
- 50 => 50
prod( [2, 3] ) = 6
+ 16 => 66
prod( [2, 3, 5] ) = 30
- 3 => 63
prod( [2, 3, 7] ) = 42
- 2 => 61
prod( [2, 5] ) = 10
+ 10 => 71
prod( [2, 5, 7] ) = 70
- 1 => 70
prod( [2, 7] ) = 14
+ 7 => 77
prod( [3] ) = 3
- 33 => 44
prod( [3, 5] ) = 15
+ 6 => 50
prod( [3, 7] ) = 21
+ 4 => 54
prod( [5] ) = 5
- 20 => 34
prod( [5, 7] ) = 35
+ 2 => 36
prod( [7] ) = 7
- 14 => 22
count: 22 + 4 - 1 => 25

さて実行時間だが、次のような感じになる。なおオリジナルのMeissel-Lehmerアルゴリズムの計算量はこんなところで考察してくれているひとがいる (やっぱり競技プログラミング筋のようだ)。

んで100億はというと……。

% python3 countprimes.py 10000000000
455052511
511.4769809246063

Macbook Proで、8分32秒ほどで求まりました。
やっぱ数えるだけのほうが、断然速い……。


(紫はPython・緑は下の『おまけ』のC言語版、Macbook Pro (ARM M1))

ここまで来たら、再帰の最初のほうのレベルで範囲を10人くらいの子供プロセスに分割して並列……とか思ったのだが、先に下のC言語への書き直しをして爆速体験をしてしまったので、もう萎えた(笑)。

おまけ: C言語で書き直してみた

カウントするだけのほう、ただC言語に直してみました。

C言語版
countprimes.c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <stdbool.h>


long long mx;
long *plist;
long long count;


bool
isprime6(long x)
{
	long d = 5;
	long mx = sqrt(x);
	while (d <= mx) {
		if (x % d == 0) {
			return false;
		}
		d = d + 2;
		if (x % d == 0) {
			return false;
		}
		d = d + 4;
	}
	return true;
}


void
primelist(long end, long *plist, int plsize)
{
	plist[0] = 2L;
	plist[1] = 3L;
	int ptr = 2;
	for (long i = 5; i < end + 1 && ptr < plsize; i = i + 6) {
		if (isprime6(i)) {
			plist[ptr++] = i;
		}
		if (isprime6(i + 2L)) {
			plist[ptr++] = i + 2L;
		}
	}
	plist[ptr] = -1L;
	return;
}


void
countcombi(int level, int start, long long lprod)
{
	long long div;
	
	for (int pp = start; plist[pp] != -1L; pp++) {
		div = lprod * (long long)plist[pp];
		if (mx < div) {
			return;
		}
		if (level % 2 == 0) {
			count = count - mx / div;
		} else {
			count = count + mx / div;
		}
		countcombi(level + 1, pp + 1, div);
	}
	return;
}

			
int
main(int argc, char **argv)
{
	int lenplist;
	
	if (argc != 2) {
		fprintf(stderr, "usage: countprimes MAX\n");
		exit(1);
	}
	mx = atoll(argv[1]);
	int plsize = (int)((float)mx / log((float)mx) * 1.01);
	if ((plist = (long *)malloc(sizeof(long) * (size_t)plsize)) == NULL) {
		fprintf(stderr, "cannot malloc()\n");
		exit(11);
	}
	long mxsq = (long)sqrt((float)mx);
	primelist(mxsq, plist, plsize);
	for (lenplist = 0; plist[lenplist] != -1L; lenplist++) {
//		printf("%ld ", plist[lenplist]);
	}
	
	count = mx;
	countcombi(0, 0, 1LL);
	count = count + (long long)lenplist - 1LL;
	printf("%lld\n", count);
	
	exit(0);
}
% time ./countprimes 10000000000
455052511
./countprimes 10000000000  8.20s user 0.05s system 99% cpu 8.253 total

% time ./countprimes 100000000000
4118054813
./countprimes 100000000000  80.24s user 0.62s system 98% cpu 1:22.34 total

ぶっ。
やっぱ、C言語速いわ……。

Discussion