
Extract array from array without loop in python


I am trying to extract part of an array from an array.

Let's assume I have an array array1 with shape (M, N, P). For my specific case, M = 10, N = 5, P = 2000. I have another array, array2 of shape (M, N, 1), which contains the starting points of the interesting data in array1 along the last axis. I want to extract 50 points of this data starting with the indices given by array2, kind of like this:

array1[:, :, array2:array2 + 50] 

I would expect a result of shape (M, N, 50). Unfortunately, I get the error:

TypeError: only integer scalar arrays can be converted to a scalar index
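The error comes from the slice itself: the start and stop of a slice must be single integers, and an (M, N, 1) array cannot be converted to one. Here is a minimal reproduction with small stand-in shapes. The sizes and the `raised` flag are illustrative choices, not part of the original question:

```python
import numpy as np

M, N, P, k = 4, 2, 15, 3  # small stand-ins for 10, 5, 2000, 50

array1 = np.arange(M * N * P).reshape((M, N, P))
array2 = np.arange(M * N).reshape((M, N, 1)) + 1  # one start index per (m, n)

raised = False  # illustrative flag recording whether the slice failed
try:
    # Slice bounds must be scalar integers; an (M, N, 1) array is not,
    # so NumPy raises TypeError when it tries to use it as a bound.
    array1[:, :, array2:array2 + k]
except TypeError as exc:
    raised = True
    print(exc)
```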

Sure, I could get the result by looping over the array, but I feel there must be a smarter way, because I need this quite often.
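For reference, the loop version would look roughly like the sketch below. `extract_windows_loop` is a hypothetical helper name, and it assumes every window fits inside the last axis (start + k <= P). It is mainly useful as a baseline to check a vectorized version against:

```python
import numpy as np

def extract_windows_loop(array1, array2, k):
    """Hypothetical reference loop: copy k values along the last axis,
    starting at array2[m, n, 0], for every (m, n). Assumes start + k <= P."""
    M, N, _ = array1.shape
    out = np.empty((M, N, k), dtype=array1.dtype)
    for m in range(M):
        for n in range(N):
            start = int(array2[m, n, 0])
            out[m, n] = array1[m, n, start:start + k]
    return out

M, N, P, k = 4, 2, 15, 3  # small stand-ins for 10, 5, 2000, 50
array1 = np.arange(M * N * P).reshape((M, N, P))
array2 = np.arange(M * N).reshape((M, N, 1)) + 1
result = extract_windows_loop(array1, array2, k)
print(result.shape)  # (4, 2, 3)
```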


Solution

  • You can build a boolean mask by comparing the values in array2 with an index range over the last dimension:

    For example:

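    import numpy as np

    M, N, P, k = 4, 2, 15, 3   # yours would be 10, 5, 2000, 50

    A1 = np.arange(M*N*P).reshape((M, N, P))    # data array
    A2 = np.arange(M*N).reshape((M, N, 1)) + 1  # start index per (m, n)

    # Positions along the last axis, shaped (1, 1, P) so they broadcast against A2
    rP = np.arange(P)[None, None, :]
    # True where start <= position < start + k; broadcasts to shape (M, N, P).
    # Boolean indexing returns the selected values flattened in C order, exactly
    # k per (m, n) row as long as every window fits (A2 + k <= P), so the result
    # can be reshaped to (M, N, k).
    A3 = A1[(rP >= A2) & (rP < A2 + k)].reshape((M, N, k))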
    

    Input:

    print(A1)
    
    [[[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14]
      [ 15  16  17  18  19  20  21  22  23  24  25  26  27  28  29]]
    
     [[ 30  31  32  33  34  35  36  37  38  39  40  41  42  43  44]
      [ 45  46  47  48  49  50  51  52  53  54  55  56  57  58  59]]
    
     [[ 60  61  62  63  64  65  66  67  68  69  70  71  72  73  74]
      [ 75  76  77  78  79  80  81  82  83  84  85  86  87  88  89]]
    
     [[ 90  91  92  93  94  95  96  97  98  99 100 101 102 103 104]
      [105 106 107 108 109 110 111 112 113 114 115 116 117 118 119]]]
    
    print(A2)
    
    [[[1]
      [2]]
    
     [[3]
      [4]]
    
     [[5]
      [6]]
    
     [[7]
      [8]]]
    

    Output:

    print(A3)
    
    [[[  1   2   3]
      [ 17  18  19]]
    
     [[ 33  34  35]
      [ 49  50  51]]
    
     [[ 65  66  67]
      [ 81  82  83]]
    
     [[ 97  98  99]
      [113 114 115]]]
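A possible alternative, offered as a sketch rather than part of the answer above: instead of a full (M, N, P) boolean mask, build the (M, N, k) array of wanted positions directly (A2 + np.arange(k)) and gather with `np.take_along_axis` or with plain advanced indexing. This only touches M*N*k positions instead of M*N*P. As with the mask, every window must fit (A2 + k <= P); otherwise these raise IndexError. The check at the end compares both against the mask result on the same example data:

```python
import numpy as np

M, N, P, k = 4, 2, 15, 3  # yours would be 10, 5, 2000, 50
A1 = np.arange(M * N * P).reshape((M, N, P))
A2 = np.arange(M * N).reshape((M, N, 1)) + 1

# Index array of shape (M, N, k): row (m, n) holds A2[m, n, 0] + [0, 1, ..., k-1]
idx = A2 + np.arange(k)

# Option 1: take_along_axis picks those positions along the last axis
A3_take = np.take_along_axis(A1, idx, axis=-1)

# Option 2: advanced indexing with open index grids for the first two axes
m_idx = np.arange(M)[:, None, None]  # shape (M, 1, 1)
n_idx = np.arange(N)[None, :, None]  # shape (1, N, 1)
A3_fancy = A1[m_idx, n_idx, idx]     # broadcasts to (M, N, k)

# Mask approach from the answer, for comparison
rP = np.arange(P)[None, None, :]
A3_mask = A1[(rP >= A2) & (rP < A2 + k)].reshape((M, N, k))

print(np.array_equal(A3_take, A3_mask), np.array_equal(A3_fancy, A3_mask))  # True True
```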