I often encounter functions that accept only a finite set of values. I know how to reflect this in the type annotations using `typing.Literal`, like so:
```python
import typing

def func(a: typing.Literal['foo', 'bar']):
    pass
```
I would like to have a decorator `@validate_literals` that validates that the arguments passed to the function are consistent with their types:
```python
@validate_literals
def picky_typed_function(
    binary: typing.Literal[0, 1],
    char: typing.Literal['a', 'b']
) -> None:
    pass
```
so that the input is validated against the restrictions defined by the parameters' types, and a `ValueError` is raised in case of a violation:
```python
picky_typed_function(0, 'a')              # should pass
picky_typed_function(2, 'a')              # should raise "ValueError: binary must be one of (0, 1)"
picky_typed_function(0, 'c')              # should raise "ValueError: char must be one of ('a', 'b')"
picky_typed_function(0, char='c')         # should raise "ValueError: char must be one of ('a', 'b')"
picky_typed_function(binary=2, char='c')  # should raise "ValueError: binary must be one of (0, 1)"
```
`typing` type checks are designed to be static and do not happen at runtime. How can I leverage the typing definitions for runtime validation?
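For illustration, a minimal sketch of the problem: the interpreter stores the `Literal` annotation on the function but never enforces it, although the annotation remains available for introspection at runtime.

```python
import typing


def func(a: typing.Literal['foo', 'bar']) -> str:
    # The annotation is recorded in func.__annotations__ but not checked by Python.
    return a


# No error at runtime, even though 'baz' is not one of the allowed literals;
# only a static type checker would flag this call.
result = func('baz')
print(result)  # baz

# The annotation object itself can still be inspected at runtime.
print(typing.get_args(func.__annotations__['a']))  # ('foo', 'bar')
```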
We can inspect the decorated (validated) function's signature using `inspect.signature`, check which of its parameters are annotated with a `Literal` alias by getting the "origin" of each parameter's annotation through `typing.get_origin()` (or `__origin__` on Python versions < 3.8), and retrieve the valid values from the `Literal` alias using [`typing.get_args()`](https://stackoverflow.com/a/64522240/3566606), iterating recursively over nested `Literal` definitions.
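The introspection step can be sketched on its own as below. It is a minimal example, assuming Python 3.8+ (where `typing.get_origin` and `typing.get_args` exist); the extra `count` parameter is hypothetical and only there to show that non-`Literal` annotations are skipped.

```python
import inspect
import typing


def picky_typed_function(
    binary: typing.Literal[0, 1],
    char: typing.Literal['a', 'b'],
    count: int = 3,  # hypothetical non-Literal parameter, added for contrast
) -> None:
    pass


literal_params = {}
for name, parameter in inspect.signature(picky_typed_function).parameters.items():
    # get_origin() is typing.Literal for Literal[...] aliases and None for plain int
    if typing.get_origin(parameter.annotation) is typing.Literal:
        # get_args() returns the tuple of allowed values
        literal_params[name] = typing.get_args(parameter.annotation)

print(literal_params)  # {'binary': (0, 1), 'char': ('a', 'b')}
```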
Then all that is left is to figure out which parameters have been passed as positional arguments and map the corresponding values to the parameters' names, so that each value can be compared against the valid values of its parameter.
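`inspect.Signature.bind` performs this positional-to-name mapping directly and also rejects calls that do not match the signature. A small sketch, where `picky` is a hypothetical stand-in for the decorated function:

```python
import inspect


def picky(binary, char='a'):  # hypothetical stand-in for the decorated function
    pass


sig = inspect.signature(picky)

# bind() maps positional and keyword arguments onto parameter names and
# raises TypeError if the call does not match the signature.
mixed = sig.bind(0, char='c')
mixed.apply_defaults()  # fills in parameters that were not passed
print(dict(mixed.arguments))  # {'binary': 0, 'char': 'c'}

one_positional = sig.bind(1)
one_positional.apply_defaults()
print(dict(one_positional.arguments))  # {'binary': 1, 'char': 'a'}
```

Calling `apply_defaults()` matters for validation: without it, a parameter that falls back to its default would be missing from `arguments`, and looking it up by name would raise a `KeyError`.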
Finally, we build the decorator using the standard recipe with `functools.wraps`. In the end, this is the code:
```python
import functools
import inspect
import typing


def args_to_kwargs(func: typing.Callable, *args: typing.Any, **kwargs: typing.Any) -> dict:
    """Map positional and keyword arguments to the parameter names of ``func``."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()  # include parameters that fall back to their default
    return dict(bound.arguments)


def valid_args_from_literal(annotation: typing.Any) -> typing.Tuple[typing.Any, ...]:
    """Collect the allowed values of a Literal, flattening nested Literals."""
    valid_values = []
    for arg in typing.get_args(annotation):
        # check the origin of each argument, not of the outer annotation
        if typing.get_origin(arg) is typing.Literal:
            valid_values += valid_args_from_literal(arg)
        else:
            valid_values.append(arg)
    # drop duplicates while keeping the declaration order for the error message
    return tuple(dict.fromkeys(valid_values))


def validate_literals(func: typing.Callable) -> typing.Callable:
    signature = inspect.signature(func)

    @functools.wraps(func)
    def validated(*args, **kwargs):
        arguments = args_to_kwargs(func, *args, **kwargs)
        for name, parameter in signature.parameters.items():
            # use parameter.annotation.__origin__ for Python versions < 3.8
            # (where Literal comes from typing_extensions)
            if typing.get_origin(parameter.annotation) is typing.Literal:
                valid_values = valid_args_from_literal(parameter.annotation)
                if arguments[name] not in valid_values:
                    raise ValueError(f"{name} must be one of {valid_values}")
        return func(*args, **kwargs)

    return validated
```
This gives the results specified in the question.
I have also published the alpha version of a Python package, `runtime-typing`, to perform runtime type checking: https://pypi.org/project/runtime-typing/ (documentation: https://runtime-typing.readthedocs.io). It handles more cases than just `typing.Literal`, such as `typing.TypeVar` and `typing.Union`.