SECCON CTF WitchQuiz stage3 速報writeup
概要
stage1は1bitずつ山登りを中心とした戦略を、stage2は半数程度のチームが全問正解をすることができました。stage4は国内で解いている人がいるので、それらは動画を作りつつwriteupを待つとして、ここではstage3のwriteupを紹介します。
問題内容
ここでは、問題に取り組んだ人向けに紹介するので、雑に紹介します。
こちらの画像は(見にくいですが、1ドットの緑の点があります)楕円曲線を表していて、十字が今のtickが指している座標を表しています。
この問題では、「点」と「十字」取得することが目標です
楕円曲線のパラメータ、aとbを知る
この問題では、
な楕円曲線が与えられており、画像上の点は上記を満たす
self.p = 1009
for x in range(self.p):
for y in range(1, self.p):
if (y*y) % self.p == (x*x*x + self.a * x + self.b) % self.p:
self.points.append((x,y))
self.points.append((x,self.p-y))
yを固定して探索する
二分探索をしたいのですが、そこのyに点がないと二分探索する意味がありません。まずは点のいちを調べましょう。
解答のクエリとして、以下のようなものを投げることにより、
[9] * 1009 * y + [1] * 1009
このような点の数はいくつ程度あるのでしょうか?
以下のような実験コードを回すことで大まかに500個前後存在しているので、50%程度の確率で見つけることができそうです。
(ここ雑ですみません、実際にはもっと多くの数で検証していますが、未証明です。90度回転して縦を固定するケースだと直感的には50%程度になりそう(平方剰余が存在する確率と同じなので)ですが、こちらもある程度確認はしていますが未証明です🙇)
import random
for _ in range(100):
p = 1009
a = random.randrange(1,p)
b = random.randrange(1,p)
memo = {}
for x in range(p):
for y in range(1, p):
if (y*y) % p == (x*x*x + a * x + b) % p:
if y not in memo:
memo[y] = set()
memo[p-y] = set()
memo[y].add(x)
memo[p-y].add(x)
cnt = 0
for k,v in memo.items():
if len(v) == 1:
cnt += 1
print(cnt, end=", ")
# output:
# 494, 496, 514, 488, 520, 524, 522, 524, 498, 508, 484, 506, 514, 496, 490, 528, 508, 528, 486, 506, 518, 492, 500, 492, 532, 482, 496, 508, 502, 518, 480, 514, 506, 510, 514, 502, 478, 480, 516, 528, 480, 520, 480, 518, 494, 520, 480, 508, 532, 524, 520, 492, 504, 490, 490, 522, 496, 520, 512, 504, 504, 522, 480, 516, 490, 488, 534, 502, 508, 506, 524, 498, 520, 526, 476, 484, 476, 492, 524, 490, 508, 532, 508, 516, 490, 496, 530, 488, 530, 496, 494, 504, 516, 484, 506, 500, 510, 486, 504, 506,
さて、ここで発見した行に対して二分探索をかけていきたいのですが、縦横無尽に駆け巡る十字が結果にノイズを加えてきます。丁寧に場合分けをし、気合を出せば答えは求まりそうですがそんなことをしている時間はありません。
ここで、生成アルゴリズムに注目すると、
ここで、点を探索するときに、
十字がある | 十字がない | |
---|---|---|
点がある | 3 | 1 |
点がない | 2 | 0 |
今、知りたいのは「点があるかどうかなので」、表の行方向に注目すると、偶数と奇数で場合分けができることがわかりました。
あとは同様に二分探索をすると1点求まります。これをもう一度繰り返すことで2点が求まり、連立方程式を解くことで
十字の位置を特定する
さて、点の位置が特定できたので次は十字の位置を特定します。ここで1/1000の確率で運良く当てると十字を特定することができちゃう(10チームが1000回やったら当たりかねない…)ので、3回パラメータを変更し、なおかつ再度特定しないと順位がリセットされる仕様にしました。これでもまだ確率的にはありえますが、妥協しました。ごめんなさい。(これ以上pを増やすと計算負荷がやばいことになるので…)
ここでは、十字の
以上がstage3の倒し方になります。
ソースコードの全体像はこちらになります。回答ファイルは自動生成されるので、このまま実行すると1018081 = 1009*1009 の値が最後の方に出ることが確認できると思います。
from typing import Optional
import random
import pickle
import os
import numpy as np
class EC:
def __init__(self, a=None, b=None):
self.p = 1009
self.field = [[0 for i in range(self.p)] for j in range(self.p)]
self.a = random.randrange(1,self.p) if a is None else a
self.b = random.randrange(1,self.p) if b is None else b
# y^2 = x^3 + a*x + b
self.points = []
for x in range(self.p):
for y in range(1, self.p):
if (y*y) % self.p == (x*x*x + self.a * x + self.b) % self.p:
self.points.append((x,y))
self.points.append((x,self.p-y))
self.field[y][x] = 1
self.field[self.p - y][x] = 1
self.x, self.y = self.points[random.randrange(0,len(self.points))]
for p in self.points:
x, y = p
assert y*y % self.p == (x*x*x + self.a * x + self.b) % self.p
assert self.y*self.y % self.p == (self.x*self.x*self.x + self.a * self.x + self.b) % self.p
def get_answer(self, x, y):
ans = [0] * (1009*1009)
for p in self.points:
ans[p[1] * self.p + p[0]] = 1
for i in range(self.p):
ans[i * self.p + x] = 1
ans[y * self.p + i] = 1
return ans
def next(self):
phi = (3*self.x*self.x + self.a) * pow(2*self.y, -1, self.p)
psi = (-3*self.x*self.x*self.x - self.a*self.x + 2 * self.y * self.y) * pow(2*self.y, -1, self.p)
self.x = (phi * phi - 2 * self.x) % self.p
self.y = (-phi * self.x - psi) % self.p
assert self.y*self.y % self.p == (self.x*self.x*self.x + self.a * self.x + self.b) % self.p
class ECProblem:
def __init__(self):
# __init__ is called only once at server startup. The answer is saved in a file so that the answer does not change.
server_ans_filename = "ec_problem.ans"
if not os.path.isfile(server_ans_filename):
with open(server_ans_filename, "wb") as f:
curvepoints = []
###########################################################################################
# ATTENSION #
# This problem changes parameters and points every 480 ticks. Please be careful #
###########################################################################################
for i in range(3):
xys = []
ec = EC()
for tick in range(480):
xys.append((ec.x, ec.y))
ec.next()
curvepoints.append({"a":ec.a, "b":ec.b, "xys":xys})
pickle.dump(curvepoints, f)
with open(server_ans_filename, "rb") as f:
self.curves = pickle.load(f)
def score(self, tick: int, your_answer) -> int:
"""
Calculate scores for submitted answers
Parameters
----------
tick: int
`tick` you given as request (1-indexed)
given_answer : int
`answer` you given as request
"""
curve_index = ((tick-1)//480)
curve = self.curves[curve_index]
ec = EC(curve["a"], curve["b"])
xys = curve["xys"]
if your_answer == None:
return 0
cnt = 0
expected = ec.get_answer(xys[tick-480*curve_index][0], xys[tick-480*curve_index][1])
for l, r in zip(your_answer, expected):
cnt += (1 if l == r else 0)
###########################################################################################
# ATTENSION #
# This problem changes parameters and points every 480 ticks. Please be careful #
###########################################################################################
return cnt * (((tick-1)//480) + 1)
import secrets
from Crypto.Util.number import *
import gmpy
import time
import random
def square_root(n, p):
n %= p
if pow(n, (p-1)>>1, p) != 1:
return -1
q = p-1; m = 0
while q & 1 == 0:
q >>= 1
m += 1
z = random.randint(1, p-1)
while pow(z, (p-1)>>1, p) == 1:
z = random.randint(1, p-1)
c = pow(z, q, p)
t = pow(n, q, p)
r = pow(n, (q+1)>>1, p)
if t == 0:
return 0
m -= 2
while t != 1:
while pow(t, 2**m, p) == 1:
c = c * c % p
m -= 1
r = r * c % p
c = c * c % p
t = t * c % p
m -= 1
return r
tickcnt = 0
server = ECProblem()
tick = 1
def score(answer):
global server
global tick
res = server.score(tick, answer)
tick += 1
return res
def get_y(x,a,b):
y = square_root(x*x*x + a*x + b, p) % p
return y
def next(x,y,a,b):
phi = (3*x*x + a) * pow(2*y, -1, p)
psi = (-3*x*x*x - a*x + 2 * y * y) * pow(2*y, -1, p)
x = (phi * phi - 2 * x) % p
y = (-phi * x - psi) % p
return x,y
p = 1009
offset = 0
prev_score = 0
posses = []
tickcnt = 0
print("phase1")
zeroy = 0
# for i in range(p):
# if score([2]*p*i + [1]*p) == 1:
# zeroy = i
# break
print("phase2")
yoffset = zeroy + 1
for _ in range(2):
for i in range(yoffset, p):
base = [2] * (p * p)
for j in range(p):
base[zeroy * p + j] = 1
res = score(base)
for j in range(p):
base[i * p + j] = 1
res = score(base)
if res == 3:
yoffset = i
break
high = 1009
low = 0
while high - low > 1:
print(f"high={high}, low={low}")
mid = (high + low) // 2
a = [1] * mid
base = [2] * (p * p)
for j in range(mid):
base[zeroy * p + j] = 1
for j in range(mid):
base[yoffset*p + j] = 1
res = score(base) % 2
if res == 0:
low = mid
else:
high = mid
A = [0] * (high-1) + [1]
y, x = yoffset, high-1
offset += high+2
prev_score += 1
posses.append((y,x))
yoffset += 1
Xs = []
for pos in posses:
Xs.append(pos[0]*pos[0] - pos[1] * pos[1] * pos[1] % p)
a = (Xs[1]-Xs[0]) * pow(posses[1][1] - posses[0][1], -1, p) % p
y, x = posses[0]
b = (y*y - x*x*x - a * x) % p
print(f"a={a}, b={b}")
print("phase3")
width = p//2
zeroy = 0
results = []
for i in range(10):
results.append(score([2]*p*zeroy + [1] * width))
result_points = []
for x in range(p):
xx = x
y = square_root(x*x*x + a * x + b, p) % p
if y*y % p == 1:
continue
print("==================================================")
print(x,y)
print("--------------------------------------------------")
print(y*y % p)
print((x*x*x + a *x + b) % p)
print("--------------------------------------------------")
assert y*y % p == (x*x*x + a *x + b) % p
if y == 0:
continue
f = True
for j in range(10):
if (True if x < width else False) != results[j]:
print("error")
f = False
print(x,y)
x, y = next(x,y,a,b)
print(x,y)
if y == 0 or x == 0:
f = False
break
assert y*y % p == (x*x*x + a*x + b) % p
if f:
result_points.append((x,y))
print("true:", x, y)
print("point:")
cands = []
for x, _ in result_points:
y = get_y(x, a, b)
print(x,y)
print(x,p-y)
cands.append((x,y))
cands.append((x,p-y))
for i in range(len(cands)):
x,y = cands[i]
print("memomemo:", i,x,y)
for j in range(i):
phi = (3*x*x + a) * pow(2*y, -1, p)
psi = (-3*x*x*x - a*x + 2 * y * y) * pow(2*y, -1, p)
x = (phi * phi - 2 * x) % p
y = (-phi * x - psi) % p
ans = [0] * (1009*1009)
for xx in range(p):
for yy in range(1, p):
if (yy*yy) % p == (xx*xx*xx + a * xx + b) % p:
ans[yy * p + xx] = 1
for i in range(p):
ans[i * p + x] = 1
ans[y * p + i] = 1
print("--------------------------------------------------")
print(score(ans))
print("--------------------------------------------------")
Discussion