Pydantic allows the elements of a list to be nested discriminated unions. Is there an elegant way to apply constraints (such as MaxLen or MinLen) to a subset of the list based on the inner discriminator, without writing a custom validator?
For example, in the PetsModel below, limit the number of Cat elements to MinLen(1) and the number of Dog elements to MaxLen(2):
PetsModel.model_validate(
    {
        "pets": [
            {"pet_type": "cat", "color": "black", "black_name": "black_cat_name"},
            {"pet_type": "dog", "name": "dog_name"},
        ]
    }
)
Adapted from https://docs.pydantic.dev/latest/concepts/unions/#nested-discriminated-unions
from typing import Literal, Union
from typing_extensions import Annotated
from pydantic import BaseModel, Field, ValidationError

class BlackCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['black']
    black_name: str

class WhiteCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['white']
    white_name: str

Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]

class Dog(BaseModel):
    pet_type: Literal['dog']
    name: str

Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]

class PetsModel(BaseModel):
    pets: list[Pet]  # list of pets
I know that we can apply constraints to the pets list as a whole, as follows:
from annotated_types import MaxLen, MinLen

class PetsModel(BaseModel):
    pets: Annotated[list[Pet], MinLen(1), MaxLen(5)]  # list of pets
But I want to apply the constraints to the number of Cat and Dog elements within the pets list, not to the list as a whole.
You could create a field validator on pets that runs after Pydantic's own validation, so you'd be checking a list of already-validated model objects, like so:
from collections import Counter
from typing import List

from annotated_types import MaxLen, MinLen
from pydantic import BaseModel, field_validator
from typing_extensions import Annotated

class PetsModel(BaseModel):
    pets: Annotated[List[Pet], MinLen(1), MaxLen(5)]

    @field_validator('pets')
    @classmethod
    def special_rules(cls, v: List[Pet]) -> List[Pet]:
        rules = {
            'Dog': {'min': 1, 'max': None},
            'BlackCat': {'min': None, 'max': 5}
        }
        # This counts by concrete class name (BlackCat, not Cat)
        c = Counter([x.__class__.__name__ for x in v])
        # Handy helper: fall back to y when x is None
        replace_none = lambda x, y: x if x is not None else y
        for key in rules:
            x_min = replace_none(rules[key].get('min'), 0)
            x_max = replace_none(rules[key].get('max'), float('inf'))
            inbounds = x_min <= c.get(key, 0) <= x_max
            if not inbounds:
                raise ValueError(
                    f"expected between {x_min} and {x_max} {key} items, got {c.get(key, 0)}"
                )
        # A field validator must return the value, otherwise the field becomes None
        return v
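Since the question asks for limits per Cat and Dog rather than per concrete class, one variation is to count by the pet_type discriminator value instead of the class name. The following is a minimal, self-contained sketch of that idea, not a built-in Pydantic feature; the PET_TYPE_LIMITS name, the validator name, and the limits themselves (at least 1 cat, at most 2 dogs, taken from the question) are assumptions for illustration.

```python
from collections import Counter
from typing import List, Literal, Union

from annotated_types import MaxLen, MinLen
from pydantic import BaseModel, Field, ValidationError, field_validator
from typing_extensions import Annotated


class BlackCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['black']
    black_name: str


class WhiteCat(BaseModel):
    pet_type: Literal['cat']
    color: Literal['white']
    white_name: str


Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')]


class Dog(BaseModel):
    pet_type: Literal['dog']
    name: str


Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]

# Hypothetical limits from the question: (min, max) per pet_type, None = unbounded.
PET_TYPE_LIMITS = {'cat': (1, None), 'dog': (None, 2)}


class PetsModel(BaseModel):
    pets: Annotated[List[Pet], MinLen(1), MaxLen(5)]

    @field_validator('pets')
    @classmethod
    def check_counts_per_pet_type(cls, v: List[Pet]) -> List[Pet]:
        # Runs after Pydantic has built the model objects, so pet_type is set.
        counts = Counter(pet.pet_type for pet in v)
        for pet_type, (lo, hi) in PET_TYPE_LIMITS.items():
            n = counts.get(pet_type, 0)
            if lo is not None and n < lo:
                raise ValueError(f"expected at least {lo} {pet_type!r} pets, got {n}")
            if hi is not None and n > hi:
                raise ValueError(f"expected at most {hi} {pet_type!r} pets, got {n}")
        return v


black_cat = {"pet_type": "cat", "color": "black", "black_name": "black_cat_name"}
dog = {"pet_type": "dog", "name": "dog_name"}

# One cat and one dog satisfies both limits.
model = PetsModel.model_validate({"pets": [black_cat, dog]})

# Three dogs exceeds the dog limit, so validation fails.
try:
    PetsModel.model_validate({"pets": [black_cat, dog, dog, dog]})
except ValidationError as exc:
    print(exc)
```

Because BlackCat and WhiteCat share pet_type 'cat', both count toward the same cat limit here, which the class-name approach would not do.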
Now, you can be even fancier and attach the rules object to the field itself via json_schema_extra, like so:
from pydantic import Field, ValidationInfo, field_validator

rules = {
    'Dog': {'min': 1, 'max': None},
    'BlackCat': {'min': None, 'max': 5}
}

class PetsModel(BaseModel):
    pets: Annotated[
        List[Pet], MinLen(1), MaxLen(5),
        Field(json_schema_extra={"rules": rules})
    ]

    @field_validator('pets')
    @classmethod
    def special_rules(cls, v: List[Pet], info: ValidationInfo) -> List[Pet]:
        # Look up the rules attached to whichever field is being validated
        rules = cls.model_fields[info.field_name].json_schema_extra["rules"]
        ...
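For reference, here is a minimal end-to-end sketch of the json_schema_extra variant. It assumes simplified stand-in models (a single Cat class without the nested color union) and illustrative limits, and it keys the rules by class name as above. Note that anything placed in json_schema_extra also shows up in the generated JSON schema for the field.

```python
from collections import Counter
from typing import List, Literal, Union

from annotated_types import MaxLen, MinLen
from pydantic import BaseModel, Field, ValidationError, ValidationInfo, field_validator
from typing_extensions import Annotated


# Simplified stand-ins for the question's models (assumption: no color union).
class Cat(BaseModel):
    pet_type: Literal['cat']
    name: str


class Dog(BaseModel):
    pet_type: Literal['dog']
    name: str


Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]

# Keys are class names; the limits are illustrative only.
rules = {
    'Cat': {'min': 1, 'max': None},
    'Dog': {'min': None, 'max': 2},
}


class PetsModel(BaseModel):
    pets: Annotated[
        List[Pet], MinLen(1), MaxLen(5),
        Field(json_schema_extra={"rules": rules}),
    ]

    @field_validator('pets')
    @classmethod
    def special_rules(cls, v: List[Pet], info: ValidationInfo) -> List[Pet]:
        # Read the rules stored on the field currently being validated.
        field_rules = cls.model_fields[info.field_name].json_schema_extra["rules"]
        counts = Counter(type(pet).__name__ for pet in v)
        for name, bounds in field_rules.items():
            lo = bounds['min'] if bounds['min'] is not None else 0
            hi = bounds['max'] if bounds['max'] is not None else float('inf')
            n = counts.get(name, 0)
            if not lo <= n <= hi:
                raise ValueError(f"expected between {lo} and {hi} {name} items, got {n}")
        return v


cat = {"pet_type": "cat", "name": "cat_name"}
dog = {"pet_type": "dog", "name": "dog_name"}

# One cat and two dogs is within both limits.
model = PetsModel.model_validate({"pets": [cat, dog, dog]})

# A list with no cat violates the Cat minimum.
try:
    PetsModel.model_validate({"pets": [dog]})
except ValidationError as exc:
    print(exc)
```

Keeping the rules on the field means the same validator body can be reused for several list fields with different limits, since it looks them up through info.field_name.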
Hope this helps!