python-3.x pydantic abstract-base-class

Use subclasses of abstract outer class in nested class


I want to use all subclasses of an abstract class in the nested Config class of a pydantic model, like this:

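import abc
from functools import partial
from pydantic import BaseModel

def custom_json_loads(classes, ....):
    # use classes here for JSON parsing
    ...

class Outer(BaseModel, abc.ABC):
    name = "test"
    class Config:
        # Outer is not bound yet while its own class body runs, so this raises NameError
        json_loads = partial(custom_json_loads, Outer.__subclasses__)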

The aim is that I know the JSON holds an Outer object, and the name field tells me which subclass should be instantiated.

For example, I have BlueOuter, RedOuter and GreenOuter, and the JSON would contain "outer": {"name": "BlueOuter", ....}

But I don't want to import every variant of the subclasses explicitly, because they evolve over time.
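The mechanism described in the question can be sketched without pydantic. This is a minimal sketch, assuming plain classes in place of the pydantic models; load_outer, all_subclasses and the "shade" field are hypothetical names. The key point is that Outer.__subclasses__() is called when the JSON is parsed, not while the Outer class body is still executing. Note that a subclass only shows up in __subclasses__() once the module defining it has been imported, so the variants still have to be imported somewhere, just not listed by hand.

```python
import abc
import json
from typing import Any, Dict, List


class Outer(abc.ABC):
    """Plain-Python stand-in for the pydantic Outer model (hypothetical)."""

    def __init__(self, **fields: Any) -> None:
        self.fields: Dict[str, Any] = fields


class BlueOuter(Outer):
    pass


class RedOuter(Outer):
    pass


def all_subclasses(cls: type) -> List[type]:
    # __subclasses__() only returns direct children, so recurse for deeper ones.
    found: List[type] = []
    for sub in cls.__subclasses__():
        found.append(sub)
        found.extend(all_subclasses(sub))
    return found


def load_outer(raw: str) -> Outer:
    """Hypothetical loader: instantiate the subclass named by the "name" key."""
    data = json.loads(raw)["outer"]
    # Resolved at call time, so Outer already exists and every subclass
    # defined so far is included (unlike Outer.__subclasses__ in the class body).
    by_name = {sub.__name__: sub for sub in all_subclasses(Outer)}
    cls = by_name.get(data.get("name"))
    if cls is None:
        raise ValueError(f"unknown Outer subclass: {data.get('name')!r}")
    return cls(**data)


blue = load_outer('{"outer": {"name": "BlueOuter", "shade": 3}}')
print(type(blue).__name__, blue.fields)
```

Because the lookup happens on every call, a subclass defined after load_outer (for example a hypothetical LightBlueOuter(BlueOuter)) is still found.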


Solution

  • Why not use a discriminated union?

    import abc
    from typing import Annotated, Literal, Union
    
    from pydantic import BaseModel, Field
    
    
    class Outer(BaseModel, abc.ABC):
        ...
    
    
    class BlueOuter(Outer):
        name: Literal["BlueOuter"]
    
    
    class RedOuter(Outer):
        name: Literal["RedOuter"]
    
    
    class GreenOuter(Outer):
        name: Literal["GreenOuter"]
    
    
    OuterUnion = Annotated[
        Union[BlueOuter, RedOuter, GreenOuter], Field(discriminator="name")
    ]
    
    
    class Foo(BaseModel):
        outer: OuterUnion
    
    
    print(Foo.parse_raw('{"outer": {"name": "BlueOuter"}}'))
    print(Foo.parse_raw('{"outer": {"name": "RedOuter"}}'))
    print(Foo.parse_raw('{"outer": {"name": "GreenOuter"}}'))
    

    Output:

    outer=BlueOuter(name='BlueOuter')
    outer=RedOuter(name='RedOuter')
    outer=GreenOuter(name='GreenOuter')
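If listing the variants by hand is the main worry, one alternative (not part of the answer above) is to build the union from Outer.__subclasses__() after all variants have been imported. This is a minimal sketch, assuming Python 3.9+ and plain classes in place of the pydantic models; build_outer_union is a hypothetical helper. In the pydantic version, the result would be wrapped in Annotated[..., Field(discriminator="name")] just like OuterUnion.

```python
import abc
from typing import Literal, Union, get_args


class Outer(abc.ABC):
    ...


class BlueOuter(Outer):
    name: Literal["BlueOuter"]


class RedOuter(Outer):
    name: Literal["RedOuter"]


class GreenOuter(Outer):
    name: Literal["GreenOuter"]


def build_outer_union(base: type):
    """Hypothetical helper: a Union of the direct subclasses of base."""
    subclasses = tuple(base.__subclasses__())
    if not subclasses:
        raise TypeError(f"{base.__name__} has no subclasses yet")
    # Union[...] accepts a tuple of types; with one member it collapses to that class.
    return Union[subclasses]


# Evaluated once, here: subclasses defined or imported after this line are missed.
OuterUnion = build_outer_union(Outer)

print([cls.__name__ for cls in get_args(OuterUnion)])
```

The trade-off is that the union is only as complete as the imports that ran before build_outer_union(Outer): a subclass in a module imported later is silently left out, and static type checkers such as mypy cannot analyze a union built at runtime.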
    

    If you are worried about having to maintain OuterUnion whenever a new Outer subclass is added, you could add a unit test that checks that OuterUnion contains all subclasses of Outer:

    class OrangeOuter(Outer):
        name: Literal["OrangeOuter"]
    
    
    outer_union_classes = OuterUnion.__args__[0].__args__
    for subclass in Outer.__subclasses__():
        assert (
            subclass in outer_union_classes
        ), f"{subclass.__name__} must be a member of OuterUnion (classes: {[c.__name__ for c in outer_union_classes]}). Please add it."
    

    Output:

    AssertionError: OrangeOuter must be a member of OuterUnion (classes: ['BlueOuter', 'RedOuter', 'GreenOuter']). Please add it.