This question asks about using the name of an enum when serializing a model. I want something like that, except with the @validate_call decorator.
Take this function foo():
from enum import Enum
from pydantic import validate_call

class Direction(Enum):
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

@validate_call
def foo(d: Direction):
    print(d)
I want all of these inputs to work:
# These work
>>> foo(0)
Direction.NORTH
>>> foo(Direction.EAST)
Direction.EAST
# These don't, but I want them to
>>> foo('WEST')
Direction.WEST
>>> foo(' sOUtH ') # This would be great though not essential
Direction.SOUTH
What's the simplest way to do this?
If it requires creating a function used as a BeforeValidator, I'd prefer that that function be generic. I have many foos and many enums, and I don't want a separate validator for handling each one. (At that point, it's easier to validate the type within the function itself instead of using Pydantic.)
I can think of 2 options. The simplest solution would be to make Enum more flexible.
from enum import Enum
from pydantic import validate_call

def normalize_enum_name(name: str) -> str:
    """A user-defined normalization function for enum names."""
    # The results are only used as dictionary keys, so you can do whatever you like.
    return name.strip().upper()

class FlexibleEnumNameMixin:
    @classmethod
    def _missing_(cls: type[Enum], value: object) -> Enum | None:
        enum_name_map = {normalize_enum_name(member.name): member for member in cls}
        if isinstance(value, str) and (normalized_name := normalize_enum_name(value)) in enum_name_map:
            return enum_name_map[normalized_name]
        return None

class Direction(FlexibleEnumNameMixin, Enum):  # Insert the mixin here.
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

@validate_call
def foo(d: Direction):
    print(repr(d))

foo(0)               # <Direction.NORTH: 0>
foo(Direction.EAST)  # <Direction.EAST: 1>
foo("WEST")          # <Direction.WEST: 3>
foo(" sOUtH ")       # <Direction.SOUTH: 2>
foo("unknown")       # ValidationError
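Since the normalized names are only used as dictionary keys, the normalization function alone decides which spellings are accepted. As a rough sketch, a more permissive variant might also treat runs of spaces or hyphens as underscores. The multi-word inputs below are hypothetical and only illustrate the idea; they assume an enum with a member named NORTH_EAST.

```python
import re

def normalize_enum_name(name: str) -> str:
    """Permissive variant: trim, map runs of whitespace/hyphens to '_', then upper-case."""
    return re.sub(r"[\s\-]+", "_", name.strip()).upper()

print(normalize_enum_name(" sOUtH "))       # SOUTH
print(normalize_enum_name(" north-east "))  # NORTH_EAST
print(normalize_enum_name("North East"))    # NORTH_EAST
```

Because member names go through the same function when the lookup map is built, both sides stay consistent as long as no two member names normalize to the same key.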
The good (or bad) thing about this approach is that your Enum class will also be able to accept such badly-formed input.
print(repr(Direction(" sOUtH "))) # <Direction.SOUTH: 2>
This might be a useful feature, but its scope is very wide and could cause unexpected bugs.
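To make the risk concrete, here is a small standard-library-only sketch. The Color enum is hypothetical and deliberately contrived: its string values happen to look like the names of other members. Regular value lookup runs first and _missing_ only runs when that fails, so two spellings of the "same" word can resolve to different members.

```python
from enum import Enum

def normalize_enum_name(name: str) -> str:
    return name.strip().upper()

class FlexibleEnumNameMixin:
    @classmethod
    def _missing_(cls, value):
        # Only reached when the normal lookup by value has failed.
        enum_name_map = {normalize_enum_name(member.name): member for member in cls}
        if isinstance(value, str):
            return enum_name_map.get(normalize_enum_name(value))
        return None

class Color(FlexibleEnumNameMixin, Enum):  # hypothetical enum, for illustration only
    RED = "green"
    GREEN = "red"

print(repr(Color("red")))  # <Color.GREEN: 'red'>  (matched by value)
print(repr(Color("RED")))  # <Color.RED: 'green'>  (matched by name via _missing_)
```

Enums with integer values, like Direction, do not hit this particular collision, but any code that calls the enum class directly will silently pick up the relaxed parsing.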
A safer, though slightly more complex, solution is to implement a custom validator. Below is a working example of a custom "before" validator that runs ahead of Pydantic's own enum validation.
from enum import Enum
from typing import Annotated
from pydantic import validate_call
from pydantic_core import core_schema

def normalize_enum_name(name: str) -> str:
    """Same normalization function as in the first example."""
    return name.strip().upper()

class FlexibleEnumNameValidator:
    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: type[Enum], _):
        # Built once per enum type, when the schema is generated.
        enum_name_map = {normalize_enum_name(member.name): member for member in source_type}

        def parse_as_enum(value: object):
            if isinstance(value, str) and (normalized_name := normalize_enum_name(value)) in enum_name_map:
                return enum_name_map[normalized_name]
            # Anything else is passed through to the regular enum validation.
            return value

        return core_schema.no_info_before_validator_function(
            parse_as_enum,
            schema=core_schema.enum_schema(source_type, list(source_type)),
        )

class Direction(Enum):  # No need to modify your enum class!
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

@validate_call
def foo(d: Annotated[Direction, FlexibleEnumNameValidator]):  # Append the validator here.
    print(repr(d))

foo(0)               # <Direction.NORTH: 0>
foo(Direction.EAST)  # <Direction.EAST: 1>
foo("WEST")          # <Direction.WEST: 3>
foo(" sOUtH ")       # <Direction.SOUTH: 2>
foo("unknown")       # ValidationError
With this approach, you can explicitly specify which arguments you want to be flexible.
Note that this validator can be used as a mixin for Enum as well. Unlike the mixin that overrides _missing_, this is only used by Pydantic.
class Direction(FlexibleEnumNameValidator, Enum):  # Insert the validator here.
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

@validate_call
def foo(d: Direction):
    print(repr(d))

foo(" sOUtH ")                     # <Direction.SOUTH: 2>
print(repr(Direction(" sOUtH ")))  # ValueError
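Since the question mentions BeforeValidator specifically, a third variant may be worth sketching: one generic factory that wraps any enum in an Annotated alias. This is only a sketch, not something from the approaches above. The names flexible and FlexibleDirection are made up for illustration, and foo returns its argument instead of printing so the result can be inspected.

```python
from enum import Enum
from typing import Annotated, Any

from pydantic import BeforeValidator, ValidationError, validate_call

def normalize_enum_name(name: str) -> str:
    return name.strip().upper()

def flexible(enum_cls):
    """Hypothetical helper: an Annotated alias of enum_cls that also accepts member names."""
    enum_name_map = {normalize_enum_name(member.name): member for member in enum_cls}

    def parse_as_enum(value: Any) -> Any:
        if isinstance(value, str):
            # Unknown strings fall through unchanged and fail the normal enum validation.
            return enum_name_map.get(normalize_enum_name(value), value)
        return value

    return Annotated[enum_cls, BeforeValidator(parse_as_enum)]

class Direction(Enum):
    NORTH = 0
    EAST = 1
    SOUTH = 2
    WEST = 3

FlexibleDirection = flexible(Direction)  # one line per enum, no per-enum validator

@validate_call
def foo(d: FlexibleDirection):
    return d

print(repr(foo(0)))          # <Direction.NORTH: 0>
print(repr(foo(" sOUtH ")))  # <Direction.SOUTH: 2>
try:
    foo("unknown")
except ValidationError:
    print("rejected")
```

One trade-off: because the alias is built at runtime, static type checkers will likely not accept it as an annotation, whereas Annotated[Direction, FlexibleEnumNameValidator] stays fully static.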