# Source code for local_migrator._serialize_hooks

import dataclasses
import enum
import json
import typing
from pathlib import Path

from ._class_register import REGISTER, class_to_str

# pydantic is an optional dependency. When it cannot be imported, an empty
# placeholder class is defined under the same name so that the
# ``isinstance(obj, BaseModel)`` check in this module stays valid.
try:
    from pydantic import BaseModel
except ImportError:  # pragma: no cover
    # Allow usage in an environment without pydantic.
    class BaseModel:  # type: ignore  [no-redef]
        pass


# numpy is an optional dependency. When it cannot be imported, empty
# placeholder classes are defined under the same names so that the
# ``isinstance`` checks against ``ndarray``, ``integer`` and ``floating``
# in this module stay valid.
try:
    from numpy import floating, integer, ndarray
except ImportError:  # pragma: no cover
    # Allow usage in an environment without numpy.
    class ndarray:  # type: ignore  [no-redef]
        pass

    class integer:  # type: ignore  [no-redef]
        pass

    class floating:  # type: ignore  [no-redef]
        pass


def add_class_info(obj: typing.Any, dkt: dict) -> dict:
    """
    Wrap serialized values in a dictionary that also records class information.

    The class of ``obj`` is stored as a string, together with the registered
    version of every class found in its ``__mro__`` (a fixed set of generic
    base classes and everything from ``collections.abc`` are left out).

    :param obj: object being serialized; its class and ``__mro__`` are inspected.
    :param dkt: values of ``obj``, stored under the ``"__values__"`` key.
    :return: dictionary with ``"__class__"``, ``"__class_version_dkt__"``
        and ``"__values__"`` keys.
    """
    # Base classes that never get an entry in the version dictionary.
    skipped_names = {
        "object",
        "pydantic.main.BaseModel",
        "pydantic.utils.Representation",
        "enum.Enum",
        "builtins.object",
        "typing.Generic",
    }
    obj_class = obj.__class__
    class_name = class_to_str(obj_class)

    versions = {}
    for base in obj_class.__mro__:
        base_name = class_to_str(base)
        if base_name in skipped_names or base_name.startswith("collections.abc"):
            continue
        versions[base_name] = str(REGISTER.get_version(base))

    return {
        "__class__": class_name,
        "__class_version_dkt__": versions,
        "__values__": dkt,
    }


def object_encoder(obj: typing.Any):  # noqa: PLR0911
    """
    Function changing supported types to basic python types supported by
    most serializers and which could be restored by
    :py:func:`nme_object_hook` function.

    Supported types are:

    * :py:class:`enum.Enum`
    * :py:func:`dataclasses.dataclass` (instances only)
    * :py:class:`numpy.ndarray`
    * :py:class:`pydantic.BaseModel`
    * :py:class:`numpy.integer` (change to pure int)
    * :py:class:`numpy.floating` (change to pure float)
    * :py:class:`pathlib.Path` (Serialized to string)
    * Any class with an ``as_dict`` method. This method should return
      a dictionary of valid constructor arguments.

    The checks run in the order of the code below, and the first match wins.

    :param obj: object to be encoded.
    :return: encoded object for supported types. Otherwise ``None``.
    """
    if isinstance(obj, enum.Enum):
        dkt = {"value": obj.value}
        return add_class_info(obj, dkt)
    # ``is_dataclass`` is also true for dataclass *types*; only instances
    # carry field values, so class objects are not treated as dataclasses.
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        fields = dataclasses.fields(obj)
        dkt = {x.name: getattr(obj, x.name) for x in fields}
        return add_class_info(obj, dkt)
    if isinstance(obj, ndarray):
        return obj.tolist()
    if isinstance(obj, BaseModel):
        try:
            dkt = dict(obj)
        except (ValueError, TypeError):
            dkt = obj.dict()  # workaround for napari Colormap class
        return add_class_info(obj, dkt)
    if hasattr(obj, "as_dict"):
        dkt = obj.as_dict()
        return add_class_info(obj, dkt)
    if isinstance(obj, integer):
        return int(obj)
    if isinstance(obj, floating):
        return float(obj)
    if isinstance(obj, Path):
        return str(obj)
    return None
class Encoder(json.JSONEncoder):
    """
    JSONEncoder subclass for serializing Python objects into JSON.

    For list of supported types check :py:func:`object_encoder` function.
    """

    def default(self, o):
        """
        Implementation that calls :py:func:`object_encoder` function.

        :param o: object that the standard JSON encoder could not serialize.
        :return: value returned by :py:func:`object_encoder`. When that
            value is ``None`` the call is delegated to
            :py:meth:`json.JSONEncoder.default`.
        """
        val = object_encoder(o)
        if val is None:  # pragma: no cover
            return super().default(o)
        return val
def check_for_errors_in_dkt_values(dkt: dict) -> typing.List[str]:
    """
    Function checking if any of values in dict contains ``"__error__"`` key.

    Only direct values of ``dkt`` are inspected; nested dictionaries are
    not searched recursively.

    :param dkt: dictionary to check.
    :return: list of keys whose value is a dict containing ``"__error__"`` key.
    """
    return [key for key, value in dkt.items() if isinstance(value, dict) and "__error__" in value]
def object_hook(dkt: dict) -> typing.Any:
    """
    Function restoring supported types from :py:func:`nme_object_encoder`
    function output.

    If ``dkt`` does not contain ``__class__`` key, it is returned as is
    (only an ``"__error__"`` entry, if present, is removed).
    If the restoring object fails then function return dict with
    ``"__error__"`` key.

    :param dkt: dictionary with data to restore. It may be modified in place.
    :return: restored object, or a dictionary when there is nothing to
        restore or restoring failed.
    """
    # An error stored in the data may come from a different environment
    # without the same plugins installed, so it is dropped and restoring
    # is attempted again.
    dkt.pop("__error__", None)
    if "__class__" not in dkt:
        return dkt

    if "__values__" not in dkt:
        # Flat layout: values are stored next to ``__class__``. Wrap them in
        # the layout with a separate ``__values__`` entry; when no version
        # information is present, version "0.0.0" is assumed.
        cls_str = dkt.pop("__class__")
        version_dkt = dkt.pop("__class_version_dkt__", {cls_str: "0.0.0"})
        dkt = {"__values__": dkt, "__class__": cls_str, "__class_version_dkt__": version_dkt}

    try:
        cls = REGISTER.get_class(dkt["__class__"])
    except (KeyError, ValueError):
        dkt["__error__"] = f"Class {dkt['__class__']} not found in register."
        return dkt

    # Values that already failed to restore make the whole object fail,
    # unless the register allows errors in values for this class.
    problematic_fields = check_for_errors_in_dkt_values(dkt["__values__"])
    if problematic_fields and not REGISTER.allow_errors_in_values(cls):
        dkt["__error__"] = f"Error in fields: {', '.join(problematic_fields)}"
        return dkt

    try:
        dkt_migrated = REGISTER.migrate_data(dkt["__class__"], dkt["__class_version_dkt__"], dkt["__values__"])
        # ``cls`` was already resolved above for the same class name.
        return cls(**dkt_migrated)
    except Exception as e:  # pylint: disable=W0703
        # Any failure in migration or construction is reported in the data
        # instead of being raised.
        dkt["__error__"] = str(e)

    return dkt
# Aliases exposing the same objects under ``nme``-prefixed names
# (presumably legacy names kept for existing callers - confirm).
nme_object_hook = object_hook
nme_object_encoder = object_encoder
NMEEncoder = Encoder