
Python Typing: Validation Decorator for Literal-Typed Arguments


I often encounter functions that accept only a finite set of values. I know how to reflect this in the type annotations using typing.Literal, like so:

import typing


def func(a: typing.Literal['foo', 'bar']):
    pass
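At runtime such an annotation is an ordinary object, so the allowed values can be read back with typing.get_origin and typing.get_args (both added in Python 3.8). A minimal sketch using the func defined above:

```python
import typing


def func(a: typing.Literal['foo', 'bar']):
    pass


annotation = func.__annotations__['a']

# The "origin" identifies the special form; the "args" are the allowed values
print(typing.get_origin(annotation) is typing.Literal)  # True
print(typing.get_args(annotation))  # ('foo', 'bar')
```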

I would like to have a decorator @validate_literals that validates that the arguments passed to the function are consistent with their types:

@validate_literals
def picky_typed_function(
    binary: typing.Literal[0, 1],
    char: typing.Literal['a', 'b']
) -> None:
    pass

so that the input is validated against the restrictions defined by the parameters' types, and a ValueError is raised in case of a violation:

picky_typed_function(0, 'a')  # should pass
picky_typed_function(2, 'a')  # should raise "ValueError: binary must be one of (0, 1)"
picky_typed_function(0, 'c')  # should raise "ValueError: char must be one of ('a', 'b')"
picky_typed_function(0, char='c')  # should raise "ValueError: char must be one of ('a', 'b')"
picky_typed_function(binary=2, char='c')  # should raise "ValueError: binary must be one of (0, 1)"

Type checks based on typing annotations are designed to be static; they do not happen at runtime. How can I leverage the typing definitions for runtime validation?
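To illustrate the problem: without any decorator, a call with values outside the Literal runs without complaint, and only a static checker such as mypy would report it. A minimal sketch; the function name picky_unchecked is hypothetical and it returns its arguments only so the effect is visible:

```python
import typing


def picky_unchecked(
    binary: typing.Literal[0, 1],
    char: typing.Literal['a', 'b']
) -> typing.Tuple[int, str]:
    # The annotations are stored on the function but never enforced
    return binary, char


# No error at runtime: the interpreter does not check annotations
result = picky_unchecked(2, 'c')
print(result)  # (2, 'c')
```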


Solution

  • We can inspect the decorated (validated) function's signature with inspect.signature and check which of its parameters are annotated as a Literal alias by getting the "origin" of each parameter's annotation through typing.get_origin() (or, for Python versions < 3.8, through __origin__). The valid values can then be retrieved from the Literal alias with typing.get_args() (see https://stackoverflow.com/a/64522240/3566606), iterating recursively over nested Literal definitions.

    With that in place, all that is left is to figure out which parameters have been passed as positional arguments and to map their values to the corresponding parameter names, so that each value can be compared against the valid values of its parameter.

    Finally, we build the decorator using the standard recipe with functools.wraps. Putting it all together:

    import functools
    import inspect
    import typing
    
    
    def args_to_kwargs(func: typing.Callable, *args: typing.Any, **kwargs: typing.Any) -> dict:
        """Map positional arguments to their parameter names and merge in kwargs."""
        names = list(inspect.signature(func).parameters.keys())
        args_dict = {names[i]: arg for i, arg in enumerate(args)}
    
        return {**args_dict, **kwargs}
    
    
    def valid_args_from_literal(annotation: typing.Any) -> typing.Set[typing.Any]:
        """Collect the allowed values of a Literal, flattening nested Literals."""
        valid_values = set()
    
        for arg in typing.get_args(annotation):
            # Check the origin of each *argument* (not of the enclosing annotation),
            # so that only nested Literal[...] entries are expanded recursively
            if typing.get_origin(arg) is typing.Literal:
                valid_values |= valid_args_from_literal(arg)
            else:
                valid_values.add(arg)
    
        return valid_values
    
    
    def validate_literals(func: typing.Callable) -> typing.Callable:
        signature = inspect.signature(func)  # inspect once, at decoration time
    
        @functools.wraps(func)
        def validated(*args, **kwargs):
            kwargs = args_to_kwargs(func, *args, **kwargs)
            for name, parameter in signature.parameters.items():
                # use parameter.annotation.__origin__ for Python versions < 3.8
                if typing.get_origin(parameter.annotation) is typing.Literal:
                    valid_values = valid_args_from_literal(parameter.annotation)
                    # parameters omitted by the caller (relying on defaults) are skipped
                    if name in kwargs and kwargs[name] not in valid_values:
                        raise ValueError(
                            f"Argument '{name}' must be one of {valid_values}"
                        )
    
            return func(**kwargs)
    
        return validated
    

    This raises a ValueError for exactly the calls flagged in the question, although the wording of the message differs slightly.

    I have also published an alpha version of a Python package, runtime-typing, that performs runtime type checking: https://pypi.org/project/runtime-typing/ (documentation: https://runtime-typing.readthedocs.io). It handles more cases than just typing.Literal, such as typing.TypeVar and typing.Union.
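As a possible variation on the hand-written args_to_kwargs step, the standard library's inspect.Signature.bind already maps positional arguments to parameter names, and BoundArguments.apply_defaults() fills in omitted parameters so that their defaults get validated too. The sketch below assumes that approach; validate_literals_bound and literal_values are hypothetical names, not part of the answer's code or of the runtime-typing package. It keeps the Literal's declared order so the messages match the question's wording, skips Literal-annotated *args/**kwargs, and calls the original function with bound.args/bound.kwargs so positional-only parameters keep working.

```python
import functools
import inspect
import typing


def literal_values(annotation: typing.Any) -> typing.Tuple[typing.Any, ...]:
    """Allowed values of a Literal annotation, in declared order, expanding nested Literals."""
    values: typing.List[typing.Any] = []
    for arg in typing.get_args(annotation):
        if typing.get_origin(arg) is typing.Literal:
            values.extend(literal_values(arg))
        else:
            values.append(arg)
    return tuple(values)


def validate_literals_bound(func: typing.Callable) -> typing.Callable:
    """Hypothetical variant of validate_literals built on Signature.bind."""
    signature = inspect.signature(func)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    # Precompute the allowed values once, for Literal-annotated, non-variadic parameters
    allowed = {
        name: literal_values(param.annotation)
        for name, param in signature.parameters.items()
        if param.kind not in variadic
        and typing.get_origin(param.annotation) is typing.Literal
    }

    @functools.wraps(func)
    def validated(*args, **kwargs):
        # bind() raises TypeError for missing/extra arguments, like a normal call would
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()  # omitted parameters get their defaults, which are checked too
        for name, valid_values in allowed.items():
            if bound.arguments[name] not in valid_values:
                raise ValueError(f"{name} must be one of {valid_values}")
        return func(*bound.args, **bound.kwargs)

    return validated


@validate_literals_bound
def picky_typed_function(
    binary: typing.Literal[0, 1],
    char: typing.Literal['a', 'b']
) -> None:
    pass


picky_typed_function(0, 'a')  # passes
try:
    picky_typed_function(binary=2, char='c')
except ValueError as err:
    print(err)  # binary must be one of (0, 1)
```

Because parameters are checked in signature order, the first offending parameter determines the message, as in the question's last example.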