👋

SECCON CTF WitchQuiz stage3 速報writeup

2023/02/16に公開

概要

stage1は1bitずつ山登りを中心とした戦略を、stage2は半数程度のチームが全問正解をすることができました。stage4は国内で解いている人がいるので、それらは動画を作りつつwriteupを待つとして、ここではstage3のwriteupを紹介します。

問題内容

ここでは、問題に取り組んだ人向けに紹介するので、雑に紹介します。
こちらの画像は(見にくいですが、1ドットの緑の点があります)楕円曲線を表していて、十字が今のtickが指している座標を表しています。
この問題では、「点」と「十字」取得することが目標です


楕円曲線のパラメータ、aとbを知る

この問題では、

y^2 = x^3 + ax + b \mod 1009

な楕円曲線が与えられており、画像上の点は上記を満たす x,y 座標を表しています。生成コードは以下のようになっています。

self.p = 1009
for x in range(self.p):
    for y in range(1, self.p):
        if (y*y) % self.p == (x*x*x + self.a * x + self.b) % self.p:
            self.points.append((x,y))
            self.points.append((x,self.p-y))

ab が未知数なので、2点の (x,y) を求めることができれば、こちらを求めることができそうです。1009個もブルートフォースで求めていたら時間がかかりすぎてしまうので、ここは二部探索で求めることにしましょう。

yを固定して探索する

二分探索をしたいのですが、そこのyに点がないと二分探索する意味がありません。まずは点のいちを調べましょう。

解答のクエリとして、以下のようなものを投げることにより、 y 行目にある1の数を数えることができます。点と十字が運悪く重なっていない限り、点が存在していれば2以上の結果が返ってくるはずです。1行に点が一個あると探索が楽そうですので、点が一個のケースを探します。(つまり、2を返すケースを探します。)

[9] * 1009 * y + [1] * 1009

このような点の数はいくつ程度あるのでしょうか?
以下のような実験コードを回すことで大まかに500個前後存在しているので、50%程度の確率で見つけることができそうです。
(ここ雑ですみません、実際にはもっと多くの数で検証していますが、未証明です。90度回転して縦を固定するケースだと直感的には50%程度になりそう(平方剰余が存在する確率と同じなので)ですが、こちらもある程度確認はしていますが未証明です🙇)

import random

for _ in range(100):
    p = 1009
    a = random.randrange(1,p)
    b = random.randrange(1,p)

    memo = {}
    for x in range(p):
        for y in range(1, p):
            if (y*y) % p == (x*x*x + a * x + b) % p:
                if y not in memo:
                    memo[y] = set()
                    memo[p-y] = set()
                memo[y].add(x)
                memo[p-y].add(x)
    
    cnt = 0
    for k,v in memo.items():
        if len(v) == 1:
            cnt += 1

    print(cnt, end=", ")
# output:
# 494, 496, 514, 488, 520, 524, 522, 524, 498, 508, 484, 506, 514, 496, 490, 528, 508, 528, 486, 506, 518, 492, 500, 492, 532, 482, 496, 508, 502, 518, 480, 514, 506, 510, 514, 502, 478, 480, 516, 528, 480, 520, 480, 518, 494, 520, 480, 508, 532, 524, 520, 492, 504, 490, 490, 522, 496, 520, 512, 504, 504, 522, 480, 516, 490, 488, 534, 502, 508, 506, 524, 498, 520, 526, 476, 484, 476, 492, 524, 490, 508, 532, 508, 516, 490, 496, 530, 488, 530, 496, 494, 504, 516, 484, 506, 500, 510, 486, 504, 506,

さて、ここで発見した行に対して二分探索をかけていきたいのですが、縦横無尽に駆け巡る十字が結果にノイズを加えてきます。丁寧に場合分けをし、気合を出せば答えは求まりそうですがそんなことをしている時間はありません。

ここで、生成アルゴリズムに注目すると、 y=0 では点が生成されていないことがわかります。これは零点になるため、点が動かなくなってしまうことを避けるための処置ですが、逆に言うとこの行には点が存在していないことが保証されています。(もしこれが思いつかなくても、0となる行を同様に探せばOKです)

ここで、点を探索するときに、 y=0 の行と、狙いたい行を同時に1を入れて探索します。すると、以下のような結果が帰ってきます

十字がある 十字がない
点がある 3 1
点がない 2 0

今、知りたいのは「点があるかどうかなので」、表の行方向に注目すると、偶数と奇数で場合分けができることがわかりました。
あとは同様に二分探索をすると1点求まります。これをもう一度繰り返すことで2点が求まり、連立方程式を解くことでa,bを求めることができます。

十字の位置を特定する

さて、点の位置が特定できたので次は十字の位置を特定します。ここで1/1000の確率で運良く当てると十字を特定することができちゃう(10チームが1000回やったら当たりかねない…)ので、3回パラメータを変更し、なおかつ再度特定しないと順位がリセットされる仕様にしました。これでもまだ確率的にはありえますが、妥協しました。ごめんなさい。(これ以上pを増やすと計算負荷がやばいことになるので…)

ここでは、十字の x を特定しにいきます。xを特定すると2択に抑え込めるため、かなり高い確率で当てることができます。

x は先程使った y=0 を利用して特定します。 y=0 の左半分だけを1で埋め尽くし、10回程度繰り返すと、xの値が500以下か、それ以上かがわかるようになります。同じ移動方法をする点の位置を全列挙すると候補がそこそこ絞れるので、そこから先は全探索で求めます。

以上がstage3の倒し方になります。

ソースコードの全体像はこちらになります。回答ファイルは自動生成されるので、このまま実行すると1018081 = 1009*1009 の値が最後の方に出ることが確認できると思います。

from typing import Optional
import random
import pickle
import os
import numpy as np

class EC:
    def __init__(self, a=None, b=None):
        self.p = 1009
        self.field = [[0 for i in range(self.p)] for j in range(self.p)]

        self.a = random.randrange(1,self.p) if a is None else a
        self.b = random.randrange(1,self.p) if b is None else b

        # y^2 = x^3 + a*x + b
        self.points = []

        for x in range(self.p):
            for y in range(1, self.p):
                if (y*y) % self.p == (x*x*x + self.a * x + self.b) % self.p:
                    self.points.append((x,y))
                    self.points.append((x,self.p-y))
                    self.field[y][x] = 1
                    self.field[self.p - y][x] = 1

        self.x, self.y = self.points[random.randrange(0,len(self.points))]

        for p in self.points:
            x, y = p
            assert y*y % self.p == (x*x*x + self.a * x + self.b) % self.p
        
        assert self.y*self.y % self.p == (self.x*self.x*self.x + self.a * self.x + self.b) % self.p

    def get_answer(self, x, y):
        ans = [0] * (1009*1009)
        for p in self.points:
            ans[p[1] * self.p + p[0]] = 1

        for i in range(self.p):
            ans[i * self.p + x] = 1
            ans[y * self.p + i] = 1
        return ans 
    
    def next(self):
        phi = (3*self.x*self.x + self.a) * pow(2*self.y, -1, self.p)
        psi = (-3*self.x*self.x*self.x - self.a*self.x + 2 * self.y * self.y) * pow(2*self.y, -1, self.p)
        self.x = (phi * phi - 2 * self.x) % self.p
        self.y = (-phi * self.x - psi) % self.p
        assert self.y*self.y % self.p == (self.x*self.x*self.x + self.a * self.x + self.b) % self.p
    
class ECProblem:
    def __init__(self):
        # __init__ is called only once at server startup. The answer is saved in a file so that the answer does not change.
        server_ans_filename = "ec_problem.ans"

        if not os.path.isfile(server_ans_filename):
            with open(server_ans_filename, "wb") as f:
                curvepoints = []
###########################################################################################
#                                  ATTENSION                                              #
# This problem changes parameters and points every 480 ticks. Please be careful           #
###########################################################################################
                for i in range(3):
                    xys = []
                    ec = EC()
                    for tick in range(480):
                        xys.append((ec.x, ec.y))
                        ec.next()
                    curvepoints.append({"a":ec.a, "b":ec.b, "xys":xys})
                pickle.dump(curvepoints, f)

        with open(server_ans_filename, "rb") as f:
            self.curves = pickle.load(f)

    def score(self, tick: int, your_answer) -> int:
        """
        Calculate scores for submitted answers
        
        Parameters
        ----------
        tick: int
            `tick` you given as request (1-indexed)
        given_answer : int
            `answer` you given as request
        """
        curve_index = ((tick-1)//480)
        curve = self.curves[curve_index]

        ec = EC(curve["a"], curve["b"])
        xys = curve["xys"]

        if your_answer == None:
            return 0

        cnt = 0
        expected = ec.get_answer(xys[tick-480*curve_index][0], xys[tick-480*curve_index][1])
        for l, r in zip(your_answer, expected):
            cnt += (1 if l == r else 0)
        
        ###########################################################################################
        #                                  ATTENSION                                              #
        # This problem changes parameters and points every 480 ticks. Please be careful           #
        ###########################################################################################
        return cnt * (((tick-1)//480) + 1)

import secrets
from Crypto.Util.number import *
import gmpy
import time
import random

def square_root(n, p):
    n %= p
    if pow(n, (p-1)>>1, p) != 1:
        return -1
    q = p-1; m = 0
    while q & 1 == 0:
        q >>= 1
        m += 1
    z = random.randint(1, p-1)
    while pow(z, (p-1)>>1, p) == 1:
        z = random.randint(1, p-1)
    c = pow(z, q, p)
    t = pow(n, q, p)
    r = pow(n, (q+1)>>1, p)
    if t == 0:
        return 0
    m -= 2
    while t != 1:
        while pow(t, 2**m, p) == 1:
            c = c * c % p
            m -= 1
        r = r * c % p
        c = c * c % p
        t = t * c % p
        m -= 1
    return r

tickcnt = 0

server = ECProblem()
tick = 1

def score(answer):
    global server
    global tick
    res = server.score(tick, answer)
    tick += 1
    return res

def get_y(x,a,b):
    y = square_root(x*x*x + a*x + b, p) % p
    return y

def next(x,y,a,b):
    phi = (3*x*x + a) * pow(2*y, -1, p)
    psi = (-3*x*x*x - a*x + 2 * y * y) * pow(2*y, -1, p)
    x = (phi * phi - 2 * x) % p
    y = (-phi * x - psi) % p
    return x,y

p = 1009
offset = 0
prev_score = 0
posses = []
tickcnt = 0

print("phase1")
zeroy = 0
# for i in range(p):
#     if score([2]*p*i + [1]*p) == 1:
#         zeroy = i
#         break

print("phase2")
yoffset = zeroy + 1
for _ in range(2):
    for i in range(yoffset, p):
        base = [2] * (p * p)
        for j in range(p):
            base[zeroy * p + j] = 1

        res = score(base)

        for j in range(p):
            base[i * p + j] = 1
        res = score(base)

        if res == 3:
            yoffset = i
            break

    high = 1009
    low = 0 
    while high - low > 1:
        print(f"high={high}, low={low}")
        mid = (high + low) // 2
        a = [1] * mid

        base = [2] * (p * p)

        for j in range(mid):
            base[zeroy * p + j] = 1
        for j in range(mid):
            base[yoffset*p + j] = 1

        res = score(base) % 2
        if res == 0:
            low = mid
        else:
            high = mid

    A = [0] * (high-1) + [1]
    y, x = yoffset, high-1
    offset += high+2
    prev_score += 1
    posses.append((y,x))
    yoffset += 1

Xs = []
for pos in posses:
    Xs.append(pos[0]*pos[0] - pos[1] * pos[1] * pos[1] % p)

a = (Xs[1]-Xs[0]) * pow(posses[1][1] - posses[0][1], -1, p) % p
y, x = posses[0]
b = (y*y - x*x*x - a * x) % p
print(f"a={a}, b={b}")

print("phase3")

width = p//2
zeroy = 0
results = []
for i in range(10):
    results.append(score([2]*p*zeroy + [1] * width))

result_points = []
for x in range(p):
    xx = x
    y = square_root(x*x*x + a * x + b, p) % p
    if y*y % p == 1:
        continue

    print("==================================================")
    print(x,y)
    print("--------------------------------------------------")
    print(y*y % p)
    print((x*x*x + a *x + b) % p)
    print("--------------------------------------------------")

    assert y*y % p == (x*x*x + a *x + b) % p
    if y == 0:
        continue
    f = True
    for j in range(10):
        if (True if x < width else False) != results[j]:
            print("error")
            f = False
        print(x,y)

        x, y = next(x,y,a,b)
        print(x,y)
        if y == 0 or x == 0:
           f = False 
           break
        assert y*y % p == (x*x*x + a*x + b) % p
    if f:
        result_points.append((x,y))
        print("true:", x, y)

print("point:")
cands = []
for x, _ in result_points:
    y = get_y(x, a, b)
    print(x,y)
    print(x,p-y)
    cands.append((x,y))
    cands.append((x,p-y))

for i in range(len(cands)):
    x,y = cands[i]
    print("memomemo:", i,x,y)

    for j in range(i):
        phi = (3*x*x + a) * pow(2*y, -1, p)
        psi = (-3*x*x*x - a*x + 2 * y * y) * pow(2*y, -1, p)
        x = (phi * phi - 2 * x) % p
        y = (-phi * x - psi) % p

    ans = [0] * (1009*1009)
    for xx in range(p):
        for yy in range(1, p):
            if (yy*yy) % p == (xx*xx*xx + a * xx + b) % p:
                ans[yy * p + xx] = 1

    for i in range(p):
        ans[i * p + x] = 1
        ans[y * p + i] = 1

    print("--------------------------------------------------")
    print(score(ans))
    print("--------------------------------------------------")

Discussion