Tags: python, generics, pydantic, pydantic-v2

Implementing a lazily evaluated field for Pydantic v2


I'm trying to implement a lazily evaluated generic field type for Pydantic v2. This is my simple implementation: you can assign a value, a function, or an async function to the lazy field, and it is only evaluated when you access it. If you use it in a normal class, it works perfectly, but it doesn't work as a Pydantic field.

The problem is that `__set__` is never called here, although `__get__` is called twice for some reason. I know Pydantic does some unusual things internally, which might be the reason. Any help resolving this would be highly appreciated.

import asyncio
import inspect
from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar, Union, cast

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

# Type of the value a LazyField ultimately yields (e.g. bytes in LazyField[bytes]).
T = TypeVar("T")


class LazyField(Generic[T]):
    """A lazy field that can hold a value, a function, or an async function.

    The value is evaluated only when accessed and then cached. State is kept
    per owning instance (in ``obj.__dict__``), so two instances of the same
    class never share a value or loader; owning instances therefore need a
    ``__dict__``.

    NOTE(review): when used as an annotated Pydantic field, the validated input
    appears to be stored directly on the model instance without going through
    this descriptor (``__set__`` is never called), so laziness does not apply
    there.
    """

    def __init__(self, value: Any = None) -> None:
        """Create the descriptor.

        Args:
            value: Optional default for instances that were never assigned.
                A callable default is used as a loader, evaluated separately for
                each instance. ``None`` means "no default".
        """
        print("LazyField.__init__")

        # Default used by instances that have not been assigned a value.
        self._default: Any = value
        # Per-instance storage key; __set_name__ replaces it with one derived
        # from the attribute name once the owning class is created.
        self._state_key: str = f"_lazy_field_{id(self)}"

    def __set_name__(self, owner: type, name: str) -> None:
        """Derive the per-instance storage key from the attribute name."""
        self._state_key = f"_lazy_field_{name}"

    def __get__(self, obj: Any, objtype: Optional[type] = None) -> T:
        """Return *obj*'s value, running its loader on first access.

        Raises:
            AttributeError: nothing was assigned to *obj* and there is no default.
            RuntimeError: an async loader is accessed while an event loop is
                already running in this thread.
        """
        print("LazyField.__get__")

        if obj is None:
            # Class-level access returns the descriptor itself.
            return self  # type: ignore

        state = obj.__dict__.get(self._state_key)
        if state is None:
            if self._default is None:
                raise AttributeError("LazyField has no value or loader set")
            state = self._store(obj, self._default)

        is_loader, payload = state
        if is_loader:
            payload = self._run_loader(payload)
            # Cache the result so the loader runs at most once per instance.
            obj.__dict__[self._state_key] = (False, payload)
        return cast(T, payload)

    def __set__(
        self, obj: Any, value: Union[T, Callable[[], T], Callable[[], Awaitable[T]]]
    ) -> None:
        """Assign a plain value, or a (sync or async) loader if *value* is callable."""
        print("LazyField.__set__")

        self._store(obj, value)

    def _store(self, obj: Any, value: Any) -> "tuple[bool, Any]":
        """Record *value* on *obj* as ``(is_loader, payload)`` and return that pair."""
        state = (callable(value), value)
        obj.__dict__[self._state_key] = state
        return state

    @staticmethod
    def _run_loader(loader: Callable[[], Union[T, Awaitable[T]]]) -> T:
        """Call *loader*; if it returns an awaitable, run it to completion."""
        result = loader()
        if not inspect.isawaitable(result):
            return cast(T, result)

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No loop is running in this thread, so we may start one.
        else:
            # A running loop cannot be re-entered synchronously.
            if inspect.iscoroutine(result):
                result.close()  # Avoid a "coroutine was never awaited" warning.
            raise RuntimeError(
                "LazyField cannot evaluate an async loader while an event loop "
                "is running; await the loader directly instead"
            )

        # Use a private loop and always close it instead of leaking it.
        loop = asyncio.new_event_loop()
        try:
            return cast(T, loop.run_until_complete(result))
        finally:
            loop.close()

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: type[Any], handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Accept either a value of the inner type or a zero-argument loader."""
        print("LazyField.__get_pydantic_core_schema__")

        # Extract the inner type from LazyField[T]; a bare LazyField uses Any.
        inner_type = (
            source_type.__args__[0] if hasattr(source_type, "__args__") else Any
        )
        # Generate schema for the inner type
        inner_schema = handler.generate_schema(inner_type)

        def serialize(value: Any) -> Any:
            # Plain values are emitted as-is; unevaluated loaders and bare
            # descriptors have nothing to emit yet.
            if isinstance(value, LazyField) or callable(value):
                return None
            return value

        return core_schema.json_or_python_schema(
            json_schema=inner_schema,
            python_schema=core_schema.union_schema(
                [
                    # Handle direct value assignment
                    inner_schema,
                    # Sync and async loaders are both plain callables.
                    core_schema.callable_schema(),
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize,
                return_schema=inner_schema,
                when_used="json",
            ),
        )


# Demo model from the question: `content` is an annotated field, so Pydantic
# builds its schema via LazyField.__get_pydantic_core_schema__; as the printed
# output shows, LazyField.__set__ is never called for it.
class A(BaseModel):
    content: LazyField[bytes] = LazyField()


async def get_content():
    """Demo async loader; resolves to a fixed byte string."""
    greeting = b"Hello, world!"
    return greeting


# Pass the async function itself (not a coroutine object) as the field value.
a = A(content=get_content)

# Intended to print the loaded bytes; per the output below it prints the function.
print(a.content)

This is the output from above:

LazyField.__init__
LazyField.__get__
LazyField.__get__
LazyField.__get_pydantic_core_schema__
<function get_content at 0x102cc4860>

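As you can see, `__get__` is called twice, and both calls happen before `__get_pydantic_core_schema__`, i.e. while the class is being created. No `__get__` call is printed when `a.content` is read, so that read does not go through the descriptor at all. Because `__set__` is never called, `_is_loaded` is still `False` and `_loader` is still `None`, and what comes back is the raw function, unevaluated.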


Solution

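  • I finally found a solution! The trick is to make `LazyField` a subclass of `property`. Pydantic ignores properties when collecting model fields, so it leaves the descriptor in place, and assignments such as `a.content = ...` go through its `__set__`. Note that `content` is then no longer a Pydantic field, so it is not validated or included in `model_dump()`. I hope this helps someone else looking for a solution.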

    import asyncio
    import inspect
    from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar, Union, cast
    
    from pydantic import BaseModel
    
    # Type of the value a LazyField ultimately yields (e.g. bytes in LazyField[bytes]).
    T = TypeVar("T")
    
    
    class LazyField(Generic[T], property):
        """A lazily evaluated attribute holding a value, a function, or an async function.

        Subclassing ``property`` keeps Pydantic from treating the attribute as a
        model field, so ``model.attr = ...`` reaches ``__set__``. Each owning
        instance keeps its own state in ``obj.__dict__``, so two instances of the
        same class never share a value or loader; owning instances therefore need
        a ``__dict__`` (regular classes and Pydantic models have one).
        """

        def __init__(self, value: Any = None) -> None:
            """Create the descriptor.

            Args:
                value: Optional default for instances that were never assigned.
                    A callable default is used as a loader, evaluated separately
                    for each instance. ``None`` means "no default".
            """
            # Default used by instances that have not been assigned a value.
            self._default: Any = value
            # Per-instance storage key; __set_name__ replaces it with one derived
            # from the attribute name once the owning class is created.
            self._state_key: str = f"_lazy_field_{id(self)}"
            super().__init__()

        def __set_name__(self, owner: type, name: str) -> None:
            """Derive the per-instance storage key from the attribute name."""
            # Forward to property.__set_name__ when the running Python defines one.
            parent = getattr(super(), "__set_name__", None)
            if parent is not None:
                parent(owner, name)
            self._state_key = f"_lazy_field_{name}"

        def __get__(self, obj: Any, objtype: Optional[type] = None) -> T:
            """Return *obj*'s value, running its loader on first access.

            Raises:
                AttributeError: nothing was assigned to *obj* and there is no default.
                RuntimeError: an async loader is accessed while an event loop is
                    already running in this thread.
            """
            if obj is None:
                # Class-level access returns the descriptor itself.
                return self  # type: ignore

            state = obj.__dict__.get(self._state_key)
            if state is None:
                if self._default is None:
                    raise AttributeError("LazyField has no value or loader set")
                state = self._store(obj, self._default)

            is_loader, payload = state
            if is_loader:
                payload = self._run_loader(payload)
                # Cache the result so the loader runs at most once per instance.
                obj.__dict__[self._state_key] = (False, payload)
            return cast(T, payload)

        def __set__(
            self, obj: Any, value: Union[T, Callable[[], T], Callable[[], Awaitable[T]]]
        ) -> None:
            """Assign a plain value, or a (sync or async) loader if *value* is callable."""
            self._store(obj, value)

        def _store(self, obj: Any, value: Any) -> "tuple[bool, Any]":
            """Record *value* on *obj* as ``(is_loader, payload)`` and return that pair."""
            state = (callable(value), value)
            obj.__dict__[self._state_key] = state
            return state

        @staticmethod
        def _run_loader(loader: Callable[[], Union[T, Awaitable[T]]]) -> T:
            """Call *loader*; if it returns an awaitable, run it to completion."""
            result = loader()
            if not inspect.isawaitable(result):
                return cast(T, result)

            try:
                asyncio.get_running_loop()
            except RuntimeError:
                pass  # No loop is running in this thread, so we may start one.
            else:
                # A running loop cannot be re-entered synchronously.
                if inspect.iscoroutine(result):
                    result.close()  # Avoid a "coroutine was never awaited" warning.
                raise RuntimeError(
                    "LazyField cannot evaluate an async loader while an event loop "
                    "is running; await the loader directly instead"
                )

            # Use a private loop and always close it instead of leaking it.
            loop = asyncio.new_event_loop()
            try:
                return cast(T, loop.run_until_complete(result))
            finally:
                loop.close()
    
    
    # `content` is deliberately left unannotated: LazyField subclasses `property`,
    # which Pydantic does not collect as a model field, so assignments such as
    # `a.content = ...` go through LazyField.__set__.
    class Artifact(BaseModel):
        content = LazyField[bytes]()
    
    
    # No content is passed at construction; it is assigned later.
    a = Artifact()
    
    
    async def load_content() -> bytes:
        """Demo async loader; resolves to a fixed byte string."""
        payload = b"content"
        return payload
    
    
    # Assigning a callable stores it as a loader; it runs on first access.
    a.content = load_content
    print(a.content)

    # Assigning a non-callable stores it as a plain value, replacing the previous state.
    a.content = b"content2"
    print(a.content)