
Indexing numpy array of shape `(A, B, C)` with `[[a, b], [c, d], :]` (`0 <= a, b < A`, `0 <= c, d < B`) produces shape `(2, C)` instead of `(2, 2, C)`


Here's the example:

    import numpy as np

    # Random sizes; start at 1 so no axis is empty (randint(0, 0) would raise)
    A = np.random.randint(1, 100)
    B = np.random.randint(1, 100)
    C = np.random.randint(1, 100)
    print(f"{A=}, {B=}, {C=}")

    x = np.random.random((A, B, C))
    print(f"{x.shape=}")

    a = np.random.randint(0, A)
    b = np.random.randint(0, A)
    print(f"{a=}, {b=}")

    c = np.random.randint(0, B)
    d = np.random.randint(0, B)
    print(f"{c=}, {d=}")

    print(f"{x[[a, b], [c, d], :].shape=}")
    print(f"{x[[a, b]][:, [c, d]].shape=}")

Output:

    A=7, B=40, C=57
    x.shape=(7, 40, 57)
    a=4, b=1
    c=10, d=5
    x[[a, b], [c, d], :].shape=(2, 57)
    x[[a, b]][:, [c, d]].shape=(2, 2, 57)

Why is that? I would have expected indexing with `[[a, b], [c, d], :]` to produce shape `(2, 2, C)`, just like the chained version.
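A quick check of what the combined index actually selects. This is a sketch with fixed, illustrative sizes and indices (taken from the run above) instead of random ones, so the result is reproducible:

```python
import numpy as np

# Illustrative fixed values, matching the run shown above
A, B, C = 7, 40, 57
rng = np.random.default_rng(0)
x = rng.random((A, B, C))
a, b, c, d = 4, 1, 10, 5

# Two integer-array indices are broadcast together and paired elementwise:
# the pairs are (a, c) and (b, d), giving one axis of length 2 plus the sliced axis.
paired = x[[a, b], [c, d], :]
expected = np.stack([x[a, c, :], x[b, d, :]])

print(paired.shape)                      # (2, 57)
print(np.array_equal(paired, expected))  # True

# Indexing one axis at a time applies each list independently (outer selection).
chained = x[[a, b]][:, [c, d]]
print(chained.shape)                     # (2, 2, 57)
```

So the "missing" axis is not lost: the two lists are zipped into pairs rather than combined as an outer product.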


Solution

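  • When more than one axis is indexed with an integer array, NumPy broadcasts the index arrays against each other and pairs them elementwise, so `x[[a, b], [c, d], :]` selects `x[a, c, :]` and `x[b, d, :]`, hence shape `(2, C)`. To get every combination instead, use np.ix_, which builds an open mesh from the index lists. It only accepts 1-D sequences, not slices such as `:`, but you can pass `range(C)` (or `range(B)`) in place of the full axis: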

    print(f"{x[np.ix_([a, b], [c, d], range(C))].shape = }")
    # x[np.ix_([a, b], [c, d], range(C))].shape = (2, 2, 57)

    print(f"{x[np.ix_([a, b], range(B), [c, d])].shape = }")
    # x[np.ix_([a, b], range(B), [c, d])].shape = (2, 40, 2)
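If you would rather not spell out `range(C)`, a few variants keep the full axis implicit. The sketch below uses fixed, illustrative sizes and indices (not from a particular run) and checks each variant against the np.ix_ result. It relies only on standard NumPy indexing rules: axes not mentioned in the index are taken whole, and index arrays are broadcast together.

```python
import numpy as np

# Illustrative fixed values
A, B, C = 7, 40, 57
rng = np.random.default_rng(0)
x = rng.random((A, B, C))
a, b, c, d = 4, 1, 10, 5

reference = x[np.ix_([a, b], [c, d], range(C))]  # approach from the answer

# 1) Pass only the indexed axes to np.ix_; the trailing axis is taken whole.
only_indexed = x[np.ix_([a, b], [c, d])]

# 2) Unpack np.ix_ and append an explicit slice for the last axis.
with_slice = x[(*np.ix_([a, b], [c, d]), slice(None))]

# 3) Broadcast by hand: a (2, 1) column against a (2,) row gives a (2, 2) grid.
manual = x[np.array([a, b])[:, None], [c, d], :]

for result in (only_indexed, with_slice, manual):
    assert result.shape == (2, 2, C)
    assert np.array_equal(result, reference)

# Full axis in the middle: advanced indices separated by a slice are moved
# to the front, so manual broadcasting gives (2, 2, B), while np.ix_ and
# chained indexing keep the original axis order.
print(x[np.ix_([a, b], range(B), [c, d])].shape)      # (2, 40, 2)
print(x[np.array([a, b])[:, None], :, [c, d]].shape)  # (2, 2, 40)
print(x[[a, b]][:, :, [c, d]].shape)                  # (2, 40, 2)
```

Variant 1 is the shortest when the full axes come last; for a full axis between indexed ones, np.ix_ with `range(...)` or chained indexing avoids the axis reordering.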