I am playing around with xtensor and I just want to perform a simple operation: select the rows that have specific column values. Imagine I have the following array:
[[0, 1, 1, 3, 4],
 [0, 2, 1, 5, 6],
 [0, 3, 1, 3, 2],
 [0, 4, 1, 5, 7]]
Now I want to select the rows where col2 and col4 (column indices 1 and 3, counting from zero) both have the value 3, which in this case is only row 3 (index 2):
[0, 3, 1, 3, 2 ]
I want to achieve something similar to what this answer achieves.
How can I do this in xtensor?
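The way to go is to slice out the columns you need and then check, row by row, where the condition holds for all of those columns. For the latter, an axis-wise overload of xt::all(...) does not seem to be implemented (yet!), but xt::sum(..., axis) achieves the same: a row satisfies the condition in every selected column exactly when its number of true entries equals the number of selected columns (2 here):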
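#include <iostream>
#include <xtensor/xtensor.hpp>
#include <xtensor/xview.hpp>
#include <xtensor/xmath.hpp>       // xt::sum
#include <xtensor/xoperation.hpp>  // xt::equal, xt::argwhere
#include <xtensor/xio.hpp>

int main()
{
    xt::xtensor<int,2> a =
        {{0, 1, 1, 3, 4},
         {0, 2, 1, 5, 6},
         {0, 3, 1, 3, 2},
         {0, 4, 1, 5, 7}};

    // boolean mask: does a(i, j) == 3 for the kept columns j = 1 and j = 3?
    auto test = xt::equal(xt::view(a, xt::all(), xt::keep(1, 3)), 3);

    // number of matching columns in each row (sum along axis 1)
    auto n = xt::sum(test, 1);

    // rows where both kept columns match, as a flat list of row indices
    auto idx = xt::flatten_indices(xt::argwhere(xt::equal(n, 2)));

    // select those rows, keeping all columns
    auto b = xt::view(a, xt::keep(idx), xt::all());

    std::cout << b << std::endl;
    return 0;
}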