Tags: python, list, numpy, runtime-error

How does one compare the contents of two lists of non-hashable objects in python without caring about order?


I have two lists of numpy arrays, and I want to check if the two lists have the same set of numpy arrays. If they have the same arrays in a different order, I still want it to return true. The numpy arrays in each list are not all the same shape, so I can't use anything that would rely on that. As an example, let's just set

list_1 = [numpy.array([1, 2]), numpy.array([3])]
list_2 = [numpy.array([3]), numpy.array([1, 2])]

These two lists should return true when compared. The order of elements within each numpy array should still matter, so

list_3 = [numpy.array([2, 1]), numpy.array([3])]

should return false when compared with both of list_1 and list_2.

I can't use set(list_1) == set(list_2), which is what I would normally try, because I get TypeError: unhashable type: numpy.ndarray. My backup solution was to try the following function:

def contents_equal(list_1, list_2):
    for array in list_1:
        if array not in list_2:
            return False
    for array in list_2:
        if array not in list_1:
            return False
    return True

Unfortunately, this did not work either. I got the following error on the line if array not in list_2:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I tried another solution with .any(), as the error message suggested:

def contents_equal(list_1, list_2):
    if [array not in list_2 for array in list_1].any():
        return False
    if [array not in list_1 for array in list_2].any():
        return False
    return True

This one felt less readable to me, but I figured if it worked, it worked. Unfortunately, it didn't work. I got the same ValueError on the if statement. I'm not sure what to try from here, or even why my solution isn't working. I'm not evaluating any list as a boolean in either solution as far as I can tell, which seems to be what the error is complaining about. Can anyone enlighten me?


Solution

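  • NumPy arrays aren’t hashable, so they can’t go in a set. They also don’t work reliably with in: a membership test on a list compares the item to each element with ==, and for arrays == returns an elementwise array of booleans rather than a single truth value. Python then has to convert that boolean array to True or False, which raises the ValueError whenever the array has more than one element. That is also why your second attempt fails in the same place: the array not in list_2 test runs inside the list comprehension, before .any() is ever reached (and a plain Python list has no .any() method anyway).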

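    Instead, you can convert each array to a hashable key (its shape plus a tuple of its elements) and compare sets of those keys. Including the shape matters because flatten() alone would make a 2x2 array and a 4-element array with the same values look identical. Note that sets discard duplicates, so a list that contains the same array twice compares equal to a list that contains it once.

    import numpy as np
    
    def array_key(array):
        # Shape plus flattened elements, so equal values in different shapes stay distinct
        return (array.shape, tuple(array.flatten()))
    
    def contents_equal(list_1, list_2):
        # Compare the sets of hashable keys; order within each list is ignored
        set_1 = {array_key(array) for array in list_1}
        set_2 = {array_key(array) for array in list_2}
        return set_1 == set_2
    
    
    list_1 = [np.array([1, 2]), np.array([3])]
    list_2 = [np.array([3]), np.array([1, 2])]
    list_3 = [np.array([2, 1]), np.array([3])]
    
    print(contents_equal(list_1, list_2))  # True
    print(contents_equal(list_1, list_3))  # False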
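
    If duplicates should count (so that a list holding the same array twice does not match a list holding it once), one possible alternative is to skip hashing and match the arrays pairwise with np.array_equal, which returns a single boolean. The sketch below is only an illustration under that assumption; the function name contents_equal_multiset is made up for this example.

```python
import numpy as np


def contents_equal_multiset(list_1, list_2):
    """Order-insensitive comparison of two lists of arrays that respects duplicates."""
    if len(list_1) != len(list_2):
        return False
    # Shallow copy: entries are removed as they are matched, the arrays are untouched
    unmatched = list(list_2)
    for array in list_1:
        for i, candidate in enumerate(unmatched):
            # np.array_equal gives one bool: True only if shape and all elements match
            if np.array_equal(array, candidate):
                del unmatched[i]
                break
        else:
            # No remaining array in list_2 matches this one
            return False
    return True


list_1 = [np.array([1, 2]), np.array([3])]
list_2 = [np.array([3]), np.array([1, 2])]
list_3 = [np.array([2, 1]), np.array([3])]

print(contents_equal_multiset(list_1, list_2))  # True
print(contents_equal_multiset(list_1, list_3))  # False
```

    This does a pairwise scan, so the number of comparisons grows with the product of the two list lengths. That is usually fine for short lists; for long ones the set-of-keys approach is likely the better fit.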