SECCON CTF 2023 mystic_harmony writeup

Published on 2023/09/25

Problem

import random
import Crypto.Cipher.AES as AES
from Crypto.Util.number import long_to_bytes
import hashlib
from flag import FLAG

# Bivariate polynomial ring; x and y are the generators used by every function below.
R.<x, y> = PolynomialRing(GF(2))
size = 2^8
# GF(2^8) defined by the modulus x^8+x^4+x^3+x^2+1. All discrete logarithms in
# this script are taken base alpha with group order size-1.
K.<alpha> = GF(size, modulus=x^8+x^4+x^3+x^2+1)

def make_human_world(human_world_size):
    """Return a random polynomial with a coefficient on every monomial x^i*y^j.

    Exponents range over 0 <= i, j < human_world_size. Each coefficient is
    alpha^r with r = random.randint(0, size-2). The draws are made with i as
    the outer loop and j as the inner loop.
    """
    world = 0
    for ex in range(human_world_size):
        x_part = x^ex
        for ey in range(human_world_size):
            coeff = alpha^(random.randint(0, size-2))
            world += x_part * (y^ey) * coeff
    return world

def make_spirit_world(H, spirit_world_size_param):
    """Return H reduced modulo Gx + Gy.

    Let m = spirit_world_size_param. Then Gx = prod_{k=1..m} (x - alpha^k)
    and Gy = prod_{k=1..m} (y - alpha^k). Because of these factors, Gx + Gy
    vanishes at every point (alpha^a, alpha^b) with 1 <= a, b <= m.
    """
    roots = [alpha^k for k in range(1, spirit_world_size_param + 1)]
    divisor = prod(x - r for r in roots) + prod(y - r for r in roots)
    return H % divisor

def make_disharmony(C, count):
    """Build a sparse "error" polynomial with `count` terms placed on monomials of C.

    Each term copies the exponent pair of a randomly chosen monomial of C. Its
    coefficient is alpha^r with r = random.randint(0, size-2). The x exponents
    of the chosen monomials are pairwise distinct, so the result has exactly
    `count` terms.

    Note: if C has fewer than `count` distinct x exponents, the inner while
    loop never terminates.
    """
    x_set = set()  # x exponents already used by an earlier term

    D = 0
    for i in range(count):
        r = random.randint(0,size-2)
        p = random.choice(list(C.dict().keys()))
        # Re-draw until the monomial's x exponent has not been used yet.
        while p[0] in x_set:
            p = random.choice(list(C.dict().keys()))
        x_set.add(p[0])
        D += (x^p[0]) * (y^p[1]) * alpha^(r)
    return D

def make_key(D):
    """Derive the 32-byte AES key from the terms of D.

    Terms are visited in sorted exponent order. For each term, the bytes of
    the x exponent, the y exponent and the discrete log (base alpha) of the
    coefficient are concatenated. The key is the SHA-256 digest of the
    concatenation.
    """
    pieces = []
    for pos, coeff in sorted(D.dict().items()):
        ex = pos[0]
        ey = pos[1]
        log_coeff = discrete_log(coeff, alpha, size-1)
        pieces.append(long_to_bytes(ex) + long_to_bytes(ey) + long_to_bytes(log_coeff))
    return hashlib.sha256(b"".join(pieces)).digest()

def get_polynomial_dict(C):
    """Map every exponent pair of C to the discrete log (base alpha) of its coefficient."""
    return {pos: discrete_log(coeff, alpha, size-1) for pos, coeff in C.dict().items()}

# Challenge parameters.
human_world_size = 64
spirit_world_size_param = 32
disharmony_count = 16

H = make_human_world(human_world_size)
S = make_spirit_world(H, spirit_world_size_param)
# S = H % (Gx + Gy). In characteristic 2, H + S equals H minus that remainder,
# so it is a multiple of Gx + Gy, which vanishes on every witch-map point below.
World = H+S
# D is the secret sparse error polynomial; its terms sit on monomials of World.
D = make_disharmony(World, disharmony_count)
C = H+S+D

# The AES key is derived from D alone.
key = make_key(D)
cipher = AES.new( key, AES.MODE_ECB )

print("# Witch making the map! please wait.", flush=True)
# witch_map[i][j] = discrete log (base alpha) of C(alpha^(j+1), alpha^(i+1)),
# or None where that value is 0.
witch_map = []
for i in range(spirit_world_size_param):
    row = []
    C_y = C(y=alpha^(i+1))
    for j in range(spirit_world_size_param):
        temp = C_y(x=alpha^(j+1))
        if temp == 0:
            row.append(None)
        else:
            row.append(discrete_log(temp, alpha, size-1))
    witch_map.append(row)
print("witch_map=", witch_map)
print("treasure_box=", cipher.encrypt(FLAG))

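Since S = H \mod (G_x + G_y) and the field has characteristic 2, H + S = H - (H \mod (G_x + G_y)) is a multiple of G_x + G_y. Moreover, G_x(\alpha^i) = 0 and G_y(\alpha^j) = 0 for 1 \leq i, j \leq 32, so H + S vanishes at every witch_map point, and we get the following equation: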

C(\alpha^i, \alpha^j) = D(\alpha^i, \alpha^j) (1 \leq i,j \leq 32)

Here, we write D as follows.

\begin{aligned} D(x,y) &= \sum_{i=1}^{n} d_i x^{e_{x_i}} y^{e_{y_i}} \\ &\mathrm{where\ } n = 16\ (\mathrm{the\ number\ of\ errors}) \end{aligned}

If we look closely at the structure of this problem, we notice that it resembles
Reed–Solomon error correction (ref: https://en.wikipedia.org/wiki/Reed–Solomon_error_correction).

However, this problem appears to be bivariate. How do we solve it?
The answer is to fix one of the variables to a value in \{\alpha^1, \dots, \alpha^{32}\}!
Here, let's substitute y = \alpha.
Then D(x, y) becomes the following.

\begin{aligned} D(x, \alpha) &= \sum_{i=1}^{n} d_i \alpha^{e_{y_i}} x^{e_{x_i}} \\ &= \sum_{i=1}^{n} c_i x^{e_{x_i}} \\ & \mathrm{where}\ c_i = d_i \alpha^{e_{y_i}} \end{aligned}

We have transformed it into a univariate Reed–Solomon decoding problem :)
Decoding it gives the e_{x_i}; symmetrically, substituting x = \alpha instead gives the e_{y_i}.

However, although we now know the candidates for e_{x_i} and e_{y_i}, we do not know how they pair up.
For example, suppose we obtain these candidates:

\begin{aligned} e_x &= \{1,2,3,4,5\} \\ e_y &= \{5,6,7,8\} \end{aligned}

We cannot tell whether (x,y)=(1,5) or (x,y)=(1,6) is correct.
Each coordinate has up to 16 candidates, so the number of possible pairings is still far too large to brute-force.

Now let's consider the diagonal x = y:

\begin{aligned} D(x,x) &= \sum_{i=1}^{n} d_i x^{\left(e_{x_i} + e_{y_i}\right)} \\ \end{aligned}

Therefore, by reading witch_map along the diagonal x = y (the points (\alpha^1,\alpha^1), (\alpha^2,\alpha^2), (\alpha^3,\alpha^3), ...), we can decode the exponent sums e_{x_i} + e_{y_i}.

Now look at this loop in make_disharmony:

        while p[0] in x_set:
            p = random.choice(list(C.dict().keys()))

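This loop guarantees that the e_{x_i} are pairwise distinct. Every decoded e_y candidate and every decoded sum e_x + e_y must also be used by at least one error term. Combining these facts, we can prune the candidate pairings heavily and recover the (e_{x_i}, e_{y_i}) pairs with a depth-first search; the search succeeds roughly once every few attempts.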

The coefficients are computed the same way as in Reed–Solomon decoding, so I'll skip the details. (Sorry!)

It can be solved with the following code :)
(Sorry for the very dirty code.)

import random
import output
import copy

human_world_size = 64
spirit_world_size_param = 32
disharmony_count = 16
# t: largest error count decoded from the 2t syndromes. A witch-map row,
# column or diagonal has 2t = 32 entries.
t = spirit_world_size_param // 2
# n: exclusive upper bound on the exponents tested as error positions in
# find_x / find_y; find_xy tests exponent sums below 2*n.
n = human_world_size + spirit_world_size_param

# ref: https://www.jstage.jst.go.jp/article/essfr/4/3/4_3_183/_pdf
R.<x, y> = PolynomialRing(GF(2))
# Alternative (unused) field definition:
# K.<a> = GF(2^8, name='x', modulus=x^8+x^4+x^3+x^2+1)
size = 2^8
# Same field and generator as the challenge, so discrete logs agree with witch_map.
K.<alpha> = GF(size, modulus=x^8+x^4+x^3+x^2+1)

def solve_sigma(S, max_errors=None):
    """Compute the error-locator polynomial from syndromes (Peterson's method).

    S[k] is the syndrome evaluated at alpha^(k+1), i.e. S = [S_1, ..., S_2t].
    For a candidate error count nu, the nu x nu matrix
        M[i][j] = S[nu - 1 - j + i]
    is singular when nu exceeds the true number of errors and nonsingular
    when nu equals it (Peterson-Gorenstein-Zierler). Searching nu downward
    from max_errors therefore finds the error count. Solving
    M * sigma = (S[nu], ..., S[2*nu - 1]) gives the coefficients of
    sigma(x) = 1 + sigma_1 x + ... + sigma_nu x^nu.

    Args:
        S: at least 2 * max_errors syndromes.
        max_errors: largest error count to try; defaults to len(S) // 2.

    Returns:
        (sigma, error_cnt). If every candidate matrix is singular (no errors
        detected), returns the constant polynomial 1 and 0.
    """
    if max_errors is None:
        max_errors = len(S) // 2

    def make_mat(k):
        # k x k syndrome matrix; row i is S[k-1+i], S[k-2+i], ..., S[i].
        return matrix([[S[k - j - 1 + i] for j in range(k)] for i in range(k)])

    # Search downward. Scanning upward and stopping at the first singular
    # matrix can stop too early: a matrix smaller than the true error count
    # may be singular by chance.
    mat = None
    error_cnt = 0
    for nu in range(max_errors, 0, -1):
        candidate = make_mat(nu)
        if candidate.det() != 0:
            mat, error_cnt = candidate, nu
            break

    if error_cnt == 0:
        # Callers evaluate sigma(...), so return a polynomial, not the int 1.
        return x^0, 0

    vec = [S[error_cnt + i] for i in range(error_cnt)]
    coeffs = mat.solve_right(vec)
    sigma = x^0
    for i, c in enumerate(coeffs):
        sigma += c * x^(i + 1)

    return sigma, error_cnt

def find_x(alpha_map):
    """Locate the x exponents of the error terms and their combined values.

    Row 0 of the witch map fixes y = alpha, so the syndromes are
    S[i] = C(alpha^(i+1), alpha) for i < 2t. Along that row, each error term
    d * x^ex * y^ey contributes (d * alpha^ey) * x^ex. This makes it a
    univariate Reed-Solomon decoding problem in x.

    Returns:
        (error_pos, values): the x exponents in ascending order. For each
        exponent, the combined value d * alpha^ey solves
        sum_j values[j] * alpha^(error_pos[j] * (i+1)) = S[i].

    Raises:
        ValueError: if the number of locator roots found among exponents
            0..n-1 differs from the error count reported by solve_sigma
            (decoding failed).
    """
    S = [alpha_map[0][i] for i in range(2*t)]
    sigma, error_cnt = solve_sigma(S)

    # The roots of sigma are alpha^(-e) for each error exponent e.
    # alpha^(size-e-1) == alpha^(-e) because alpha^(size-1) == 1.
    error_pos = [i for i in range(n) if sigma(alpha^(size-i-1), 0) == 0]
    if len(error_pos) != error_cnt:
        raise ValueError(
            f"error locator expects {error_cnt} errors but has "
            f"{len(error_pos)} roots among exponents below {n}"
        )
    if error_cnt == 0:
        return error_pos, []

    # Vandermonde system: row i evaluates every error position at alpha^(i+1).
    mat = matrix([[alpha^(pos * (i+1)) for pos in error_pos] for i in range(error_cnt)])
    vec = vector(S[:error_cnt])

    values = mat.solve_right(vec)
    return error_pos, values

def find_y(alpha_map):
    """Return the y exponents of the error terms, in ascending order.

    Column 0 of the witch map fixes x = alpha, so entry [i][0] is
    C(alpha, alpha^(i+1)). Decoding those 2t values as a univariate
    Reed-Solomon problem in y locates the y exponents.
    """
    syndromes = [alpha_map[row][0] for row in range(2*t)]
    locator, _error_cnt = solve_sigma(syndromes)
    # The locator vanishes at alpha^(-e) == alpha^(size-e-1) for each error exponent e.
    return [e for e in range(n) if locator(alpha^(size-e-1), 0) == 0]

def find_xy(alpha_map):
    """Return the exponent sums e_x + e_y of the error terms, in ascending order.

    The diagonal entry [i][i] is C(alpha^(i+1), alpha^(i+1)). On the line
    x = y, each error term d * x^ex * y^ey becomes d * x^(ex+ey), so decoding
    the diagonal locates the sums. Sums below 2*n are searched.
    """
    diagonal = [alpha_map[k][k] for k in range(2*t)]
    locator, _error_cnt = solve_sigma(diagonal)
    sums = []
    for e in range(2*n):
        if locator(alpha^(size-e-1), 0) == 0:
            sums.append(e)
    return sums

def witchmap_to_alphamap(map):
    """Convert a witch map of discrete logs back into field elements.

    Each entry k becomes alpha^k, and None becomes 0. The challenge writes
    None where the evaluated value is 0. The nested-list shape is preserved.

    Note: the parameter name shadows the builtin ``map``; it is kept so that
    keyword callers keep working.
    """
    return [[0 if entry is None else alpha^entry for entry in row] for row in map]

# Witch-map logs -> field elements. Row i has y = alpha^(i+1) and column j has x = alpha^(j+1).
alpha_map = witchmap_to_alphamap(output.witch_map)
print(alpha_map)
# Decode along three lines:
# - y = alpha gives the x exponents plus combined values;
# - x = alpha gives the y exponents;
# - the diagonal x = y gives the exponent sums.
x_pos, error_value = find_x(alpha_map)
y_pos = find_y(alpha_map)
xy_pos = find_xy(alpha_map)
print("x_pos=", x_pos)
print("y_pos=", y_pos)
print("xy_pos=", xy_pos)
# NOTE(review): x_pos_index is not read anywhere else in this script.
x_pos_index = {}
for i in range(len(x_pos)):
    x_pos_index[x_pos[i]] = i

# NOTE(review): this `cand` is not read before it is reassigned further down.
cand = []
# xmemo[xp] lists the y exponents yp for which xp + yp is a decoded diagonal sum.
# An xp with no such yp gets no entry at all.
xmemo = {}
for xp in x_pos:
    for yp in y_pos:
        if xp+yp in xy_pos:
            if xp not in xmemo:
                xmemo[xp] = []
            xmemo[xp].append(yp)

# Suffix counts used for pruning in dfs. After the reversal below:
# - ycnt[k][yp] is how many of x_pos[k:] list yp as a candidate;
# - xycnt[k][s] is how many candidate pairs (xp, yp) with xp in x_pos[k:] have xp + yp == s.
# Both lists have len(x_pos) + 1 entries, and the last entry is all zeros.
yprev = {}
xyprev = {}
for yp in y_pos:
    yprev[yp] = 0
for xyp in xy_pos:
    xyprev[xyp] = 0
ycnt = [copy.deepcopy(yprev)]
xycnt = [copy.deepcopy(xyprev)]

for xp in reversed(x_pos):
    ytemp = copy.deepcopy(yprev)
    xytemp =copy.deepcopy(xyprev) 
    # Raises KeyError if this xp has no candidate in xmemo.
    for yp in xmemo[xp]:
        ytemp[yp] += 1
        xytemp[xp+yp] += 1
    ycnt.append(ytemp)
    xycnt.append(xytemp)
    yprev = copy.deepcopy(ytemp)
    xyprev = copy.deepcopy(xytemp)

ycnt = list(reversed(ycnt))
xycnt = list(reversed(xycnt))
for xy in xycnt:
    print(xy)

def dfs(x_pos, xmemo, index, error_value, expected_E, res, ycnt, xycnt, yused, xyused):
    """Depth-first search over y assignments for each decoded x position.

    x_pos[index] is paired with a y exponent taken from xmemo[x_pos[index]].
    Each pairing adds the term x^xp * y^yp * error_value[index] / alpha^yp
    to expected_E. The division removes the alpha^yp factor carried by the
    values decoded along y = alpha.

    Every complete assignment is appended to `res`. yused / xyused count how
    often each y exponent / exponent sum is used on the current path; they
    are mutated in place and restored on backtrack. ycnt / xycnt are suffix
    counts: ycnt[index+1] counts candidates available from later x positions.
    """
    if index == len(x_pos):
        res.append(expected_E)
        return
    xp = x_pos[index]

    cand = set()
    for yp in xmemo[xp]:
        # A candidate is forced here if this y (or this x+y sum) no longer
        # appears for any later x position and has not been used yet.
        if (ycnt[index+1][yp] == 0 and yused[yp] == 0) or (xycnt[index+1][xp+yp] == 0 and xyused[xp+yp] == 0):
            cand.add(yp)
    print("--------------------------------------------------")
    print(xp)
    print(yused)
    print(cand)
    # Two different forced values cannot both be assigned to this xp: prune.
    if len(cand) >= 2:
        return
    # Nothing is forced: try every candidate.
    if len(cand) == 0:
        cand = xmemo[xp]

    for yp in cand:
        temp = x^xp * y^yp * error_value[index] / alpha^yp
        expected_E += temp
        yused[yp] += 1
        xyused[xp+yp] += 1
        dfs(x_pos, xmemo, index+1, error_value, expected_E, res, ycnt, xycnt, yused, xyused)
        # Undo this choice before trying the next candidate.
        xyused[xp+yp] -= 1
        yused[yp] -= 1
        expected_E -= temp

# Upper bound on the search space: the product of the candidate-list sizes in
# xmemo. Give up (exit status 1) if it is too large to enumerate.
res_cnt = 1
for _, value in xmemo.items():
    res_cnt *= len(value)
print("res_cnt:", res_cnt)
if res_cnt > 10000000:
    exit(1)

# Initial dfs state: empty partial polynomial and zero usage counts.
expected_E = 0
res = []
yused = {}
xyused = {}
for yp in y_pos:
    yused[yp] = 0
for xyp in xy_pos:
    xyused[xyp] = 0

dfs(x_pos, xmemo, 0, error_value, expected_E, res, ycnt, xycnt, yused, xyused)
print("res cnt", len(res))

import Crypto.Cipher.AES as AES
from Crypto.Util.number import long_to_bytes
import hashlib

def make_key(D):
    """Derive the AES key from a candidate error polynomial D.

    This must reproduce the challenge's key derivation byte for byte. Terms
    are visited in sorted exponent order. For each term, the bytes of the x
    exponent, the y exponent and the discrete log (base alpha) of the
    coefficient are concatenated. The key is the SHA-256 digest of the result.
    """
    key_seed = b""
    for pos, value in sorted(D.dict().items()):
        # Named ex/ey so the ring generators x and y are not shadowed.
        ex = pos[0]
        ey = pos[1]
        power = discrete_log(value, alpha, size-1)
        key_seed += long_to_bytes(ex) + long_to_bytes(ey) + long_to_bytes(power)
    return hashlib.sha256(key_seed).digest()

# Try each candidate error polynomial as the key source. A plaintext that
# contains the flag prefix identifies the right candidate.
cand = []
for i in range(len(res)):
    print(i, "/", len(res))
    r = res[i]

    # NOTE(review): numerator() is presumably a safeguard in case the division
    # by alpha^yp in dfs produced a fraction-field element; confirm.
    key = make_key(r.numerator())
    cipher = AES.new( key, AES.MODE_ECB )
    flag = cipher.decrypt(output.treasure_box)
    cand.append(flag)
    if b"SECCON{" in flag:
        print(cand)
        print(flag)
        break
print(cand)

Discussion