python cryptography rsa oaep

RSA Implementation with OAEP occasionally produces an error with lhash and lhash'


I attempted to implement an RSA algorithm in Python 3.12. I first implemented a textbook RSA algorithm, which worked. (I verified this with many repeated attempts using various keys and messages, all of which encrypted and decrypted successfully.)

However, since I added Optimal Asymmetric Encryption Padding (OAEP) encoding and decoding, lHash is occasionally not equal to lHash' during the OAEP decoding process. (I'm using terminology from Wikipedia's page on OAEP: https://en.wikipedia.org/wiki/Optimal_asymmetric_encryption_padding). One quick glance at my code will reveal that many of the OAEP-related functions were "inspired by" or taken directly from the following GitHub gist: https://gist.github.com/ppoffice/e10e0a418d5dafdd5efe9495e962d3d2.

Approximately half of the time, the line assert(lhash == lhash_prime) raises an assertion error: in 100 runs of the program, it was raised 46 times. I have a few examples of key values which worked as well as ones that did not.

I've attempted analysing patterns in the n, e, and d values produced by the RSA key generation in the runs that raise the assertion error. I believe that the n values in particular would be helpful, given that OAEP uses the length of n as a significant part of the process. However, with such large values, it's difficult for a beginner programmer like me to make sense of them.

When I skipped the textbook RSA step and simply encoded and decoded messages with OAEP alone, the process worked fine. Likewise, every test I did using only textbook RSA worked as well.

Below is the code for a minimal, reproducible example. I'm sorry, but even though I've tried my best to reduce how long it is (removed checks to ensure the primes p and q are actually secure primes, etc.), it's still quite long if the bug is to be reproduced.

import random
from math import ceil
import hashlib
import os
from typing import Callable

def byte_len(n: int) -> int:
    return ceil(n.bit_length() / 8)

def get_n_bit_rand_num(n: int) -> int:
    return random.randrange(2**(n-1)+1,2**n-1)

def rabin_miller_composite_test(a: int, m: int, k: int, n: int) -> bool:
    if (pow(a,m,n) == 1): 
        return False 
    for i in range(k):
        if (pow(a,2**i*m,n) == n-1):
            return False
    return True

def probablistic_is_prime_test(n: int) -> bool:
    first_primes_list = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
                     31, 37, 41, 43, 47, 53, 59, 61, 67,
                     71, 73, 79, 83, 89, 97, 101, 103,
                     107, 109, 113, 127, 131, 137, 139,
                     149, 151, 157, 163, 167, 173, 179,
                     181, 191, 193, 197, 199, 211, 223,
                     227, 229, 233, 239, 241, 251, 257,
                     263, 269, 271, 277, 281, 283, 293,
                     307, 311, 313, 317, 331, 337, 347, 349]
    for divisor in first_primes_list:
        if n % divisor == 0:
            return False
    k = 0
    m = n-1
    while (m % 2 == 0):
        m >>= 1 
        k += 1

    iterations = 20 
    for _ in range(iterations):
        a = random.randrange(2,n-1) 
        if rabin_miller_composite_test(a,m,k,n):
            return False
    return True

def get_random_large_prime() -> int:
    num_is_prime = False
    num = 0
    while(num_is_prime == False):
        num = get_n_bit_rand_num(1024)
        if probablistic_is_prime_test(num):
            num_is_prime = True
    return num

def euclidean_algorithm_GCD(larger_num: int, smaller_num: int) -> int:
    if (smaller_num == 0):
        return larger_num
    else:
        return euclidean_algorithm_GCD(smaller_num,larger_num % smaller_num)

def extended_euclidean_algorithm_second_num_of_linear_combination(larger_num: int, smaller_num: int) -> int:
    s = 0
    r = smaller_num
    old_r = larger_num
    old_s = 1
    quotient = 0
    temp = 0
    while (r != 0):
        quotient = old_r // r
        temp = old_r
        old_r = r
        r = temp - quotient * r
        temp = old_s
        old_s = s
        s = temp - quotient * s
    second_num = (old_r - old_s * larger_num) // smaller_num
    return second_num

def generate_keys() -> tuple[int, int, int]:
    p = get_random_large_prime()
    q = get_random_large_prime()
    n = p*q
    phi_of_n = (p-1) * (q-1)
    e = 65537
    while euclidean_algorithm_GCD(phi_of_n,e) != 1:
        e += 1
    d = extended_euclidean_algorithm_second_num_of_linear_combination(phi_of_n,e) % phi_of_n
    return n, e, d

def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
    int_message = int.from_bytes(message, 'little')
    return pow(int_message,e,n)

def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
    int_message = pow(encrypted_message, d, n)
    return int_message.to_bytes(byte_len(int_message), 'little')

def encrypt_message_oaep(message: bytes, e: int, n: int) -> int:
    n_byte_length = byte_len(n)
    padded_message = oaep_encode(message,n_byte_length)
    return textbook_encrypt_message(padded_message,e,n)

def decrypt_message_oaep(encrypted_message: int, d: int, n: int) -> str:
    encoded_message = textbook_decrypt_message(encrypted_message, d, n)
    encoded_message_as_bytes = encoded_message
    n_byte_length = byte_len(n)
    message = oaep_decode(encoded_message_as_bytes,n_byte_length)
    return message.decode()

def bytewise_xor(data: bytes, mask: bytes) -> bytes: 
    masked = b""
    for i in range(max(len(data),len(mask))):
        if i < len(data) and i < len(mask):
            masked += (data[i] ^ mask[i]).to_bytes(1, byteorder = 'big')
        elif i < len(data):
            masked += data[i].to_bytes(1, byteorder="big")
        else:
            break
    return masked

def sha1(m: bytes) -> bytes:
    '''SHA-1 hash function'''
    hasher = hashlib.sha1()
    hasher.update(m)
    return hasher.digest()

def mgf1(seed: bytes, mlen: int, f_hash: Callable = sha1) -> bytes: 
    '''MGF1 mask generation function with SHA-1'''
    t = b''
    hlen = len(f_hash(b''))
    for c in range(0, ceil(mlen / hlen)):
        _c = c.to_bytes(4, byteorder="big")
        t += f_hash(seed + _c)
    return t[:mlen]

def oaep_encode(message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes: 
    lhash = hash_func(label)
    padding_string = (k - len(message)-2*len(lhash)-2) * b"\x00"
    data_block = lhash + padding_string + b"\x01" + message
    seed = os.urandom(len(lhash))
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    masked_data_block = bytewise_xor(data_block,data_block_mask)
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    masked_seed = bytewise_xor(seed,seed_mask)
    return b"\x00" + masked_seed + masked_data_block

def oaep_decode(encoded_message: bytes, k: int, label: bytes = b"", hash_func: Callable = sha1, mgf: Callable = mgf1) -> bytes:
    lhash = hash_func(label)
    masked_seed = encoded_message[1:1 + len(lhash)]
    masked_data_block = encoded_message[1+len(lhash):]
    seed_mask = mgf(masked_data_block,len(lhash),hash_func)
    seed = bytewise_xor(masked_seed,seed_mask)
    data_block_mask = mgf(seed,k-len(lhash)-1,hash_func)
    data_block = bytewise_xor(masked_data_block, data_block_mask)
    lhash_prime = data_block[:len(lhash)]
    assert(lhash == lhash_prime)
    i = len(lhash)
    while i < len(data_block):
        if data_block[i] == 0:
            i += 1
            continue
        elif data_block[i] == 1:
            i += 1
            break
        else:
            raise Exception('This should never happen.')
    return data_block[i:]

[n, e, d] = generate_keys()
print("n: ", n)
print("e: ", e)
print("d: ", d)
message = "Imagine that this is some secure test message"
oaep_encrypted_message = encrypt_message_oaep(message.encode(), e, n)
print(oaep_encrypted_message)
print(decrypt_message_oaep(oaep_encrypted_message, d, n))

Solution

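  • RFC 8017 (PKCS #1: RSA Cryptography Specifications) ensures that the encoded message is smaller than the modulus, so this problem cannot occur if RFC 8017 is implemented correctly!
    Rather, the problem is caused by two bugs in your original implementation:

    • The encoded message is converted to an integer (and back) in little-endian byte order. RFC 8017 uses big-endian order for both conversions.
    • On decryption, the integer is converted to a byte string of minimal length, byte_len(int_message), which discards any zero bytes at the most significant end (with big-endian order, this includes the leading 0x00 of the encoded message). The result must instead have the length of the modulus, byte_len(n).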

    Overall:

    def textbook_encrypt_message(message: bytes, e: int, n: int) -> int:
        int_message = int.from_bytes(message, 'big')
        return pow(int_message,e,n)
    
    def textbook_decrypt_message(encrypted_message: int, d: int, n: int) -> bytes:
        int_message = pow(encrypted_message, d, n)
        return int_message.to_bytes(byte_len(n), 'big')
    

    With this, decryption works and the (inefficient) loop from your answer is not required.
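To illustrate why both changes matter, here is a minimal sketch of the two conversions that RFC 8017 calls OS2IP and I2OSP. The helper names `os2ip` and `i2osp` and the 8-byte sample value are assumptions chosen for illustration; they are not part of your code.

```python
def os2ip(octets: bytes) -> int:
    """Octet string to integer, most significant byte first (big endian)."""
    return int.from_bytes(octets, "big")


def i2osp(value: int, length: int) -> bytes:
    """Integer to octet string of exactly `length` bytes, big endian."""
    return value.to_bytes(length, "big")


# An OAEP-encoded message always starts with a 0x00 byte.
# Hypothetical 8-byte stand-in for the k-byte encoded message.
encoded = b"\x00\xa1\xb2\xc3\xd4\xe5\xf6\x07"
k = len(encoded)

as_int = os2ip(encoded)

# Big endian: the leading 0x00 keeps the integer below 256**(k-1),
# which is never larger than a k-byte modulus.
print(as_int < 256 ** (k - 1))   # True

# Little endian: the last (arbitrary) byte becomes the most significant
# one, so the integer can exceed the modulus.
little = int.from_bytes(encoded, "little")
print(little >= 256 ** (k - 1))  # True

# Fixed-length conversion restores the leading zero byte.
fixed = i2osp(as_int, k)
print(fixed == encoded)          # True

# Minimal-length conversion drops it, which shifts every slice
# taken later in oaep_decode by one byte.
minimal_len = (as_int.bit_length() + 7) // 8
minimal = as_int.to_bytes(minimal_len, "big")
print(len(minimal))              # 7
```

In your code, `k` corresponds to `byte_len(n)`, so the fixed-length call is the `int_message.to_bytes(byte_len(n), 'big')` from the corrected function.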


    Note that I have not checked the consistency of your implementation with RFC 8017 per se. A possible test would be to encrypt with your code and decrypt with a reliable library (and vice versa).
    Also keep in mind that, beyond pure functionality, security matters as well (see side-channel attacks in particular).
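As a complement to the library cross-check mentioned above (not a replacement for it), a cheap self-consistency test is to push random byte strings shaped like an OAEP encoding (a leading 0x00 followed by random bytes) through the textbook RSA step alone. The sketch below rests on several assumptions: the toy 92-bit modulus built from two Mersenne primes is far too small for real use, and all function names are hypothetical.

```python
import os

# Toy key from two known Mersenne primes; insecure, for demonstration only.
P, Q = 2**61 - 1, 2**31 - 1
N = P * Q
E = 65537
D = pow(E, -1, (P - 1) * (Q - 1))  # modular inverse, needs Python 3.8+
K = (N.bit_length() + 7) // 8      # modulus length in bytes


def roundtrip_fixed(em: bytes) -> bytes:
    """Big endian in, fixed-length big endian out (the corrected variant)."""
    c = pow(int.from_bytes(em, "big"), E, N)
    return pow(c, D, N).to_bytes(K, "big")


def roundtrip_original(em: bytes) -> bytes:
    """Little endian and minimal length, as in the question's code."""
    c = pow(int.from_bytes(em, "little"), E, N)
    m = pow(c, D, N)
    return m.to_bytes((m.bit_length() + 7) // 8, "little")


def count_failures(roundtrip, trials: int = 200) -> int:
    """Count how many random encodings do not survive the round trip."""
    failures = 0
    for _ in range(trials):
        em = b"\x00" + os.urandom(K - 1)  # same shape as an OAEP encoding
        if roundtrip(em) != em:
            failures += 1
    return failures


print("corrected:", count_failures(roundtrip_fixed))
print("original: ", count_failures(roundtrip_original))
```

The corrected variant should report no failures, because the leading 0x00 keeps the integer below the modulus. The original variant fails whenever the little-endian integer is not smaller than the modulus, so its failure rate depends on the top byte of the modulus. That is plausibly why you saw failures in only about half of your runs with freshly generated keys.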