Tags: python, polymorphism, marshmallow, python-dataclasses, discriminator

How to combine marshmallow-dataclass with marshmallow-oneofschema for a polymorphic structure?


I'm trying to combine marshmallow-dataclass with marshmallow-oneofschema to process a data structure that is given to me and is used to connect a Java and a Python application.

In Java the concept is commonly known as a "discriminator", and it is implemented by several frameworks, e.g. "Jackson (de-)serialization for polymorph list".

I thought it should be possible, but now I'm facing the following issue:

Traceback (most recent call last):
File "one_file_dataclass_oneofschema_example.py", line 69, in <module>
    ContainerSchema.load(example)
File "/site-packages/marshmallow_dataclass/__init__.py", line 473, in load
    all_loaded = super().load(data, many=many, **kwargs)
File "/site-packages/marshmallow/schema.py", line 722, in load
    return self._do_load(
File "/site-packages/marshmallow/schema.py", line 856, in _do_load
    result = self._deserialize(
File "/site-packages/marshmallow/schema.py", line 664, in _deserialize
    value = self._call_and_store(
File "/site-packages/marshmallow/schema.py", line 493, in _call_and_store
    value = getter_func(data)
File "/site-packages/marshmallow/schema.py", line 661, in <lambda>
    getter = lambda val: field_obj.deserialize(
File "/site-packages/marshmallow/fields.py", line 342, in deserialize
    output = self._deserialize(value, attr, data, **kwargs)
File "/site-packages/marshmallow/fields.py", line 713, in _deserialize
    result.append(self.inner.deserialize(each, **kwargs))
File "/site-packages/marshmallow/fields.py", line 342, in deserialize
    output = self._deserialize(value, attr, data, **kwargs)
File "/site-packages/marshmallow/fields.py", line 597, in _deserialize
    return self._load(value, data, partial=partial)
File "/site-packages/marshmallow/fields.py", line 580, in _load
    valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)
File "/site-packages/marshmallow_dataclass/__init__.py", line 481, in load
    raise e
File "/site-packages/marshmallow_dataclass/__init__.py", line 479, in load
    return clazz(**all_loaded)
TypeError: type object argument after ** must be a mapping, not SubClass1

While debugging, it looks like marshmallow_dataclass tries to deserialize an object that has already been deserialized:

def _base_schema(
    clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
) -> Type[marshmallow.Schema]:
    """
    Base schema factory that creates a schema for `clazz` derived either from `base_schema`
    or `BaseSchema`

    :param clazz: class that is instantiated from the loaded data.
    :param base_schema: optional schema class to derive from; when it is
        falsy, `marshmallow.Schema` is used instead.
    :return: the generated `BaseSchema` class (not an instance).
    """

    # Remove `type: ignore` when mypy handles dynamic base classes
    # https://github.com/python/mypy/issues/2813
    class BaseSchema(base_schema or marshmallow.Schema):  # type: ignore
        def load(self, data: Mapping, *, many: Optional[bool] = None, **kwargs):
            """
            Load `data` through the parent schema and wrap the result in `clazz`.

            Returns a list of `clazz` instances when `many` is in effect,
            otherwise a single `clazz` instance.
            """
            all_loaded = super().load(data, many=many, **kwargs)
            # An explicit `many` argument wins over the schema-level setting.
            many = self.many if many is None else bool(many)
            if many:
                return [clazz(**loaded) for loaded in all_loaded]
            else:
                # `all_loaded` must be a mapping here; if the parent `load`
                # already returned an object, `**` unpacking raises TypeError.
                return clazz(**all_loaded) # Here the error occurs.

    # Hand the generated schema class back to the caller; a bare `return`
    # would yield None and contradict the declared return type.
    return BaseSchema

I've seen the post_load logic mentioned in this question, but I have no idea how to integrate it or whether it would help.

Here is what I tried:

from dataclasses import dataclass, field  # @UnusedImport
from typing import List

from marshmallow_dataclass import class_schema, add_schema
from marshmallow_oneofschema.one_of_schema import OneOfSchema


@dataclass
class BaseClass:
    """Common base of the objects held in ``Container.some_objects``."""
    # Property shared by every subclass.
    base_property_a: str


@dataclass
class SubClass1(BaseClass):
    """First concrete variant of ``BaseClass``; adds ``sub_property_1``."""
    # Property that only this variant carries.
    sub_property_1: str


@dataclass
class SubClass2(BaseClass):
    """Second concrete variant of ``BaseClass``; adds ``sub_property_2``."""
    # Property that only this variant carries.
    sub_property_2: str


@dataclass
class Container:
    """Top-level structure holding a list of ``BaseClass`` objects."""
    container_property_c: str
    # Declared with the base type; the example data contains SubClass1 and
    # SubClass2 entries.
    some_objects: List[BaseClass]

# Schema instances for the concrete subclasses. They are referenced by
# BaseClassSchema.type_schemas (these schemas are required by OneOfSchema).
SubClass1Schema = class_schema(SubClass1)()
SubClass2Schema = class_schema(SubClass2)()


class BaseClassSchema(OneOfSchema):
    """
    Polymorphic schema for ``BaseClass`` that picks one of the subclass
    schemas based on the ``modelType`` entry of the data.
    """

    # Name of the discriminator key in the serialized data.
    type_field = 'modelType'

    # Discriminator value -> schema handling that variant.
    type_schemas = dict(
        SubClass1=SubClass1Schema,
        SubClass2=SubClass2Schema,
    )

    def get_obj_type(self, obj):
        """
        Return the class name of `obj`, which doubles as the discriminator value.

        Raises `Exception` when `obj` is not a `BaseClass` instance.
        """
        class_name = obj.__class__.__name__
        if not isinstance(obj, BaseClass):
            raise Exception("Unknown object type: {}".format(class_name))
        return class_name

# Publish BaseClassSchema so that it is used even when it is not called explicitly.
# NOTE(review): the second argument of add_schema is presumably treated as a
# base schema to derive from, not as the final schema - confirm against the
# marshmallow_dataclass documentation.
add_schema(BaseClass, BaseClassSchema)

# Schema instance for the top-level Container dataclass.
ContainerSchema = class_schema(Container)()


# Sample input: a container with one entry per subclass, each tagged with a
# 'modelType' discriminator.
example = {
    'container_property_c': 'bla',
    'some_objects': [
        {
            'modelType': 'SubClass1',
            'base_property_a': 'blub_a1',
            'sub_property_1': 'blub_1'
        },
        {
            'modelType': 'SubClass2',
            'base_property_a': 'blub_a2',
            'sub_property_2': 'blub_2'
        }
    ]
}

# This call raises the TypeError shown in the traceback above.
ContainerSchema.load(example)

Solution

  • I found a simpler solution using marshmallow-union, which is also the approach recommended by marshmallow-dataclass: https://github.com/lovasoa/marshmallow_dataclass/issues/62

    from dataclasses import field  # @UnusedImport
    from typing import List, Union
    
    from marshmallow.validate import Equal
    from marshmallow_dataclass import dataclass
    
    
    @dataclass
    class BaseClass:
        """Common base carrying the property shared by both subclasses."""
        base_property_a: str
    
    
    @dataclass
    class SubClass1(BaseClass):
        """First variant, identified by ``modelType == "SubClass1"``."""
        # Discriminator field: validated to be exactly "SubClass1".
        modelType: str = field(metadata={"validate": Equal("SubClass1")})
        sub_property_1: str
    
    
    @dataclass
    class SubClass2(BaseClass):
        """Second variant, identified by ``modelType == "SubClass2"``."""
        # Discriminator field: validated to be exactly "SubClass2".
        modelType: str = field(metadata={"validate": Equal("SubClass2")})
        sub_property_2: str
    
    
    @dataclass
    class Container:
        """Top-level structure holding a list of SubClass1/SubClass2 objects."""
        container_property_c: str
        # Each element may be either variant; the Union lists the candidates
        # explicitly instead of referring to the base class.
        some_objects: List[Union[
            SubClass1,
            SubClass2
        ]]
    
    
    # Sample input: one entry per subclass, each carrying the matching
    # 'modelType' value.
    example = {
        'container_property_c': 'bla',
        'some_objects': [
            {
                'modelType': 'SubClass1',
                'base_property_a': 'blub_a1',
                'sub_property_1': 'blub_1'
            },
            {
                'modelType': 'SubClass2',
                'base_property_a': 'blub_a2',
                'sub_property_2': 'blub_2'
            }
        ]
    }
    
    # Load the sample through the schema attached to Container and show the result.
    print(Container.Schema().load(example))