Consider this:
a <- array(1:60, dim = c(3, 4, 5))
# Extract the (*, 2, 2)^th elements through an explicit index matrix
a[cbind(seq_len(3), 2, 2)]
# ^^ returns c(16, 17, 18) -- the same values as the plain slice a[, 2, 2]
# Note this is asymptotically inefficient: cbind(1:k, 2, 2) materialises a
# k x 3 index matrix instead of making use of slicing.
# Extract the (1, y_i, z_i)^th elements
N <- 100000
y <- sample.int(4L, N, replace = TRUE)
z <- sample.int(5L, N, replace = TRUE)
a[cbind(1, y, z)]
# ^^ returns a vector of length N
How do I efficiently extract the (*, y_i, z_i)^th elements, with the result as an N x 3 matrix (i.e. a 2-D array)?
Note that a similar question exists, but its only answer proceeds elementwise and so will be slow.
One way is to compute the linear index into the array from its dim.
# Linear-index approach: (y - 1) * dim(a)[1] + (z - 1) * dim(a)[1] * dim(a)[2]
# is the offset of column (y, z); adding the first-dimension index i gives the
# column-major linear position of a[i, y, z].
outer(
  (y - 1) * dim(a)[1] + (z - 1) * prod(dim(a)[1:2]),
  seq_len(nrow(a)),
  \(offset, i) a[offset + i]
)
# Same idea: build one output column per first-dimension index, then bind.
do.call(
  cbind,
  lapply(
    seq_len(nrow(a)),
    \(i, offset) a[i + offset],
    (y - 1) * dim(a)[1] + (z - 1) * prod(dim(a)[1:2])
  )
)
Alternatively, collapse the trailing dimensions so the array becomes a 2-D matrix, then select whole columns or whole rows.
# Reshape the 3-D array a into a dim(a)[1] x (dim(a)[2] * dim(a)[3]) matrix;
# in column-major storage, column y + (z - 1) * dim(a)[2] then holds a[, y, z].
# drop = FALSE keeps the N x dim(a)[1] matrix shape even when N == 1 or
# dim(a)[1] == 1 (the default drop = TRUE would collapse it to a vector).
t(`dim<-`(a, c(dim(a)[1], prod(dim(a)[-1])))[, y + (z - 1) * dim(a)[2], drop = FALSE])
# Same reshape, but transpose first and select whole rows instead of columns.
t(`dim<-`(a, c(dim(a)[1], prod(dim(a)[-1]))))[y + (z - 1) * dim(a)[2], , drop = FALSE]
Benchmark:
# Candidate expressions for bench::mark(exprs = m). Each is evaluated later
# against the globals `a` (3-D array), `y`, `z` (index vectors) and `N`, and
# should yield the N x dim(a)[1] matrix whose row i is a[, y[i], z[i]].
m <- alist(
  # One (i, y, z) index triple per output element, then reshape to N x d1.
  matrix_indexing = `dim<-`(a[cbind(rep(seq_len(nrow(a)), each = N), y, z)], c(N, nrow(a))),
  # One index-matrix lookup per first-dimension index.
  sapply_indexing = sapply(seq_len(nrow(a)), \(i) a[cbind(i, y, z)]),
  # Slice a[i, , ] and index that matrix with the (y, z) pairs.
  sapply_indexing2 = {yz <- cbind(y, z); sapply(seq_len(nrow(a)), \(i) a[i, , ][yz])},
  # Linear index: column offset of (y, z) plus the first-dimension index.
  linearIndex = outer((y - 1) * dim(a)[1] + (z - 1) * prod(dim(a)[1:2]), seq_len(nrow(a)), \(offset, i) a[offset + i]),
  linearIndex2 = do.call(cbind, lapply(seq_len(nrow(a)), \(i, offset) a[i + offset], (y - 1) * dim(a)[1] + (z - 1) * prod(dim(a)[1:2]))),
  asplit = do.call(rbind, asplit(a, -1)[cbind(y, z)]),
  # Reshape/permute to 2-D and select whole rows or columns; drop = FALSE
  # keeps a matrix result when N == 1 or dim(a)[1] == 1.
  aperm = `dim<-`(aperm(a, c(2, 3, 1)), c(prod(dim(a)[-1]), dim(a)[1]))[y + (z - 1) * dim(a)[2], , drop = FALSE],
  dimT = t(`dim<-`(a, c(dim(a)[1], prod(dim(a)[-1])))[, y + (z - 1) * dim(a)[2], drop = FALSE]),
  tDim = t(`dim<-`(a, c(dim(a)[1], prod(dim(a)[-1]))))[y + (z - 1) * dim(a)[2], , drop = FALSE]
)
# Benchmark 1: a small 3 x 4 x 5 array indexed by many (N = 1e5) (y, z) pairs.
# set.seed() makes the sampled indices, and so the benchmark, reproducible.
set.seed(1)
a <- array(1:60, c(3, 4, 5))
N <- 100000
# Index ranges come from dim(a) so they always match the array's extents.
y <- sample.int(dim(a)[2], N, replace = TRUE)
z <- sample.int(dim(a)[3], N, replace = TRUE)
bench::mark(exprs = m)
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# <bch:expr> <bch:t> <bch:t> <dbl> <bch:byt> <dbl> <int> <dbl> <bch:tm>
#1 matrix_in… 9.46ms 9.59ms 102. 6.87MB 26.2 39 10 382.4ms
#2 sapply_in… 3.76ms 3.81ms 261. 8.02MB 92.2 85 30 325.4ms
#3 sapply_in… 2.89ms 2.96ms 334. 5.38MB 72.5 115 25 344.7ms
#4 linearInd… 4.7ms 4.76ms 208. 9.57MB 132. 55 35 264.3ms
#5 linearInd… 3.1ms 3.14ms 316. 7.25MB 115. 96 35 303.3ms
#6 asplit 40.8ms 40.8ms 24.5 3.08MB 221. 1 9 40.8ms
#7 aperm 1.3ms 1.31ms 748. 2.29MB 48.5 324 21 433.3ms
#8 dimT 2.03ms 2.06ms 478. 3.44MB 55.8 197 23 412.5ms
#9 tDim 1.3ms 1.31ms 748. 2.29MB 48.2 326 21 435.9ms
# Benchmark 2: a larger 30 x 40 x 50 array indexed by fewer (N = 1000) pairs.
# set.seed() makes the sampled indices, and so the benchmark, reproducible.
set.seed(1)
a <- array(1:60000, c(30, 40, 50))
N <- 1000
# Index ranges come from dim(a) so they always match the array's extents.
y <- sample.int(dim(a)[2], N, replace = TRUE)
z <- sample.int(dim(a)[3], N, replace = TRUE)
bench::mark(exprs = m)
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc
# <bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl> <int> <dbl>
#1 matrix_indexing 952.7µs 960.17µs 1030. 703KB 14.7 492 7
#2 sapply_indexing 530.8µs 540.62µs 1813. 825KB 28.3 832 13
#3 sapply_indexing2 896.2µs 919.12µs 1077. 729KB 14.9 506 7
#4 linearIndex 426.7µs 435.92µs 2263. 836KB 39.5 1032 18
#5 linearIndex2 326.9µs 334.29µs 2945. 606KB 33.1 1244 14
#6 asplit 6.6ms 6.79ms 145. 637KB 8.53 68 4
#7 aperm 318.1µs 323.11µs 2900. 363KB 19.3 1355 9
#8 dimT 155.9µs 163.63µs 6061. 481KB 55.8 2714 25
#9 tDim 198.1µs 208.65µs 4721. 598KB 53.7 2108 24