Tags: r, arrays, performance, matrix-indexing

Is there an efficient way to extract a slice of a 3d array?


Consider this:

a <- array(1:60, dim = c(3, 4, 5))

# Pull out the elements at positions (*, 2, 2)
a[cbind(seq_len(3), 2, 2)]

# ^^ gives a length-3 vector: c(16, 17, 18)
# Asymptotically wasteful, though: cbind(1:k, 2, 2) builds a k x 3 index
# matrix, so no slicing is taking place

# Pull out the elements at positions (1, y_i, z_i)
N <- 100000
y <- sample(seq_len(4), N, replace = TRUE)
z <- sample(seq_len(5), N, replace = TRUE)
a[cbind(1, y, z)]

# ^^ gives a vector of length N

How do I efficiently extract the (*, y_i, z_i) elements, with the result as an N x 3 matrix (i.e. a 2d array)?

Note: this question is similar, but the only answer proceeds elementwise and so will be slow.


Solution

  • One way is to compute the linear (column-major) index of each wanted element from the array's dim.

    # dim(a)[1] * (y - 1) + prod(dim(a)[1:2]) * (z - 1) is the part of the
    # linear index that depends only on (y_i, z_i); adding the first index
    # i = 1..nrow(a) gives the position of a[i, y_i, z_i].
    outer(
      dim(a)[1] * (y - 1) + prod(dim(a)[1:2]) * (z - 1),
      seq_len(nrow(a)),
      \(offset, i) a[offset + i]
    )

    # Same offsets, but one vector lookup per first index, bound as columns.
    do.call(
      cbind,
      lapply(
        seq_len(nrow(a)),
        \(i, offset) a[i + offset],
        dim(a)[1] * (y - 1) + prod(dim(a)[1:2]) * (z - 1)
      )
    )
    

    Alternatively, collapse the last two dimensions into one and select whole columns (or, after transposing, whole rows).

    # View a as an nrow(a) x (d2 * d3) matrix; its column y + ncol(a) * (z - 1)
    # is a[, y, z]. Select the columns first, then transpose.
    t(`dim<-`(a, c(nrow(a), prod(dim(a)[-1])))[, y + ncol(a) * (z - 1)])

    # Same reshaping, but transpose the whole matrix first and select rows.
    t(`dim<-`(a, c(nrow(a), prod(dim(a)[-1]))))[y + ncol(a) * (z - 1), ]
    


    Benchmark:

    # Candidate expressions, stored unevaluated via alist() and passed to
    # bench::mark(exprs = m) below. Each entry is meant to build the
    # N x nrow(a) matrix whose k-th row is a[, y[k], z[k]]; they read the
    # globals a, N, y and z at evaluation time.
    m <- alist(
      # A single index matrix with N * nrow(a) rows (cbind recycles y and z);
      # the extracted vector is then given dim c(N, nrow(a)).
      matrix_indexing = `dim<-`(a[cbind(rep(seq_len(nrow(a)), each = N), y, z)], c(N, nrow(a))),
      # One 3-column matrix-indexing call per value of the first index.
      sapply_indexing = sapply(seq_len(nrow(a)), \(i) a[cbind(i, y, z)]),
      # Take the 2d slice a[i,,] first, then index it with the 2-column (y, z) matrix.
      sapply_indexing2 = {yz <- cbind(y, z); sapply(1:dim(a)[1], \(i) a[i,,][yz])},
      # Linear indexing: the first outer() argument is the offset determined by
      # (y, z); the lambda adds the first index (note its parameter y shadows
      # the global y inside the lambda only).
      linearIndex = outer((y-1)*dim(a)[1] + (z-1)*prod(dim(a)[1:2]), seq_len(nrow(a)), \(x, y) a[x+y]),
      # Same offsets, one vector lookup per first index, bound column-wise.
      linearIndex2 = do.call(cbind, lapply(seq_len(nrow(a)), \(x, y) a[x+y], (y-1)*dim(a)[1] + (z-1)*prod(dim(a)[1:2]))),
      # Split a into a list arranged by the remaining dimensions, pick the
      # (y, z) entries by matrix indexing, and row-bind them.
      asplit = do.call(rbind, asplit(a, -1)[cbind(y, z)]),
      # Move the first dimension last, flatten the other two into rows, select rows.
      aperm = `dim<-`(aperm(a, c(2, 3, 1)), c(prod(dim(a)[-1]), dim(a)[1]))[y + (z - 1)*dim(a)[2],],
      # Flatten dims 2 and 3 into columns, select columns, then transpose the selection.
      dimT = t(`dim<-`(a, c(dim(a)[1], prod(dim(a)[-1])))[,y + (z - 1)*dim(a)[2]]),
      # Flatten dims 2 and 3 into columns, transpose the whole matrix, then select rows.
      tDim = t(`dim<-`(a, c(dim(a)[1], prod(dim(a)[-1]))))[y + (z - 1)*dim(a)[2],]
    )
    
    # Case 1: small 3 x 4 x 5 array, N = 100000 lookups
    a <- array(1:60, dim = c(3, 4, 5))
    N <- 100000
    y <- sample(seq_len(4), N, replace = TRUE)
    z <- sample(seq_len(5), N, replace = TRUE)
    bench::mark(exprs = m)
    #  expression     min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    #  <bch:expr> <bch:t> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
    #1 matrix_in…  9.46ms  9.59ms     102.     6.87MB     26.2    39    10    382.4ms
    #2 sapply_in…  3.76ms  3.81ms     261.     8.02MB     92.2    85    30    325.4ms
    #3 sapply_in…  2.89ms  2.96ms     334.     5.38MB     72.5   115    25    344.7ms
    #4 linearInd…   4.7ms  4.76ms     208.     9.57MB    132.     55    35    264.3ms
    #5 linearInd…   3.1ms  3.14ms     316.     7.25MB    115.     96    35    303.3ms
    #6 asplit      40.8ms  40.8ms      24.5    3.08MB    221.      1     9     40.8ms
    #7 aperm        1.3ms  1.31ms     748.     2.29MB     48.5   324    21    433.3ms
    #8 dimT        2.03ms  2.06ms     478.     3.44MB     55.8   197    23    412.5ms
    #9 tDim         1.3ms  1.31ms     748.     2.29MB     48.2   326    21    435.9ms
    
    # Case 2: larger 30 x 40 x 50 array, N = 1000 lookups
    a <- array(1:60000, dim = c(30, 40, 50))
    N <- 1000
    y <- sample(seq_len(40), N, replace = TRUE)
    z <- sample(seq_len(50), N, replace = TRUE)
    bench::mark(exprs = m)
    #  expression            min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc
    #  <bch:expr>       <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>
    #1 matrix_indexing   952.7µs 960.17µs     1030.     703KB    14.7    492     7
    #2 sapply_indexing   530.8µs 540.62µs     1813.     825KB    28.3    832    13
    #3 sapply_indexing2  896.2µs 919.12µs     1077.     729KB    14.9    506     7
    #4 linearIndex       426.7µs 435.92µs     2263.     836KB    39.5   1032    18
    #5 linearIndex2      326.9µs 334.29µs     2945.     606KB    33.1   1244    14
    #6 asplit              6.6ms   6.79ms      145.     637KB     8.53    68     4
    #7 aperm             318.1µs 323.11µs     2900.     363KB    19.3   1355     9
    #8 dimT              155.9µs 163.63µs     6061.     481KB    55.8   2714    25
    #9 tDim              198.1µs 208.65µs     4721.     598KB    53.7   2108    24