
Unravel strided indices


I am trying to write a routine that returns the raveled strided indices given the shapes of two operands. It should take broadcasting into account, and from these "raveled strided indices" it should be possible to tell which values of the two arrays (operands) are being accessed once the arrays are flattened.

Here are some examples that I hope explain this better:

shapeA = (4,3)
shapeB = (3,)

should output

strideA = [0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]
strideB = [0,  1,  2,  0,  1,  2,  0,  1,  2,  0,  1,  2]

Another example:

shapeA = (4,1)
shapeB = (3,)

should output

strideA = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
strideB = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]

My attempt (which only works on the first example):

import numpy as np

def ravel_strided_indices(shapeA, shapeB):
    out_shape = np.broadcast_shapes(shapeA, shapeB)
    flatten = np.prod(out_shape)
    strideA = np.arange(flatten) % np.prod(shapeA)
    strideB = np.arange(flatten) % np.prod(shapeB)
    return strideA, strideB
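The modulo trick cycles through the operand's elements in order. That is only correct when broadcasting repeats the whole operand as one block, i.e. when its shape, after dropping leading 1s, equals the trailing axes of the output shape. This holds for (3,) against (4, 3). But (4, 1) is broadcast along its last axis, so each of its elements must be repeated three times in a row rather than cycled. A small check of that condition follows. The helper name is made up for illustration.

```python
import numpy as np

def modulo_trick_is_valid(shape, out_shape):
    # Hypothetical helper: True when np.arange(n) % np.prod(shape) yields the
    # correct flat indices for an operand of `shape` broadcast to `out_shape`.
    core = tuple(shape)
    while core and core[0] == 1:
        core = core[1:]  # leading length-1 axes only repeat the whole array
    # What remains must match the trailing axes of the output exactly
    return core == tuple(out_shape)[len(out_shape) - len(core):]

out_shape = np.broadcast_shapes((4, 1), (3,))    # (4, 3)
print(modulo_trick_is_valid((3,), out_shape))    # True
print(modulo_trick_is_valid((4, 1), out_shape))  # False
```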

Is there a convenient way of doing this in NumPy (perhaps something similar to np.ravel_multi_index or np.indices)?
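Both functions mentioned can be combined. np.indices(out_shape) gives the coordinates of every output element. For each operand, coordinates along its broadcast (length-1) axes are replaced by 0, and np.ravel_multi_index turns the result into flat indices into that operand. The sketch below assumes C-order flattening, which is what A.flatten() uses by default. The inner helper name is hypothetical.

```python
import numpy as np

def ravel_strided_indices(shapeA, shapeB):
    out_shape = np.broadcast_shapes(shapeA, shapeB)
    grid = np.indices(out_shape)  # grid[k] holds the k-th coordinate of every output element

    def flat_indices(shape):
        # Left-pad with 1s so the operand has as many axes as the output
        padded = (1,) * (len(out_shape) - len(shape)) + tuple(shape)
        # Along a length-1 (broadcast) axis the operand is always read at position 0
        coords = tuple(np.zeros_like(g) if n == 1 else g for g, n in zip(grid, padded))
        return np.ravel_multi_index(coords, padded).ravel()

    return flat_indices(shapeA), flat_indices(shapeB)

strideA, strideB = ravel_strided_indices((4, 1), (3,))
print(strideA)  # [0 0 0 1 1 1 2 2 2 3 3 3]
print(strideB)  # [0 1 2 0 1 2 0 1 2 0 1 2]
```

Because every output element reads exactly the operand elements that broadcasting pairs up, these indices also satisfy the equivalence asked for in the edit below.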

EDIT

I am looking for something that would make this operation

A = np.random.randn(3,1,5,2,1)
B = np.random.randn(1,4,5,1,5)

A + B

equivalent to this one:

indicesA, indicesB = ravel_strided_indices(A.shape, B.shape)
out_shape = np.broadcast_shapes(A.shape, B.shape)

(A.flatten()[indicesA] + B.flatten()[indicesB]).reshape(out_shape)
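Since the title talks about strides, the same indices can also be built from per-axis strides. This mirrors the zero strides NumPy uses for broadcast views, which are visible in np.broadcast_to(...).strides. Each operand gets one element stride per output axis, with stride 0 on axes where it is broadcast. The flat index of an output element is then the dot product of its coordinates with those strides. The sketch below assumes C-order flattening, as A.flatten() does by default. Both helper names are hypothetical.

```python
import numpy as np

def broadcast_element_strides(shape, out_shape):
    # Hypothetical helper: per-axis element strides of a C-contiguous array of
    # `shape` viewed as `out_shape`; broadcast axes get stride 0.
    padded = (1,) * (len(out_shape) - len(shape)) + tuple(shape)
    strides, step = [], 1
    for n in reversed(padded):
        strides.append(0 if n == 1 else step)
        step *= n
    return strides[::-1]

def strided_flat_indices(shape, out_shape):
    # Flat index = sum over axes of (output coordinate * stride)
    flat = np.zeros(out_shape, dtype=np.intp)
    for coord, stride in zip(np.indices(out_shape), broadcast_element_strides(shape, out_shape)):
        flat += coord * stride
    return flat.ravel()

A = np.random.randn(3, 1, 5, 2, 1)
B = np.random.randn(1, 4, 5, 1, 5)
out_shape = np.broadcast_shapes(A.shape, B.shape)
indicesA = strided_flat_indices(A.shape, out_shape)
indicesB = strided_flat_indices(B.shape, out_shape)
C = (A.flatten()[indicesA] + B.flatten()[indicesB]).reshape(out_shape)
print(np.array_equal(C, A + B))  # True
```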

Solution

  • Maybe the following can get you started.

    import numpy as np
    AA = np.random.randn(3,1,5,2,1)
    BB = np.random.randn(1,4,5,1,5)
    CC = AA + BB  # reference result via ordinary broadcasting
    out_shape = np.broadcast_shapes(AA.shape, BB.shape)
    # label each element of AA and BB with its own flat (C-order) index
    aidx   = np.arange(AA.size).reshape(AA.shape)
    bidx   = np.arange(BB.size).reshape(BB.shape)
    # broadcast the labels: each output position now holds the flat index it reads
    aidx_b = np.broadcast_to(aidx, out_shape)
    bidx_b = np.broadcast_to(bidx, out_shape)
    cnew   = (AA.flatten()[aidx_b.flatten()] + BB.flatten()[bidx_b.flatten()]).reshape(out_shape)
    print('norm diff=', np.linalg.norm(CC - cnew))  # expected: 0.0
    

    Edit: the question, of course, is why you'd want to do such a thing.
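The answer's np.broadcast_to idea can be packaged into the ravel_strided_indices(shapeA, shapeB) signature the question asks for. The following is a sketch adapting the answer, not code taken from it. np.arange(size).reshape(shape) labels every element with its flat position, and broadcasting those labels to the output shape gives, for each output element, the flat index it reads.

```python
import numpy as np

def ravel_strided_indices(shapeA, shapeB):
    out_shape = np.broadcast_shapes(shapeA, shapeB)
    # Label every element with its own flat (C-order) position ...
    labelsA = np.arange(np.prod(shapeA, dtype=int)).reshape(shapeA)
    labelsB = np.arange(np.prod(shapeB, dtype=int)).reshape(shapeB)
    # ... then broadcast the labels; ravel() copies the read-only broadcast view
    return (np.broadcast_to(labelsA, out_shape).ravel(),
            np.broadcast_to(labelsB, out_shape).ravel())

strideA, strideB = ravel_strided_indices((4, 1), (3,))
print(strideA)  # [0 0 0 1 1 1 2 2 2 3 3 3]
print(strideB)  # [0 1 2 0 1 2 0 1 2 0 1 2]
```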