I am using Windows 11 and I plan to implement the code in C++. As you might know, building C++ libraries on Windows is very complicated, so I want to make sure it uses the least amount of dependencies possible.
For more context about the larger project this will be a part of, see this.
I have decided to implement a load-balancing ShadowSocks5 proxy in C++ from scratch. This is a programming challenge, a learning project, and a practical project all in one.
I decided to start with the easiest problem: the encryption I need to use is AES-256-GCM. GCM stands for Galois/Counter Mode; I haven't figured out how to implement it from reading the Wikipedia article. However, the Wikipedia article on the Advanced Encryption Standard is very helpful and is one of the primary references I used in implementing this. Another Wikipedia article I referenced is the one on the AES key schedule. I got the values for SBOX and RSBOX from this article.
Now, here is the implementation, I wrote it all by myself, an effort that took two days:
import json
# Load the lookup tables used by the cipher from the JSON file on disk.
with open("D:/AES_256.json", "r") as f:
    AES_256 = json.load(f)

# Largest integer accepted as a password by get_key (fits in 32 bytes).
MAX_256 = 2**256 - 1
SBOX = AES_256["SBOX"]
RCON = AES_256["RCON"]

# (start, stop) slices that cut a 32-byte password into eight 4-byte words.
OCTUPLE = tuple((start, start + 4) for start in range(0, 32, 4))

# Indices of the 60 expanded key words, grouped four at a time (15 groups).
SEXAGESY = tuple(tuple(range(start, start + 4)) for start in range(0, 60, 4))

# Index table that transposes a 4x4 matrix stored as 16 consecutive entries.
HEXMAT = tuple(row + 4 * col for row in range(4) for col in range(4))

# Index table that rotates row r of the 4x4 layout left by r positions.
SHIFT = tuple(4 * row + (col + row) % 4 for row in range(4) for col in range(4))

# For each of the 16 state positions: the (multiplier, target index) pairs
# that mix_columns accumulates into its result.
COLMAT = tuple(
    tuple((mult, 4 * k + col) for k, mult in enumerate(coeffs))
    for coeffs in ((2, 1, 1, 3), (3, 2, 1, 1), (1, 3, 2, 1), (1, 1, 3, 2))
    for col in range(4)
)
def add_round_key(state: list, key: list) -> list:
    """XOR every state value with the key value at the same position.

    The result is as long as the shorter of the two inputs.
    """
    mixed = []
    for state_value, key_value in zip(state, key):
        mixed.append(state_value ^ key_value)
    return mixed
def state_matrix(data: list) -> list:
    """Return the entries of `data` reordered according to the HEXMAT table."""
    reordered = []
    for index in HEXMAT:
        reordered.append(data[index])
    return reordered
def sub_bytes(state: list) -> list:
    """Replace every state value with the SBOX entry it indexes."""
    substituted = []
    for value in state:
        substituted.append(SBOX[value])
    return substituted
def shift_rows(state: list) -> list:
    """Return the state entries reordered according to the SHIFT table."""
    shifted = []
    for index in SHIFT:
        shifted.append(state[index])
    return shifted
def rot8(byte: int, x: int) -> int:
    """Rotate an 8-bit value left by `x` bit positions.

    Only the low three bits of `x` are used, so the rotation amount is taken
    modulo 8. Intended for `byte` in the range 0-255.
    """
    amount = x & 7
    upper = byte << amount
    lower = byte >> (8 - amount)
    return (upper | lower) & 0xFF
def quadword(quad: bytes) -> int:
    """Pack exactly four byte values into one integer, first byte highest.

    Raises ValueError (from unpacking) if `quad` does not hold four items.
    """
    first, second, third, fourth = quad
    word = 0
    for octet in (first, second, third, fourth):
        word = (word << 8) | octet
    return word
def rot_word(word: int) -> int:
    """Rotate a 32-bit word left by 8 bits (one byte)."""
    shifted_up = word << 8
    wrapped_around = word >> 24
    return (shifted_up | wrapped_around) & 0xFFFFFFFF
def sub_word(word: int) -> int:
    """Substitute each of the four bytes of a 32-bit word through SBOX."""
    substituted = 0
    for shift in (24, 16, 8, 0):
        # Extract one byte, look it up, and put the result back in place.
        substituted |= SBOX[(word >> shift) & 0xFF] << shift
    return substituted
def galois_mult(x: int, y: int) -> int:
    """Multiply two bytes as polynomials over GF(2), reduced with 0x11B.

    Shift-and-add: for every set bit of `y`, the current multiple of `x` is
    XORed into the product; `x` is doubled each step and reduced with 0x11B
    whenever its top bit (0x80) was set.
    """
    product = 0
    while x and y:
        product ^= x if y & 1 else 0
        # Double x; fold the overflow bit back in via the reduction constant.
        x = (x << 1) ^ (0x11B if x & 0x80 else 0)
        y >>= 1
    return product
def mix_columns(state: list) -> list:
    """Accumulate Galois products of the state per COLMAT, then reorder.

    Each state value is paired with one COLMAT row; for every
    (multiplier, target) pair in that row the product is XORed into
    position `target` of a 16-entry accumulator, which is finally passed
    through state_matrix.
    """
    accumulator = [0] * 16
    for element, terms in zip(state, COLMAT):
        for factor, target in terms:
            accumulator[target] = accumulator[target] ^ galois_mult(element, factor)
    return state_matrix(accumulator)
def key_matrix(key: list) -> list:
    """Turn 60 expanded 32-bit key words into 15 reordered 16-byte round keys.

    For each group of four word indices in SEXAGESY the words are split into
    bytes (high byte first) and the 16 bytes are passed through state_matrix.
    """
    round_keys = []
    for group in SEXAGESY:
        flat = [
            (key[index] >> shift) & 0xFF
            for index in group
            for shift in (24, 16, 8, 0)
        ]
        round_keys.append(state_matrix(flat))
    return round_keys
def derive_key(password: bytes) -> list:
    """Expand a 32-byte password into round keys via key_matrix.

    The password is read as eight 32-bit words; 52 further words are
    generated, each being the word eight positions back XORed with a
    (possibly transformed) copy of the previous word.
    """
    words = [quadword(password[start:stop]) for start, stop in OCTUPLE]
    for index in range(8, 60):
        temp = words[-1]
        position = index % 8
        if position == 0:
            # Every eighth word: rotate, substitute, and mix in a round constant.
            temp = sub_word(rot_word(temp)) ^ (RCON[index // 8] << 24)
        elif position == 4:
            # Halfway through each group of eight: substitute only.
            temp = sub_word(temp)
        words.append(words[index - 8] ^ temp)
    return key_matrix(words)
def aes_256_cipher(data: bytes, password: bytes) -> list:
    """Run one block through the 14-round pipeline of this listing.

    `password` is indexed from 0 to 14 and each entry is handed to
    add_round_key, i.e. it is the sequence of round keys, not the raw
    password.
    """
    state = add_round_key(state_matrix(data), password[0])
    for round_number in range(1, 14):
        state = sub_bytes(state)
        state = shift_rows(state)
        state = mix_columns(state)
        state = add_round_key(state, password[round_number])
    # The last round has no mix_columns step.
    final = add_round_key(shift_rows(sub_bytes(state)), password[14])
    return state_matrix(final)
def get_padded_data(data: bytes | str) -> bytes:
if isinstance(data, str):
data = data.encode("utf8")
if not isinstance(data, bytes):
raise ValueError("argument data must be bytes or str")
return data + b"\x00" * (16 - len(data) % 16)
def get_key(password: bytes | int | str) -> list:
if isinstance(password, int):
if password < 0 or password > MAX_256:
raise ValueError("argument password must be between 0 and 2^256-1")
password = password.to_bytes(32, "big")
if isinstance(password, str):
password = "".join(i for i in password if i.isalnum()).encode("utf8")
if len(password) > 32:
raise ValueError("argument password must be 32 bytes or less")
if not isinstance(password, bytes):
raise ValueError("argument password must be bytes | int | str")
return derive_key(password.rjust(32, b"\x00"))
def ecb_encrypt(data: bytes | str, password: bytes | str) -> str:
    """Pad `data`, encrypt it block by block, and return lowercase hex.

    Each 16-byte block is processed independently with the same round keys.
    """
    padded = get_padded_data(data)
    round_keys = get_key(password)
    hex_parts = []
    for offset in range(0, len(padded), 16):
        block = aes_256_cipher(padded[offset : offset + 16], round_keys)
        hex_parts.extend(f"{value:02x}" for value in block)
    return "".join(hex_parts)
AES_256.json
{
"SBOX": [
99 , 124, 119, 123, 242, 107, 111, 197, 48 , 1 , 103, 43 , 254, 215, 171, 118,
202, 130, 201, 125, 250, 89 , 71 , 240, 173, 212, 162, 175, 156, 164, 114, 192,
183, 253, 147, 38 , 54 , 63 , 247, 204, 52 , 165, 229, 241, 113, 216, 49 , 21 ,
4 , 199, 35 , 195, 24 , 150, 5 , 154, 7 , 18 , 128, 226, 235, 39 , 178, 117,
9 , 131, 44 , 26 , 27 , 110, 90 , 160, 82 , 59 , 214, 179, 41 , 227, 47 , 132,
83 , 209, 0 , 237, 32 , 252, 177, 91 , 106, 203, 190, 57 , 74 , 76 , 88 , 207,
208, 239, 170, 251, 67 , 77 , 51 , 133, 69 , 249, 2 , 127, 80 , 60 , 159, 168,
81 , 163, 64 , 143, 146, 157, 56 , 245, 188, 182, 218, 33 , 16 , 255, 243, 210,
205, 12 , 19 , 236, 95 , 151, 68 , 23 , 196, 167, 126, 61 , 100, 93 , 25 , 115,
96 , 129, 79 , 220, 34 , 42 , 144, 136, 70 , 238, 184, 20 , 222, 94 , 11 , 219,
224, 50 , 58 , 10 , 73 , 6 , 36 , 92 , 194, 211, 172, 98 , 145, 149, 228, 121,
231, 200, 55 , 109, 141, 213, 78 , 169, 108, 86 , 244, 234, 101, 122, 174, 8 ,
186, 120, 37 , 46 , 28 , 166, 180, 198, 232, 221, 116, 31 , 75 , 189, 139, 138,
112, 62 , 181, 102, 72 , 3 , 246, 14 , 97 , 53 , 87 , 185, 134, 193, 29 , 158,
225, 248, 152, 17 , 105, 217, 142, 148, 155, 30 , 135, 233, 206, 85 , 40 , 223,
140, 161, 137, 13 , 191, 230, 66 , 104, 65 , 153, 45 , 15 , 176, 84 , 187, 22
],
"RCON": [0, 1, 2, 4, 8, 16, 32, 64, 128, 27, 54]
}
It doesn't raise exceptions, but I have absolutely no idea what I am doing.
This is an example output:
In [428]: ecb_encrypt('6a84867cd77e12ad07ea1be895c53fa3', '0'*32)
Out[428]: 'b981b1853c16fbb6adc7cf4a01c9c57b94a3e5ce608239660c324b01400ebdd5d45a5452d22fed94b7ca9d916ac47736'
I have used this website to check the correctness of the output, with the same key and plaintext, AES-256-ECB mode, the cipher text in hex is:
e07beff38697f04e7adbc971adc2a9135f60746178fcd0f1b3040e4d15c920ad0318e084e1666e699891c78f8aa98960
It isn't what my code outputs as you can see.
Why isn't my code working properly?
I have fixed my code, thanks to this comment, I have used this website to check all the intermediate results, and I walked through all of them step by step, and I got the right results in the end.
I am not good at explaining things, so I will just describe the fixes I implemented.
There were multiple issues in the code. First of all, the padding was incorrect: the original code appended a full extra block of padding even when the data length was already a multiple of 16. I fixed that. (I don't know how to implement proper padding yet, but it should be extremely easy at this point.)
Then the order of operations was all wrong. The key is derived correctly, but state_matrix is wrongly applied and key_matrix is bugged; as a result, the subsequent steps are all wrong. Neither of these two functions is needed.
The keys should be in the original order. In round zero, the data block is first mixed with the key using bitwise XOR. Then, in rounds 1 to 13, the data is first substituted with the corresponding values in SBOX (sub_bytes was implemented correctly), and then the data needs to be permuted; this wasn't implemented correctly.
I don't know how to derive the indices needed, but the correct indices I need are these: (0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11)
Next we need to apply Galois matrix multiplication, it wasn't implemented correctly, I have fixed it.
Fixed code
import json
# Load the lookup tables used by the cipher from the JSON file on disk.
with open("D:/AES_256.json", "r") as f:
    AES_256 = json.load(f)

# Largest integer accepted as a password by get_key (fits in 32 bytes).
MAX_256 = 2**256 - 1
SBOX = AES_256["SBOX"]
# NOTE(review): the AES_256.json listing shown earlier has no "RSBOX" key;
# this lookup raises KeyError unless the file has been extended. RSBOX is
# not referenced by the encryption code in this listing.
RSBOX = AES_256["RSBOX"]
RCON = AES_256["RCON"]

# Byte permutation applied after substitution: output position k takes the
# input byte at index 5*k mod 16.
PERMI = tuple(5 * k % 16 for k in range(16))

# (start, stop) slices that cut the 240-byte expanded key into 15 round keys.
I240 = tuple((start, start + 16) for start in range(0, 240, 16))
def sub_bytes(state: list) -> list:
    """Replace every state value with the SBOX entry it indexes."""
    substituted = []
    for value in state:
        substituted.append(SBOX[value])
    return substituted
def permute_state(state: list) -> list:
    """Return the state entries reordered according to the PERMI table."""
    permuted = []
    for index in PERMI:
        permuted.append(state[index])
    return permuted
def add_round_key(state: list, key: list) -> list:
    """XOR every state value with the key value at the same position.

    The result is as long as the shorter of the two inputs.
    """
    mixed = []
    for state_value, key_value in zip(state, key):
        mixed.append(state_value ^ key_value)
    return mixed
def rot8(byte: int, x: int) -> int:
    """Rotate an 8-bit value left by `x` bit positions.

    Only the low three bits of `x` are used, so the rotation amount is taken
    modulo 8. Intended for `byte` in the range 0-255.
    """
    amount = x & 7
    rotated = (byte << amount) | (byte >> (8 - amount))
    return rotated & 0xFF
def rot_word(word: list) -> list:
    """Return a new list with the first element of `word` moved to the end.

    Accepts any indexable sequence (list, tuple, bytes) and always returns a
    list; the input is not modified.

    Raises:
        IndexError: if `word` is empty.
    """
    # Unpacking instead of `+` so that non-list sequences work as well.
    return [*word[1:], word[0]]
def sub_word(word: list) -> list:
    """Substitute each of the four bytes of `word` through SBOX.

    Returns a new four-element list.

    Raises:
        ValueError: (from unpacking) if `word` does not hold exactly four
            items.
    """
    a, b, c, d = word
    return [SBOX[a], SBOX[b], SBOX[c], SBOX[d]]
def galois_mult(x: int, y: int) -> int:
    """Multiply two bytes as polynomials over GF(2), reduced with 0x11B.

    Shift-and-add: for every set bit of `y`, the current multiple of `x` is
    XORed into the product; `x` is doubled each step and reduced with 0x11B
    whenever its top bit (0x80) was set.
    """
    product = 0
    while x and y:
        if y % 2:
            product ^= x
        carry = x & 0x80
        x <<= 1
        if carry:
            # Fold the overflow bit back in via the reduction constant.
            x ^= 0x11B
        y >>= 1
    return product
def mix_column(state: list) -> tuple:
    """Mix one four-byte column and return the four resulting bytes.

    Each output byte is the XOR of the inputs weighted, in GF(2^8), by one
    row of the circulant matrix with first row (2, 3, 1, 1).

    Raises:
        ValueError: (from unpacking) if `state` does not hold exactly four
            items.
    """
    a, b, c, d = state
    a2 = galois_mult(a, 2)
    b2 = galois_mult(b, 2)
    c2 = galois_mult(c, 2)
    d2 = galois_mult(d, 2)
    # 3*x equals 2*x XOR x in GF(2^8), so no second multiplication is needed.
    a3 = a2 ^ a
    b3 = b2 ^ b
    c3 = c2 ^ c
    d3 = d2 ^ d
    return a2 ^ b3 ^ c ^ d, a ^ b2 ^ c3 ^ d, a ^ b ^ c2 ^ d3, a3 ^ b ^ c ^ d2
def mix_columns(state: list) -> list:
    """Apply mix_column to each consecutive group of four state bytes."""
    mixed = []
    for start in (0, 4, 8, 12):
        column = state[start : start + 4]
        mixed += mix_column(column)
    return mixed
def derive_key(password: bytes) -> list:
    """Expand a 32-byte key into the 240-byte AES-256 key schedule.

    The schedule starts with the key itself; every further 4-byte word is
    the word 32 bytes back XORed with a (possibly transformed) copy of the
    previous word.

    Returns:
        A flat list of 240 integers.

    Raises:
        ValueError: if `password` is not exactly 32 bytes long.
    """
    if len(password) != 32:
        raise ValueError("argument password must be exactly 32 bytes")
    result = list(password)
    last = result[28:]
    for i in range(32, 240, 4):
        j = i // 4  # index of the word being generated
        if j % 8 == 0:
            # Every eighth word: rotate, substitute, and mix in a round constant.
            last = sub_word(rot_word(last))
            last[0] ^= RCON[j // 8]
        elif j % 8 == 4:
            # Halfway through each group of eight: substitute only.
            last = sub_word(last)
        # Only the four bytes of the word 32 bytes back are needed.
        last = [a ^ b for a, b in zip(result[i - 32 : i - 28], last)]
        result.extend(last)
    return result
def aes_256_cipher(data: bytes, password: bytes) -> list:
    """Encrypt one block with 14 rounds and return the result as a list.

    `password` is indexed from 0 to 14 and each entry is handed to
    add_round_key, i.e. it is the sequence of round keys, not the raw
    password.
    """
    state = add_round_key(data, password[0])
    for round_number in range(1, 14):
        state = sub_bytes(state)
        state = permute_state(state)
        state = mix_columns(state)
        state = add_round_key(state, password[round_number])
    # The last round has no mix_columns step.
    state = permute_state(sub_bytes(state))
    return add_round_key(state, password[14])
def get_padded_data(data: bytes | str) -> bytes:
if isinstance(data, str):
data = data.encode("utf8")
if not isinstance(data, bytes):
raise ValueError("argument data must be bytes or str")
l = len(data)
return data + b"\x00" * ((l + 15 & -16) - l)
def get_key(password: bytes | int | str) -> list:
if isinstance(password, int):
if password < 0 or password > MAX_256:
raise ValueError("argument password must be between 0 and 2^256-1")
password = password.to_bytes(32, "big")
if isinstance(password, str):
password = "".join(i for i in password if i.isalnum()).encode("utf8")
if len(password) > 32:
raise ValueError("argument password must be 32 bytes or less")
if not isinstance(password, bytes):
raise ValueError("argument password must be bytes | int | str")
return derive_key(password.rjust(32, b"\x00"))
def ecb_encrypt(data: bytes | str, password: bytes | str) -> str:
    """Zero-pad `data`, encrypt it block by block, and return lowercase hex.

    The expanded key from get_key is cut into 15 round keys via I240, and
    every 16-byte block is encrypted independently with those round keys.
    """
    padded = get_padded_data(data)
    expanded = get_key(password)
    round_keys = [expanded[start:stop] for start, stop in I240]
    ciphertext = bytearray()
    for offset in range(0, len(padded), 16):
        block = aes_256_cipher(padded[offset : offset + 16], round_keys)
        ciphertext += bytes(block)
    return ciphertext.hex()
Testing:
hexdigits = "0123456789abcdefABCDEF"


def hex2bytes(s: str) -> list:
    """Parse a hex string into a list of integers, two characters at a time.

    Characters that are not hex digits are skipped when they appear where a
    new pair would start. A trailing single hex digit is parsed on its own.
    """
    values = []
    position = 0
    length = len(s)
    while position < length:
        if s[position] in hexdigits:
            values.append(int(s[position : position + 2], 16))
            # The second character of the pair has been consumed too.
            position += 2
        else:
            position += 1
    return values
# Key and plaintext block for the AES-256-ECB test run shown below.
# hex2bytes returns a list of ints, so the values are wrapped in bytes()
# before being handed to the functions that require bytes.
password = hex2bytes('603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4')
data = get_padded_data(bytes(hex2bytes('6bc1bee22e409f96e93d7e117393172a')))
In [177]: ecb_encrypt(bytes(data), bytes(password))
Out[177]: 'f3eed1bdb5d2a03c064b5a7e3db181f8'
In [178]: %timeit ecb_encrypt(bytes(data), bytes(password))
296 μs ± 3.58 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
The output is the same as the output for the AES-256-ECB test vector 1 example given by the linked website.
In [179]: ecb_encrypt('6a84867cd77e12ad07ea1be895c53fa3', '00000000000000000000000000000000')
Out[179]: 'e07beff38697f04e7adbc971adc2a9135f60746178fcd0f1b3040e4d15c920ad'
The above example output is the same as the output from this website if AES-256-ECB without padding is selected and the output is formatted as hex.
I am working on decryption and adding more modes and then porting it to C++.
I have figured out how to implement PKCS5 padding:
def pkcs5_pad(data: bytes) -> bytes:
    """Pad `data` to a multiple of 16 bytes with N bytes of value N.

    Between 1 and 16 bytes are always appended, so input that is already
    block-aligned gains a full block of value 16.

    Raises:
        ValueError: if `data` is not bytes.
    """
    if not isinstance(data, bytes):
        raise ValueError("argument data must be bytes")
    count = 16 - len(data) % 16
    return data + bytes((count,)) * count
def pkcs5_unpad(data: bytes) -> bytes:
    """Strip padding of the form "N bytes of value N" from `data`.

    `data` must be a non-empty whole number of 16-byte blocks whose last N
    bytes (1 <= N <= 16) all have the value N.

    Raises:
        ValueError: if `data` is not bytes, is empty, is not a multiple of
            16 bytes long, or does not end in valid padding.
    """
    if not isinstance(data, bytes):
        raise ValueError("argument data must be bytes")
    # Padded data always consists of at least one complete block; rejecting
    # anything else up front also avoids an IndexError on empty input.
    if not data or len(data) % 16:
        raise ValueError("invalid padding")
    pad = data[-1]
    if not 1 <= pad <= 16 or data[-pad:] != bytes((pad,)) * pad:
        raise ValueError("invalid padding")
    return data[:-pad]