pydantic

How to add and retrieve metadata on a pydantic field


I would like to add some metadata to model fields. It is not needed for validation or serialisation, but it might be useful for data display or as a hint when entering data. For instance, one might want to add a unit to a field.

I came up with this:

from pydantic import BaseModel, Field
from typing import Annotated
from dataclasses import dataclass

@dataclass
class MyFieldMetadata:
    unit: str


class MyModel(BaseModel):
    length: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="meter")]
    duration: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="seconds")]

and, to print every field with its unit appended:

m = MyModel(length=10.0, duration=60.0)
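for field_name, field_info in MyModel.model_fields.items():
    # Use a name other than `m` so the generator variable does not shadow the model instance.
    extra_info = next(
        (meta for meta in field_info.metadata if isinstance(meta, MyFieldMetadata)), None
    )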

    print(
        f"{field_name}: {getattr(m, field_name)} {extra_info.unit if extra_info else ''}"
    )

This works, but is it correct? And a second question: the way to retrieve MyFieldMetadata from the model seems convoluted (iterating through metadata and finding the first instance of MyFieldMetadata). Is there a cleaner or standard way to achieve this?

I'm using pydantic 2.7.1


Solution

  • I came up with this: [...]. [T]he question is if this is correct

    Annotated was introduced to attach metadata that can be used by "static analysis tools or at runtime", so you're definitely not abusing this mechanism.

    If it works for you, I'd say it's correct enough. But it very much depends on your specific needs.

    E.g., if it's just for display, this is fine, and you should call it a day. If it were just for communicating the field semantics to other developers, you could consider an even simpler approach and rename length to metres and duration to seconds. But if you wanted to ensure that unit data is always available for some fields, or to do calculations that depend on the units involved, I would add this data explicitly.^[YAGNI] There is an example at the bottom.

    ^[YAGNI]: Most likely, you aren't going to need it!

    As for whether there's a less convoluted, "cleaner or standard way": I fear there isn't. The docs have the following to say on this:

    A tool or library encountering an Annotated type can scan through the metadata elements to determine if they are of interest (e.g., using isinstance()).
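
The isinstance() scan described in that quote can at least be hidden behind a small helper. The following is only a sketch: find_metadata is a hypothetical name, not part of pydantic, and it assumes that custom objects placed in Annotated show up in FieldInfo.metadata, as they do in the question.

```python
from dataclasses import dataclass
from typing import Annotated, Optional, TypeVar

from pydantic import BaseModel, Field

T = TypeVar("T")


@dataclass(frozen=True)
class MyFieldMetadata:
    unit: str


class MyModel(BaseModel):
    length: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="meter")]
    duration: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="seconds")]
    name: str = "probe"  # no unit attached


def find_metadata(model_cls: type[BaseModel], field_name: str, kind: type[T]) -> Optional[T]:
    """Hypothetical helper: first metadata entry of type `kind` on a field, or None."""
    field_info = model_cls.model_fields[field_name]
    return next((item for item in field_info.metadata if isinstance(item, kind)), None)


m = MyModel(length=10.0, duration=60.0)
for field_name in MyModel.model_fields:
    meta = find_metadata(MyModel, field_name, MyFieldMetadata)
    unit = f" {meta.unit}" if meta else ""
    print(f"{field_name}: {getattr(m, field_name)}{unit}")
```

The scan itself does not go away; the helper only keeps it in one place and makes the "no metadata" case explicit.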

    Another way to get at this data is typing.get_type_hints with include_extras=True, but that's not much cleaner.
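
For comparison, here is a rough sketch of the typing.get_type_hints route. It is shown on a plain class (Measurement and units_by_field are made-up names) so that it depends only on the standard library; how get_type_hints behaves on a pydantic model may depend on the pydantic version, so treat that part as untested.

```python
from dataclasses import dataclass
from typing import Annotated, get_args, get_origin, get_type_hints


@dataclass(frozen=True)
class MyFieldMetadata:
    unit: str


class Measurement:
    # Plain class with the same annotation shape as the model in the question.
    length: Annotated[float, MyFieldMetadata(unit="meter")]
    duration: Annotated[float, MyFieldMetadata(unit="seconds")]
    label: str  # no metadata attached


def units_by_field(cls: type) -> dict[str, str]:
    """Map each annotated attribute to its unit, skipping attributes without one."""
    units = {}
    # include_extras=True keeps the Annotated wrapper instead of stripping it.
    for name, hint in get_type_hints(cls, include_extras=True).items():
        if get_origin(hint) is not Annotated:
            continue
        # get_args returns (base_type, *metadata), so an isinstance scan is still needed.
        for extra in get_args(hint)[1:]:
            if isinstance(extra, MyFieldMetadata):
                units[name] = extra.unit
                break
    return units


print(units_by_field(Measurement))
```

As the loop shows, this still ends in the same isinstance() check, which is why it is not much of an improvement.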

    Example

    Consider the contrived example of collecting weather data on Mars. When you apply thrust to your orbiter, you want to be able to check your units easily, and only allow certain kinds of values. You could use specific unit fields that hold enumerations for that.

    The following has all kinds of issues, but it outlines how an extra enumeration field can be used to check unit compatibility. The downside of this approach is that it adds a lot of complexity and is harder to handle. You can work around some of these issues, but that doesn't do much to reduce the mental burden of reasoning about the code.

    from abc import ABC
    from enum import StrEnum  # requires Python 3.11+
    from typing import Any
    
    from pydantic import BaseModel, Field, NonNegativeFloat
    
    
    class Unit(StrEnum):
        """Unit base class."""
    
    
    class UnitValue(BaseModel, ABC):
        value: Any
        unit: Unit
    
        def __str__(self):
            return f"{self.value} {self.unit}"
    
        def __add__(self, other):
            klass = self.__class__
            if not isinstance(other, klass):
                return NotImplemented
            if self.unit == other.unit:
                value = self.value + other.value
            else:
                raise TypeError("Adding incompatible units")
    
            return klass(value=value, unit=self.unit)
    
    
    class ForceUnit(Unit):
        POUND_FORCE = "lbf"
        NEWTON = "N"
    
    
    class Force(UnitValue):
        value: float
        unit: ForceUnit = Field(default=ForceUnit.NEWTON)
    
    
    class MassUnit(Unit):
        KILOGRAM = "kg"
        POUND_MASS = "lb"
    
    
    class Mass(UnitValue):
        value: NonNegativeFloat
        unit: MassUnit
    
    
    class VelocityUnit(Unit):
        MILES_PER_HOUR = "mph"
        METRES_PER_SECOND = "m/s"
        FEET_PER_SECOND = "ft/s"
    
    
    class Velocity(UnitValue):
        value: float
        unit: VelocityUnit = Field(default=VelocityUnit.METRES_PER_SECOND)
    
    
    class Orbiter(BaseModel):
        mass: Mass
        velocity: Velocity = Field(default=Velocity(value=0))
    
        def apply_thrust(self, force: Force):
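            # Only newtons and kilograms are accepted here; other units are rejected, not converted.
            assert force.unit is ForceUnit.NEWTON and self.mass.unit is MassUnit.KILOGRAM, "Wrong units"
            acceleration = force.value / self.mass.value
            # Adding the acceleration directly amounts to a one-second burn; the velocity unit is not checked.
            self.velocity.value += acceleration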
    
    
    if __name__ == '__main__':
        orbiter = Orbiter(mass=Mass(value=500, unit=MassUnit.KILOGRAM))
        orbiter.apply_thrust(Force(value=1200))
        print(orbiter.velocity)
    
        thruster_a = Force(value=300)
        thruster_b = Force(value=500)
        orbiter.apply_thrust(thruster_a + thruster_b)
        print(orbiter.velocity)
    
        try:
            orbiter.apply_thrust(Force(value=500, unit=ForceUnit.POUND_FORCE))  # Wrong units
        except AssertionError as e:
            print(e)
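
One of the workarounds hinted at above would be to convert to a common unit instead of rejecting everything that isn't in newtons. This is a minimal standalone sketch, not part of the example above: NEWTONS_PER_UNIT and to_newtons are hypothetical names, and the enum is redefined with (str, Enum) so the snippet also runs on Python versions without StrEnum.

```python
from enum import Enum


class ForceUnit(str, Enum):
    POUND_FORCE = "lbf"
    NEWTON = "N"


# Newtons per one unit of each force unit (1 lbf is defined as 4.4482216152605 N).
NEWTONS_PER_UNIT = {
    ForceUnit.NEWTON: 1.0,
    ForceUnit.POUND_FORCE: 4.4482216152605,
}


def to_newtons(value: float, unit: ForceUnit) -> float:
    """Convert a force expressed in `unit` to newtons."""
    return value * NEWTONS_PER_UNIT[unit]


print(to_newtons(500, ForceUnit.NEWTON))
print(to_newtons(100, ForceUnit.POUND_FORCE))
```

Something like this could be called at the start of apply_thrust, but every additional unit needs another table entry, which is the kind of added complexity mentioned above.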