I'm trying to combine marshmallow-dataclass with marshmallow-oneofschema to process a data structure that is given to me and is used to connect a java and a python application.
In java the concept is commonly known as "discriminator" and it's implemented by different frameworks: Jackson (de-)serialization for polymorph list
I thought it should be possible, but now I'm facing the following issue:
Traceback (most recent call last):
File "one_file_dataclass_oneofschema_example.py", line 69, in <module>
ContainerSchema.load(example)
File "/site-packages/marshmallow_dataclass/__init__.py", line 473, in load
all_loaded = super().load(data, many=many, **kwargs)
File "/site-packages/marshmallow/schema.py", line 722, in load
return self._do_load(
File "/site-packages/marshmallow/schema.py", line 856, in _do_load
result = self._deserialize(
File "/site-packages/marshmallow/schema.py", line 664, in _deserialize
value = self._call_and_store(
File "/site-packages/marshmallow/schema.py", line 493, in _call_and_store
value = getter_func(data)
File "/site-packages/marshmallow/schema.py", line 661, in <lambda>
getter = lambda val: field_obj.deserialize(
File "/site-packages/marshmallow/fields.py", line 342, in deserialize
output = self._deserialize(value, attr, data, **kwargs)
File "/site-packages/marshmallow/fields.py", line 713, in _deserialize
result.append(self.inner.deserialize(each, **kwargs))
File "/site-packages/marshmallow/fields.py", line 342, in deserialize
output = self._deserialize(value, attr, data, **kwargs)
File "/site-packages/marshmallow/fields.py", line 597, in _deserialize
return self._load(value, data, partial=partial)
File "/site-packages/marshmallow/fields.py", line 580, in _load
valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)
File "/site-packages/marshmallow_dataclass/__init__.py", line 481, in load
raise e
File "/site-packages/marshmallow_dataclass/__init__.py", line 479, in load
return clazz(**all_loaded)
TypeError: type object argument after ** must be a mapping, not SubClass1
When debugging it looks like an already deserialized object should be deserialized again by marshmallow_dataclass:
def _base_schema(
clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
) -> Type[marshmallow.Schema]:
"""
Base schema factory that creates a schema for `clazz` derived either from `base_schema`
or `BaseSchema`
"""
# Remove `type: ignore` when mypy handles dynamic base classes
# https://github.com/python/mypy/issues/2813
class BaseSchema(base_schema or marshmallow.Schema): # type: ignore
def load(self, data: Mapping, *, many: bool = None, **kwargs):
all_loaded = super().load(data, many=many, **kwargs)
many = self.many if many is None else bool(many)
if many:
return [clazz(**loaded) for loaded in all_loaded]
else:
return clazz(**all_loaded) # Here the error occurs.
return
I've seen the post_load logic mentioned in this question, but have no idea how to integrate or if it would help.
Here is what I tried:
from dataclasses import dataclass, field # @UnusedImport
from typing import List
from marshmallow_dataclass import class_schema, add_schema
from marshmallow_oneofschema.one_of_schema import OneOfSchema
@dataclass
class BaseClass:
base_property_a: str
@dataclass
class SubClass1(BaseClass):
sub_property_1: str
@dataclass
class SubClass2(BaseClass):
sub_property_2: str
@dataclass
class Container:
container_property_c: str
some_objects: List[BaseClass]
# These schemas are required by OneOfSchema
SubClass1Schema = class_schema(SubClass1)()
SubClass2Schema = class_schema(SubClass2)()
class BaseClassSchema(OneOfSchema):
type_field = 'modelType'
type_schemas = {
'SubClass1': SubClass1Schema,
'SubClass2': SubClass2Schema
}
def get_obj_type(self, obj):
if isinstance(obj, BaseClass):
return obj.__class__.__name__
else:
raise Exception("Unknown object type: {}".format(obj.__class__.__name__))
# Publishing BaseClassSchema so that it's use while not explicitly called.
add_schema(BaseClass, BaseClassSchema)
ContainerSchema = class_schema(Container)()
example = {
'container_property_c': 'bla',
'some_objects': [
{
'modelType': 'SubClass1',
'base_property_a': 'blub_a1',
'sub_property_1': 'blub_1'
},
{
'modelType': 'SubClass2',
'base_property_a': 'blub_a2',
'sub_property_2': 'blub_2'
}
]
}
ContainerSchema.load(example)
Found a simpler solution using marshmallow-union, which is also the recommended way by marshmallow-dataclass: https://github.com/lovasoa/marshmallow_dataclass/issues/62
from dataclasses import field # @UnusedImport
from typing import List, Union
from marshmallow.validate import Equal
from marshmallow_dataclass import dataclass
@dataclass
class BaseClass:
base_property_a: str
@dataclass
class SubClass1(BaseClass):
modelType: str = field(metadata={"validate": Equal("SubClass1")})
sub_property_1: str
@dataclass
class SubClass2(BaseClass):
modelType: str = field(metadata={"validate": Equal("SubClass2")})
sub_property_2: str
@dataclass
class Container:
container_property_c: str
some_objects: List[Union[
SubClass1,
SubClass2
]]
example = {
'container_property_c': 'bla',
'some_objects': [
{
'modelType': 'SubClass1',
'base_property_a': 'blub_a1',
'sub_property_1': 'blub_1'
},
{
'modelType': 'SubClass2',
'base_property_a': 'blub_a2',
'sub_property_2': 'blub_2'
}
]
}
print(Container.Schema().load(example))