I would like to add some metadata to model fields. This is not necessary for validation or serialisation but might be used for data display or perhaps as a hint when entering data. For instance, one might want to add a unit to a field.
I came up with this:
```python
from dataclasses import dataclass
from typing import Annotated

from pydantic import BaseModel, Field


@dataclass
class MyFieldMetadata:
    unit: str


class MyModel(BaseModel):
    # `ge` is the Pydantic v2 constraint; `gte` would be ignored for validation
    length: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="meter")]
    duration: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="seconds")]
```
And to print all fields with the unit appended:
```python
m = MyModel(length=10.0, duration=60.0)
for field_name, field_info in MyModel.model_fields.items():
    # Pick the first MyFieldMetadata instance among the field's Annotated extras
    extra_info = next(
        (meta for meta in field_info.metadata if isinstance(meta, MyFieldMetadata)), None
    )
    print(
        f"{field_name}: {getattr(m, field_name)} {extra_info.unit if extra_info else ''}"
    )
```
This works, but the question is if this is correct. And the second question: the way to retrieve `MyFieldMetadata` from the model seems convoluted (iterating through `metadata` and finding the first instance of `MyFieldMetadata`). Is there a cleaner or standard way to achieve this?
I'm using Pydantic 2.7.1.
> I came up with this: [...]. [T]he question is if this is correct

`Annotated` was introduced to add metadata for use by "static analysis tools or at runtime", so you're definitely not abusing this mechanism.
If it works for you, I'd say it's correct enough. But it very much depends on your specific needs.
E.g., if it's just for display, this is fine, and you should call it a day. If this were just for communicating the field semantics to other developers, you could consider an even simpler approach and just rename `length` to `metres` and `duration` to `seconds`. But if you wanted to ensure that unit data is always available for some fields, or do some calculations that depend on the units involved, I would add this data explicitly.^[YAGNI] I have an example at the bottom.
^[YAGNI]: Most likely, you aren't going to need it!
Regarding whether there's a less convoluted "cleaner or standard way", I fear there isn't. The docs have the following to say about this:
> A tool or library encountering an `Annotated` type can scan through the metadata elements to determine if they are of interest (e.g., using `isinstance()`).

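
If the `isinstance` scan has to happen anyway, it can at least live in one small helper instead of being repeated at every call site. This is a minimal sketch assuming Pydantic 2.x, where `Annotated` extras that Pydantic does not recognise end up in `FieldInfo.metadata`; `get_field_metadata` is a hypothetical helper name, not a Pydantic API.

```python
from dataclasses import dataclass
from typing import Annotated, Optional, Type, TypeVar

from pydantic import BaseModel, Field

T = TypeVar("T")


@dataclass
class MyFieldMetadata:
    unit: str


def get_field_metadata(
    model_cls: Type[BaseModel], field_name: str, kind: Type[T]
) -> Optional[T]:
    """Return the first metadata object of type `kind` on a field, or None."""
    # Unrecognised Annotated extras are kept in FieldInfo.metadata
    field_info = model_cls.model_fields[field_name]
    return next((meta for meta in field_info.metadata if isinstance(meta, kind)), None)


class MyModel(BaseModel):
    length: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="meter")]
    duration: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="seconds")]
    label: str = ""  # a field without unit metadata


m = MyModel(length=10.0, duration=60.0)
for name in MyModel.model_fields:
    meta = get_field_metadata(MyModel, name, MyFieldMetadata)
    suffix = f" {meta.unit}" if meta else ""
    print(f"{name}: {getattr(m, name)}{suffix}")
```

Passing the metadata class as a parameter keeps the helper generic, so other marker types (say, a display label) can reuse it.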
Another way to get at this data is via `typing.get_type_hints` with `include_extras=True`, but that's not much cleaner.
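For comparison, here is a sketch of the `get_type_hints` route, assuming Python 3.9+ (for `include_extras`, `get_origin` and `get_args`) and Pydantic 2.x; the `units` dict is just an illustrative result. It reads the raw `Annotated` hints instead of Pydantic's `FieldInfo`, so you still end up filtering with `isinstance`:

```python
from dataclasses import dataclass
from typing import Annotated, get_args, get_origin, get_type_hints

from pydantic import BaseModel, Field


@dataclass
class MyFieldMetadata:
    unit: str


class MyModel(BaseModel):
    length: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="meter")]
    duration: Annotated[float, Field(ge=0.0), MyFieldMetadata(unit="seconds")]


# include_extras=True keeps the Annotated wrapper instead of stripping it to `float`
hints = get_type_hints(MyModel, include_extras=True)
units = {}
for name in MyModel.model_fields:
    hint = hints[name]
    if get_origin(hint) is Annotated:
        # get_args returns (base_type, *metadata); skip the base type
        _base, *extras = get_args(hint)
        for meta in extras:
            if isinstance(meta, MyFieldMetadata):
                units[name] = meta.unit
                break
print(units)
```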
Consider the contrived example of collecting weather data on Mars. When you apply thrust to your orbiter, you want to be able to check your units easily, and only allow certain kinds of values. You could use specific unit fields that hold enumerations for that.
The following has all kinds of issues, but outlines how an extra enumeration field can be used to check unit compatibility. The downside of this approach is that it adds a lot of complexity and is harder to handle. You can work around some of these issues, but that doesn't reduce the mental burden of reasoning about it by much.
```python
from abc import ABC
from enum import StrEnum  # Python 3.11+
from typing import Any

from pydantic import BaseModel, Field, NonNegativeFloat


class Unit(StrEnum):
    """Unit base class."""


class UnitValue(BaseModel, ABC):
    value: Any
    unit: Unit

    def __str__(self):
        return f"{self.value} {self.unit}"

    def __add__(self, other):
        klass = self.__class__
        if not isinstance(other, klass):
            return NotImplemented
        if self.unit == other.unit:
            value = self.value + other.value
        else:
            raise TypeError("Adding incompatible units")
        return klass(value=value, unit=self.unit)


class ForceUnit(Unit):
    POUND_FORCE = "lbf"
    NEWTON = "N"


class Force(UnitValue):
    value: float
    unit: ForceUnit = Field(default=ForceUnit.NEWTON)


class MassUnit(Unit):
    KILOGRAM = "kg"
    POUND_MASS = "lb"


class Mass(UnitValue):
    value: NonNegativeFloat
    unit: MassUnit


class VelocityUnit(Unit):
    MILES_PER_HOUR = "mph"
    METRES_PER_SECOND = "m/s"
    FEET_PER_SECOND = "ft/s"


class Velocity(UnitValue):
    value: float
    unit: VelocityUnit = Field(default=VelocityUnit.METRES_PER_SECOND)


class Orbiter(BaseModel):
    mass: Mass
    # default_factory gives each Orbiter its own Velocity instance
    velocity: Velocity = Field(default_factory=lambda: Velocity(value=0))

    def apply_thrust(self, force: Force):
        # TODO check unit compatibility
        assert force.unit is ForceUnit.NEWTON and self.mass.unit is MassUnit.KILOGRAM, "Wrong units"
        acceleration = force.value / self.mass.value
        # Contrived: treats every burn as lasting exactly one second
        self.velocity.value += acceleration


if __name__ == '__main__':
    orbiter = Orbiter(mass=Mass(value=500, unit=MassUnit.KILOGRAM))
    orbiter.apply_thrust(Force(value=1200))
    print(orbiter.velocity)

    thruster_a = Force(value=300)
    thruster_b = Force(value=500)
    # UnitValue.__add__ only allows adding values that share a unit
    orbiter.apply_thrust(thruster_a + thruster_b)
    print(orbiter.velocity)

    try:
        orbiter.apply_thrust(Force(value=500, unit=ForceUnit.POUND_FORCE))  # Wrong units
    except AssertionError as e:
        print(e)
```