c++ pybind11 std-ranges c++-coroutine std-generator

Binding std::generator<T> with pybind11?


I am trying to expose a std::generator<T> as a Python generator through pybind11. I am currently using the following:

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <generator>
#include <ranges>

namespace py = pybind11;

// Endless counter: produces n, n + 1, n + 2, ... and never finishes.
std::generator<int> f1(int n = 0) {
    for (;;) {
        co_yield n++;
    }
}

// Yields n itself first, then delegates to the nested generator f1(n + 1)
// through std::ranges::elements_of, so every value f1 produces is forwarded
// to the consumer of f2.
std::generator<int> f2(int n = 0) {
    co_yield n;
    co_yield std::ranges::elements_of(f1(n + 1));
}

// Python module "test_generator": registers std::generator<int> as the
// module-local class "_generator_int" with __iter__/__next__, and exposes
// f1 and f2 as module-level functions.
PYBIND11_MODULE(test_generator, m) {
    py::class_<std::generator<int>>(m, "_generator_int", pybind11::module_local())
        // __iter__ hands back the generator object it was called on.
        .def("__iter__",
             [](std::generator<int>& gen) -> std::generator<int>& {
                 return gen;
             })
        // __next__ asks the generator for a new iterator via begin() on every
        // call and returns the element that iterator refers to; the iterator is
        // never incremented here. Raises StopIteration once begin() == end().
        // NOTE(review): calling begin() more than once on the same
        // std::generator is undefined, so this only appears to work for f1.
        .def("__next__", [](std::generator<int>& gen) {
            auto it = gen.begin();
            if (it != gen.end()) { return *it; }
            else                 { throw py::stop_iteration(); }
        });

    // Both functions accept an optional keyword argument "n" that defaults to 0.
    m.def("f1", &f1, py::arg("n") = 0);
    m.def("f2", &f2, py::arg("n") = 0);
}

The above code works with f1 but not with f2: if f2(0) is called in Python, it only generates the values 0 and 1, so I assume my implementation does not work with std::ranges::elements_of.

How can I make it work with std::ranges::elements_of without modifying f2?


Solution

  • As pointed out in a comment, calling g.begin() more than once is undefined behavior. I was able to fix the code by storing the iterator alongside the generator in an intermediate struct:

    // Pairs a generator with the single iterator obtained from it, so that
    // begin() is invoked exactly once (in the constructor) and the same
    // iterator can be advanced across successive calls.
    template <class T>
    struct state {
        // Owning generator; declared before `it` because `it` is initialised from it.
        std::generator<T> g;
        // Current position inside `g`.
        decltype(g.begin()) it;

        // Takes ownership of the generator and fetches its iterator right away.
        // Deliberately not `explicit`: a std::generator<T> can then be returned
        // directly from a function whose return type is state<T>.
        state(std::generator<T> gen) : g(std::move(gen)), it(g.begin()) {}
    };
    
    // Python module "test_generator": exposes state<int> as the module-local
    // iterator class "_generator_int" and wraps f1/f2 so that they return it.
    PYBIND11_MODULE(test_generator, m)
    {
        py::class_<state<int>>(m, "_generator_int", py::module_local())
            // An iterator is its own iterable: return the object itself.
            .def("__iter__", [](state<int>& self) -> state<int>& { return self; })
            // Return the element at the stored position, then step past it.
            .def("__next__", [](state<int>& self) {
                // Exhausted: signal StopIteration to Python.
                if (self.it == self.g.end()) {
                    throw py::stop_iteration();
                }

                // Copy the element out before advancing the iterator.
                const auto value = *self.it;
                ++self.it;
                return value;
            });

        // Wrap each generator in state<int> (via its converting constructor) so
        // Python receives an object carrying both the generator and its iterator.
        m.def("f1", [](int start) -> state<int> { return f1(start); }, py::arg("n") = 0);
        m.def("f2", [](int start) -> state<int> { return f2(start); }, py::arg("n") = 0);
    }