pythonstringbytestream

Checking for equality if either input can be `str` or `bytes`


I am trying to write a function that checks whether two values, each of which may be a `str` (with ASCII-only content) or `bytes`, are equal.

Right now I have:

import typing as typ


def is_equal_str_bytes(
    a: typ.Union[str, bytes],
    b: typ.Union[str, bytes],
) -> bool:
    """Compare *a* and *b* by content, treating ``str`` and ``bytes`` alike.

    Every ``str`` argument is encoded with the default codec (UTF-8) and the
    resulting byte strings are compared, so any mix of ``str`` and ``bytes``
    inputs is accepted.
    """
    # Normalise both operands to bytes, first a and then b.
    left, right = (
        value.encode() if isinstance(value, str) else value
        for value in (a, b)
    )
    return left == right

This works with any combination of `str` and `bytes` arguments, whereas the `==` operator (rightfully) returns `False` when the two types differ.

import itertools


# Print the helper's verdict next to plain ``==`` for every str/bytes pairing.
samples = ("ciao", b"ciao")
for a, b in itertools.product(samples, samples):
    print(f"{a!r:<8} {b!r:<8} {is_equal_str_bytes(a, b)} {a == b}")
# Output:
# 'ciao'   'ciao'   True True
# 'ciao'   b'ciao'  True False
# b'ciao'  'ciao'   True False
# b'ciao'  b'ciao'  True True

Is there a simpler / faster way?


Solution

  • Some benchmarks with random equal strings/bytes of a million characters (on TIO with Python 3.8 pre-release, but I got similar times with 3.10.2):

      186.88 us  s.encode()
      187.39 us  s.encode("utf-8")
      183.85 us  s.encode("ascii")
       94.62 us  b.decode()
       94.27 us  b.decode("utf-8")
      137.91 us  b.decode("ascii")
       79.93 us  s == s2
       82.69 us  b == b2
      182.72 us  s + "a"
      177.06 us  b + b"a"
        0.08 us  len(s)
        0.07 us  len(b)
        1.14 us  s[:1000].encode()
        0.97 us  b[:1000].decode()
        2.06 us  s[::1000].encode()
        1.45 us  b[::1000].decode()
        1.91 us  hash(s)
        1.56 us  hash(b)
      508.62 us  hash(s2)
      546.00 us  hash(b2)
        2.85 us  str(s)
     9142.59 us  str(b)
    13541.64 us  repr(s)
     9100.34 us  repr(b)
    

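    Thoughts based on that:

      - `len` is practically free (~0.08 us), so compare the lengths first.
      - `hash` is cheap when the hash is already stored on the object (~2 us), but computing it fresh (~500 us) costs more than a full comparison (~80 us).
      - Decoding the bytes (~94 us) is about twice as fast as encoding the string (~187 us).
      - Converting only a short prefix or a sparse sample (~1-2 us) is a cheap way to catch mismatches before the full conversion.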

    So here's a potentially faster version using the above optimizations (not tested/benchmarked, partly because it depends on your data):

    import typing as typ
    
    def is_equal_str_bytes(
        a: typ.Union[str, bytes],
        b: typ.Union[str, bytes],
    ) -> bool:
        """Return whether *a* and *b* hold the same ASCII content.

        Each argument may be ``str`` or ``bytes``. Cheap rejections (length,
        hash, a short prefix, a sparse sample) are tried before the full
        comparison.

        Strings are expected to be ASCII-only: the length check compares a
        character count with a byte count, so a non-ASCII ``str`` is never
        reported equal to its multi-byte UTF-8 encoding.

        Returns:
            ``True`` if both arguments have the same content, else ``False``.
            Bytes that cannot be decoded yield ``False`` rather than an
            exception.
        """
        if len(a) != len(b):
            return False
        # NOTE(review): for mixed str/bytes inputs this shortcut assumes that
        # ASCII text and the equivalent bytes hash alike. That is an
        # implementation detail, not a language guarantee - confirm before
        # running on anything other than CPython.
        if hash(a) != hash(b):
            return False
        if type(a) is type(b):
            return a == b
        if isinstance(a, bytes):  # make a=str, b=bytes
            a, b = b, a
        try:
            # Cheapest checks first: a short prefix, then every 1000th
            # element, and only then the full decode and comparison.
            return (
                a[:1000] == b[:1000].decode()
                and a[::1000] == b[::1000].decode()
                and a == b.decode()
            )
        except UnicodeDecodeError:
            # The lengths are equal, so bytes that fail to decode (non-ASCII
            # data, possibly cut mid-sequence by the slicing) cannot match
            # the str; report a mismatch instead of propagating the error.
            return False
    

    My benchmark code:

    import os
    from timeit import repeat
    
    # Benchmark payload: one million random bytes with the high bit cleared,
    # so the data is ASCII-only and decodes to a str of the same length.
    n = 10**6
    b = bytes(x & 127 for x in os.urandom(n))
    s = b.decode()
    assert hash(s) == hash(b)
    
    # Runs before each timing: s2/b2 are rebuilt so that they are distinct
    # objects from s/b (see the comment inside the setup code).
    setup = '''
    from __main__ import s, b
    s2 = b.decode()  # Always fresh so it doesn't have a hash stored already 
    b2 = s.encode()
    assert s2 is not s and b2 is not b
    '''
    
    exprs = [
        's.encode()',
        's.encode("utf-8")',
        's.encode("ascii")',
        'b.decode()',
        'b.decode("utf-8")',
        'b.decode("ascii")',
        's == s2',
        'b == b2',
        's + "a"',
        'b + b"a"',
        'len(s)',
        'len(b)',
        's[:1000].encode()',
        'b[:1000].decode()',
        's[::1000].encode()',
        'b[::1000].decode()',
        'hash(s)',
        'hash(b)',
        'hash(s2)',
        'hash(b2)',
        'str(s)',
        'str(b)',
        'repr(s)',
        'repr(b)',
    ]
    
    # Expressions from 'hash(s)' onward are executed once per measurement;
    # the earlier ones are averaged over 100 executions. The boundary is
    # looked up once here instead of on every loop iteration.
    single_run_from = exprs.index('hash(s)')
    
    for _ in range(3):
        for i, e in enumerate(exprs):
            number = 100 if i < single_run_from else 1
            # Best of the repeated measurements, as time per execution.
            t = min(repeat(e, setup, number=number)) / number
            print('%8.2f us ' % (t * 1e6), e)
        print()