
xtensor: Select rows with specific column values


I am playing around with xtensor and wanted to perform a simple operation: selecting rows with specific column values. Imagine I have the following array:

[
  [0, 1, 1, 3, 4],
  [0, 2, 1, 5, 6],
  [0, 3, 1, 3, 2],
  [0, 4, 1, 5, 7]
]

Now I want to select the rows where col2 and col4 both have the value 3 (counting rows and columns from 1). In this case that is only row 3:

  [0, 3, 1, 3, 2 ]

I want to achieve something similar to what this answer does.

How can I achieve this in xtensor?


Solution

  • The way to go is to slice out the columns you need, and then find the rows where the condition holds for all of those columns.

    For the latter, an overload of xt::all(...) that reduces along an axis does not seem to be implemented (yet!), but we can count the matches per row with xt::sum(..., axis) and keep the rows whose count equals the number of selected columns:

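    #include <iostream>
    #include <xtensor/xtensor.hpp>
    #include <xtensor/xview.hpp>
    #include <xtensor/xmath.hpp>
    #include <xtensor/xoperation.hpp>
    #include <xtensor/xio.hpp>
    
    int main()
    {
      xt::xtensor<int,2> a =
        {{0, 1, 1, 3, 4},
         {0, 2, 1, 5, 6},
         {0, 3, 1, 3, 2},
         {0, 4, 1, 5, 7}};
    
      // Boolean (rows x 2) mask: does column 1 / column 3 (0-based) equal 3?
      auto test = xt::equal(xt::view(a, xt::all(), xt::keep(1, 3)), 3);
      // Number of matching columns in each row (reduce along axis 1).
      auto n = xt::sum(test, {1});
      // Indices of the rows where both selected columns match.
      auto idx = xt::flatten_indices(xt::argwhere(xt::equal(n, 2)));
    
      // Keep only those rows, with all of their columns.
      auto b = xt::view(a, xt::keep(idx), xt::all());
    
      std::cout << b << std::endl;
    
      return 0;
    }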