Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherrypick TF1.15] Fix performance slowdown #33120

Merged
merged 1 commit into from Oct 7, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 38 additions & 12 deletions tensorflow/python/util/module_wrapper.py
Expand Up @@ -90,6 +90,8 @@ def __init__( # pylint: disable=super-on-old-class
deprecation=True,
has_lite=False): # pylint: enable=super-on-old-class
super(TFModuleWrapper, self).__init__(wrapped.__name__)
# A cache for all members which do not print deprecations (any more).
self._tfmw_attr_map = {}
self.__dict__.update(wrapped.__dict__)
# Prefix all local attributes with _tfmw_ so that we can
# handle them differently in attribute access methods.
Expand Down Expand Up @@ -136,6 +138,8 @@ def _tfmw_add_deprecation_warning(self, name, attr):
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), full_name, rename)
self._tfmw_warning_count += 1
return True
return False

def _tfmw_import_module(self, name):
symbol_loc_info = self._tfmw_public_apis[name]
Expand All @@ -149,25 +153,45 @@ def _tfmw_import_module(self, name):
return attr

def __getattribute__(self, name): # pylint: disable=super-on-old-class
# Workaround to make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
# Handle edge case where we unpickle and the object is not initialized yet
# and does not have _tfmw_attr_map attribute. Otherwise, calling
# __getattribute__ on __setstate__ will result in infinite recursion where
# we keep trying to get _tfmw_wrapped_module in __getattr__.
try:
attr_map = object.__getattribute__(self, '_tfmw_attr_map')
except AttributeError:
self._tfmw_attr_map = attr_map = {}

try:
# Use cached attrs if available
return attr_map[name]
except KeyError:
# Make sure we do not import from tensorflow/lite/__init__.py
if name == 'lite':
if self._tfmw_has_lite:
attr = self._tfmw_import_module(name)
setattr(self._tfmw_wrapped_module, 'lite', attr)
attr_map[name] = attr
return attr

attr = super(TFModuleWrapper, self).__getattribute__(name)

# Return and cache dunders and our own members.
if name.startswith('__') or name.startswith('_tfmw_'):
attr_map[name] = attr
return attr

attr = super(TFModuleWrapper, self).__getattribute__(name)
if name.startswith('__') or name.startswith('_tfmw_'):
# Print deprecations, only cache functions after deprecation warnings have
# stopped.
if not (self._tfmw_print_deprecation_warnings and
self._tfmw_add_deprecation_warning(name, attr)):
attr_map[name] = attr
return attr

if self._tfmw_print_deprecation_warnings:
self._tfmw_add_deprecation_warning(name, attr)
return attr

def __getattr__(self, name):
try:
attr = getattr(self._tfmw_wrapped_module, name)
except AttributeError as e:
except AttributeError:
if not self._tfmw_public_apis:
raise
if name not in self._tfmw_public_apis:
Expand All @@ -184,6 +208,8 @@ def __setattr__(self, arg, val): # pylint: disable=super-on-old-class
self.__dict__[arg] = val
if arg not in self.__all__ and arg != '__all__':
self.__all__.append(arg)
if arg in self._tfmw_attr_map:
self._tfmw_attr_map[arg] = val
super(TFModuleWrapper, self).__setattr__(arg, val)

def __dir__(self):
Expand Down