Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize the numpy hook #542

Merged
Merged 4 commits (source and target branch names not captured) on Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
143 changes: 54 additions & 89 deletions dill/_dill.py
Expand Up @@ -91,54 +91,24 @@ def __hook__():
from numpy import dtype as NumpyDType
return True
if NumpyArrayType: # then has numpy
def ndarraysubclassinstance(obj):
if type(obj) in (TypeType, ClassType):
return False # all classes return False
try: # check if is ndarray, and elif is subclass of ndarray
cls = getattr(obj, '__class__', None)
if cls is None: return False
elif cls is TypeType: return False
elif 'numpy.ndarray' not in str(getattr(cls, 'mro', int.mro)()):
return False
except OSError: return False # ctypes.LibraryLoader
except ReferenceError: return False # handle 'R3' weakref in 3.x
except TypeError: return False
def ndarraysubclassinstance(obj_type):
if all((c.__module__, c.__name__) != ('numpy', 'ndarray') for c in obj_type.__mro__):
return False
# anything below here is a numpy array (or subclass) instance
__hook__() # import numpy (so the following works!!!)
# verify that __reduce__ has not been overridden
NumpyInstance = NumpyArrayType((0,),'int8')
if id(obj.__reduce_ex__) == id(NumpyInstance.__reduce_ex__) and \
id(obj.__reduce__) == id(NumpyInstance.__reduce__): return True
return False
def numpyufunc(obj):
if type(obj) in (TypeType, ClassType):
return False # all classes return False
try: # check if is ufunc
cls = getattr(obj, '__class__', None)
if cls is None: return False
elif cls is TypeType: return False
if 'numpy.ufunc' not in str(getattr(cls, 'mro', int.mro)()):
return False
except OSError: return False # ctypes.LibraryLoader
except ReferenceError: return False # handle 'R3' weakref in 3.x
except TypeError: return False
# anything below here is a numpy ufunc
if obj_type.__reduce_ex__ is not NumpyArrayType.__reduce_ex__ \
or obj_type.__reduce__ is not NumpyArrayType.__reduce__:
return False
return True
def numpydtype(obj):
if type(obj) in (TypeType, ClassType):
return False # all classes return False
try: # check if is dtype
cls = getattr(obj, '__class__', None)
if cls is None: return False
elif cls is TypeType: return False
if 'numpy.dtype' not in str(getattr(obj, 'mro', int.mro)()):
return False
except OSError: return False # ctypes.LibraryLoader
except ReferenceError: return False # handle 'R3' weakref in 3.x
except TypeError: return False
def numpyufunc(obj_type):
return any((c.__module__, c.__name__) == ('numpy', 'ufunc') for c in obj_type.__mro__)
def numpydtype(obj_type):
if all((c.__module__, c.__name__) != ('numpy', 'dtype') for c in obj_type.__mro__):
return False
# anything below here is a numpy dtype
__hook__() # import numpy (so the following works!!!)
return type(obj) is type(NumpyDType) # handles subclasses
return obj_type is type(NumpyDType) # handles subclasses
else:
def ndarraysubclassinstance(obj): return False
def numpyufunc(obj): return False
Expand Down Expand Up @@ -373,42 +343,44 @@ def __init__(self, file, *args, **kwds):
def save(self, obj, save_persistent_id=True):
# register if the object is a numpy ufunc
# thanks to Paul Kienzle for pointing out ufuncs didn't pickle
if NumpyUfuncType and numpyufunc(obj):
@register(type(obj))
def save_numpy_ufunc(pickler, obj):
logger.trace(pickler, "Nu: %s", obj)
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
StockPickler.save_global(pickler, obj, name=name)
logger.trace(pickler, "# Nu")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def udump(f): return f.__name__
# def uload(name): return getattr(numpy, name)
# copy_reg.pickle(NumpyUfuncType, udump, uload)
# register if the object is a numpy dtype
if NumpyDType and numpydtype(obj):
@register(type(obj))
def save_numpy_dtype(pickler, obj):
logger.trace(pickler, "Dt: %s", obj)
pickler.save_reduce(_create_dtypemeta, (obj.type,), obj=obj)
logger.trace(pickler, "# Dt")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def uload(name): return type(NumpyDType(name))
# def udump(f): return uload, (f.type,)
# copy_reg.pickle(NumpyDTypeType, udump, uload)
# register if the object is a subclassed numpy array instance
if NumpyArrayType and ndarraysubclassinstance(obj):
@register(type(obj))
def save_numpy_array(pickler, obj):
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype)
npdict = getattr(obj, '__dict__', None)
f, args, state = obj.__reduce__()
pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj)
logger.trace(pickler, "# Nu")
return
obj_type = type(obj)
if NumpyArrayType and not (obj_type is type or obj_type in Pickler.dispatch):
if NumpyUfuncType and numpyufunc(obj_type):
@register(obj_type)
def save_numpy_ufunc(pickler, obj):
logger.trace(pickler, "Nu: %s", obj)
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
StockPickler.save_global(pickler, obj, name=name)
logger.trace(pickler, "# Nu")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def udump(f): return f.__name__
# def uload(name): return getattr(numpy, name)
# copy_reg.pickle(NumpyUfuncType, udump, uload)
# register if the object is a numpy dtype
if NumpyDType and numpydtype(obj_type):
@register(obj_type)
def save_numpy_dtype(pickler, obj):
logger.trace(pickler, "Dt: %s", obj)
pickler.save_reduce(_create_dtypemeta, (obj.type,), obj=obj)
logger.trace(pickler, "# Dt")
return
# NOTE: the above 'save' performs like:
# import copy_reg
# def uload(name): return type(NumpyDType(name))
# def udump(f): return uload, (f.type,)
# copy_reg.pickle(NumpyDTypeType, udump, uload)
# register if the object is a subclassed numpy array instance
if NumpyArrayType and ndarraysubclassinstance(obj_type):
@register(obj_type)
def save_numpy_array(pickler, obj):
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype)
npdict = getattr(obj, '__dict__', None)
f, args, state = obj.__reduce__()
pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj)
logger.trace(pickler, "# Nu")
return
# end hack
if GENERATOR_FAIL and type(obj) == GeneratorType:
msg = "Can't pickle %s: attribute lookup builtins.generator failed" % GeneratorType
Expand Down Expand Up @@ -1604,18 +1576,11 @@ def save_weakref(pickler, obj):
@register(ProxyType)
@register(CallableProxyType)
def save_weakproxy(pickler, obj):
# Must do string substitution here and use %r to avoid ReferenceError.
logger.trace(pickler, "R2: %r" % obj)
refobj = _locate_object(_proxy_helper(obj))
try:
_t = "R2"
logger.trace(pickler, "%s: %s", _t, obj)
except ReferenceError:
_t = "R3"
logger.trace(pickler, "%s: %s", _t, sys.exc_info()[1])
#callable = bool(getattr(refobj, '__call__', None))
if type(obj) is CallableProxyType: callable = True
else: callable = False
pickler.save_reduce(_create_weakproxy, (refobj, callable), obj=obj)
logger.trace(pickler, "# %s", _t)
pickler.save_reduce(_create_weakproxy, (refobj, callable(obj)), obj=obj)
logger.trace(pickler, "# R2")
return

def _is_builtin_module(module):
Expand Down
4 changes: 2 additions & 2 deletions dill/logger.py
Expand Up @@ -50,7 +50,7 @@
import math
import os
from functools import partial
from typing import NoReturn, TextIO, Union
from typing import TextIO, Union

import dill

Expand Down Expand Up @@ -214,7 +214,7 @@ def format(self, record):
stderr_handler = logging._StderrHandler()
adapter.addHandler(stderr_handler)

def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') -> NoReturn:
def trace(arg: Union[bool, TextIO, str, os.PathLike] = None, *, mode: str = 'a') -> None:
"""print a trace through the stack when pickling; useful for debugging

With a single boolean argument, enable or disable the tracing.
Expand Down
7 changes: 3 additions & 4 deletions dill/tests/test_classdef.py
Expand Up @@ -128,8 +128,8 @@ def test_dtype():
import numpy as np

dti = np.dtype('int')
assert np.dtype == dill.loads(dill.dumps(np.dtype))
assert dti == dill.loads(dill.dumps(dti))
assert np.dtype == dill.copy(np.dtype)
assert dti == dill.copy(dti)
except ImportError: pass


Expand All @@ -139,8 +139,7 @@ def test_array_nested():

x = np.array([1])
y = (x,)
dill.dumps(x)
assert y == dill.loads(dill.dumps(y))
assert y == dill.copy(y)

except ImportError: pass

Expand Down
47 changes: 15 additions & 32 deletions dill/tests/test_weakref.py
Expand Up @@ -14,15 +14,7 @@ class _class:
def _method(self):
pass

class _class2:
def __call__(self):
pass

class _newclass(object):
def _method(self):
pass

class _newclass2(object):
class _callable_class:
def __call__(self):
pass

Expand All @@ -32,42 +24,33 @@ def _function():

def test_weakref():
o = _class()
oc = _class2()
n = _newclass()
nc = _newclass2()
oc = _callable_class()
f = _function
z = _class
x = _newclass
x = _class

# ReferenceType
r = weakref.ref(o)
dr = weakref.ref(_class())
p = weakref.proxy(o)
dp = weakref.proxy(_class())
c = weakref.proxy(oc)
dc = weakref.proxy(_class2())
d_r = weakref.ref(_class())
fr = weakref.ref(f)
xr = weakref.ref(x)

m = weakref.ref(n)
dm = weakref.ref(_newclass())
t = weakref.proxy(n)
dt = weakref.proxy(_newclass())
d = weakref.proxy(nc)
dd = weakref.proxy(_newclass2())
# ProxyType
p = weakref.proxy(o)
d_p = weakref.proxy(_class())

fr = weakref.ref(f)
# CallableProxyType
cp = weakref.proxy(oc)
d_cp = weakref.proxy(_callable_class())
fp = weakref.proxy(f)
#zr = weakref.ref(z) #XXX: weakrefs not allowed for classobj objects
#zp = weakref.proxy(z) #XXX: weakrefs not allowed for classobj objects
xr = weakref.ref(x)
xp = weakref.proxy(x)

objlist = [r,dr,m,dm,fr,xr, p,dp,t,dt, c,dc,d,dd, fp,xp]
objlist = [r,d_r,fr,xr, p,d_p, cp,d_cp,fp,xp]
#dill.detect.trace(True)

for obj in objlist:
res = dill.detect.errors(obj)
if res:
print ("%s" % res)
#print ("%s:\n %s" % (obj, res))
print ("%r:\n %s" % (obj, res))
# else:
# print ("PASS: %s" % obj)
assert not res
Expand Down