
How to pickle an Enum whose values are functools.partial objects


Problem

Suppose we have a Python Enum whose values are functools.partial objects. How can we pickle and unpickle a member of that enum?

import pickle
from enum import Enum
from functools import partial


def function_a():
    pass


class EnumOfPartials(Enum):
    FUNCTION_A = partial(function_a)


if __name__ == "__main__":
    with open("test.pkl", "wb") as f:
        pickle.dump(EnumOfPartials.FUNCTION_A, f)

    with open("test.pkl", "rb") as f:
        pickle.load(f)

The code above tries to pickle and unpickle such a member. The pickle.load call fails with:

ValueError: functools.partial(<function function_a at 0x7f5973e804a0>) is not a valid EnumOfPartials

Motivation

The object itself is useful for configuration purposes: using Hydra, I can have a parameter in a YAML file that selects one of the functions in the enum. I use partial so that FUNCTION_A is not interpreted as a method (see this question). I need to pickle a member of this enum so that I can send it to another process.

Given my use case, an obvious workaround would be a dictionary of functions keyed by an enum, but I would prefer to have the relevant value (the function) directly in the enum.

Note

I am using Python 3.11.11.


Solution

  • The first answer is to pickle members by name instead of by value:

    from enum import Enum, pickle_by_enum_name
    
    class EnumDefs(Enum):
        # Pickle members by name instead of by value
        __reduce_ex__ = pickle_by_enum_name
    
    

    The second answer is to use the enum.member class/decorator (new in Python 3.11) instead of partial, and to add __call__ so that the members themselves can be invoked:

    import pickle
    from enum import Enum, member, pickle_by_enum_name
    
    def function_a():
        print('function a()!')
    
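    class EnumOfFunctions(Enum):
        # Pickle members by name instead of by value
        __reduce_ex__ = pickle_by_enum_name

        # Calling a member calls the function stored as its value
        def __call__(self, *args, **kwds):
            return self._value_(*args, **kwds)

        # member() keeps the function as a value instead of turning it into a method
        FUNCTION_A = member(function_a)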
    
    
    if __name__ == "__main__":
        with open("test.pkl", "wb") as f:
            pickle.dump(EnumOfFunctions.FUNCTION_A, f)
    
        with open("test.pkl", "rb") as f:
            func = pickle.load(f)
            print(func)
            func()
    
    

    Note that you could also define the functions directly inside the enum:

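    class EnumOfFunctions(Enum):
        __reduce_ex__ = pickle_by_enum_name

        def __call__(self, *args, **kwds):
            return self._value_(*args, **kwds)

        # @member keeps function_a as a member whose value is the function
        @member
        def function_a():
            print('function a()!')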
    

    Disclosure: I am the author of the Python stdlib Enum, the enum34 backport, and the Advanced Enumeration (aenum) library.