Tags: c++, unordered-map, std-function, stdset

Why can't I use std::function as a std::set or std::unordered_set value type?


Why can't I have a std::set or std::unordered_set of std::functions?

Is there any way to get it to work anyway?


Solution

  • Why can't I have a std::set or std::unordered_set of std::functions?

    std::set relies on a comparator, which is used to determine if one element is less than the other.

    It uses std::less by default, and std::less doesn't work with std::functions.
    (Because there is no way to properly compare std::functions.)

    Similarly, std::unordered_set relies on std::hash and std::equal_to (or custom replacements for them), which also don't operate on std::functions.


    Is there any way to get it to work anyway?

    You can write a wrapper around (or a replacement for) std::function that works with std::less, std::equal_to and/or std::hash.

    Via the power of type erasure, you can forward std::less/std::equal_to/std::hash to the objects stored in your wrapper.

    Here's a proof-of-concept for such a wrapper.

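    Notes:

      - The implementation requires C++17 (`if constexpr`, `inline` static data members) and `<experimental/type_traits>`.
      - Wrapped callables of different types never compare equal; for ordering, they are sorted by the order in which each type was first wrapped. Callables of the same type are compared/hashed with `std::less`, `std::equal_to` and `std::hash` for that type.
      - A null `FancyFunction` is ordered before every non-null one and hashes to 0.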

    Example usage:

    // With `std::set`:
    
    #include <iostream>
    #include <set>
    
    // Callable that adds a fixed offset `n` to its argument.
    // Instances are ordered by `n`, which is what lets `std::less<AddN>` work.
    struct AddN
    {
        int n;

        int operator()(int x) const
        {
            return x + n;
        }

        friend bool operator<(AddN lhs, AddN rhs)
        {
            return lhs.n < rhs.n;
        }
    };
    
    int main()
    {   
        using func_t = FancyFunction<int(int), FunctionFlags::comparable_less>;
    
        // Note that `std::less` can operate on stateless lambdas by converting them to function pointers first. Otherwise this wouldn't work.
        auto square = [](int x){return x*x;};
        auto cube = [](int x){return x*x*x;};
    
        std::set<func_t> set;
        set.insert(square);
        set.insert(square); // Dupe.
        set.insert(cube);
        set.insert(AddN{100});
        set.insert(AddN{200});
        set.insert(AddN{200}); // Dupe.
    
        for (const auto &it : set)
            std::cout << "2 -> " << it(2) << '\n';
        std::cout << '\n';
        /* Prints:
         * 2 -> 4   // `square`, note that it appears only once.
         * 2 -> 8   // `cube`
         * 2 -> 102 // `AddN{100}`
         * 2 -> 202 // `AddN{200}`, also appears once.
         */
    
        set.erase(set.find(cube));
        set.erase(set.find(AddN{100}));
    
        for (const auto &it : set)
            std::cout << "2 -> " << it(2) << '\n';
        std::cout << '\n';
        /* Prints:
         * 2 -> 4   // `square`
         * 2 -> 202 // `AddN{200}`
         * `cube` and `AddN{100}` were removed.
         */
    }
    
    
    // With `std::unordered_set`:
    
    #include <iostream>
    #include <unordered_set>
    
    // Callable that adds a fixed offset `n` to its argument.
    // Two instances are equal when their offsets are equal.
    struct AddN
    {
        int n;

        int operator()(int x) const
        {
            return x + n;
        }

        friend bool operator==(AddN lhs, AddN rhs)
        {
            return lhs.n == rhs.n;
        }
    };
    
    // Callable that multiplies its argument by a fixed factor `n`.
    // Two instances are equal when their factors are equal.
    struct MulByN
    {
        int n;

        int operator()(int x) const
        {
            return x * n;
        }

        friend bool operator==(MulByN lhs, MulByN rhs)
        {
            return lhs.n == rhs.n;
        }
    };
    
    // `std::hash` specializations for the example functors.
    // Each one simply uses the functor's `n` member (converted to `size_t`) as the hash value.
    namespace std
    {
        template <>
        struct hash<AddN>
        {
            using argument_type = AddN;
            using result_type = std::size_t;

            size_t operator()(AddN f) const
            {
                return static_cast<size_t>(f.n);
            }
        };

        template <>
        struct hash<MulByN>
        {
            using argument_type = MulByN;
            using result_type = std::size_t;

            size_t operator()(MulByN f) const
            {
                return static_cast<size_t>(f.n);
            }
        };
    }
    
    int main()
    {   
        using hashable_func_t = FancyFunction<int(int), FunctionFlags::hashable | FunctionFlags::comparable_eq>;
        std::unordered_set<hashable_func_t> set;
        set.insert(AddN{100});
        set.insert(AddN{100}); // Dupe.
        set.insert(AddN{200});
        set.insert(MulByN{10});
        set.insert(MulByN{20});
        set.insert(MulByN{20}); // Dupe.
    
        for (const auto &it : set)
            std::cout << "2 -> " << it(2) << '\n';
        std::cout << '\n';
        /* Prints:
         * 2 -> 40  // `MulByN{20}`
         * 2 -> 20  // `MulByN{10}`
         * 2 -> 102 // `AddN{100}`
         * 2 -> 202 // `AddN{200}`
         */
    
        set.erase(set.find(AddN{100}));
        set.erase(set.find(MulByN{20}));
    
        for (const auto &it : set)
            std::cout << "2 -> " << it(2) << '\n';
        std::cout << '\n';
        /* Prints:
         * 2 -> 20  // `MulByN{10}`
         * 2 -> 202 // `AddN{200}`
         * `MulByN{20}` and `AddN{100}` were removed.
         */
    }
    

    Implementation:

    #include <cstddef>
    #include <functional>
    #include <experimental/type_traits>
    #include <utility>
    
    // Bit flags selecting which extra operations a `FancyFunction` supports.
    // Each non-zero flag occupies its own bit so that flags can be combined.
    enum class FunctionFlags
    {
        none            = 0,
        comparable_less = 1 << 0, // Usable with `std::less` / `operator<`.
        comparable_eq   = 1 << 1, // Usable with `std::equal_to` / `operator==`.
        hashable        = 1 << 2, // Usable with `std::hash`.
    };
    // Bitwise union of two flag sets.
    constexpr FunctionFlags operator|(FunctionFlags a, FunctionFlags b)
    {
        return static_cast<FunctionFlags>(static_cast<int>(a) | static_cast<int>(b));
    }
    // Bitwise intersection of two flag sets.
    constexpr FunctionFlags operator&(FunctionFlags a, FunctionFlags b)
    {
        return static_cast<FunctionFlags>(static_cast<int>(a) & static_cast<int>(b));
    }
    
    
    // Detection archetype: names the result type of hashing a `const T &` with a
    // default-constructed `std::hash<T>`; ill-formed (and thus "not detected") otherwise.
    template <class T>
    using detect_hashable = decltype(std::hash<T>{}(std::declval<T const &>()));
    
    
    // Primary template, declared but not defined here.
    // `Flags` selects the optional operations (ordering, equality, hashing) and defaults to none.
    template <typename T, FunctionFlags Flags = FunctionFlags::none>
    class FancyFunction;
    
    // A callable wrapper around `std::function<ReturnType(ParamTypes...)>` that can additionally be
    // ordered, compared for equality and/or hashed, as selected by `Flags`.
    //
    // Every distinct (decayed) callable type that gets wrapped receives a `TypeDetails` record with a
    // unique index and type-erased pointers to `std::less`, `std::equal_to` and `std::hash` for that
    // type. Wrappers holding different types are ordered by index and never compare equal; wrappers
    // holding the same type forward the operation to the stored callables. Null wrappers sort before
    // all others, compare equal to each other, and hash to 0.
    //
    // Note: `index_counter` is a plain `int`, so wrapping new types for the first time is not
    // synchronized across threads.
    template <typename ReturnType, typename ...ParamTypes, FunctionFlags Flags>
    class FancyFunction<ReturnType(ParamTypes...), Flags>
    {
        // Per-type table of type-erased operations. The `const void *` arguments always point to
        // `FancyFunction` objects whose stored callable has the type this record was built for.
        struct TypeDetails
        {
            int index = 0; // Unique per stored type, assigned in order of first use.
            bool (*less)(const void *, const void *) = nullptr;
            bool (*eq)(const void *, const void *) = nullptr;
            std::size_t (*hash)(const void *) = nullptr;

            inline static int index_counter = 0;
        };

        // Returns the (lazily created) details record for the stored callable type `T`.
        // `T` must already be decayed: it has to be exactly the type `std::function` stores,
        // otherwise `target<T>()` would return null and the same callable type could end up
        // with several different indices.
        template <typename T> static const TypeDetails *GetDetails()
        {
            static_assert(std::is_same_v<T, std::decay_t<T>>, "GetDetails expects a decayed type.");

            static const TypeDetails ret = []()
            {
                TypeDetails d;

                d.index = TypeDetails::index_counter++;

                if constexpr (comparable_less)
                {
                    // We can't SFINAE on `std::less`.
                    d.less = [](const void *a_ptr, const void *b_ptr) -> bool
                    {
                        const T &a = *static_cast<const FancyFunction *>(a_ptr)->func.template target<T>();
                        const T &b = *static_cast<const FancyFunction *>(b_ptr)->func.template target<T>();
                        return std::less<T>{}(a, b);
                    };
                }

                if constexpr (comparable_eq)
                {
                    // We can't SFINAE on `std::equal_to`.
                    d.eq = [](const void *a_ptr, const void *b_ptr) -> bool
                    {
                        const T &a = *static_cast<const FancyFunction *>(a_ptr)->func.template target<T>();
                        const T &b = *static_cast<const FancyFunction *>(b_ptr)->func.template target<T>();
                        return std::equal_to<T>{}(a, b);
                    };
                }

                if constexpr (hashable)
                {
                    static_assert(std::experimental::is_detected_v<detect_hashable, T>, "This type is not hashable.");
                    d.hash = [](const void *a_ptr) -> std::size_t
                    {
                        const T &a = *static_cast<const FancyFunction *>(a_ptr)->func.template target<T>();
                        // `std::hash<T>` is a function object type: create an instance, then call it.
                        return std::hash<T>{}(a);
                    };
                }

                return d;
            }();
            return &ret;
        }

        std::function<ReturnType(ParamTypes...)> func;
        const TypeDetails *details = nullptr; // Null if and only if `func` is empty.

      public:
        inline static constexpr bool
            comparable_less = bool(Flags & FunctionFlags::comparable_less),
            comparable_eq   = bool(Flags & FunctionFlags::comparable_eq),
            hashable        = bool(Flags & FunctionFlags::hashable);

        // Constructs a null (empty) function.
        FancyFunction(decltype(nullptr) = nullptr) {}

        // Wraps an arbitrary callable.
        // The constraint keeps this overload from hijacking copy/move construction of
        // `FancyFunction` itself (e.g. copying from a non-const lvalue).
        template <typename T, typename = std::enable_if_t<!std::is_same_v<std::decay_t<T>, FancyFunction>>>
        FancyFunction(T &&obj)
            : func(std::forward<T>(obj))
        {
            // `std::function` ends up empty when given e.g. a null function pointer; treat such an
            // object as a null function so the type-erased operations never see a missing target.
            if (func)
                details = GetDetails<std::decay_t<T>>();
        }

        // True if a callable is stored.
        explicit operator bool() const
        {
            return bool(func);
        }

        // Invokes the stored callable. Throws `std::bad_function_call` if the function is null.
        ReturnType operator()(ParamTypes ... params) const
        {
            return ReturnType(func(std::forward<ParamTypes>(params)...));
        }

        // Strict weak ordering: null functions first, then by stored type (in order of first use),
        // then by `std::less` on the stored callables.
        bool less(const FancyFunction &other) const
        {
            static_assert(comparable_less, "This function is disabled.");
            if (!details || !other.details) return !details && other.details;
            if (details->index != other.details->index) return details->index < other.details->index;
            return details->less(this, &other);
        }

        // Equality: both null, or same stored type and equal according to `std::equal_to`.
        bool equal_to(const FancyFunction &other) const
        {
            static_assert(comparable_eq, "This function is disabled.");
            if (bool(details) != bool(other.details)) return false;
            if (!details) return true;
            if (details->index != other.details->index) return false;
            return details->eq(this, &other);
        }

        // Hash of the stored callable according to `std::hash`, or 0 for a null function.
        std::size_t hash() const
        {
            static_assert(hashable, "This function is disabled.");
            if (!details) return 0;
            return details->hash(this);
        }

        friend bool operator<(const FancyFunction &a, const FancyFunction &b) {return a.less(b);}
        friend bool operator>(const FancyFunction &a, const FancyFunction &b) {return b.less(a);}
        friend bool operator<=(const FancyFunction &a, const FancyFunction &b) {return !b.less(a);}
        friend bool operator>=(const FancyFunction &a, const FancyFunction &b) {return !a.less(b);}
        friend bool operator==(const FancyFunction &a, const FancyFunction &b) {return a.equal_to(b);}
        friend bool operator!=(const FancyFunction &a, const FancyFunction &b) {return !a.equal_to(b);}
    };
    
    // Makes `FancyFunction` usable as a key in unordered containers by
    // forwarding `std::hash` to the wrapper's own `hash()` member.
    namespace std
    {
        template <typename T, FunctionFlags Flags>
        struct hash<FancyFunction<T, Flags>>
        {
            using argument_type = FancyFunction<T, Flags>;
            using result_type = std::size_t;

            result_type operator()(const argument_type &f) const
            {
                return f.hash();
            }
        };
    }