
How can I use equality in xt::filter?


The following code

#include <iostream>

#include "xtensor/xadapt.hpp"
#include "xtensor/xarray.hpp"
#include "xtensor/xindex_view.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xmasked_view.hpp"
#include "xtensor/xview.hpp"

using namespace std;

int main() {
    xt::xarray<float> a = {{1, 2, 3}, {4, 2, 6}, {9, 0, 2}};
    cout << a << endl;
    xt::filter(a, a == 2) = 10;
    cout << a << endl;
}

fails to compile with the following error:

error: no match for ‘operator==’ (operand types are ‘xt::xarray<float>’ ... and ‘int’)

However, the other comparison operators (>, <, >=, <=) work as expected. I'm not sure whether operator== was intentionally left out, but until it is added (if it ever is), is there a workaround, and what is it?


Solution

  • Use xt::equal(a, b) instead of a == b. For example:

    xt::filter(a, xt::equal(a, 2)) = 10;
    

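    sets every element equal to 2 to 10. This is by design rather than missing: in xtensor, a == b between two expressions compares them as a whole and returns a single bool, and there is no overload taking a scalar, which is why a == 2 does not compile. Element-wise equality is spelled xt::equal (and element-wise inequality xt::not_equal).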