python numpy multidimensional-array einops

Reshaping a 3D array of shape (K, M, N) into a 2D array of shape (n_rows * M, n_cols * N) with NumPy


I'm trying to reshape a 3D array (tensor) arr of shape (K, M, N) in NumPy, where each (M, N) subarray could be an image for instance, into a 2D array of shape (n_rows * M, n_cols * N), i.e. the K images tiled on an n_rows-by-n_cols grid.

I make sure beforehand that K = n_rows * n_cols.

After reading similar questions on SO, I tried every possible axis permutation:

    import itertools

    for perm in itertools.permutations([0, 1, 2], 3):
        test = arr.transpose(perm).reshape((n_rows * M, n_cols * N))

but none of them gives the expected result.
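None of the six permutations can work in general (they only coincide with the target in degenerate cases such as n_cols == 1): a 3-axis transpose only reorders whole axes, while the target layout needs the image axis K split into (n_rows, n_cols) first, so that each factor can be merged separately with M and N. The sketch below checks this on an assumed small example (6 images of 4 x 5 on a 3 x 2 grid), building the reference layout block by block with np.block:

```python
import itertools

import numpy as np

# Assumed small test case: 6 images of shape 4x5, tiled on a 3x2 grid.
K, M, N = 6, 4, 5
n_rows, n_cols = 3, 2
arr = np.arange(K * M * N).reshape(K, M, N)

# Reference layout built block by block: image k = r * n_cols + c goes to
# grid cell (r, c). np.block stitches the nested list of 2D blocks together.
expected = np.block(
    [[arr[r * n_cols + c] for c in range(n_cols)] for r in range(n_rows)]
)

# Try every 3-axis permutation, as in the question, and keep the ones that match.
matches = [
    perm
    for perm in itertools.permutations([0, 1, 2], 3)
    if np.array_equal(arr.transpose(perm).reshape(n_rows * M, n_cols * N), expected)
]
print(matches)  # → []
```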

However, with einops the following gives the expected result:

    import einops as ein

    test = ein.rearrange(arr, '(r c) h w -> (r h) (c w)', r=n_rows, c=n_cols)
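For reference, the einops pattern '(r c) h w -> (r h) (c w)' reads as follows: the first axis is split into (r, c) with r varying slowest, so image k sits at grid cell (k // n_cols, k % n_cols) and fills output rows r * M to (r + 1) * M and columns c * N to (c + 1) * N. A plain-loop sketch of that reading, with tile_loop as a hypothetical helper name, is slow but makes the target layout explicit:

```python
import numpy as np


def tile_loop(arr, n_rows, n_cols):
    """Hypothetical helper: place image k at grid cell (k // n_cols, k % n_cols)."""
    K, M, N = arr.shape
    if K != n_rows * n_cols:
        raise ValueError("K must equal n_rows * n_cols")
    out = np.empty((n_rows * M, n_cols * N), dtype=arr.dtype)
    for k in range(K):
        r, c = divmod(k, n_cols)  # '(r c)': r is the slower-varying index
        out[r * M:(r + 1) * M, c * N:(c + 1) * N] = arr[k]
    return out


arr = np.arange(6 * 4 * 5).reshape(6, 4, 5)
grid = tile_loop(arr, n_rows=3, n_cols=2)
print(grid.shape)  # → (12, 10)
```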

Is there a straightforward way to achieve this with plain NumPy?


Solution

  • Working from what I take the einops syntax to mean (the package is new to me, so I haven't verified that this produces exactly the output you expect):

    import numpy as np
    
    K, M, N = 6, 4, 5
    n_rows, n_cols = 3, 2
    
    arr = np.arange(K * M * N).reshape(K, M, N)
    
    out = (
        arr                                  # (r c) h w
        .reshape(n_rows, n_cols, M, N)       # r c h w
        .swapaxes(1, 2)                      # r h c w
        .reshape(n_rows * M, n_cols * N)     # (r h) (c w)
    )
    

    out:

    array([[  0,   1,   2,   3,   4,  20,  21,  22,  23,  24],
           [  5,   6,   7,   8,   9,  25,  26,  27,  28,  29],
           [ 10,  11,  12,  13,  14,  30,  31,  32,  33,  34],
           [ 15,  16,  17,  18,  19,  35,  36,  37,  38,  39],
           [ 40,  41,  42,  43,  44,  60,  61,  62,  63,  64],
           [ 45,  46,  47,  48,  49,  65,  66,  67,  68,  69],
           [ 50,  51,  52,  53,  54,  70,  71,  72,  73,  74],
           [ 55,  56,  57,  58,  59,  75,  76,  77,  78,  79],
           [ 80,  81,  82,  83,  84, 100, 101, 102, 103, 104],
           [ 85,  86,  87,  88,  89, 105, 106, 107, 108, 109],
           [ 90,  91,  92,  93,  94, 110, 111, 112, 113, 114],
           [ 95,  96,  97,  98,  99, 115, 116, 117, 118, 119]])
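The reshape / swapaxes / reshape chain above generalizes into a small helper, and running the same steps in reverse recovers the original stack. In this sketch tile_images and untile_images are hypothetical names, and transpose(0, 2, 1, 3) is used in place of swapaxes(1, 2), which is equivalent for a 4D array. The first reshape (splitting K into n_rows and n_cols) is the step a single 3-axis transpose cannot express; the final reshape generally has to copy, since the swapped axes are no longer contiguous in memory.

```python
import numpy as np


def tile_images(arr, n_rows, n_cols):
    """Hypothetical helper: (K, M, N) -> (n_rows * M, n_cols * N), row-major grid."""
    K, M, N = arr.shape
    if K != n_rows * n_cols:
        raise ValueError(f"expected K == {n_rows} * {n_cols}, got K == {K}")
    return (
        arr.reshape(n_rows, n_cols, M, N)   # (r c) h w -> r c h w
        .transpose(0, 2, 1, 3)              # r h c w, same as swapaxes(1, 2)
        .reshape(n_rows * M, n_cols * N)    # (r h) (c w); usually makes a copy
    )


def untile_images(grid, n_rows, n_cols):
    """Hypothetical inverse: (n_rows * M, n_cols * N) -> (K, M, N)."""
    H, W = grid.shape
    if H % n_rows or W % n_cols:
        raise ValueError("grid shape is not divisible by (n_rows, n_cols)")
    M, N = H // n_rows, W // n_cols
    return (
        grid.reshape(n_rows, M, n_cols, N)  # (r h) (c w) -> r h c w
        .transpose(0, 2, 1, 3)              # r c h w
        .reshape(n_rows * n_cols, M, N)     # (r c) h w
    )


arr = np.arange(6 * 4 * 5).reshape(6, 4, 5)
grid = tile_images(arr, n_rows=3, n_cols=2)
print(grid[4, :5])  # → [40 41 42 43 44]
print(np.array_equal(untile_images(grid, 3, 2), arr))  # → True
```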