c++ numpy xtensor

Having trouble using xt::where in xtensor


I am trying to find the indices of certain values in an xarray. I have an xarray called lattice that contains the numbers 1 through n, and what I would like is something like

auto x2 = xt::where(lattice == i)

to get the indices of the elements equal to i in lattice, which will then be used in a distance function. However, I get an error saying that == doesn't match the operands. The problem doesn't happen when I use >, so I'm wondering what the difference is.

I've used np.where(lattice == i) in Python and I'm trying to translate it to C++.


Solution

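  • You'll have to use xt::equal(a, b) instead of a == b. This differs from, for example, a > b, which is exactly equivalent to xt::greater(a, b). As far as I know, == in xtensor is not an element-wise operator: it compares two expressions as a whole and yields a single bool, so it cannot feed xt::where.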

    Note also that the list of indices can be converted to a matrix using xt::from_indices(...); see the documentation. Consider the following example:

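    #include <iostream>
    #include <xtensor/xtensor.hpp>
    #include <xtensor/xio.hpp>
    
    int main()
    {
        // 5 x 5 matrix holding the values 0..24
        xt::xtensor<size_t,2> a = xt::arange(5 * 5).reshape({5, 5});
        // value to look for
        size_t i = 4;
        // element-wise comparison with xt::equal, then collect the matching indices in a matrix
        xt::xtensor<size_t,2> idx = xt::from_indices(xt::where(xt::equal(a, i)));
        std::cout << idx << std::endl;
        return 0;
    }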