I attempted to implement RSA in Python 3.12. I first implemented textbook RSA, which worked. (I verified it with many repeated attempts using various keys and messages, all of which encrypted and decrypted correctly.)
However, since I added Optimal Asymmetric Encryption Padding (OAEP) encoding and decoding, lHash is occasionally not equal to lHash' during OAEP decoding. (I'm using the terminology from Wikipedia's page on OAEP: https://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding.) A quick glance at my code will reveal that many of the OAEP-related functions were "inspired by" or taken directly from the following GitHub gist: https://gist.github.com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2.
Approximately half of the time, the line `assert(lhash == lhash_prime)` raises an `AssertionError`. In 100 runs of the program, it was raised 46% of the time. I have a few examples of key values that worked as well as ones that did not.
I've tried to find patterns in the n, e, and d values produced by key generation that trigger the assertion error. I believe the n values in particular would be informative, given that OAEP uses the byte length of n as a significant part of the process. However, with such large values, it's difficult for a beginner programmer like me to make sense of them.
When I skipped the RSA encryption and decryption and simply encoded and decoded messages with OAEP alone, the process worked fine. Likewise, every test that used only textbook RSA worked.
Below is the code for a minimal, reproducible example. I'm sorry that it's still quite long: I've tried my best to shorten it (for example, by removing the checks that the primes p and q are secure primes), but this is as short as it gets while still reproducing the bug.
```python
import random
from math import ceil
import hashlib
import os
from typing import Callable

def byte_len(n: int) -> int:
    return ceil(n.bit_length() / 8)

def get_n_bit_rand_num(n: int) -> int:
    return random.randrange(2**(n-1)+1,2**n-1)

def rabin_miller_composite_test(a: int, m: int, k: int, n: int) -> bool:
    if (pow(a,m,n) == 1):
        return False
    for i in range(k):
        if (pow(a,2**i*m,n) == n-1):
            return False
    return True

def probablistic_is_prime_test(n: int) -> bool:
    first_primes_list = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
                         31, 37, 41, 43, 47, 53, 59, 61, 67,
                         71, 73, 79, 83, 89, 97, 101, 103,
                         107, 109, 113, 127, 131, 137, 139,
                         149, 151, 157, 163, 167, 173, 179,
                         181, 191, 193, 197, 199, 211, 223,
                         227, 229, 233, 239, 241, 251, 257,
                         263, 269, 271, 277, 281, 283, 293,
                         307, 311, 313, 317, 331, 337, 347, 349]
    for divisor in first_primes_list:
        if n % divisor == 0:
            return False
    k = 0
    m = n-1
    while (m % 2 == 0):
        m >>= 1
        k += 1
    iterations = 20
    for _ in range(iterations):
        a = random.randrange(2,n-1)
        if rabin_miller_composite_test(a,m,k,n):
            return False
    return True

def get_random_large_prime() -> int:
    num_is_prime = False
    num = 0
    while(num_is_prime == False):
        num = get_n_bit_rand_num(1024)
        if probablistic_is_prime_test(num):
            num_is_prime = True
    return num

def euclidean_algorithm_GCD(larger_num: int, smaller_num: int) -> int:
    if (smaller_num == 0):
        return larger_num
    else:
        return euclidean_algorithm_GCD(smaller_num,larger_num % smaller_num)

def extended_euclidean_algorithm_second_num_of_linear_combination(larger_num: int, smaller_num: int) -> int:
    s = 0
    r = smaller_num
    old_r = larger_num
    old_s = 1
    quotient = 0
    temp = 0
    while (r != 0):
        quotient = old_r // r
        temp = old_r
        old_r = r
        r = temp - quotient * r
        temp = old_s
        old_s = s
        s = temp - quotient * s
    second_num = (old_r - old_s * larger_num) // smaller_num
    return second_num

def generate_keys() -> tuple[int, int, int]:
    p = get_random_large_prime()
    q = get_random_large_prime()
    n = p*q
    phi_of_n = (p-1) * (q-1)
    e = 65537
    while euclidean_algorithm_GCD(phi_of_n,e) != 1:
        e += 1
    d = extended_euclidean_algorithm_second_num_of_linear_combination(phi_of_n,e) % phi_of_n
    return n, e, d

def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
    int_message = int.from_bytes(message, 'little')
    return pow(int_message,e,n)

def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
    int_message = pow(encrypted_message, d, n)
    return int_message.to_bytes(byte_len(int_message), 'little')

def encrypt_message_oaep(message: bytes, e: int, n: int) -> int:
    n_byte_length = byte_len(n)
    padded_message = oaep_encode(message,n_byte_length)
    return textbook_encrypt_message(padded_message,e,n)

def decrypt_message_oaep(encrypted_message: int, d: int, n: int) -> str:
    encoded_message = textbook_decrypt_message(encrypted_message, d, n)
    encoded_message_as_bytes = encoded_message
    n_byte_length = byte_len(n)
    message = oaep_decode(encoded_message_as_bytes,n_byte_length)
    return message.decode()

def bytewise_xor(data: bytes, mask: bytes) -> bytes:
    masked = b""
    for i in range(max(len(data),len(mask))):
        if i < len(data) and i < len(mask):
            masked += (data[i] ^ mask[i]).to_bytes(1, byteorder = 'big')
        elif i < len(data):
            masked += data[i].to_bytes(1, byteorder="big")
        else:
            break
    return masked

def sha1(m: bytes) -> bytes:
    '''SHA-1 hash function'''
    hasher = hashlib.sha1()
    hasher.update(m)
    return hasher.digest()

def mgf1(seed: bytes, mlen: int, f_hash: Callable = sha1) -> bytes:
    '''MGF1 mask generation function with SHA-1'''
    t = b''
    hlen = len(f_hash(b''))
    for c in range(0, ceil(mlen / hlen)):
        _c = c.to_bytes(4, byteorder="big")
        t += f_hash(seed + _c)
    return t[:mlen]

def oaep_encode(message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
    lhash = hash_func(label)
    padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00"
    data_block = lhash + padding_string + b"\x01" + message
    seed = os.urandom(len(lhash))
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    masked_data_block = bytewise_xor(data_block,data_block_mask)
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    masked_seed = bytewise_xor(seed,seed_mask)
    return b"\x00" + masked_seed + masked_data_block

def oaep_decode(encoded_message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
    lhash = hash_func(label)
    masked_seed = encoded_message[1:1 + len(lhash)]
    masked_data_block = encoded_message[1+len(lhash):]
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    seed = bytewise_xor(masked_seed,seed_mask)
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    data_block = bytewise_xor(masked_data_block, data_block_mask)
    lhash_prime = data_block[:len(lhash)]
    assert(lhash == lhash_prime)
    i = len(lhash)
    while i < len(data_block):
        if data_block[i] == 0:
            i += 1
            continue
        elif data_block[i] == 1:
            i += 1
            break
        else:
            raise Exception('This should never happen.')
    return data_block[i:]

[n, e, d] = generate_keys()
print("n: ", n)
print("e: ", e)
print("d: ", d)
message = "Imagine that this is some secure test message"
oaep_encrypted_message = encrypt_message_oaep(message.encode(), e, n)
print(oaep_encrypted_message)
print(decrypt_message_oaep(oaep_encrypted_message, d, n))
```
RFC 8017 (PKCS #1: RSA Cryptography Specifications) ensures that the encoded message is smaller than the modulus, so this problem cannot occur if RFC 8017 is implemented correctly!
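The bound behind this claim can be written as a short derivation. It assumes the big-endian OS2IP that RFC 8017 specifies, with k the byte length of the modulus n:

$$
\begin{aligned}
k &= \lceil \operatorname{bitlen}(n)/8 \rceil &&\Longrightarrow\quad 256^{k-1} \le n < 256^{k},\\
EM &= \mathtt{0x00} \,\|\, \mathit{maskedSeed} \,\|\, \mathit{maskedDB},\quad |EM| = k &&\Longrightarrow\quad \operatorname{OS2IP}(EM) \le 256^{k-1} - 1 < n.
\end{aligned}
$$

The second line holds because a zero top byte leaves only k − 1 bytes to carry value.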
Rather, the problem is caused by two bugs in your original implementation:
First, the posted code uses little-endian byte order for OS2IP and I2OSP, while RFC 8017 defines both as big-endian.
With big-endian order, 0x00 is the most significant byte of the encoded message EM = 0x00 || maskedSeed || maskedDB (see Section 7.1.1, Step 2.i). Since EM is exactly as long as the modulus n in bytes, EM is therefore smaller than n. This does not hold with little-endian order. There, 0x00 becomes the least significant byte, so the integer can exceed n, and the reduction mod n during encryption then corrupts the message.
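The effect of byte order can be checked empirically with a minimal sketch. It uses a hypothetical 4-byte toy modulus and random blocks shaped like EM (a 0x00 byte followed by random bytes), then counts how often the block's integer value is a valid RSA input, i.e. smaller than n. The helper names are illustrative and not part of the question's code.

```python
import os


def byte_len(n: int) -> int:
    # Byte length of n, equivalent to ceil(n.bit_length() / 8)
    return (n.bit_length() + 7) // 8


def random_em(k: int) -> bytes:
    # Shape of an OAEP block: a 0x00 byte followed by k-1 pseudo-random bytes
    return b"\x00" + os.urandom(k - 1)


def fraction_below_modulus(n: int, byteorder: str, trials: int = 10_000) -> float:
    # Fraction of random EM blocks whose integer value is smaller than n
    k = byte_len(n)
    hits = sum(int.from_bytes(random_em(k), byteorder) < n for _ in range(trials))
    return hits / trials


n = 0x9F3A21C7  # hypothetical toy modulus (4 bytes); primality is irrelevant here
print(fraction_below_modulus(n, "big"))     # 1.0
print(fraction_below_modulus(n, "little"))  # about 0.62, varies per run
```

Read little-endian, the block's value is roughly uniform on [0, 256^k), so it exceeds n with probability of about 1 − n/256^k. The failure rate therefore depends on the leading bytes of n, which is why it varies from key to key.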
As a fix, change the byte order in `textbook_encrypt_message()` and `textbook_decrypt_message()` from `'little'` to `'big'`.
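Second, the posted I2OSP step ignores the requirement that the octet string be exactly as long as the modulus in bytes. With big-endian order, `byte_len(int_message)` drops the leading 0x00 byte of EM (and any further leading zero bytes), which shifts every offset that `oaep_decode()` relies on.
As a fix, use `byte_len(n)` instead of `byte_len(int_message)` in `textbook_decrypt_message()`.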
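A minimal sketch of the difference follows, using a short stand-in for an EM block. The `i2osp` and `os2ip` names follow RFC 8017's primitive names, but these particular implementations are illustrative assumptions.

```python
def byte_len(n: int) -> int:
    # Byte length of n, equivalent to ceil(n.bit_length() / 8)
    return (n.bit_length() + 7) // 8


def i2osp(x: int, length: int) -> bytes:
    # I2OSP (RFC 8017, Section 4.1): fixed-length, big-endian octet string
    if x >= 256 ** length:
        raise ValueError("integer too large")
    return x.to_bytes(length, "big")


def os2ip(octets: bytes) -> int:
    # OS2IP (RFC 8017, Section 4.2): big-endian octet string to integer
    return int.from_bytes(octets, "big")


em = b"\x00\x12\x34\x56"  # short stand-in for an EM block that starts with 0x00
x = os2ip(em)
print(x.to_bytes(byte_len(x), "big"))  # b'\x124V' (3 bytes, leading 0x00 lost)
print(i2osp(x, len(em)))               # b'\x00\x124V' (4 bytes, round-trips)
```

Passing the modulus length (here `len(em)`, in the RSA code `byte_len(n)`) is what keeps the leading zero byte in place.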
Overall:
```python
def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
    int_message = int.from_bytes(message, 'big')
    return pow(int_message,e,n)

def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
    int_message = pow(encrypted_message, d, n)
    return int_message.to_bytes(byte_len(n), 'big')
```
With this, decryption works and the (inefficient) loop from your answer is not required.
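The corrected functions can be checked in isolation with a quick sketch. It round-trips random k-byte blocks that start with 0x00, which is the shape OAEP hands to RSA. To avoid a prime search, it uses a hypothetical toy key built from the Mersenne primes 2^521 − 1 and 2^607 − 1. That key is insecure and suitable only for testing.

```python
import os


def byte_len(n: int) -> int:
    # Byte length of n, equivalent to the question's ceil(n.bit_length() / 8)
    return (n.bit_length() + 7) // 8


def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
    int_message = int.from_bytes(message, 'big')
    return pow(int_message, e, n)


def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
    int_message = pow(encrypted_message, d, n)
    return int_message.to_bytes(byte_len(n), 'big')


# Toy key from two Mersenne primes: fast to set up, NOT secure (demo only)
p = 2**521 - 1
q = 2**607 - 1
n = p * q
e = 65537
d = pow(e, -1, (p - 1) * (q - 1))  # modular inverse, available since Python 3.8

k = byte_len(n)
for _ in range(200):
    # Random block shaped like EM = 0x00 || maskedSeed || maskedDB
    em = b"\x00" + os.urandom(k - 1)
    assert textbook_decrypt_message(textbook_encrypt_message(em, e, n), d, n) == em
print(f"{k}-byte blocks round-tripped")
```

Because every block starts with 0x00, its value is below n, and fixed-length decoding restores the leading zero.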
Note that I have not checked your implementation's conformance with RFC 8017 as such. A possible test would be to encrypt with your code and decrypt with a reliable library, and vice versa.
Also keep in mind that, beyond pure functionality, security matters too (side-channel attacks in particular).