
Interleaving NumPy arrays with mismatching shapes


I would like to interleave multiple numpy arrays with differing dimensions along a particular axis. In particular, I have a list of arrays of shape (_, *dims), varying along the first axis, which I would like to interleave to obtain another array of shape (_, *dims). For instance, given the input

a1 = np.array([[11,12], [41,42]])
a2 = np.array([[21,22], [51,52], [71,72], [91,92], [101,102]])
a3 = np.array([[31,32], [61,62], [81,82]])

interweave(a1,a2,a3)

the desired output would be

np.array([[11,12], [21,22], [31,32], [41,42], [51,52], [61,62], [71,72], [81,82], [91,92], [101,102]])

With the help of previous posts (such as Numpy concatenate arrays with interleaving), I've gotten this working when the arrays match along the first dimension:

import numpy as np

def interweave(*arrays, stack_axis=0, weave_axis=1):
    final_shape = list(arrays[0].shape)
    final_shape[stack_axis] = -1

    # stack up arrays along the "weave axis", then reshape back to desired shape
    return np.concatenate(arrays, axis=weave_axis).reshape(final_shape)

Unfortunately, if the input shapes mismatch along the first dimension, the above throws an exception since we must concatenate along a different axis than the mismatching one. Indeed, I don't see any way to use concatenation effectively here, since concatenating along the mismatched axis will destroy information we need to produce the desired output.
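
To make the failure concrete, here is a small sketch that runs the function above on two inputs with matching row counts and then on two with mismatching row counts. The arrays b1, b2 and b3 are hypothetical examples chosen only for illustration.

```python
import numpy as np

def interweave(*arrays, stack_axis=0, weave_axis=1):
    final_shape = list(arrays[0].shape)
    final_shape[stack_axis] = -1
    return np.concatenate(arrays, axis=weave_axis).reshape(final_shape)

# Hypothetical inputs with the same number of rows: this works,
# and the rows alternate as b1[0], b2[0], b1[1], b2[1].
b1 = np.array([[11, 12], [31, 32]])
b2 = np.array([[21, 22], [41, 42]])
matched = interweave(b1, b2)

# Hypothetical input with a different number of rows: concatenating
# along axis 1 requires equal sizes along axis 0, so NumPy refuses.
b3 = np.array([[21, 22], [41, 42], [51, 52]])
try:
    interweave(b1, b3)
    error = None
except ValueError as exc:
    error = exc
```

The second call fails inside np.concatenate, before the reshape is ever reached.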

One other idea I had was to pad the input arrays with null entries until their shapes match along the first dimension, and then remove the null entries at the end of the day. While this would work, I am not sure how best to implement it, and it seems like it should not be necessary in the first place.


Solution

  • Here's a mostly NumPy-based approach that also uses itertools.zip_longest to interleave the arrays, padding the shorter ones with a fill value:

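    import numpy as np
    from itertools import zip_longest

    def interleave(*a):
        # zip_longest pads the shorter arrays with rows of NaNs,
        # one NaN per column of the second axis
        fill = [np.nan] * a[0].shape[1]
        rows = tuple(zip_longest(*a, fillvalue=fill))
        # stack the zipped groups into a single 2D array
        out = np.concatenate(rows)
        # drop the padding rows (those whose first column is NaN)
        return out[~np.isnan(out[:, 0])]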
    

    a1 = np.array([[11,12], [41,42]])
    a2 = np.array([[21,22], [51,52], [71,72], [91,92], [101,102]])
    a3 = np.array([[31,32], [61,62], [81,82]])
    
    interleave(a1,a2,a3)
    
    array([[ 11.,  12.],
           [ 21.,  22.],
           [ 31.,  32.],
           [ 41.,  42.],
           [ 51.,  52.],
           [ 61.,  62.],
           [ 71.,  72.],
           [ 81.,  82.],
           [ 91.,  92.],
           [101., 102.]])
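
Note that the NaN padding upcasts integer input to float, as the output above shows, and it would also drop any genuine row whose first column is NaN. If that matters, one possible alternative (a sketch, not part of the answer above; the name interleave_by_index is made up here) is to skip padding entirely: concatenate all rows, record each row's position within its own array, and stable-sort by that position.

```python
import numpy as np

def interleave_by_index(*arrays):
    # Position of every row inside its own array: 0..len(a)-1 for each a
    order = np.concatenate([np.arange(len(a)) for a in arrays])
    # All rows back to back along the first axis (trailing dims must match)
    stacked = np.concatenate(arrays, axis=0)
    # A stable sort keeps the original array order among rows that
    # share the same position, which yields the interleaved order
    return stacked[np.argsort(order, kind="stable")]

a1 = np.array([[11, 12], [41, 42]])
a2 = np.array([[21, 22], [51, 52], [71, 72], [91, 92], [101, 102]])
a3 = np.array([[31, 32], [61, 62], [81, 82]])

result = interleave_by_index(a1, a2, a3)
```

Since no fill value is involved, the input dtype is preserved, and the same code should work for inputs of shape (_, *dims) as long as the trailing dimensions agree.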