Why can't I have a `std::set` or `std::unordered_set` of `std::function`s?
Is there any way to get it to work anyway?
Why can't I have a `std::set` or `std::unordered_set` of `std::function`s?
`std::set` relies on a comparator, which is used to determine whether one element is less than another.
It uses `std::less` by default, and `std::less` doesn't work with `std::function`s, because `std::function` has no `operator<`.
(There is no way to properly compare `std::function`s: the only comparison they support is against `nullptr`.)
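Similarly, `std::unordered_set` relies on `std::hash` and `std::equal_to` (or custom replacements for them), which also don't operate on `std::function`s: the standard library provides no `std::hash` specialization for `std::function`, and its `operator==` only accepts `nullptr` as the other operand.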
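The following sketch checks those claims at compile time and at run time. `Func` and `check_nullptr_comparison` are illustrative names, and the `std::hash` check assumes C++17, where an unspecialized `std::hash<T>` is "disabled" and therefore not default-constructible.

```cpp
#include <cassert>
#include <functional>
#include <type_traits>

// Illustrative alias, not part of the original answer.
using Func = std::function<int(int)>;

// std::function has no std::hash specialization; since C++17 such a "disabled"
// std::hash<T> is not default-constructible, so std::unordered_set<Func> can't use it.
static_assert(!std::is_default_constructible_v<std::hash<Func>>);

// Function pointers, on the other hand, are hashable out of the box.
static_assert(std::is_default_constructible_v<std::hash<int (*)(int)>>);

// The only comparison std::function provides is against nullptr ("is it empty?").
inline void check_nullptr_comparison()
{
    Func empty;
    Func square = [](int x) { return x * x; };
    assert(empty == nullptr);
    assert(square != nullptr);
    // `square == empty` or `square < empty` would not compile.
}
```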
Is there any way to get it to work anyway?
You can write a wrapper around (or a replacement for) `std::function` that works with `std::less`, `std::equal_to` and/or `std::hash`.
Via the power of type erasure, you can forward `std::less`/`std::equal_to`/`std::hash` to the objects stored in your wrapper.
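To make the type-erasure idea concrete, here is a minimal sketch that handles only equality. `EqFunction` and `EqThunk` are hypothetical names, not part of the proof-of-concept below. It assumes every stored functor type supports `==`, and it relies on `std::function::target_type()` (which needs RTTI) and `std::function::target<T>()` to get the concrete objects back.

```cpp
#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>

// Hypothetical minimal wrapper: a std::function plus a per-type "equal" thunk.
template <typename Signature>
class EqFunction
{
    std::function<Signature> func;
    // Points to a function that knows the concrete stored type; null when empty.
    bool (*eq)(const std::function<Signature> &, const std::function<Signature> &) = nullptr;

    template <typename T>
    static bool EqThunk(const std::function<Signature> &a, const std::function<Signature> &b)
    {
        // Both targets are known to have type T, so recover them and compare with ==.
        const T *pa = a.template target<T>();
        const T *pb = b.template target<T>();
        assert(pa && pb);
        return std::equal_to<T>{}(*pa, *pb);
    }

  public:
    EqFunction() = default;

    // The concrete type F is only known here, so remember the matching thunk.
    template <typename F, std::enable_if_t<!std::is_same_v<std::decay_t<F>, EqFunction>, int> = 0>
    EqFunction(F &&f) : func(std::forward<F>(f)), eq(&EqThunk<std::decay_t<F>>) {}

    template <typename... Args>
    decltype(auto) operator()(Args &&...args) const { return func(std::forward<Args>(args)...); }

    friend bool operator==(const EqFunction &a, const EqFunction &b)
    {
        if (!a.eq || !b.eq) return !a.eq && !b.eq;                       // empty equals only empty
        if (a.func.target_type() != b.func.target_type()) return false;  // different functor types
        return a.eq(a.func, b.func);                                      // same type: forward to ==
    }
};
```

Comparing two wrappers first checks that the stored types match and only then forwards to `std::equal_to<T>` for that type, which is the same two-step approach `FancyFunction` uses.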
Here's a proof-of-concept for such a wrapper.
Notes:
You can specify separately whether you want the class `FancyFunction` to work with `std::less`, `std::equal_to` and `std::hash`, by adjusting a template argument.
If any of them is enabled, you'll be able to apply it to `FancyFunction`.
Naturally, you'll be able to construct a `FancyFunction` from a type only if the enabled operations can be applied to that type.
There is a static assertion that fires when a type fails to provide `std::hash` if it's needed.
It seems to be impossible to SFINAE on the availability of `std::less` and `std::equal_to` (their `operator()` is declared for every type, whether or not `<`/`==` is defined for it), so I couldn't make similar assertions for those.
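One partial workaround, assuming you're willing to ignore user specializations of `std::less` and `std::equal_to`, is to detect the operators those function objects call rather than the function objects themselves. The trait names below are hypothetical:

```cpp
#include <functional>
#include <type_traits>
#include <utility>

// Hypothetical traits: detect `a < b` and `a == b` directly instead of going through
// std::less / std::equal_to, whose operator() is declared for every T.
template <typename T, typename = void>
struct has_less_op : std::false_type {};
template <typename T>
struct has_less_op<T, std::void_t<decltype(std::declval<const T &>() < std::declval<const T &>())>>
    : std::true_type {};

template <typename T, typename = void>
struct has_equal_op : std::false_type {};
template <typename T>
struct has_equal_op<T, std::void_t<decltype(std::declval<const T &>() == std::declval<const T &>())>>
    : std::true_type {};

static_assert(has_less_op<int>::value && has_equal_op<int>::value);
// std::function offers neither `<` nor `==` between two functions.
static_assert(!has_less_op<std::function<void()>>::value);
static_assert(!has_equal_op<std::function<void()>>::value);
```

The caveat is that a type with a custom `std::less` specialization but no `operator<` would be reported as not comparable.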
In theory, you could support types that don't work with `std::less`, `std::equal_to` and/or `std::hash` by always considering all instances of one type equivalent, and using `typeid(T).hash_code()` as a hash.
I'm not sure whether that behavior is desirable; implementing it is left as an exercise to the reader.
(The lack of SFINAE for `std::less` and `std::equal_to` will make it harder to implement properly.)
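A sketch of the hashing half of that fallback, assuming C++17, where a disabled `std::hash<T>` can't be constructed, so the check below fails quietly instead of producing a hard error. `is_std_hashable` and `hash_or_type_hash` are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <type_traits>
#include <typeinfo>
#include <utility>

template <typename T, typename = void>
struct is_std_hashable : std::false_type {};
template <typename T>
struct is_std_hashable<T, std::void_t<decltype(std::hash<T>{}(std::declval<const T &>()))>>
    : std::true_type {};

// Hypothetical helper: use std::hash<T> if it exists, otherwise hash only the type.
// Every instance of an unhashable T then gets the same hash, which is only consistent
// if all instances of T also compare equal.
template <typename T>
std::size_t hash_or_type_hash(const T &value)
{
    if constexpr (is_std_hashable<T>::value)
        return std::hash<T>{}(value);
    else
    {
        (void)value; // the value itself is deliberately ignored
        return typeid(T).hash_code();
    }
}
```

The equality side would need every instance of such a type to compare equal, which is where the missing SFINAE for `std::equal_to` gets in the way.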
Specifying custom replacements for `std::less`, `std::equal_to` and `std::hash` is not supported; implementing that is also left as an exercise to the reader.
(This means that this implementation can put lambdas into a regular `std::set`, but not into a `std::unordered_set`, since there is no `std::hash` for lambda types.)
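If the lambdas are captureless, one workaround that doesn't involve `FancyFunction` at all is to store them as plain function pointers: unary `+` forces the conversion, and `std::hash` and `std::equal_to` both work on pointers. This is a sketch with a hypothetical helper `make_pointer_set`; it can't handle capturing lambdas or stateful functors like `AddN`.

```cpp
#include <cassert>
#include <unordered_set>

using FnPtr = int (*)(int);

// Hypothetical helper: collects captureless lambdas as function pointers.
inline std::unordered_set<FnPtr> make_pointer_set()
{
    auto square = [](int x) { return x * x; };
    auto cube = [](int x) { return x * x * x; };

    std::unordered_set<FnPtr> set;
    set.insert(+square); // unary + converts the closure object to int(*)(int)
    set.insert(+square); // same closure type, same pointer: a dupe
    set.insert(+cube);
    return set;
}
```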
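When applied to `FancyFunction`, `std::less` and `std::equal_to` first compare the types of the stored functors (an empty `FancyFunction` is ordered before any non-empty one, and equals only another empty one).
If the types are identical, they resort to calling `std::less`/`std::equal_to` on the underlying instances.
(Thus, for two arbitrary different functor types, `std::less` always considers instances of one of them less than instances of the other. Which type comes first depends on the order in which the types are first wrapped in a `FancyFunction`, so the order is not guaranteed to be stable between program invocations.)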
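`std::variant` orders its values by the same two-step rule: its `operator<` first compares `index()` (which alternative is held) and only compares the held values when the indices match. So when the set of functor types is known in advance, a variant is a possible alternative to type erasure, with an order fixed by the alternative list. A sketch, with functor types mirroring the examples below and a hypothetical helper `call`:

```cpp
#include <cassert>
#include <set>
#include <variant>

struct AddN
{
    int n;
    int operator()(int x) const { return n + x; }
    friend bool operator<(AddN a, AddN b) { return a.n < b.n; }
};

struct MulByN
{
    int n;
    int operator()(int x) const { return n * x; }
    friend bool operator<(MulByN a, MulByN b) { return a.n < b.n; }
};

// operator< on this variant compares index() first, then the held values.
using Fn = std::variant<AddN, MulByN>;

// Calls whichever functor the variant currently holds.
inline int call(const Fn &f, int x)
{
    return std::visit([x](const auto &g) { return g(x); }, f);
}
```

The trade-off is that the list of functor types is closed: adding a new one means changing `Fn`.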
Example usage:
```cpp
// With `std::set`:
// (Compile together with the `FancyFunction` implementation below.)
#include <iostream>
#include <set>

struct AddN
{
    int n;
    int operator()(int x) const {return n + x;}
    friend bool operator<(AddN a, AddN b) {return a.n < b.n;}
};

int main()
{
    using func_t = FancyFunction<int(int), FunctionFlags::comparable_less>;

    // Note that `std::less` can operate on stateless lambdas by converting them to function pointers first. Otherwise this wouldn't work.
    auto square = [](int x){return x*x;};
    auto cube = [](int x){return x*x*x;};

    std::set<func_t> set;
    set.insert(square);
    set.insert(square); // Dupe.
    set.insert(cube);
    set.insert(AddN{100});
    set.insert(AddN{200});
    set.insert(AddN{200}); // Dupe.

    for (const auto &it : set)
        std::cout << "2 -> " << it(2) << '\n';
    std::cout << '\n';
    /* Prints:
     * 2 -> 4    // `square`, note that it appears only once.
     * 2 -> 8    // `cube`
     * 2 -> 102  // `AddN{100}`
     * 2 -> 202  // `AddN{200}`, also appears once.
     */

    set.erase(set.find(cube));
    set.erase(set.find(AddN{100}));

    for (const auto &it : set)
        std::cout << "2 -> " << it(2) << '\n';
    std::cout << '\n';
    /* Prints:
     * 2 -> 4    // `square`
     * 2 -> 202  // `AddN{200}`
     * `cube` and `AddN{100}` were removed.
     */
}
```
```cpp
// With `std::unordered_set`:
// (Compile together with the `FancyFunction` implementation below.)
#include <iostream>
#include <unordered_set>

struct AddN
{
    int n;
    int operator()(int x) const {return n + x;}
    friend bool operator==(AddN a, AddN b) {return a.n == b.n;}
};

struct MulByN
{
    int n;
    int operator()(int x) const {return n * x;}
    friend bool operator==(MulByN a, MulByN b) {return a.n == b.n;}
};

namespace std
{
    template <> struct hash<AddN>
    {
        using argument_type = AddN;
        using result_type = std::size_t;
        size_t operator()(AddN f) const {return f.n;}
    };

    template <> struct hash<MulByN>
    {
        using argument_type = MulByN;
        using result_type = std::size_t;
        size_t operator()(MulByN f) const {return f.n;}
    };
}

int main()
{
    using hashable_func_t = FancyFunction<int(int), FunctionFlags::hashable | FunctionFlags::comparable_eq>;

    std::unordered_set<hashable_func_t> set;
    set.insert(AddN{100});
    set.insert(AddN{100}); // Dupe.
    set.insert(AddN{200});
    set.insert(MulByN{10});
    set.insert(MulByN{20});
    set.insert(MulByN{20}); // Dupe.

    for (const auto &it : set)
        std::cout << "2 -> " << it(2) << '\n';
    std::cout << '\n';
    /* Prints (iteration order of an unordered set is unspecified, so yours may differ):
     * 2 -> 40   // `MulByN{20}`
     * 2 -> 20   // `MulByN{10}`
     * 2 -> 102  // `AddN{100}`
     * 2 -> 202  // `AddN{200}`
     */

    set.erase(set.find(AddN{100}));
    set.erase(set.find(MulByN{20}));

    for (const auto &it : set)
        std::cout << "2 -> " << it(2) << '\n';
    std::cout << '\n';
    /* Prints (order may differ):
     * 2 -> 20   // `MulByN{10}`
     * 2 -> 202  // `AddN{200}`
     * `MulByN{20}` and `AddN{100}` were removed.
     */
}
```
Implementation:
```cpp
#include <cstddef>
#include <functional>
#include <type_traits>
#include <utility>

enum class FunctionFlags
{
    none = 0,
    comparable_less = 0b1,
    comparable_eq = 0b10,
    hashable = 0b100,
};

constexpr FunctionFlags operator|(FunctionFlags a, FunctionFlags b) {return FunctionFlags(int(a) | int(b));}
constexpr FunctionFlags operator&(FunctionFlags a, FunctionFlags b) {return FunctionFlags(int(a) & int(b));}

// Detects whether `std::hash<T>` is usable. Written with `std::void_t` instead of
// `std::experimental::is_detected`, which not every standard library ships.
template <typename T, typename = void> struct is_hashable : std::false_type {};
template <typename T> struct is_hashable<T, std::void_t<decltype(std::hash<T>{}(std::declval<const T &>()))>> : std::true_type {};

template <typename T, FunctionFlags Flags = FunctionFlags::none>
class FancyFunction;

template <typename ReturnType, typename ...ParamTypes, FunctionFlags Flags>
class FancyFunction<ReturnType(ParamTypes...), Flags>
{
    // Per-functor-type data: a unique index plus type-erased less/eq/hash helpers.
    struct TypeDetails
    {
        int index = 0;
        bool (*less)(const void *, const void *) = 0;
        bool (*eq)(const void *, const void *) = 0;
        std::size_t (*hash)(const void *) = 0;
        inline static int index_counter = 0;
    };

    // `T` must be the decayed functor type, so that lvalues and rvalues of the
    // same type share a single `TypeDetails` (and thus a single index).
    template <typename T> static const TypeDetails *GetDetails()
    {
        static TypeDetails ret = []()
        {
            using type = T;
            TypeDetails d;
            d.index = TypeDetails::index_counter++;
            if constexpr (comparable_less)
            {
                // We can't SFINAE on `std::less`.
                d.less = [](const void *a_ptr, const void *b_ptr) -> bool
                {
                    const type &a = *static_cast<const FancyFunction *>(a_ptr)->func.template target<type>();
                    const type &b = *static_cast<const FancyFunction *>(b_ptr)->func.template target<type>();
                    return std::less<type>{}(a, b);
                };
            }
            if constexpr (comparable_eq)
            {
                // We can't SFINAE on `std::equal_to`.
                d.eq = [](const void *a_ptr, const void *b_ptr) -> bool
                {
                    const type &a = *static_cast<const FancyFunction *>(a_ptr)->func.template target<type>();
                    const type &b = *static_cast<const FancyFunction *>(b_ptr)->func.template target<type>();
                    return std::equal_to<type>{}(a, b);
                };
            }
            if constexpr (hashable)
            {
                static_assert(is_hashable<type>::value, "This type is not hashable.");
                d.hash = [](const void *a_ptr) -> std::size_t
                {
                    const type &a = *static_cast<const FancyFunction *>(a_ptr)->func.template target<type>();
                    return std::hash<type>{}(a); // Construct the hasher, then call it.
                };
            }
            return d;
        }();
        return &ret;
    }

    std::function<ReturnType(ParamTypes...)> func;
    const TypeDetails *details = 0;

  public:
    inline static constexpr bool
        comparable_less = bool(Flags & FunctionFlags::comparable_less),
        comparable_eq = bool(Flags & FunctionFlags::comparable_eq),
        hashable = bool(Flags & FunctionFlags::hashable);

    FancyFunction(decltype(nullptr) = nullptr) {}

    // Constrained so that copying a non-const `FancyFunction` uses the copy
    // constructor instead of wrapping one `FancyFunction` inside another.
    template <typename T, std::enable_if_t<!std::is_same_v<std::decay_t<T>, FancyFunction>, int> = 0>
    FancyFunction(T &&obj)
        : func(std::forward<T>(obj)), details(GetDetails<std::decay_t<T>>())
    {}

    explicit operator bool() const
    {
        return bool(func);
    }

    ReturnType operator()(ParamTypes ... params) const
    {
        return ReturnType(func(std::forward<ParamTypes>(params)...));
    }

    bool less(const FancyFunction &other) const
    {
        static_assert(comparable_less, "This function is disabled.");
        if (int delta = bool(details) - bool(other.details)) return delta < 0; // Empty sorts first.
        if (!details) return 0;
        if (int delta = details->index - other.details->index) return delta < 0; // Compare types.
        return details->less(this, &other); // Same type: compare values.
    }

    bool equal_to(const FancyFunction &other) const
    {
        static_assert(comparable_eq, "This function is disabled.");
        if (bool(details) != bool(other.details)) return 0;
        if (!details) return 1;
        if (details->index != other.details->index) return 0;
        return details->eq(this, &other);
    }

    std::size_t hash() const
    {
        static_assert(hashable, "This function is disabled.");
        if (!details) return 0;
        return details->hash(this);
    }

    friend bool operator<(const FancyFunction &a, const FancyFunction &b) {return a.less(b);}
    friend bool operator>(const FancyFunction &a, const FancyFunction &b) {return b.less(a);}
    friend bool operator<=(const FancyFunction &a, const FancyFunction &b) {return !b.less(a);}
    friend bool operator>=(const FancyFunction &a, const FancyFunction &b) {return !a.less(b);}
    friend bool operator==(const FancyFunction &a, const FancyFunction &b) {return a.equal_to(b);}
    friend bool operator!=(const FancyFunction &a, const FancyFunction &b) {return !a.equal_to(b);}
};

namespace std
{
    template <typename T, FunctionFlags Flags> struct hash<FancyFunction<T, Flags>>
    {
        using argument_type = FancyFunction<T, Flags>;
        using result_type = std::size_t;
        size_t operator()(const FancyFunction<T, Flags> &f) const
        {
            return f.hash();
        }
    };
}
```