Here's the example:
```python
import numpy as np

# Draw sizes from [1, 100): randint(100) could return 0, and randint(0, 0) below would then raise ValueError
A = np.random.randint(1, 100)
B = np.random.randint(1, 100)
C = np.random.randint(1, 100)
print(f"{A=}, {B=}, {C=}")
x = np.random.random((A, B, C))
print(f"{x.shape=}")
a = np.random.randint(0, A)
b = np.random.randint(0, A)
print(f"{a=}, {b=}")
c = np.random.randint(0, B)
d = np.random.randint(0, B)
print(f"{c=}, {d=}")
print(f"{x[[a, b], [c, d], :].shape=}")
print(f"{x[[a, b]][:, [c, d]].shape=}")
```
```
A=7, B=40, C=57
x.shape=(7, 40, 57)
a=4, b=1
c=10, d=5
x[[a, b], [c, d], :].shape=(2, 57)
x[[a, b]][:, [c, d]].shape=(2, 2, 57)
```
I would have expected indexing with `[[a, b], [c, d], :]` to produce a shape of `(2, 2, C)`. Why doesn't it?
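The result follows from NumPy's advanced-indexing rules: the integer lists `[a, b]` and `[c, d]` are broadcast against each other and paired element by element, so the first two axes collapse into a single axis of length 2 holding `x[a, c, :]` and `x[b, d, :]`. A minimal check, assuming the sizes and indices from the run above are hard-coded instead of drawn at random so the output is reproducible:

```python
import numpy as np

# Fixed sizes and indices (copied from the run above) instead of random ones
A, B, C = 7, 40, 57
a, b, c, d = 4, 1, 10, 5
x = np.random.default_rng(0).random((A, B, C))

paired = x[[a, b], [c, d], :]
# The two index lists are zipped together: (a, c) and (b, d)
expected = np.stack([x[a, c, :], x[b, d, :]])

print(paired.shape)                      # (2, 57)
print(np.array_equal(paired, expected))  # True
```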
Use `np.ix_` to get an open mesh. Unfortunately, it doesn't play particularly nicely with `:`, but you can get around it by passing the full range of that axis (`range(C)` or `range(B)`) in place of the slice:
```python
# Output below comes from a separate random run (there C = 77 and B = 92, not 57 and 40)
print(f"{x[np.ix_([a, b], [c, d], range(C))].shape = }")
# x[np.ix_([a, b], [c, d], range(C))].shape = (2, 2, 77)
print(f"{x[np.ix_([a, b], range(B), [c, d])].shape = }")
# x[np.ix_([a, b], range(B), [c, d])].shape = (2, 92, 2)
```
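`np.ix_` builds the open mesh by reshaping each 1-D sequence so it varies along its own axis only; with three inputs the index arrays have shapes `(2, 1, 1)`, `(1, 2, 1)` and `(1, 1, C)`, which broadcast to `(2, 2, C)`. The sketch below hard-codes the question's sizes (an assumption for reproducibility; the names `mesh`, `rows`, `cols`, `kept` and `split` are illustrative). It also shows a way to keep the `:` by adding the extra axis by hand with `[:, None]`, and compares both against the chained indexing from the question:

```python
import numpy as np

# Fixed sizes and indices (from the question's run) so the result is reproducible
A, B, C = 7, 40, 57
a, b, c, d = 4, 1, 10, 5
x = np.random.default_rng(0).random((A, B, C))

# np.ix_ reshapes each sequence so that it varies along its own axis only
i0, i1, i2 = np.ix_([a, b], [c, d], range(C))
print(i0.shape, i1.shape, i2.shape)  # (2, 1, 1) (1, 2, 1) (1, 1, 57)

# These index arrays broadcast to (2, 2, C): every (row, col) combination
mesh = x[i0, i1, i2]

# Alternative that keeps the slice: give the first index list a trailing axis
rows = np.array([a, b])[:, None]  # shape (2, 1)
cols = np.array([c, d])           # shape (2,)
kept = x[rows, cols, :]           # (2, 1) and (2,) broadcast to (2, 2); ':' keeps C

chained = x[[a, b]][:, [c, d]]    # the chained indexing from the question

print(mesh.shape, kept.shape)                                     # (2, 2, 57) (2, 2, 57)
print(np.array_equal(mesh, kept), np.array_equal(mesh, chained))  # True True

# Non-adjacent axes: when a slice separates the advanced indices, their
# broadcast shape (2, 2) moves to the front instead of staying in place
split = x[rows, :, cols]
print(split.shape)                                 # (2, 2, 40)
print(x[np.ix_([a, b], range(B), [c, d])].shape)   # (2, 40, 2)
```

For adjacent axes, the `[:, None]` trick avoids materialising a full `range` index. When the indexed axes are not adjacent, `np.ix_` (or chained indexing such as `x[[a, b]][:, :, [c, d]]`) keeps the axes in their original order.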