I want to scatter and gather elements from an array X at specific indices along one axis.
So, given an array of indices idx, I want to select the idx(0)th element along the 0th column, the idx(1)th element along the 1st column, etc.
In NumPy, the following code:
import numpy as np
X = np.array([[1, 2, 3], [4, 5, 6]])
print(X[[0, 1, 1], range(3)])
prints [1 5 6].
Furthermore, I can do this process in reverse:
Y = np.zeros((2, 3))
Y[[0, 1, 1], range(3)] = [1, 5, 6]
print(Y)
This will print
[[1. 0. 0.]
 [0. 5. 6.]]
However, when I try to replicate this behavior in ArrayFire:
float elements[] = {1, 2, 3, 4, 5, 6};
af::array X = af::array(3, 2, elements);
int idx_elements[] = {0, 1, 1};
af::array idx = af::array(3, idx_elements);
af::print("", X(af::span, idx));
I get an array of shape [3, 3, 1, 1] with the elements
1.0000 4.0000 4.0000
2.0000 5.0000 5.0000
3.0000 6.0000 6.0000
So how can I achieve the desired numpy-like behavior for scattering and gathering elements in ArrayFire?
To perform the gather operation on a matrix, I can extract the diagonal of the resulting matrix, but that may not generalize to the multidimensional case, and it doesn't work in the other (scatter) direction.
X
[3 2 1 1]
1.0000 4.0000
2.0000 5.0000
3.0000 6.0000
idx
[3 1 1 1]
0
1
1
When af::array objects are used as indices, ArrayFire takes the Cartesian product of the indices along each dimension, hence the output. These are the (row, column) pairs that get selected:
Row\Col   0        1        1       <- columns from idx (af::array)
0         (0, 0)   (0, 1)   (0, 1)
1         (1, 0)   (1, 1)   (1, 1)
2         (2, 0)   (2, 1)   (2, 1)
^ rows from af::span (sequence)
Thus, the output of X(af::span, idx) is a 3x3 matrix: every row selected by af::span is paired with every column listed in idx.
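To make the Cartesian-product behaviour concrete without ArrayFire, here is a plain C++ sketch (an illustration, not ArrayFire's implementation). It assumes column-major storage like ArrayFire's, so element (r, c) of a matrix with `rows` rows sits at `x[r + rows * c]`. `cartesian_select` and `diagonal` are hypothetical helper names. With the X and idx from the question it reproduces the 3x3 result, and its diagonal holds the wanted values 1, 5, 6.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major storage: element (r, c) of a matrix with `rows` rows
// lives at x[r + rows * c], as in ArrayFire.

// Models X(af::span, idx): every row r is paired with every column idx[j],
// so the result has `rows` rows and idx.size() columns (column-major).
std::vector<float> cartesian_select(const std::vector<float>& x, std::size_t rows,
                                    const std::vector<int>& idx) {
    std::vector<float> out(rows * idx.size());
    for (std::size_t j = 0; j < idx.size(); ++j) {
        // Column index must be non-negative and inside x.
        assert(idx[j] >= 0 && static_cast<std::size_t>(idx[j]) * rows < x.size());
        for (std::size_t r = 0; r < rows; ++r)
            out[r + rows * j] = x[r + rows * static_cast<std::size_t>(idx[j])];
    }
    return out;
}

// The coordinate-wise gather is the diagonal of that square result:
// element (j, j) equals X(j, idx[j]). This only works when idx.size() == rows.
std::vector<float> diagonal(const std::vector<float>& m, std::size_t rows,
                            std::size_t cols) {
    assert(rows == cols);
    std::vector<float> d;
    for (std::size_t j = 0; j < cols; ++j) d.push_back(m[j + rows * j]);
    return d;
}
```

Note that this builds rows x idx.size() elements only to keep idx.size() of them, which is one reason the diagonal trick is wasteful even where it applies.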
To gather elements by (row, column) coordinates, you need a different function, approx2. Note that this function takes its indices as floating-point arrays only.
float idx_elements[] = {0, 1, 1}; // changed the idx to floats
af::array colIdx = af::array(3, idx_elements);
af::array rowIdx = af::iota(3); // 0, 1, 2 as floats; plays the role of af::span
af::array out = approx2(X, rowIdx, colIdx);
af_print(out);
// out
// [3 1 1 1]
// 1.0000
// 5.0000
// 6.0000
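approx2 is an interpolation routine, so it returns exact elements here because the coordinates are whole numbers. For integer coordinates, the same pairwise gather can also be written with linear (column-major) indices, which is the idea the answer applies to scattering. The sketch below is plain C++ with no ArrayFire calls. It assumes column-major storage, and `gather_pairs` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: coordinate-wise gather from a column-major matrix.
// out[k] = X(row_idx[k], col_idx[k]), read through the linear index
// row + rows * col, so no Cartesian product is ever formed.
std::vector<float> gather_pairs(const std::vector<float>& x, std::size_t rows,
                                const std::vector<std::size_t>& row_idx,
                                const std::vector<std::size_t>& col_idx) {
    assert(row_idx.size() == col_idx.size());
    std::vector<float> out(row_idx.size());
    for (std::size_t k = 0; k < row_idx.size(); ++k) {
        assert(row_idx[k] < rows);
        std::size_t lin = row_idx[k] + rows * col_idx[k];  // column-major offset
        assert(lin < x.size());
        out[k] = x[lin];
    }
    return out;
}
```

With X = {1, 2, 3, 4, 5, 6} (3 rows), row_idx = {0, 1, 2} and col_idx = {0, 1, 1}, the linear indices are 0, 4, 5, so the result is {1, 5, 6}, matching the approx2 output.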
To set values at given indices, you have to flatten the array, for the very reason that array::operator() takes the Cartesian product when af::array indices are involved.
af::array A = af::constant(0, 3, 2); // same size as X
af::array B = af::flat(A); // flatten the array, this involves meta data modification only
B(rowIdx + 3 * colIdx) = out; // use row & col indices to fetch linear indices
// rowIdx + 3 * colIdx
// [3 1 1 1]
// 0.0000
// 4.0000
// 5.0000
B = moddims(B, A.dims()); // reset the dimensions to original A dims
af_print(B);
// B
// [3 2 1 1]
// 1.0000 0.0000
// 0.0000 5.0000
// 0.0000 6.0000
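The expression rowIdx + 3 * colIdx is the 2-D case of a column-major linear index. The same idea extends to more dimensions, which addresses the multidimensional concern in the question: the stride of dimension k is the product of all earlier dimension sizes. Below is a plain C++ sketch of that generalization and of the scatter itself. It is an assumption-based illustration, not ArrayFire code, and `linear_index` and `scatter` are hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Column-major linear index of an N-d coordinate:
// lin = i0 + d0 * (i1 + d1 * (i2 + ...)), i.e. stride_k = d0 * ... * d(k-1).
std::size_t linear_index(const std::vector<std::size_t>& dims,
                         const std::vector<std::size_t>& coord) {
    assert(dims.size() == coord.size());
    std::size_t lin = 0, stride = 1;
    for (std::size_t k = 0; k < dims.size(); ++k) {
        assert(coord[k] < dims[k]);
        lin += coord[k] * stride;
        stride *= dims[k];
    }
    return lin;
}

// Hypothetical helper: write values[j] at coords[j] in a flat column-major
// buffer. Only the index computation knows the dims; the buffer stays flat,
// which mirrors the flat(...) / moddims(...) pair in the answer.
void scatter(std::vector<float>& y, const std::vector<std::size_t>& dims,
             const std::vector<std::vector<std::size_t>>& coords,
             const std::vector<float>& values) {
    assert(coords.size() == values.size());
    for (std::size_t j = 0; j < coords.size(); ++j) {
        std::size_t lin = linear_index(dims, coords[j]);
        assert(lin < y.size());
        y[lin] = values[j];  // duplicate coordinates: this loop keeps the last write
    }
}
```

For dims {3, 2}, the coordinates (0, 0), (1, 1) and (2, 1) map to 0, 4 and 5, the same linear indices the answer prints, and scattering {1, 5, 6} into a zeroed buffer gives {1, 0, 0, 0, 5, 6}, i.e. the B shown above in column-major order.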
You can find more details in our indexing tutorial.