c++ arrayfire

Scatter/Gather like NumPy in ArrayFire


I want to scatter and gather elements of an array X at specific indices along one axis. Given an array of indices idx, I want to select the idx(0)th element of the 0th column, the idx(1)th element of the 1st column, and so on.

In NumPy, the following code:

import numpy as np

X = np.array([[1, 2, 3], [4, 5, 6]])
print(X[[0, 1, 1], range(3)])

prints [1 5 6].

I can also do the reverse (scatter):

Y = np.zeros((2, 3))
Y[[0, 1, 1], range(3)] = [1, 5, 6]
print(Y)

This will print

[[1. 0. 0.]
 [0. 5. 6.]]
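For reference, the NumPy semantics above boil down to simple loops in which the k-th index pairs idx[k] with column k; no Cartesian product is formed. This is a minimal sketch in plain C++ (standard library only, no ArrayFire), assuming a row-major rows x cols buffer as NumPy uses by default; gather_rows and scatter_rows are hypothetical helper names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: NumPy-style X[idx, range(cols)] on a row-major
// rows x cols buffer. Output element k is X[idx[k]][k].
std::vector<float> gather_rows(const std::vector<float>& X, std::size_t rows,
                               std::size_t cols, const std::vector<int>& idx) {
    assert(X.size() == rows * cols && idx.size() == cols);
    std::vector<float> out(cols);
    for (std::size_t k = 0; k < cols; ++k) {
        assert(idx[k] >= 0 && static_cast<std::size_t>(idx[k]) < rows);
        // Row-major offset: row * cols + col.
        out[k] = X[static_cast<std::size_t>(idx[k]) * cols + k];
    }
    return out;
}

// Hypothetical helper: NumPy-style Y[idx, range(cols)] = vals, done in place.
void scatter_rows(std::vector<float>& Y, std::size_t rows, std::size_t cols,
                  const std::vector<int>& idx, const std::vector<float>& vals) {
    assert(Y.size() == rows * cols && idx.size() == cols && vals.size() == cols);
    for (std::size_t k = 0; k < cols; ++k) {
        assert(idx[k] >= 0 && static_cast<std::size_t>(idx[k]) < rows);
        Y[static_cast<std::size_t>(idx[k]) * cols + k] = vals[k];
    }
}
```

Called with the 2x3 X and idx = {0, 1, 1}, gather_rows returns {1, 5, 6}, and scatter_rows on a zeroed Y yields the rows [1 0 0] and [0 5 6].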

However, when I try to replicate this behavior in ArrayFire:

float elements[] = {1, 2, 3, 4, 5, 6};
af::array X = af::array(3, 2, elements);  // column-major: 3x2, the transpose of the NumPy X
int idx_elements[] = {0, 1, 1};
af::array idx = af::array(3, idx_elements);
af::print("", X(af::span, idx));          // all rows, columns 0, 1, 1

I get an array of shape [3, 3, 1, 1] with the elements

1.0000     4.0000     4.0000 
2.0000     5.0000     5.0000 
3.0000     6.0000     6.0000

So how can I achieve the desired NumPy-like behavior for scattering and gathering elements in ArrayFire?

For the gather operation on a matrix, I could extract the diagonal of the resulting matrix, but that may not generalize to the multidimensional case, and it doesn't work in the other (scatter) direction.
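To make the diagonal workaround concrete, here is a plain C++ sketch (standard library only, no ArrayFire) assuming the column-major 3x2 X from the snippet above; diagonal_gather is a hypothetical name. It materializes the whole product P(i, j) = X(i, idx[j]) and keeps only P(k, k), so it builds nrows * idx.size() values in order to return idx.size() of them.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: gather via the diagonal of X(span, idx).
// X is column-major with nrows rows, so X(i, j) lives at X[i + nrows * j].
std::vector<float> diagonal_gather(const std::vector<float>& X, std::size_t nrows,
                                   const std::vector<int>& idx) {
    assert(idx.size() == nrows);  // the diagonal needs a square product
    const std::size_t n = idx.size();
    // Materialize P = X(span, idx): an nrows x n column-major matrix.
    std::vector<float> P(nrows * n);
    for (std::size_t j = 0; j < n; ++j) {
        assert(idx[j] >= 0 && static_cast<std::size_t>(idx[j]) * nrows < X.size());
        for (std::size_t i = 0; i < nrows; ++i)
            P[i + nrows * j] = X[i + nrows * static_cast<std::size_t>(idx[j])];
    }
    // Keep only the diagonal: P(k, k) = X(k, idx[k]).
    std::vector<float> out(n);
    for (std::size_t k = 0; k < n; ++k)
        out[k] = P[k + nrows * k];
    return out;
}
```

The result is right ({1, 5, 6} for this X and idx), but the quadratic intermediate is exactly the cost the question wants to avoid.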


Solution

  • X
    [3 2 1 1]
        1.0000     4.0000 
        2.0000     5.0000 
        3.0000     6.0000 
    
    idx
    [3 1 1 1]
             0 
             1 
             1 
    

    When an af::array is used as an index, ArrayFire takes the Cartesian product of the indices along each dimension, which explains the output. These are the (row, column) pairs that end up being selected:

    Row\Col   0       1       1     <- column indices from idx
    
       0    (0, 0)  (0, 1)  (0, 1)
       1    (1, 0)  (1, 1)  (1, 1)
       2    (2, 0)  (2, 1)  (2, 1)
       ^
       ^ row indices from af::span
    

    Thus, the output of X(af::span, idx) is a 3x3 matrix whose element (i, j) is X(i, idx(j)).
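The Cartesian-product rule can be reproduced with two nested loops. This is a minimal sketch in plain C++ (standard library only; it mimics the indexing semantics, not ArrayFire's implementation), assuming column-major storage as ArrayFire uses; index_span_by_array is a hypothetical name. Every row index from the span is paired with every column index from idx.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mimicking X(af::span, idx) on a column-major buffer:
// result(i, j) = X(i, idx[j]) for every i in [0, nrows) and every j.
// The result has nrows x idx.size() elements, stored column-major.
std::vector<float> index_span_by_array(const std::vector<float>& X, std::size_t nrows,
                                       const std::vector<int>& idx) {
    const std::size_t ncols = X.size() / nrows;
    std::vector<float> out(nrows * idx.size());
    for (std::size_t j = 0; j < idx.size(); ++j) {
        assert(idx[j] >= 0 && static_cast<std::size_t>(idx[j]) < ncols);
        for (std::size_t i = 0; i < nrows; ++i)
            out[i + nrows * j] = X[i + nrows * static_cast<std::size_t>(idx[j])];
    }
    return out;
}
```

For X = {1, 2, 3, 4, 5, 6} with 3 rows and idx = {0, 1, 1}, the columns of the result are (1, 2, 3), (4, 5, 6) and (4, 5, 6), matching the 3x3 output in the question.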

    To gather elements at specific (row, column) coordinates, you need a different function, approx2. Note that this function only accepts its positions as floating-point arrays.

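    float idx_elements[] = {0, 1, 1}; // changed idx to floats, as approx2 requires
    af::array colIdx = af::array(3, idx_elements);
    af::array rowIdx = af::iota(3); // [0, 1, 2], same rows as af::span
    
    // Sample X at the pairs (rowIdx[k], colIdx[k]); at integer positions the
    // default linear interpolation returns the stored values exactly.
    af::array out = af::approx2(X, rowIdx, colIdx);
    af_print(out); 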
    // out
    // [3 1 1 1]
    //     1.0000 
    //     5.0000 
    //     6.0000 
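What approx2 computes here, at integer-valued positions, is a pointwise gather: output k reads the element at (rowIdx[k], colIdx[k]). A plain C++ sketch of that semantics (standard library only; it does no interpolation and is not how approx2 is implemented), assuming column-major storage; gather_pairs is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: out[k] = X(row[k], col[k]) for a column-major
// nrows x ncols buffer, i.e. paired (not Cartesian) indexing.
std::vector<float> gather_pairs(const std::vector<float>& X, std::size_t nrows,
                                const std::vector<int>& row, const std::vector<int>& col) {
    assert(row.size() == col.size());
    const std::size_t ncols = X.size() / nrows;
    std::vector<float> out(row.size());
    for (std::size_t k = 0; k < row.size(); ++k) {
        assert(row[k] >= 0 && static_cast<std::size_t>(row[k]) < nrows);
        assert(col[k] >= 0 && static_cast<std::size_t>(col[k]) < ncols);
        // Column-major linear index: row + nrows * col.
        out[k] = X[static_cast<std::size_t>(row[k]) + nrows * static_cast<std::size_t>(col[k])];
    }
    return out;
}
```

With row = {0, 1, 2} and col = {0, 1, 1} this reads linear positions 0, 4 and 5, giving {1, 5, 6}; the same row + nrows * col formula is what the scatter step relies on.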
    

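    To set the values at given indices, you have to flatten the array, for the very same reason: array::operator() takes the Cartesian product whenever an af::array is used as an index. On a flattened (1D) array there is only one dimension, so a single index array addresses the elements directly.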

    af::array A = af::constant(0, 3, 2); // same size as X
    af::array B = af::flat(A); // flatten to 6x1; this involves metadata modification only
    
    // Column-major linear index: row + nrows * col, with nrows = 3
    B(rowIdx + 3 * colIdx) = out;
    // rowIdx + 3 * colIdx
    // [3 1 1 1]
    //     0.0000 
    //     4.0000 
    //     5.0000 
    
    B = af::moddims(B, A.dims()); // restore the original 3x2 dimensions of A
    
    af_print(B);
    // B
    // [3 2 1 1]
    //     1.0000     0.0000 
    //     0.0000     5.0000 
    //     0.0000     6.0000 
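The flatten, assign, reshape sequence works because, in column-major order, (row, col) maps to the linear index row + nrows * col (hence rowIdx + 3 * colIdx for 3 rows). A plain C++ sketch of the same scatter (standard library only), assuming a column-major buffer; scatter_pairs is a hypothetical name. Since a std::vector is already flat, no reshape is needed: reading it back as nrows x ncols plays the role of moddims.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: A(row[k], col[k]) = vals[k] on a column-major
// nrows x ncols buffer, written through linear indices row + nrows * col.
void scatter_pairs(std::vector<float>& A, std::size_t nrows,
                   const std::vector<int>& row, const std::vector<int>& col,
                   const std::vector<float>& vals) {
    assert(row.size() == col.size() && col.size() == vals.size());
    const std::size_t ncols = A.size() / nrows;
    for (std::size_t k = 0; k < vals.size(); ++k) {
        assert(row[k] >= 0 && static_cast<std::size_t>(row[k]) < nrows);
        assert(col[k] >= 0 && static_cast<std::size_t>(col[k]) < ncols);
        const std::size_t linear = static_cast<std::size_t>(row[k]) +
                                   nrows * static_cast<std::size_t>(col[k]);
        A[linear] = vals[k];  // the flat buffer stands in for af::flat(A)
    }
}
```

Scattering {1, 5, 6} into a zeroed 3x2 buffer at rows {0, 1, 2} and columns {0, 1, 1} writes positions 0, 4 and 5, leaving columns (1, 0, 0) and (0, 5, 6), the same layout as B above.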
    

    You can find more details in our indexing tutorial.