c++ std-ranges c++23 c++26

Why does std::views::take_while() do so many function invocations? (even with `cache_latest`)


I decided to use the std::views adaptors for their bounds safety and readability. However, std::views::take_while seems to invoke its predicate far more often than necessary, recomputing the same result over and over. This looks like a performance nightmare to me.

I thought the new std::views::cache_latest would be a solution, but testing my code with GCC (trunk) and Clang (trunk) showed otherwise.

In the following, std::views::take_while() invokes its lambda 68 times! 🤔☹️

#include <algorithm>
#include <iostream>
#include <iterator>
#include <print>
#include <ranges>
#include <string>
#include <vector>

auto main() -> int
{
    const std::vector<std::string> input = {"World 0", "World 1", "World 2", "World 3", "World 4",
                                            "World 5", "World 6", "World 7", "World 8", "Earth"};

    auto output = input
        | std::views::take_while([](const auto& n) {std::println("{:10} called: {}","take_while", n); return n.back() != '5'; })
        | std::views::filter    ([](const auto& n) {std::println("{:10} called: {}","filter", n); return n.back() != '2'; })
        | std::views::transform ([](const auto& n) {std::println("{:10} called: {}","transform", n); return "Hello "+n+" "; })
        | std::views::join;

    std::ranges::copy(output,std::ostream_iterator<char>(std::cout, "\n"));
}

Let's make the code a bit simpler:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <print>
#include <ranges>
#include <string>
#include <vector>

auto main() -> int
{
    const std::vector<std::string> input = {"World 0", "World 1", "World 2", "World 3", "World 4",
                                            "World 5", "World 6", "World 7", "World 8", "Earth"};

    auto output = input
        | std::views::take_while([](const auto& n) {std::println("{:10} called: {}","take_while", n); return n.back() != '5'; })
        | std::views::filter    ([](const auto& n) {std::println("{:10} called: {}","filter", n); return n.back() != '2'; })
        | std::views::transform ([](const auto& n) {std::println("{:10} called: {}","transform", n); return "Hello "+n+" "; });

    std::ranges::copy(output,std::ostream_iterator<std::string>(std::cout, "\n"));
}

OK. take_while calls its lambda twice per iteration. This looks like exactly the kind of thing std::views::cache_latest could fix, but it didn't change anything. 🤨 (godbolt)

I also tested drop_while, and there weren't any issues.

Why is that, and is there a way to avoid repeated invocations that just recompute the same thing?


Solution

  • Let's take a look at r | take_while(f) | filter(g). I'm going to drop the transform because it doesn't really play into this example. cache_latest is also immaterial since its job is to optimize repeated uses of *it and that doesn't happen in this example either (the issue is repeated invocations of the predicates, which aren't invoked on *it), so let's just focus on the first two adapters.

    Suppose we want to write this loop:

    for (auto elem : r | take_while(f) | filter(g)) {
      use(elem);
    }
    

    Similar to the example in OP, we're just iterating linearly and performing some operation on every element.

    We want this to desugar into this:

    for (; it != end; ++it) {
      // take_while
      if (not f(*it)) break;
    
      // filter
      if (not g(*it)) continue;
    
      use(*it);
    }
    

    That would be optimal.

    However, we cannot get that out of the iterator model, at least not directly. Instead, if we desugar the range operations as written, we get something like this:

    // acquire begin: skip ahead to the first element that passes the filter
    for (; it != end and f(*it); ++it) {
      if (g(*it)) break;
    }
    
    while (true) {
      // this is the bounds check on the loop
      if (it == end or not f(*it)) { // (A)
        break;
      }
    
      use(*it);
    
      // this is incrementing the iterator
      ++it;
      while (true) {
        if (it == end or not f(*it)) { // (B)
          break;
        }
        if (g(*it)) { // (C)
          break;
        }
        ++it;
      }
    }
    

    So the check in (A) (our top-level bounds check on the range) and the check in (B) (what filter has to do to find the next element) are exactly the same check. And that is where the problem lies:

    1. When (B) is true, we've run out of elements. And we know (A) will be true at that point, but we still have to check it again.
    2. When (B) is false and (C) is true, we found our next element after the filter. We know at that point that (A) will be false, but we still have to check it again.

    Those redundant checks simply do not optimize well. Of course, in your particular case it would be wrong to optimize them away, because your predicates have observable behavior (they print). But even when compilers would be allowed to remove them, in general they don't. It's just too complicated a loop structure. I've seen compilers fail to do this optimization even with just a single filter.


    In general, I haven't the slightest idea how to improve this. A specific approach that could help is a view that computes its end() iterator once (walking the whole range to find it) and caches it, so that reaching the end is no longer decided lazily by re-running the predicate.

    That's not a very complicated adapter:

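    #include <optional>
    #include <ranges>
    #include <utility>

    // Wraps a non-common view and caches its end iterator on first use.
    template <class V>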
    class eager_common_view : public std::ranges::view_interface<eager_common_view<V>> {
        V base_;
        std::optional<std::ranges::iterator_t<V>> end_;
    
    public:
        eager_common_view(V v) : base_(std::move(v)) { }
    
        auto begin() -> std::ranges::iterator_t<V> { return std::ranges::begin(base_); }
        auto end() -> std::ranges::iterator_t<V> {
            if (not end_) {
                end_.emplace(std::ranges::next(begin(), std::ranges::end(base_)));
            }
            return *end_;
        }
    };
    
    struct EagerCommon : std::ranges::range_adaptor_closure<EagerCommon> {
        template <std::ranges::viewable_range R>
        constexpr auto operator()(R&& r) const {
            if constexpr (std::ranges::common_range<R>) {
                return std::views::all((R&&)r);
            } else {
                return eager_common_view(std::views::all((R&&)r));
            }
        }
    };
    
    inline constexpr EagerCommon eager_common;
    

    With that:

    auto output = input
        | std::views::take_while([](const auto& n) {std::println("{:10} called: {}","take_while", n); return n.at(n.size()-1) != '5'; })
        | eager_common // <== this guy
        | std::views::filter   ([](const auto& n) {std::println("{:10} called: {}","filter", n); return n.at(n.size()-1) != '2'; })
        | std::views::transform([](const auto& n) {std::println("{:10} called: {}","transform", n);return "Hello "+n+" "; });
    

    now prints:

    take_while called: World 0
    take_while called: World 1
    take_while called: World 2
    take_while called: World 3
    take_while called: World 4
    take_while called: World 5
    filter     called: World 0
    transform  called: World 0
    Hello World 0
    filter     called: World 1
    transform  called: World 1
    Hello World 1
    filter     called: World 2
    filter     called: World 3
    transform  called: World 3
    Hello World 3
    filter     called: World 4
    transform  called: World 4
    Hello World 4
    

    This now calls the take_while and filter predicates exactly once per element each, and transform only on the elements that pass the filter.

    That happens to work out in this case, but it's not really a general solution, and I'm not sure what other situations it might apply to.