Skip to content

Commit

Permalink
Merge pull request #670 from matthewcarbone/mc-save-load-2
Browse files Browse the repository at this point in the history
Add save and load functionality to MSONable
  • Loading branch information
shyuep committed May 13, 2024
2 parents e0b72dd + 077159b commit a1d241a
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 80 deletions.
247 changes: 225 additions & 22 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import os
import pathlib
import pickle
import traceback
import types
from collections import OrderedDict, defaultdict
Expand All @@ -16,7 +17,8 @@
from importlib import import_module
from inspect import getfullargspec
from pathlib import Path
from uuid import UUID
from typing import Any, Dict
from uuid import UUID, uuid4

try:
import numpy as np
Expand Down Expand Up @@ -94,8 +96,14 @@ def _check_type(obj, type_str) -> bool:
Checks whether obj is an instance of the type defined by type_str. This
removes the need to explicitly import type_str. Handles subclasses like
isinstance does. E.g.::
class A: pass
class B(A): pass
class A:
pass
class B(A):
pass
a, b = A(), B()
assert isinstance(a, A)
assert isinstance(b, B)
Expand Down Expand Up @@ -166,7 +174,10 @@ def as_dict(self) -> dict:
"""
A JSON serializable dict representation of an object.
"""
d = {"@module": self.__class__.__module__, "@class": self.__class__.__name__}
d = {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
}

try:
parent_module = self.__class__.__module__.split(".", maxsplit=1)[0]
Expand Down Expand Up @@ -357,6 +368,175 @@ def __modify_schema__(cls, field_schema):
custom_schema = cls._generic_json_schema()
field_schema.update(custom_schema)

def _get_partial_json(self, json_kwargs, pickle_kwargs):
    """Encode this object to JSON while tolerating unserializable parts.

    Helper for the save machinery. The object is encoded with a
    ``MontyEncoder`` created with ``allow_unserializable_objects=True``, so
    components that cannot be written as JSON are substituted by reference
    placeholders that the encoder keeps track of.

    Parameters
    ----------
    json_kwargs : dict or None
        Keyword arguments forwarded to the encoder. ``None`` is treated as
        an empty dict.
    pickle_kwargs : dict or None
        Keyword arguments intended for pickle.dump. Not used here; ``None``
        is normalised to an empty dict and the value is handed back.

    Returns
    -------
    tuple
        ``(encoder, encoded, json_kwargs, pickle_kwargs)``: the encoder
        instance that was used, the encoded JSON string, and the two
        normalised keyword-argument dicts.
    """
    json_kwargs = {} if json_kwargs is None else json_kwargs
    pickle_kwargs = {} if pickle_kwargs is None else pickle_kwargs

    # The encoder instance is returned as well so that the caller can read
    # the placeholder bookkeeping it accumulated while encoding.
    partial_encoder = MontyEncoder(allow_unserializable_objects=True, **json_kwargs)
    json_text = partial_encoder.encode(self)
    return partial_encoder, json_text, json_kwargs, pickle_kwargs

def get_partial_json(self, json_kwargs=None, pickle_kwargs=None):
    """
    Encode the object to JSON and collect the objects that could not be
    serialized.

    Parameters
    ----------
    json_kwargs : dict
        Keyword arguments to pass to the serializer.
    pickle_kwargs : dict
        Keyword arguments to pass to pickle.dump.

    Returns
    -------
    str, dict or None, dict, dict
        A 4-tuple ``(encoded, name_object_map, json_kwargs, pickle_kwargs)``:
        the json encoding of the class, the name-object map if one is
        required (otherwise None), and the two keyword-argument
        collections as returned by ``_get_partial_json``.
    """

    encoder, encoded, json_kwargs, pickle_kwargs = self._get_partial_json(
        json_kwargs, pickle_kwargs
    )
    # An empty map means nothing had to be deferred; report that as None so
    # callers can test "is not None" to decide whether a pickle is needed.
    name_object_map = encoder._name_object_map or None
    return encoded, name_object_map, json_kwargs, pickle_kwargs

def save(
    self,
    json_path,
    mkdir=True,
    json_kwargs=None,
    pickle_kwargs=None,
    strict=True,
):
    """Utility that uses the standard tools of MSONable to convert the
    class to json format, but also save it to disk. In addition, this
    method uses pickle to store the objects that are not serializable,
    saving them separately. This maximizes the readability of the saved
    class information while allowing _any_ class to be at least partially
    serializable to disk.

    For a fully MSONable class, only the json file at ``json_path`` is
    written. For a partially MSONable class, a pickle file holding the
    objects that could not be serialized is additionally written next to
    it, using the same file stem and the ``.pkl`` extension.

    Parameters
    ----------
    json_path : os.PathLike
        The file to which to save the json object. A pickled object of
        the same name but different extension might also be saved if the
        class is not entirely MSONable.
    mkdir : bool
        If True, makes the directory containing ``json_path``, including
        all parent directories.
    json_kwargs : dict
        Keyword arguments to pass to the serializer.
    pickle_kwargs : dict
        Keyword arguments to pass to pickle.dump.
    strict : bool
        If True, will not allow you to overwrite existing files.

    Raises
    ------
    ValueError
        If a pickle file is required and its path is identical to
        ``json_path`` (i.e. ``json_path`` itself ends in ``.pkl``).
    FileExistsError
        If ``strict`` is True and either the json file or the pickle file
        already exists.
    """

    json_path = Path(json_path)
    save_dir = json_path.parent

    encoded, name_object_map, json_kwargs, pickle_kwargs = self.get_partial_json(
        json_kwargs, pickle_kwargs
    )

    if mkdir:
        save_dir.mkdir(exist_ok=True, parents=True)

    # The pickle file lives next to the json file and shares its stem
    pickle_path = save_dir / f"{json_path.stem}.pkl"

    # Writing both payloads to one path would silently destroy the json
    # output, so refuse before anything is written.
    if name_object_map is not None and pickle_path == json_path:
        raise ValueError(
            f"json file and pickle file would both be written to {json_path}; "
            "use a json_path that does not end in .pkl"
        )

    # Check if the files exist and the strict parameter is True
    if strict and json_path.exists():
        raise FileExistsError(f"strict is true and file {json_path} exists")
    if strict and pickle_path.exists():
        raise FileExistsError(f"strict is true and file {pickle_path} exists")

    # Save the json file
    with open(json_path, "w") as outfile:
        outfile.write(encoded)

    # Save the pickle file if we have anything to save from the name_object_map
    if name_object_map is not None:
        with open(pickle_path, "wb") as f:
            pickle.dump(name_object_map, f, **pickle_kwargs)

@classmethod
def load(cls, file_path):
    """Rebuild an instance of this class from a json file on disk.

    The file contents are read via ``_d_from_path`` and the resulting dict
    is passed to ``cls.from_dict``.

    Parameters
    ----------
    file_path : os.PathLike
        The json file to load from.

    Returns
    -------
    MSONable
        An instance of the class being reloaded.
    """
    return cls.from_dict(_d_from_path(file_path))


def load(path):
    """Reconstruct an object from a json file written by MSONable.save.

    The class to instantiate is taken from the "@module" and "@class"
    entries stored in the file; the module is imported and the class's
    ``from_dict`` is called with the loaded dict.

    Parameters
    ----------
    path : os.PathLike
        Path to the json file to load.

    Returns
    -------
    MSONable
    """
    payload = _d_from_path(path)
    # Read both keys before importing anything so that a missing key fails
    # ahead of any import side effects.
    module_name, class_name = payload["@module"], payload["@class"]
    target_cls = getattr(import_module(module_name), class_name)
    return target_cls.from_dict(payload)


def _d_from_path(file_path):
    """Read a saved json file and restore any pickled objects it refers to.

    If a pickle file with the same stem and a ``.pkl`` extension exists next
    to the json file, it is loaded and every "@object_reference" placeholder
    in the json data is replaced by the matching pickled object.

    Parameters
    ----------
    file_path : os.PathLike
        Path to the json file.

    Returns
    -------
    The decoded json content (a dict for files written by ``save``), with
    object references resolved when a pickle file is present.
    """
    json_path = Path(file_path)
    pickle_path = json_path.parent / f"{json_path.stem}.pkl"

    with open(json_path, "r") as infile:
        d = json.load(infile)

    if pickle_path.exists():
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # json/pickle pairs that come from a trusted source.
        # The context manager closes the handle deterministically instead of
        # leaving it to the garbage collector.
        with open(pickle_path, "rb") as pickle_file:
            name_object_map = pickle.load(pickle_file)
        d = _recursive_name_object_map_replacement(d, name_object_map)
    return d


def _recursive_name_object_map_replacement(d, name_object_map):
    """Swap "@object_reference" placeholders for the objects they name.

    Walks nested dicts and lists. A dict containing the key
    "@object_reference" is replaced by the object stored under that name in
    ``name_object_map``; the entry is popped from the map, so the map is
    consumed as references are resolved. Any other value is returned as is.

    Parameters
    ----------
    d : Any
        The decoded json structure (or a fragment of it).
    name_object_map : dict
        Mapping from reference name to the object it stands for.

    Returns
    -------
    The structure with all references resolved.
    """
    if isinstance(d, list):
        return [
            _recursive_name_object_map_replacement(item, name_object_map)
            for item in d
        ]

    if not isinstance(d, dict):
        # Scalars and anything else that cannot hold a reference
        return d

    if "@object_reference" in d:
        return name_object_map.pop(d["@object_reference"])

    resolved = {}
    for key, value in d.items():
        resolved[key] = _recursive_name_object_map_replacement(value, name_object_map)
    return resolved


class MontyEncoder(json.JSONEncoder):
"""
Expand All @@ -367,6 +547,18 @@ class MontyEncoder(json.JSONEncoder):
json.dumps(object, cls=MontyEncoder)
"""

def __init__(self, *args, allow_unserializable_objects=False, **kwargs):
super().__init__(*args, **kwargs)
self._allow_unserializable_objects = allow_unserializable_objects
self._name_object_map: Dict[str, Any] = {}
self._index = 0

def _update_name_object_map(self, o):
name = f"{self._index:012}-{str(uuid4())}"
self._index += 1
self._name_object_map[name] = o
return {"@object_reference": name}

def default(self, o) -> dict: # pylint: disable=E0202
"""
Overriding default method for JSON encoding. This method does two
Expand All @@ -380,7 +572,11 @@ def default(self, o) -> dict: # pylint: disable=E0202
Python dict representation.
"""
if isinstance(o, datetime.datetime):
return {"@module": "datetime", "@class": "datetime", "string": str(o)}
return {
"@module": "datetime",
"@class": "datetime",
"string": str(o),
}
if isinstance(o, UUID):
return {"@module": "uuid", "@class": "UUID", "string": str(o)}
if isinstance(o, Path):
Expand Down Expand Up @@ -431,10 +627,20 @@ def default(self, o) -> dict: # pylint: disable=E0202
}

if bson is not None and isinstance(o, bson.objectid.ObjectId):
return {"@module": "bson.objectid", "@class": "ObjectId", "oid": str(o)}
return {
"@module": "bson.objectid",
"@class": "ObjectId",
"oid": str(o),
}

if callable(o) and not isinstance(o, MSONable):
return _serialize_callable(o)
try:
return _serialize_callable(o)
except AttributeError as e:
# Some callables may not have instance __name__
if self._allow_unserializable_objects:
return self._update_name_object_map(o)
raise AttributeError(e)

try:
if pydantic is not None and isinstance(o, pydantic.BaseModel):
Expand All @@ -450,6 +656,11 @@ def default(self, o) -> dict: # pylint: disable=E0202
d = o.as_dict()
elif isinstance(o, Enum):
d = {"value": o.value}
elif self._allow_unserializable_objects:
# Last resort logic. We keep track of some name of the object
# as a reference, and instead of the object, store that
# name, which of course is json-serializable
d = self._update_name_object_map(o)
else:
raise TypeError(
f"Object of type {o.__class__.__name__} is not JSON serializable"
Expand Down Expand Up @@ -639,7 +850,11 @@ class MSONError(Exception):


def jsanitize(
obj, strict=False, allow_bson=False, enum_values=False, recursive_msonable=False
obj,
strict=False,
allow_bson=False,
enum_values=False,
recursive_msonable=False,
):
"""
This method cleans an input json-like object, either a list or a dict or
Expand Down Expand Up @@ -680,24 +895,12 @@ def jsanitize(
return obj
if isinstance(obj, (list, tuple)):
return [
jsanitize(
i,
strict=strict,
allow_bson=allow_bson,
enum_values=enum_values,
recursive_msonable=recursive_msonable,
)
jsanitize(i, strict=strict, allow_bson=allow_bson, enum_values=enum_values)
for i in obj
]
if np is not None and isinstance(obj, np.ndarray):
return [
jsanitize(
i,
strict=strict,
allow_bson=allow_bson,
enum_values=enum_values,
recursive_msonable=recursive_msonable,
)
jsanitize(i, strict=strict, allow_bson=allow_bson, enum_values=enum_values)
for i in obj.tolist()
]
if np is not None and isinstance(obj, np.generic):
Expand Down

0 comments on commit a1d241a

Please sign in to comment.