diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-08-08 13:04:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 13:08:27 -0700 |
commit | 151aed209dc5e11059046367a3a115bba324800d (patch) | |
tree | 6a4a5db0745cc1043615c7a8bdb4c2e846f1198f /tensorflow/python/util | |
parent | 054b0463ebd748b7fe41e5ac22337c8df0ed9821 (diff) |
Bring back TFShouldUse without the memory leaks.
PiperOrigin-RevId: 207933109
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/tf_should_use.py | 169 | ||||
-rw-r--r-- | tensorflow/python/util/tf_should_use_test.py | 80 |
2 files changed, 161 insertions, 88 deletions
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 28e49afa02..ca6710bcf2 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -17,23 +17,124 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import types +import copy +import sys +import traceback import six # pylint: disable=unused-import -from tensorflow.python.eager import context +from tensorflow.python.platform import tf_logging from tensorflow.python.util import tf_decorator # pylint: enable=g-bad-import-order,g-import-not-at-top -# TODO(b/65412899): Re-implement to avoid leaking python objects. -# This function / class remains since the API is public (mark_used()). +class _TFShouldUseHelper(object): + """Object stored in TFShouldUse-wrapped objects. + + When it is deleted it will emit a warning or error if its `sate` method + has not been called by time of deletion. + """ + + def __init__(self, type_, repr_, stack_frame, fatal_error_if_unsated): + self._type = type_ + self._repr = repr_ + self._stack_frame = stack_frame + self._fatal_error_if_unsated = fatal_error_if_unsated + self._sated = False + + def sate(self): + self._sated = True + self._type = None + self._repr = None + self._stack_frame = None + self._logging_module = None + + def __del__(self): + if self._sated: + return + if self._fatal_error_if_unsated: + logger = tf_logging.fatal + else: + logger = tf_logging.error + creation_stack = ''.join( + [line.rstrip() for line in traceback.format_stack(self._stack_frame)]) + logger( + '==================================\n' + 'Object was never used (type %s):\n%s\nIf you want to mark it as ' + 'used call its "mark_used()" method.\nIt was originally created ' + 'here:\n%s\n' + '==================================' % + (self._type, self._repr, creation_stack)) + + +def _new__init__(self, true_value, tf_should_use_helper): + # pylint: disable=protected-access + self._tf_should_use_helper = tf_should_use_helper + self._true_value = true_value + + +def _new__setattr__(self, key, value): + if key in ('_tf_should_use_helper', '_true_value'): + return object.__setattr__(self, key, value) + return setattr( + object.__getattribute__(self, '_true_value'), + key, value) + + +def _new__getattribute__(self, key): + if key not in ('_tf_should_use_helper', '_true_value'): + object.__getattribute__(self, '_tf_should_use_helper').sate() + if key in ('_tf_should_use_helper', 'mark_used', '__setatt__'): + return object.__getattribute__(self, key) + return getattr(object.__getattribute__(self, '_true_value'), key) + + +def _new_mark_used(self, *args, **kwargs): + object.__getattribute__(self, '_tf_should_use_helper').sate() + try: + mu = object.__getattribute__( + object.__getattribute__(self, '_true_value'), + 'mark_used') + return mu(*args, **kwargs) + except AttributeError: + pass + + +_WRAPPERS = dict() + + +def _get_wrapper(x, tf_should_use_helper): + """Create a wrapper for object x, whose class subclasses type(x). + + The wrapper will emit a warning if it is deleted without any of its + properties being accessed or methods being called. + + Args: + x: The instance to wrap. + tf_should_use_helper: The object that tracks usage. + + Returns: + An object wrapping `x`, of type `type(x)`. + """ + type_x = type(x) + memoized = _WRAPPERS.get(type_x, None) + if memoized: + return memoized(x, tf_should_use_helper) + + tx = copy.deepcopy(type_x) + copy_tx = type(tx.__name__, tx.__bases__, dict(tx.__dict__)) + copy_tx.__init__ = _new__init__ + copy_tx.__getattribute__ = _new__getattribute__ + copy_tx.mark_used = _new_mark_used + copy_tx.__setattr__ = _new__setattr__ + _WRAPPERS[type_x] = copy_tx + + return copy_tx(x, tf_should_use_helper) + + def _add_should_use_warning(x, fatal_error=False): """Wraps object x so that if it is never used, a warning is logged. - Does nothing when executing eagerly. - Args: x: Python object. fatal_error: Python bool. If `True`, tf.logging.fatal is raised @@ -43,50 +144,22 @@ def _add_should_use_warning(x, fatal_error=False): An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)` and is a very shallow wrapper for `x` which logs access into `x`. """ - del fatal_error if x is None or x == []: # pylint: disable=g-explicit-bool-comparison return x - if context.executing_eagerly(): - # Typically not needed when executing eagerly (the main use case is for ops - # which need to be incorporated into the graph), and even the no-op wrapper - # creates reference cycles which require garbage collection. - return x - - def override_method(method): - def fn(self, *args, **kwargs): - return method(self, *args, **kwargs) - return fn - - class TFShouldUseWarningWrapper(type(x)): - """Wrapper for objects that keeps track of their use.""" - - def __init__(self, true_self): - self.__dict__ = true_self.__dict__ + # Extract the current frame for later use by traceback printing. + try: + raise ValueError() + except ValueError: + stack_frame = sys.exc_info()[2].tb_frame.f_back - # Not sure why this pylint warning is being used; this is not an - # old class form. - # pylint: disable=super-on-old-class - def __getattribute__(self, name): - return super(TFShouldUseWarningWrapper, self).__getattribute__(name) - - def mark_used(self, *args, **kwargs): - return + tf_should_use_helper = _TFShouldUseHelper( + type_=type(x), + repr_=repr(x), + stack_frame=stack_frame, + fatal_error_if_unsated=fatal_error) - # pylint: enable=super-on-old-class - - for name in dir(TFShouldUseWarningWrapper): - method = getattr(TFShouldUseWarningWrapper, name) - if not isinstance(method, types.FunctionType): - continue - if name in ('__init__', '__getattribute__', '__del__', 'mark_used'): - continue - setattr(TFShouldUseWarningWrapper, name, - functools.wraps(method)(override_method(method))) - - wrapped = TFShouldUseWarningWrapper(x) - wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects. - return wrapped + return _get_wrapper(x, tf_should_use_helper) def should_use_result(fn): @@ -106,8 +179,6 @@ def should_use_result(fn): - `t != 0`. In this case, comparison is done on types / ids. - `isinstance(t, tf.Tensor)`. Similar to above. - Does nothing when executing eagerly. - Args: fn: The function to wrap. @@ -142,8 +213,6 @@ def must_use_result_or_fatal(fn): - `t != 0`. In this case, comparison is done on types / ids. - `isinstance(t, tf.Tensor)`. Similar to above. - Does nothing when executing eagerly. - Args: fn: The function to wrap. diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py index 4c6e48b11c..4c09c2107e 100644 --- a/tensorflow/python/util/tf_should_use_test.py +++ b/tensorflow/python/util/tf_should_use_test.py @@ -30,48 +30,53 @@ from tensorflow.python.util import tf_should_use @contextlib.contextmanager -def reroute_error(captured): +def reroute_error(): """Temporarily reroute errors written to tf_logging.error into `captured`.""" - del captured[:] - true_logger = tf_logging.error - def capture_errors(*args, **unused_kwargs): - captured.extend(args) - tf_logging.error = capture_errors - try: - yield - finally: - tf_logging.error = true_logger + with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error: + with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal: + yield error, fatal class TfShouldUseTest(test.TestCase): def testAddShouldUseWarningWhenNotUsed(self): - self.skipTest('b/65412899') c = constant_op.constant(0, name='blah0') - captured = [] - with reroute_error(captured): - def in_this_function(): - h = tf_should_use._add_should_use_warning(c) - del h + def in_this_function(): + h = tf_should_use._add_should_use_warning(c) + del h + with reroute_error() as (error, _): in_this_function() - self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah0:0', '\n'.join(captured)) - self.assertIn('in_this_function', '\n'.join(captured)) - gc.collect() + error.assert_called() + msg = '\n'.join(error.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah0:0', msg) + self.assertIn('in_this_function', msg) + self.assertFalse(gc.garbage) + + def testAddShouldUseFatalWhenNotUsed(self): + c = constant_op.constant(0, name='blah0') + def in_this_function(): + h = tf_should_use._add_should_use_warning(c, fatal_error=True) + del h + with reroute_error() as (_, fatal): + in_this_function() + fatal.assert_called() + msg = '\n'.join(fatal.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah0:0', msg) + self.assertIn('in_this_function', msg) self.assertFalse(gc.garbage) def _testAddShouldUseWarningWhenUsed(self, fn, name): c = constant_op.constant(0, name=name) - captured = [] - with reroute_error(captured): + with reroute_error() as (error, fatal): h = tf_should_use._add_should_use_warning(c) fn(h) del h - self.assertNotIn('Object was never used', '\n'.join(captured)) - self.assertNotIn('%s:0' % name, '\n'.join(captured)) + error.assert_not_called() + fatal.assert_not_called() def testAddShouldUseWarningWhenUsedWithAdd(self): - self.skipTest('b/65412899') def add(h): _ = h + 1 self._testAddShouldUseWarningWhenUsed(add, name='blah_add') @@ -79,7 +84,6 @@ class TfShouldUseTest(test.TestCase): self.assertFalse(gc.garbage) def testAddShouldUseWarningWhenUsedWithGetName(self): - self.skipTest('b/65412899') def get_name(h): _ = h.name self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name') @@ -87,35 +91,35 @@ class TfShouldUseTest(test.TestCase): self.assertFalse(gc.garbage) def testShouldUseResult(self): - self.skipTest('b/65412899') @tf_should_use.should_use_result def return_const(value): return constant_op.constant(value, name='blah2') - captured = [] - with reroute_error(captured): + with reroute_error() as (error, _): return_const(0.0) - self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah2:0', '\n'.join(captured)) - self.assertIn('return_const', '\n'.join(captured)) + error.assert_called() + msg = '\n'.join(error.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah2:0', msg) + self.assertIn('return_const', msg) gc.collect() self.assertFalse(gc.garbage) def testShouldUseResultWhenNotReallyUsed(self): - self.skipTest('b/65412899') @tf_should_use.should_use_result def return_const(value): return constant_op.constant(value, name='blah3') - captured = [] - with reroute_error(captured): + with reroute_error() as (error, _): with self.test_session(): return_const(0.0) # Creating another op and executing it does not mark the # unused op as being "used". v = constant_op.constant(1.0, name='meh') v.eval() - self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah3:0', '\n'.join(captured)) - self.assertIn('return_const', '\n'.join(captured)) + error.assert_called() + msg = '\n'.join(error.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah3:0', msg) + self.assertIn('return_const', msg) gc.collect() self.assertFalse(gc.garbage) |