I'm trying to use msgspec to encode NumPy data into JSON and decode it back again. I've found lots of good resources on encoding the data and got my encoder working without any problem, but I can't get the data decoded back into its original format.
from dataclasses import dataclass
import numpy as np
import msgspec as ms
from traits.api import List, Array, Instance
def NumpyEncoder(obj):
    """msgspec ``enc_hook`` that converts NumPy values to JSON-native types.

    :param obj: object the encoder could not serialize natively
    :return: ``int`` for NumPy integer scalars, ``float`` for NumPy floating
        scalars, nested ``list`` for ``np.ndarray``
    :raises NotImplementedError: if ``obj`` is not a supported NumPy type
    """
    # np.integer / np.floating are the abstract bases of every sized NumPy
    # integer / floating scalar, so no explicit list of widths is needed
    # (and the np.float_ alias, removed in NumPy 2.0, is never touched).
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # An enc_hook signals "unsupported type" by raising NotImplementedError;
    # it must return a serializable object, not already-encoded bytes.
    raise NotImplementedError(
        f"Objects of type {type(obj).__name__} are not supported"
    )
# JSON encoder that hands any type msgspec cannot serialize natively to
# NumpyEncoder.
enc = ms.json.Encoder(enc_hook=NumpyEncoder)
# Leaf record holding a single NumPy array.
# NOTE(review): the default is a traits ``Array`` trait object used as a plain
# dataclass default; dataclasses presumably treat it as an ordinary value and
# do not validate ``c1`` against it -- confirm this is intended.
@dataclass
class C:
    c1: np.ndarray = Array(dtype=np.float64)
# Container record holding a list of C instances.
# NOTE(review): the default is a traits ``List(Instance(C))`` trait object used
# as a plain dataclass default; dataclasses presumably treat it as an ordinary
# value and do not validate ``a1`` against it -- confirm this is intended.
@dataclass
class A:
    a1: list[C] = List(Instance(C))
# Build a sample payload and encode it.
c = C(np.ones(10))
# A's only field is ``a1`` and it is annotated as a list of C, so the
# instance is wrapped in a list and passed under that field name.
a = A(a1=[c])
enc.encode(a)
This gives the correct serialized value of `a`. But how do I decode it correctly?
I've tried the following:
def NumpyEncoder(obj):
    """msgspec ``enc_hook`` that converts NumPy values to JSON-native types.

    Arrays are emitted as a tagged dict so that a decoder can recognise them
    and restore the dtype.

    :param obj: object the encoder could not serialize natively
    :return: ``int`` for NumPy integer scalars, ``float`` for NumPy floating
        scalars, ``{"__ndarray__": <nested list>, "dtype": <dtype name>}``
        for ``np.ndarray``
    :raises NotImplementedError: if ``obj`` is not a supported NumPy type
    """
    # np.integer / np.floating are the abstract bases of every sized NumPy
    # integer / floating scalar, so no explicit list of widths is needed
    # (and the np.float_ alias, removed in NumPy 2.0, is never touched).
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return {"__ndarray__": obj.tolist(), "dtype": str(obj.dtype)}
    # An enc_hook signals "unsupported type" by raising NotImplementedError;
    # it must return a serializable object, not already-encoded bytes.
    raise NotImplementedError(
        f"Objects of type {type(obj).__name__} are not supported"
    )
class NumpyDecoder:
    """Restores NumPy arrays from the tagged-dict form
    ``{"__ndarray__": <nested list>, "dtype": <dtype name>}``.
    """

    def decoderHook(self, dct):
        """Decodes a previously encoded numpy ndarray with proper shape and
        dtype.

        :param dct: (dict) json encoded ndarray
        :return: (ndarray) if input was an encoded ndarray, otherwise the
            input object unchanged
        """
        if isinstance(dct, dict) and "__ndarray__" in dct:
            return np.array(dct["__ndarray__"], dtype=dct["dtype"])
        # ``dct`` is an already-decoded Python object here, not JSON bytes, so
        # it cannot be fed back into a JSON decoder; pass it through as-is.
        return dct
enc = ms.json.Encoder(enc_hook=NumpyEncoder)
# msgspec invokes ``dec_hook(type, obj)``. Passing the NumpyDecoder class
# itself would therefore try to construct ``NumpyDecoder(type, obj)``, so the
# one-argument ``decoderHook`` method of an instance is adapted to the
# two-argument hook signature instead.
_numpy_decoder = NumpyDecoder()
dec = ms.json.Decoder(
    type=A,
    dec_hook=lambda type_, obj: _numpy_decoder.decoderHook(obj),
)
b = enc.encode(a)
print(dec.decode(b))
However, this does not decode `b` back into an object of type `A`.
Thanks!
Since you intend to use dataclasses (for inheritance purposes), I've added this answer to address that case. Please let me know if this works for you as well.
import numpy as np
import msgspec
from dataclasses import dataclass, field
from typing import List
# Custom encoder function for numpy objects
def NumpyEncoder(obj):
    """msgspec ``enc_hook`` that converts NumPy values to JSON-native types.

    Arrays are emitted as a tagged dict so that the decoder hook can
    recognise them and restore the dtype.

    Args:
        obj: Object the encoder could not serialize natively.

    Returns:
        ``int`` for NumPy integer scalars, ``float`` for NumPy floating
        scalars, and ``{"__ndarray__": <nested list>, "dtype": <dtype name>}``
        for ``np.ndarray``.

    Raises:
        NotImplementedError: If ``obj`` is not a supported NumPy type.
    """
    # np.integer / np.floating are the abstract bases of every sized NumPy
    # integer / floating scalar, so no explicit list of widths is needed
    # (and the np.float_ alias, removed in NumPy 2.0, is never touched).
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return {"__ndarray__": obj.tolist(), "dtype": str(obj.dtype)}
    # An enc_hook signals "unsupported type" by raising NotImplementedError;
    # it must return a serializable object, not already-encoded bytes.
    raise NotImplementedError(
        f"Objects of type {type(obj).__name__} are not supported"
    )
# Custom decoder function for numpy objects
def numpy_decoder_hook(type_, dct):
    """msgspec ``dec_hook`` that rebuilds NumPy arrays from tagged dicts.

    Args:
        type_: The type the decoder expects for this value.
        dct: The already-decoded value; an encoded array has the form
            ``{"__ndarray__": <nested list>, "dtype": <dtype name>}``.

    Returns:
        An ``np.ndarray`` built from the nested list with the recorded dtype.

    Raises:
        NotImplementedError: If ``dct`` is not a tagged array dict.
    """
    if isinstance(dct, dict) and "__ndarray__" in dct:
        return np.array(dct["__ndarray__"], dtype=dct["dtype"])
    # A dec_hook reports values it cannot convert by raising
    # NotImplementedError rather than handing the raw object back.
    raise NotImplementedError(
        f"Objects of type {getattr(type_, '__name__', type_)} are not supported"
    )
# Define data classes
@dataclass
class C:
    """Leaf record holding a single NumPy array."""

    # Required field; no default is declared.
    c1: np.ndarray
@dataclass
class A:
    """Container record holding a list of C instances."""

    # default_factory gives every instance its own fresh list.
    a1: List[C] = field(default_factory=list)
# Wire the custom hooks into an encoder/decoder pair; the decoder is typed so
# the JSON is rebuilt as an A instance.
enc = msgspec.json.Encoder(enc_hook=NumpyEncoder)
dec = msgspec.json.Decoder(type=A, dec_hook=numpy_decoder_hook)

# Sample payload: one C (ten ones) wrapped in an A.
c = C(c1=np.ones(10))
a = A(a1=[c])

# Serialize to JSON bytes and show them.
encoded = enc.encode(a)
print(f"Encoded: {encoded}")

# Parse the bytes back and show the result, including the restored array.
decoded = dec.decode(encoded)
print(f"Decoded: {decoded}")
print(f"Decoded a1[0].c1: {decoded.a1[0].c1}")