
Allow Enum Names as Valid Inputs with Pydantic's @validate_call


A related question asks how to use an enum member's name when serializing a model. I want something similar for input validation with the @validate_call decorator.

Take this function foo():

from enum import Enum
from pydantic import validate_call

class Direction(Enum):
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

@validate_call
def foo(d: Direction):
    print(d)

I want all of these inputs to work:

# These work
>>> foo(0)
Direction.NORTH
>>> foo(Direction.EAST)
Direction.EAST

# These don't, but I want them to
>>> foo('WEST')
Direction.WEST
>>> foo('   sOUtH ') # This would be great, though not essential
Direction.SOUTH

What's the simplest way to do this?

If it requires creating a function used as a BeforeValidator, I'd prefer that that function be generic. I have many foos and many enums, and I don't want a separate validator for handling each one. (At that point, it's easier to validate the type within the function itself instead of using Pydantic).


Solution

  • I can think of two options. The simplest is to make the Enum class itself more flexible by overriding its _missing_ class method, which Enum calls when a lookup by value fails.

    from enum import Enum
    
    from pydantic import validate_call
    
    
    def normalize_enum_name(name: str) -> str:
        """A user-defined normalization function for enum names."""
        # The results are only used as dictionary keys, so you can do whatever you like.
        return name.strip().upper()
    
    
    class FlexibleEnumNameMixin:
        @classmethod
        def _missing_(cls: type[Enum], value: object) -> Enum | None:
            enum_name_map = {normalize_enum_name(member.name): member for member in cls}
            if isinstance(value, str) and (normalized_name := normalize_enum_name(value)) in enum_name_map:
                return enum_name_map[normalized_name]
            return None
    
    
    class Direction(FlexibleEnumNameMixin, Enum):  # Insert the mixin here.
        NORTH = 0
        EAST = 1
        SOUTH = 2
        WEST = 3
    
    
    @validate_call
    def foo(d: Direction):
        print(repr(d))
    
    
    foo(0)  # <Direction.NORTH: 0>
    foo(Direction.EAST)  # <Direction.EAST: 1>
    foo("WEST")  # <Direction.WEST: 3>
    foo("   sOUtH ")  # <Direction.SOUTH: 2>
    foo("unknown")  # ValidationError
    

    The upside (or downside) of this approach is that the Enum class itself will also accept such loosely formatted input, even outside Pydantic.

    print(repr(Direction("   sOUtH ")))  # <Direction.SOUTH: 2>
    

    This can be useful, but it changes how the enum behaves everywhere in your code, not only in validated calls, and that could cause unexpected bugs.
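    The _missing_ hook is standard-library Enum behavior, so you can check how far it reaches without Pydantic. The sketch below is stdlib-only and assumes the same normalization as above. It replaces the dictionary with a linear scan, which is fine for small enums. When _missing_ returns None, Enum raises its usual ValueError, so unknown names still fail.

```python
from enum import Enum


def normalize_enum_name(name: str) -> str:
    """Same normalization as above: ignore surrounding whitespace and case."""
    return name.strip().upper()


class FlexibleEnumNameMixin:
    @classmethod
    def _missing_(cls, value):
        # Enum calls this only after the normal lookup by value has failed.
        if isinstance(value, str):
            wanted = normalize_enum_name(value)
            for member in cls:
                if normalize_enum_name(member.name) == wanted:
                    return member
        # Returning None tells Enum to raise its usual ValueError.
        return None


class Direction(FlexibleEnumNameMixin, Enum):
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3


# No Pydantic involved: plain Enum construction now accepts names too.
print(repr(Direction(3)))          # <Direction.WEST: 3>
print(repr(Direction("  west ")))  # <Direction.WEST: 3>

try:
    Direction("unknown")
except ValueError as exc:
    print(exc)  # 'unknown' is not a valid Direction
```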

    A safer, though slightly more complex, solution is to implement a custom validator. Below is a working example of a custom before validator that wraps Pydantic's enum schema.

    from enum import Enum
    from typing import Annotated
    
    from pydantic import validate_call
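    from pydantic_core import core_schema
    
    
    def normalize_enum_name(name: str) -> str:
        """Same normalization function as in the first example."""
        return name.strip().upper()
    
    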
    class FlexibleEnumNameValidator:
        @classmethod
        def __get_pydantic_core_schema__(cls, source_type: type[Enum], _):
            enum_name_map = {normalize_enum_name(member.name): member for member in source_type}
    
            def parse_as_enum(value: object):
                if isinstance(value, str) and (normalized_name := normalize_enum_name(value)) in enum_name_map:
                    return enum_name_map[normalized_name]
                return value
    
            return core_schema.no_info_before_validator_function(
                parse_as_enum,
                schema=core_schema.enum_schema(source_type, list(source_type)),
            )
    
    
    class Direction(Enum):  # No need to modify your enum class!
        NORTH = 0
        EAST = 1
        SOUTH = 2
        WEST = 3
    
    
    @validate_call
    def foo(d: Annotated[Direction, FlexibleEnumNameValidator]):  # Append the validator here.
        print(repr(d))
    
    
    foo(0)  # <Direction.NORTH: 0>
    foo(Direction.EAST)  # <Direction.EAST: 1>
    foo("WEST")  # <Direction.WEST: 3>
    foo("   sOUtH ")  # <Direction.SOUTH: 2>
    foo("unknown")  # ValidationError
    

    With this approach, you can explicitly specify which arguments you want to be flexible.
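    The key design choice in parse_as_enum is that it converts only strings that match a member name and returns everything else unchanged. The wrapped enum schema therefore still accepts values like 0 and still reports proper errors for bad input. The name map is built from source_type, so one validator serves every enum. Below is a stdlib-only sketch of that inner logic. make_name_parser and the Color enum are hypothetical names used only for illustration.

```python
from enum import Enum
from typing import Callable


def normalize_enum_name(name: str) -> str:
    return name.strip().upper()


def make_name_parser(enum_cls: type[Enum]) -> Callable[[object], object]:
    """Hypothetical helper mirroring parse_as_enum above, for any Enum class."""
    # Build the lookup once per enum class, not once per call.
    name_map = {normalize_enum_name(m.name): m for m in enum_cls}

    def parse_as_enum(value: object) -> object:
        if isinstance(value, str):
            member = name_map.get(normalize_enum_name(value))
            if member is not None:
                return member
        # Anything else passes through unchanged, so the inner enum
        # validator can accept it (e.g. 0) or reject it with a proper error.
        return value

    return parse_as_enum


class Direction(Enum):
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3


class Color(Enum):  # hypothetical second enum, to show the helper is generic
    RED = "r"
    GREEN = "g"


parse_direction = make_name_parser(Direction)
parse_color = make_name_parser(Color)

print(parse_direction("   sOUtH "))  # Direction.SOUTH
print(parse_direction(0))            # 0 (left for the enum validator)
print(parse_direction("unknown"))    # unknown (left for the enum validator)
print(parse_color(" green"))         # Color.GREEN
```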

    Note that this validator can also be used as a mixin for the Enum. Unlike the mixin that overrides _missing_, it only affects Pydantic validation. Calling the enum class directly keeps the standard behavior.

    # Reuses FlexibleEnumNameValidator (and normalize_enum_name) from the previous example.
    class Direction(FlexibleEnumNameValidator, Enum):  # Insert the validator here.
        NORTH = 0
        EAST = 1
        SOUTH = 2
        WEST = 3
    
    
    @validate_call
    def foo(d: Direction):
        print(repr(d))
    
    
    foo("   sOUtH ")  # <Direction.SOUTH: 2>
    print(repr(Direction("   sOUtH ")))  # ValueError