Tags: python, arrays, numpy, numpy-ndarray, array-broadcasting

How to use a mask to limit broadcasted operations between two numpy arrays?


I have an array like so:

import numpy as np

data = np.array([
    [[10, 10, 10],
     [10, 10, 10],
     [10, 10, 10]],

    [[20, 20, 20],
     [20, 20, 20],
     [20, 20, 20]],

    [[30, 30, 30],
     [30, 30, 30],
     [30, 30, 30]],
], dtype=np.float64)

and one to divide values by, like so:

divide_by = np.array([
    [[10, 10, 1]],
    [[1, 10, 10]],
    [[1, 1, 1]],
], dtype=np.float64)

I would like to divide each 3x3 slice along axis 0 of the data array (each data[i]) by the values in the divide_by array, applied like a stamp, but only at positions where a given boolean mask (which has the same shape as data) is True.

I can achieve the first part (ignoring the mask) with:

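# (3, 1, 3) -> (3, 3): drop the singleton middle axis to get a 3x3 "stamp"
divide_by = divide_by.reshape(divide_by.shape[0], divide_by.shape[2])

# (3, 3, 3) / (3, 3) broadcasts the stamp over axis 0,
# so every 3x3 slice data[i] is divided element-wise by it
data /= divide_by

print(data)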

Which yields:

[[[ 1.  1. 10.]
  [10.  1.  1.]
  [10. 10. 10.]]

 [[ 2.  2. 20.]
  [20.  2.  2.]
  [20. 20. 20.]]

 [[ 3.  3. 30.]
  [30.  3.  3.]
  [30. 30. 30.]]]

Note that each 3x3 slice of data has been divided element-wise by divide_by, as if it had been stamped on top of it. Great.
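The reshape matters because broadcasting aligns trailing axes. A minimal sketch, rebuilding the arrays from the question (the names stamp, stamped and per_block are only illustrative), compares dividing by the (3, 3) stamp with dividing by the original (3, 1, 3) divide_by, which would instead divide block i by row i of divide_by only:

```python
import numpy as np

# Rebuilt from the question: block i of data is filled with 10, 20, 30
data = np.array([10.0, 20.0, 30.0])[:, None, None] * np.ones((3, 3, 3))
divide_by = np.array([[[10, 10, 1]], [[1, 10, 10]], [[1, 1, 1]]], dtype=np.float64)  # (3, 1, 3)

stamp = divide_by[:, 0, :]  # (3, 3), same values as divide_by.reshape(3, 3)

# (3, 3, 3) / (3, 3): trailing axes align, so the whole stamp hits every block data[i]
stamped = data / stamp

# (3, 3, 3) / (3, 1, 3): block i is divided by row i of divide_by, repeated down its rows
per_block = data / divide_by

print(stamped[0])
print(per_block[0])
```

With the (3, 3) stamp every block gets the full pattern; with the unreshaped array, per_block[0] just repeats [1, 1, 10] on each row.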

Now I would like to do the same, but only apply the division where this mask is True:

mask = np.array([
    [[False, True, False],
     [False, False, False],
     [True, False, False]],

    [[True, True, True],
     [False, False, True],
     [False, False, False]],

    [[True, False, False],
     [False, False, False],
     [False, False, False]],
])

So that the expected output is:

[[[10.  1. 10.]
  [10. 10. 10.]
  [10. 10. 10.]]

 [[ 2.  2. 20.]
  [20. 20.  2.]
  [20. 20. 20.]]

 [[ 3. 30. 30.]
  [30. 30. 30.]
  [30. 30. 30.]]]

The mask selects the subset of positions where the division should be applied.

But if I do:

data[mask] /= divide_by

instead of

data /= divide_by

I get:

ValueError: operands could not be broadcast together with shapes (7,) (3,3) (7,) 
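The shapes in the error explain it: with a boolean mask, data[mask] is a flat 1-D copy holding only the 7 positions where mask is True, and a (3, 3) array cannot broadcast against shape (7,). A short sketch reproducing this, with the arrays rebuilt from the question (selected and raised are illustrative names):

```python
import numpy as np

# Arrays rebuilt from the question (divide_by already reshaped to the (3, 3) stamp)
data = np.array([10.0, 20.0, 30.0])[:, None, None] * np.ones((3, 3, 3))
stamp = np.array([[10, 10, 1], [1, 10, 10], [1, 1, 1]], dtype=np.float64)
mask = np.array([
    [[False, True, False], [False, False, False], [True, False, False]],
    [[True, True, True], [False, False, True], [False, False, False]],
    [[True, False, False], [False, False, False], [False, False, False]],
])

# Boolean indexing returns a flat 1-D copy of the selected elements
selected = data[mask]
print(selected.shape, stamp.shape)  # the left side has lost its 3-D layout

raised = False
try:
    data[mask] /= stamp  # (7,) vs (3, 3): not broadcast-compatible
except ValueError:
    raised = True
print("ValueError raised:", raised)
```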

How can I use this mask in this particular case?


Solution

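  • You can use np.where(mask, data / divide_by[None, :, 0], data). This assumes divide_by still has its original (3, 1, 3) shape: divide_by[None, :, 0] has shape (1, 3, 3), so the division broadcasts the stamp over axis 0, and np.where takes the divided value where mask is True and the original value elsewhere. It returns a new array, so assign the result back (e.g. data = np.where(...)) if you want data updated. If divide_by has already been reshaped to (3, 3), use divide_by[None] instead.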
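A runnable sketch of the np.where answer, alongside two alternatives built only from standard NumPy features: np.divide with its out= and where= arguments (divides in place, leaving unselected entries untouched), and boolean indexing against a stamp broadcast to data's full shape with np.broadcast_to (which resolves the shape mismatch from the question). The arrays are rebuilt from the question; names such as stamp3d and result_where are only illustrative.

```python
import numpy as np

# Arrays rebuilt from the question, divide_by in its original (3, 1, 3) shape
data = np.array([10.0, 20.0, 30.0])[:, None, None] * np.ones((3, 3, 3))
divide_by = np.array([[[10, 10, 1]], [[1, 10, 10]], [[1, 1, 1]]], dtype=np.float64)
mask = np.array([
    [[False, True, False], [False, False, False], [True, False, False]],
    [[True, True, True], [False, False, True], [False, False, False]],
    [[True, False, False], [False, False, False], [False, False, False]],
])

stamp3d = divide_by[None, :, 0]  # (3, 1, 3) -> (1, 3, 3): one stamp, broadcast over axis 0

# 1) np.where: divide everywhere, keep the divided value only where mask is True
result_where = np.where(mask, data / stamp3d, data)

# 2) In place: where mask is False, `out` keeps its current value
result_inplace = data.copy()
np.divide(result_inplace, stamp3d, out=result_inplace, where=mask)

# 3) Boolean indexing: broadcast the stamp to data's shape first,
#    so both sides select the same 7 elements in the same order
result_index = data.copy()
result_index[mask] /= np.broadcast_to(stamp3d, data.shape)[mask]

print(result_where)
```

np.where allocates a new array and evaluates data / stamp3d at every position, masked or not; the np.divide form writes only the selected entries into an existing array, which is closer to the in-place data /= divide_by style used in the question.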