Published on 2023/09/25

# problem

import random
import Crypto.Cipher.AES as AES
from Crypto.Util.number import long_to_bytes
import hashlib
from flag import FLAG

# Bivariate polynomial ring over GF(2) in x, y (Sage preparser syntax injects x, y).
R.<x, y> = PolynomialRing(GF(2))
# Number of elements of the field K below.
size = 2^8
# K = GF(2^8) defined by the modulus x^8+x^4+x^3+x^2+1, with generator alpha.
# The code below takes discrete logs base alpha with order size-1, i.e. it
# treats alpha as a generator of the multiplicative group of K.
K.<alpha> = GF(size, modulus=x^8+x^4+x^3+x^2+1)

def make_human_world(human_world_size):
    """Return a random polynomial over K containing every monomial x^i y^j.

    For 0 <= i, j < human_world_size the coefficient of x^i y^j is alpha^r,
    with r drawn by random.randint(0, size-2). Coefficients are drawn in
    row-major order (i outer, j inner).
    """
    span = range(human_world_size)
    # sum() consumes the generator in order, so the random draws happen in
    # exactly the same sequence as a nested for-loop would produce.
    return sum(
        (x^i) * (y^j) * (alpha^(random.randint(0, size-2)))
        for i in span
        for j in span
    )

def make_spirit_world(H, spirit_world_size_param):
    """Reduce H modulo G_x + G_y.

    With m = spirit_world_size_param:
      G_x = prod_{k=1}^{m} (x - alpha^k),  G_y = prod_{k=1}^{m} (y - alpha^k),
    so G_x + G_y vanishes at every point (alpha^i, alpha^j), 1 <= i, j <= m.
    """
    roots = [alpha^k for k in range(1, spirit_world_size_param + 1)]
    G_sum = prod(x - a for a in roots) + prod(y - a for a in roots)
    return H % G_sum

def make_disharmony(C, count):
    """Build the error polynomial D from `count` random monomials of C's support.

    Each error reuses the (x, y) exponent of a randomly chosen monomial of C
    and gets the coefficient alpha^r with r = random.randint(0, size-2).
    The x exponents of the chosen monomials are kept pairwise distinct.

    Raises:
        ValueError: if C's support has fewer than `count` distinct x exponents
            (the distinct-x requirement could never be satisfied).
    """
    # C does not change inside the loop, so list its support only once.
    positions = list(C.dict().keys())
    if count > len({p[0] for p in positions}):
        raise ValueError("not enough distinct x exponents in C for %d errors" % count)

    x_set = set()
    D = 0
    for i in range(count):
        r = random.randint(0,size-2)
        p = random.choice(positions)
        while p[0] in x_set:
            p = random.choice(positions)
        # Record the x exponent; without this x_set stays empty and the
        # guard above never rejects anything.
        x_set.add(p[0])
        D += (x^p[0]) * (y^p[1]) * alpha^(r)
    return D

def make_key(D):
    """Hash the terms of D into an AES key (SHA-256 digest, 32 bytes).

    Terms are visited in sorted exponent order. For each one, the x exponent,
    the y exponent and log_alpha of the coefficient are each encoded with
    long_to_bytes and concatenated before hashing.
    """
    parts = []
    for pos, value in sorted(D.dict().items()):
        power = discrete_log(value, alpha, size-1)
        parts.append(long_to_bytes(pos[0]))
        parts.append(long_to_bytes(pos[1]))
        parts.append(long_to_bytes(power))
    return hashlib.sha256(b"".join(parts)).digest()

def get_polynomial_dict(C):
    """Map each exponent tuple of C to log_alpha of its coefficient."""
    return {pos: discrete_log(coeff, alpha, size-1) for pos, coeff in C.dict().items()}

human_world_size = 64
spirit_world_size_param = 32
disharmony_count = 16

# Codeword H + S, then corrupted by the error polynomial D.
H = make_human_world(human_world_size)
S = make_spirit_world(H, spirit_world_size_param)
World = H+S
D = make_disharmony(World, disharmony_count)
C = World + D

# The AES key depends only on D.
key = make_key(D)
cipher = AES.new(key, AES.MODE_ECB)

print("# Witch making the map! please wait.", flush=True)
# witch_map[r][c] = log_alpha of C(alpha^(c+1), alpha^(r+1)), or None where it is 0.
witch_map = []
for r_idx in range(spirit_world_size_param):
    # Substitute y once per row, then evaluate the row at each x.
    C_row = C(y=alpha^(r_idx+1))
    cells = []
    for c_idx in range(spirit_world_size_param):
        value = C_row(x=alpha^(c_idx+1))
        cells.append(None if value == 0 else discrete_log(value, alpha, size-1))
    witch_map.append(cells)
print("witch_map=", witch_map)
print("treasure_box=", cipher.encrypt(FLAG))


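Since S = H \bmod (G_x + G_y), the sum H + S (which equals H - S in characteristic 2) is a multiple of G_x + G_y. G_x vanishes at x = \alpha^1, \dots, \alpha^{32} and G_y vanishes at y = \alpha^1, \dots, \alpha^{32}, so H + S vanishes at all of these points and we get the following equation: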

C(\alpha^i, \alpha^j) = D(\alpha^i, \alpha^j) (1 \leq i,j \leq 32)

Here, we write D as follows.

\begin{aligned} D(x,y) &= \sum_{i=1}^{n} d_i x^{e_{x_i}} y^{e_{y_i}} \\ &\mathrm{where\ } n = 16\ (\mathrm{the\ number\ of\ errors}) \end{aligned}

Looking at the structure of this problem, we notice that it resembles
Reed–Solomon error correction (ref: https://en.wikipedia.org/wiki/Reed–Solomon_error_correction).

However, this problem is bivariate. How do we handle that?
The answer is to fix one of the variables to one of the evaluation points \alpha^1, \dots, \alpha^{32}!
Here, let's assign y = \alpha.
Then, D(x,y) becomes as follows.

\begin{aligned} D(x, \alpha) &= \sum_{i=1}^{n} d_i \alpha^{e_{y_i}} x^{e_{x_i}} \\ &= \sum_{i=1}^{n} c_i x^{e_{x_i}} \\ & \mathrm{where}\ c_i = d_i \alpha^{e_{y_i}} \end{aligned}

We have turned it into an ordinary univariate Reed–Solomon decoding problem :)
Decoding the row y = \alpha gives the e_{x_i}; symmetrically, decoding the column x = \alpha gives the e_{y_i}.

Now we know the sets of values taken by e_{x_i} and e_{y_i}, but we do not know how they pair up.
For example, suppose we obtain these candidates:

\begin{aligned} e_x &= \{1,2,3,4,5\} \\ e_y &= \{5,6,7,8\} \end{aligned}

Then we cannot tell whether (x,y)=(1,5) or (x,y)=(1,6) is correct.
With 16 errors, the number of possible pairings is still large.

Now let's consider the diagonal x = y:

\begin{aligned} D(x,x) &= \sum_{i=1}^{n} d_i x^{\left(e_{x_i} + e_{y_i}\right)} \\ \end{aligned}

Therefore, by decoding the diagonal entries of witch_map (x = y, i.e. (1,1), (2,2), (3,3), ...), we obtain the values e_{x_i} + e_{y_i}.

Next, consider the following check in make_disharmony, which is meant to give every error a distinct x exponent:

        while p[0] in x_set:
p = random.choice(list(C.dict().keys()))


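Using this condition (every error has a distinct x exponent), together with the fact that every recovered y exponent and every recovered x+y value must be used by at least one error, we can prune the candidate pairings. A depth-first search over the remaining pairings then recovers the correct (e_{x_i}, e_{y_i}) once every few attempts.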

The coefficients are computed exactly as in ordinary Reed–Solomon decoding, so I'll skip the details. (Sorry!)

It can be solved with the following code :)
(Sorry for the very messy code.)

import random
import output
import copy

human_world_size = 64
spirit_world_size_param = 32
disharmony_count = 16
# 2t syndromes are read from each row/column/diagonal of witch_map, so at
# most t errors can be located on each of them.
t = spirit_world_size_param // 2
# Exclusive upper bound on the exponents tested as error positions.
# NOTE(review): presumably sized to cover every exponent that survives the
# reduction H % (G_x + G_y) in the challenge — confirm.
n = human_world_size + spirit_world_size_param

# ref: https://www.jstage.jst.go.jp/article/essfr/4/3/4_3_183/_pdf
R.<x, y> = PolynomialRing(GF(2))
# Alternative (unused) construction of the same field:
# K.<a> = GF(2^8, name='x', modulus=x^8+x^4+x^3+x^2+1)
size = 2^8
# Same field and modulus as in the challenge source.
K.<alpha> = GF(size, modulus=x^8+x^4+x^3+x^2+1)

def solve_sigma(S):
    """Compute the error-locator polynomial from syndromes (Peterson–Gorenstein–Zierler).

    Args:
        S: syndromes, S[k] being the (k+1)-th one; at least 2*t entries.

    Returns:
        (sigma, error_cnt): sigma = 1 + sum_{i=1}^{error_cnt} sigma_i * x^i,
        whose roots are alpha^{-pos} for the error positions pos, and the
        number of errors detected. With no errors detected, returns (R(1), 0).
    """
    def make_mat(S, k):
        # k x k syndrome matrix with entry (i, j) = S[i + k-1 - j].
        return matrix([[S[k-j-1+i] for j in range(k)] for i in range(k)])

    # The syndrome matrix is nonsingular at the true error count and singular
    # for every larger size, but a smaller one can be singular by accident.
    # Searching downward from t therefore finds the true count; searching
    # upward can stop too early.
    for error_cnt in range(t, 0, -1):
        mat = make_mat(S, error_cnt)
        if mat.det() != 0:
            break
    else:
        # Every syndrome matrix is singular: no errors to locate.
        return R(1), 0

    # Newton identities: S[error_cnt + i] = sum_j sigma_{j+1} * S[error_cnt + i - 1 - j].
    vec = [S[error_cnt + i] for i in range(error_cnt)]
    temp = mat.solve_right(vec)

    sigma = 1
    for i in range(len(temp)):
        sigma += temp[i] * x^(i+1)
    return sigma, error_cnt

def find_x(alpha_map):
    """Locate the errors' x exponents and their coefficients on the row y = alpha^1.

    alpha_map[0][j] is the map value at (x, y) = (alpha^(j+1), alpha^1); its
    first 2t entries are used as syndromes of the univariate error
    polynomial sum_i c_i x^{e_{x_i}}.

    Returns:
        (error_pos, values): x exponents found in range(n), and the solved
        coefficients c_i in the same order.

    Raises:
        ValueError: if the number of locator roots found in range(n) differs
            from the error count reported by solve_sigma (decoding failed).
    """
    S = [alpha_map[0][i] for i in range(2*t)]
    sigma, error_cnt = solve_sigma(S)

    # sigma vanishes at alpha^{-pos}; since alpha^(size-1) == 1,
    # alpha^(size-pos-1) is exactly alpha^{-pos}.
    error_pos = [pos for pos in range(n) if sigma(alpha^(size-pos-1), 0) == 0]
    if len(error_pos) != error_cnt:
        raise ValueError("find_x: found %d locator roots but expected %d errors"
                         % (len(error_pos), error_cnt))
    if error_cnt == 0:
        # Nothing to solve for on this row.
        return error_pos, []

    # Vandermonde system: sum_j c_j * alpha^(pos_j * (i+1)) = S[i], i < error_cnt.
    mat = matrix([[alpha^(pos * (i+1)) for pos in error_pos] for i in range(error_cnt)])
    vec = vector(S[:error_cnt])

    temp = mat.solve_right(vec)
    print(temp)

    return error_pos, temp

def find_y(alpha_map):
    """Return the errors' y exponents, decoded from the column x = alpha^1.

    alpha_map[k][0] is the map value at (x, y) = (alpha^1, alpha^(k+1)); the
    first 2t entries of that column are the syndromes in y.
    """
    syndromes = [alpha_map[k][0] for k in range(2*t)]
    locator = solve_sigma(syndromes)[0]
    # The locator vanishes at alpha^{-pos}, i.e. at alpha^(size-pos-1).
    return [pos for pos in range(n) if locator(alpha^(size-pos-1), 0) == 0]

def find_xy(alpha_map):
    """Return the values e_x + e_y of the errors, decoded from the diagonal x = y.

    alpha_map[k][k] is the map value at (alpha^(k+1), alpha^(k+1)). Exponent
    sums can reach beyond n, so positions up to 2n are tested.
    """
    syndromes = [alpha_map[k][k] for k in range(2*t)]
    locator = solve_sigma(syndromes)[0]
    # The locator vanishes at alpha^{-pos}, i.e. at alpha^(size-pos-1).
    return [pos for pos in range(2*n) if locator(alpha^(size-pos-1), 0) == 0]

def witchmap_to_alphamap(map):
    """Turn witch_map's discrete logs back into field elements.

    Each entry of witch_map is either log_alpha of an evaluation, or None
    where the evaluation was 0 (which has no logarithm). Returns a grid of
    the same shape holding 0 or alpha^entry.

    (The parameter name shadows the builtin `map`; it is kept for callers.)
    """
    return [[0 if e is None else alpha^e for e in row] for row in map]

alpha_map = witchmap_to_alphamap(output.witch_map)
print(alpha_map)
x_pos, error_value = find_x(alpha_map)
y_pos = find_y(alpha_map)
xy_pos = find_xy(alpha_map)
print("x_pos=", x_pos)
print("y_pos=", y_pos)
print("xy_pos=", xy_pos)
# Index of each x exponent within x_pos (kept for interactive inspection).
x_pos_index = {xp: i for i, xp in enumerate(x_pos)}

# xmemo[xp]: y exponents that may pair with xp, i.e. those whose sum xp+yp
# was recovered from the diagonal. Every xp gets an entry, possibly empty,
# so a decoding failure shows up as "no candidates" rather than a KeyError.
xy_set = set(xy_pos)
xmemo = {xp: [yp for yp in y_pos if xp + yp in xy_set] for xp in x_pos}

# Suffix counts: ycnt[k][yp] (xycnt[k][s]) = how many of x_pos[k:] still list
# yp (the sum s) as a candidate. The last entry, for k == len(x_pos), is all zeros.
ycnt = [dict.fromkeys(y_pos, 0)]
xycnt = [dict.fromkeys(xy_pos, 0)]
for xp in reversed(x_pos):
    ytemp = dict(ycnt[-1])
    xytemp = dict(xycnt[-1])
    for yp in xmemo[xp]:
        ytemp[yp] += 1
        xytemp[xp + yp] += 1
    ycnt.append(ytemp)
    xycnt.append(xytemp)
ycnt.reverse()
xycnt.reverse()
for xy in xycnt:
    print(xy)

def dfs(x_pos, xmemo, index, error_value, expected_E, res, ycnt, xycnt, yused, xyused):
    """Enumerate consistent (x, y) pairings and collect each candidate D.

    Args:
        x_pos: recovered x exponents, one per error.
        xmemo: xmemo[xp] -> list of y exponents that may pair with xp.
        index: position in x_pos currently being assigned.
        error_value: error_value[k] is the coefficient decoded on the row
            y = alpha for x_pos[k], i.e. d_k * alpha^(y exponent).
        expected_E: the partial D built from the assignments so far.
        res: output list; every complete, consistent D is appended to it.
        ycnt, xycnt: suffix counts; ycnt[k][yp] (xycnt[k][s]) is how many of
            x_pos[k:] still list yp (the sum s) as a candidate.
        yused, xyused: how many times each y exponent / x+y sum is currently
            used. Modified during the search and restored before returning.
    """
    if index == len(x_pos):
        # Every recovered y exponent and x+y sum must belong to some error;
        # otherwise this complete assignment is inconsistent.
        if all(c > 0 for c in yused.values()) and all(c > 0 for c in xyused.values()):
            res.append(expected_E)
        return
    xp = x_pos[index]

    # If a candidate yp (or the sum xp+yp) has not been used yet and no later
    # x can supply it, this x is forced to take it.
    forced = set()
    for yp in xmemo[xp]:
        y_needed = ycnt[index+1][yp] == 0 and yused[yp] == 0
        xy_needed = xycnt[index+1][xp+yp] == 0 and xyused[xp+yp] == 0
        if y_needed or xy_needed:
            forced.add(yp)
    if len(forced) >= 2:
        # One x takes exactly one y, so two forced choices is a contradiction.
        return
    cand = forced if forced else xmemo[xp]

    for yp in cand:
        # Divide alpha^yp back out of the row coefficient to get d_i.
        temp = x^xp * y^yp * error_value[index] / alpha^yp
        yused[yp] += 1
        xyused[xp+yp] += 1
        dfs(x_pos, xmemo, index+1, error_value, expected_E + temp, res, ycnt, xycnt, yused, xyused)
        xyused[xp+yp] -= 1
        yused[yp] -= 1

# Size of the unpruned search space: product of the candidate counts per x.
res_cnt = 1
for choices in xmemo.values():
    res_cnt *= len(choices)
print("res_cnt:", res_cnt)
if res_cnt > 10000000:
    exit(1)

expected_E = 0
res = []
# Usage counters for every recovered y exponent and x+y sum, all starting at 0.
yused = dict.fromkeys(y_pos, 0)
xyused = dict.fromkeys(xy_pos, 0)

dfs(x_pos, xmemo, 0, error_value, expected_E, res, ycnt, xycnt, yused, xyused)
print("res cnt", len(res))

import Crypto.Cipher.AES as AES
from Crypto.Util.number import long_to_bytes
import hashlib

def make_key(D):
    """Rebuild the challenge's AES key from a candidate D, printing each exponent tuple.

    Terms are visited in sorted exponent order; the x exponent, y exponent
    and log_alpha of the coefficient are encoded with long_to_bytes,
    concatenated, and hashed with SHA-256.
    """
    key_seed = b""
    for pos, value in sorted(D.dict().items()):
        print(pos)
        power = discrete_log(value, alpha, size-1)
        key_seed += long_to_bytes(pos[0]) + long_to_bytes(pos[1]) + long_to_bytes(power)
    return hashlib.sha256(key_seed).digest()

# Try every candidate D and stop at the first decryption containing the flag prefix.
cand = []
for attempt, r in enumerate(res):
    print(attempt, "/", len(res))
    # NOTE(review): numerator() presumably strips a fraction-field wrapper
    # introduced by the division by alpha^yp in dfs — confirm.
    key = make_key(r.numerator())
    cipher = AES.new(key, AES.MODE_ECB)
    flag = cipher.decrypt(output.treasure_box)
    cand.append(flag)
    if b"SECCON{" not in flag:
        continue
    print(cand)
    print(flag)
    break
print(cand)