📝

Universal Cup 3-27 A: (A + B) mod P

2025/01/26に公開

解法も面白かったし、背景も面白かった

問題

奇素数 p \leq 100 が与えられる。 以下の条件を満たす D 次元 (\mathbb{R}) ベクトルたち A_0,..,A_{p-1},B_0,..,B_{p-1} を構築せよ。

  • 任意の 0 \leq a,b < p に対し、v[d] := \max(A_a[d]+A_b[d],0) として D 次元ベクトル v を計算する。この時、vB_k の内積が最も大きくなる k は ちょうど一つであり、それは k = a+b \mod p である
  • D \leq 25 でなければいけない

ちなみに元の問題文は機械学習の文脈で書いてあって、
ReLU(a_{one-hot}A + b_{one-hot}A) \cdot B^T というモデルを使って a+b \mod p を計算させましょうと書いてあった。

アイデア

  • 普通の D = 2 次元平面に、 A_a = ( \cos(\frac{2\pi}{p} \cdot a), \sin(\frac{2\pi}{p} \cdot a) ) と単位円上にプロットしてみることを考える。すると A_a + A_b の方向は、円周上を 2p 個に分割した時の c := (a+b)\%p 番目の点 P_c := ( \cos(\frac{\pi}{p} \cdot c), \sin(\frac{\pi}{p} \cdot c) ) 、あるいは真逆の点 P_{c+p} と同じ向きを向いている
    • たとえば p = 5 として、A_1 + A_4P_0 = (1,0) と同じ向きを向いている。 A_2+A_3 は真逆である P_5 = (-1,0) の方を向いている。
  • なので、A_a + A_bP_k の内積の最大値は、 k = c, あるいは k = c + p のどちらかで取る、ということは言える
    • 内積の最大値で |P_k|=1 なので A_a+A_b の arg しか関係ないことに注意
    • p 奇数より A_a + A_b = 0 と潰れることもない
  • この二択をうまくひとつにまとめるために \max(\cdot,0) の部分を使う

解法

実は D = 4 次元しか使わずに出来る。あと p が素数である必要もない。(奇数である必要はある)

まず 0 < \varphi \leq \pi/2 をうまく(後述)取り、

  • A_a[0] := \cos(\frac{2\pi}{p} \cdot a)
  • A_a[1] := \cos(\frac{2\pi}{p} \cdot a - \varphi)
  • A_a[2] := -A_a[0]
  • A_a[3] := -A_a[1]
    と定義する。それぞれ A_a(1,0), (\cos \varphi, \sin \varphi), -(1,0), -(\cos \varphi, \sin \varphi) 方向に射影した時の(符号付き)長さであることに注意。

B も同様に、P_k \ (0 \leq k < p) をこれら4方向に射影した長さと定義するが、こちらは符号の絶対値を取り、さらにノルムを正規化する。すなわち、

  • B'_k[0] := |\cos(\frac{\pi}{p} \cdot k)|
  • B'_k[1] := |\cos(\frac{\pi}{p} \cdot k - \varphi)|
  • B'_k[2] := B'_k[0]
  • B'_k[3] := B'_k[1]
  • B_k := (B'_k の長さを 1 にしたやつ)

A_a + A_b(x,y,-x,-y) のような形になっており、各値 0 と max を取ってから B_k = (z,w,z,w) と内積を取ることを考えると、 (x,y,-x,-y)(-x,-y,x,y) で結果は変わらず |x|z+|y|w となる。
つまり、A_a+A_bP_c を向いていようが P_{c+p} を向いていようがどちらにせよ、|x|,|y| = B_c[0], B_c[1] を使って |x|z + |y|w を最大にする B_kk = c であることを示せば十分。
これは長さ1のベクトルたち B_0,..,B_{p-1} の中で B_c との内積が最大になるのは唯一 k = c である、というほぼ当たり前のことを言っている。ひとつ確かめないといけないのは B_k が unique だということで、これは \varphi を適切に取れば問題ない。
A,B の定義的に \varphi = \pi/2 と取るのが自然そうに見えるが、そうするとめちゃくちゃ点がかぶるのでダメだったんですね。 \varphi = \pi/3 とかでよさそう。

(|cos t|,|cos(t-pi/3)|) のプロット。赤い点は (B'_k[0], B'_k[1]) たち。これを正規化してるので赤い点たちのargが一致しなければ良い

おまけ

元の文脈は、機械学習でこういうきれいな数理的関数を一部training dataを使って学習していると、途中までは training data に適合していくだけだが、途中から急に数理的性質を"発見" するという現象 (Grokking) があって面白いね、という話らしい。
https://pair.withgoogle.com/explorables/grokking/ ここにめちゃくちゃ綺麗にまとまっている。インタラクティブに遊べて楽しい。

Discussion