Skip to content

Commit

Permalink
Merge pull request #33120 from tensorflow/perf
Browse files Browse the repository at this point in the history
[Cherrypick TF1.15] Fix performance slowdown
  • Loading branch information
goldiegadde committed Oct 7, 2019
2 parents 7764511 + a8ef0f5 commit 38ea9bb
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions tensorflow/python/util/module_wrapper.py
Expand Up @@ -90,6 +90,8 @@ def __init__( # pylint: disable=super-on-old-class
deprecation=True,
has_lite=False): # pylint: enable=super-on-old-class
super(TFModuleWrapper, self).__init__(wrapped.__name__)
# A cache for all members which do not print deprecations (any more).
self._tfmw_attr_map = {}
self.__dict__.update(wrapped.__dict__)
# Prefix all local attributes with _tfmw_ so that we can
# handle them differently in attribute access methods.
Expand Down Expand Up @@ -136,6 +138,8 @@ def _tfmw_add_deprecation_warning(self, name, attr):
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), full_name, rename)
self._tfmw_warning_count += 1
return True
return False

def _tfmw_import_module(self, name):
symbol_loc_info = self._tfmw_public_apis[name]
Expand All @@ -149,25 +153,45 @@ def _tfmw_import_module(self, name):
return attr

def __getattribute__(self, name): # pylint: disable=super-on-old-class
# Workaround to make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
# Handle edge case where we unpickle and the object is not initialized yet
# and does not have _tfmw_attr_map attribute. Otherwise, calling
# __getattribute__ on __setstate__ will result in infinite recursion where
# we keep trying to get _tfmw_wrapped_module in __getattr__.
try:
attr_map = object.__getattribute__(self, '_tfmw_attr_map')
except AttributeError:
self._tfmw_attr_map = attr_map = {}

try:
# Use cached attrs if available
return attr_map[name]
except KeyError:
# Make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
attr_map[name] = attr
return attr

attr = super(TFModuleWrapper, self).__getattribute__(name)

# Return and cache dunders and our own members.
if name.startswith('__') or name.startswith('_tfmw_'):
attr_map[name] = attr
return attr

attr = super(TFModuleWrapper, self).__getattribute__(name)
if name.startswith('__') or name.startswith('_tfmw_'):
# Print deprecations, only cache functions after deprecation warnings have
# stopped.
if not (self._tfmw_print_deprecation_warnings and
self._tfmw_add_deprecation_warning(name, attr)):
attr_map[name] = attr
return attr

if self._tfmw_print_deprecation_warnings:
self._tfmw_add_deprecation_warning(name, attr)
return attr

def __getattr__(self, name):
try:
attr = getattr(self._tfmw_wrapped_module, name)
except AttributeError as e:
except AttributeError:
if not self._tfmw_public_apis:
raise
if name not in self._tfmw_public_apis:
Expand All @@ -184,6 +208,8 @@ def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
self.__dict__[arg] = val
if arg not in self.__all__ and arg != '__all__':
self.__all__.append(arg)
if arg in self._tfmw_attr_map:
self._tfmw_attr_map[arg] = val
super(TFModuleWrapper, self).__setattr__(arg, val)

def __dir__(self):
Expand Down

0 comments on commit 38ea9bb

Please sign in to comment.