python numpy

In NumPy, how to compare all values along an axis


For a NumPy array, how can I replace a row only if all of its elements (i.e. all values along that axis) are equal to a given array? For example:

array = np.array([[1, 0, 1],
                  [0, 0, 1],
                  [1, 1, 0],
                  [0, 0, 0],
                  [1, 0, 1]])

I want to replace all [1, 0, 1] with [1, 1, 1]... so that array becomes

array([[1, 1, 1],
       [0, 0, 1],
       [1, 1, 0],
       [0, 0, 0],
       [1, 1, 1]])

When I index with a boolean array built from array == [1, 0, 1], the comparison is done element by element, so individual numbers get matched. How can I compare the entire row at once instead?


Solution

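  • Compare the array against the target row, reduce the result with .all(axis=1) to get one boolean per row, and use that as a mask for the assignment: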

    array[(array == [1, 0, 1]).all(axis=1)] = [1, 1, 1]
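
    To see why this works, the one-liner can be split into its steps. The sketch below uses the array from the question; the intermediate names elementwise and row_mask are only illustrative and do not appear in the original answer.

```python
import numpy as np

array = np.array([[1, 0, 1],
                  [0, 0, 1],
                  [1, 1, 0],
                  [0, 0, 0],
                  [1, 0, 1]])

# Step 1: elementwise comparison. The list [1, 0, 1] is broadcast
# against every row, giving a boolean array of shape (5, 3).
elementwise = (array == [1, 0, 1])

# Step 2: reduce along axis 1, so each row collapses to a single
# boolean that is True only if every element in that row matched.
row_mask = elementwise.all(axis=1)   # shape (5,)

# Step 3: boolean-index the rows and assign the replacement row,
# which is broadcast to every selected row.
array[row_mask] = [1, 1, 1]

print(row_mask)
print(array)
```

    Here row_mask selects only the first and last rows, so the rows in between are left untouched. Note that the assignment modifies array in place; work on array.copy() if the original values are still needed.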