diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-08-08 15:28:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 15:32:36 -0700 |
commit | 80dae290b7d4e24b005d419da866f2c22410d818 (patch) | |
tree | fab4391ce4bb6920d5d6c650e809a3ca42b9e9a0 /tensorflow/python/util | |
parent | a1915c5f008cd7e6f01d563f83b36de783a76a0a (diff) |
Automated rollback of commit 151aed209dc5e11059046367a3a115bba324800d
PiperOrigin-RevId: 207956477
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/tf_should_use.py | 169 | ||||
-rw-r--r-- | tensorflow/python/util/tf_should_use_test.py | 80 |
2 files changed, 88 insertions, 161 deletions
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index ca6710bcf2..28e49afa02 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -17,124 +17,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy -import sys -import traceback +import functools +import types import six # pylint: disable=unused-import -from tensorflow.python.platform import tf_logging +from tensorflow.python.eager import context from tensorflow.python.util import tf_decorator # pylint: enable=g-bad-import-order,g-import-not-at-top -class _TFShouldUseHelper(object): - """Object stored in TFShouldUse-wrapped objects. - - When it is deleted it will emit a warning or error if its `sate` method - has not been called by time of deletion. - """ - - def __init__(self, type_, repr_, stack_frame, fatal_error_if_unsated): - self._type = type_ - self._repr = repr_ - self._stack_frame = stack_frame - self._fatal_error_if_unsated = fatal_error_if_unsated - self._sated = False - - def sate(self): - self._sated = True - self._type = None - self._repr = None - self._stack_frame = None - self._logging_module = None - - def __del__(self): - if self._sated: - return - if self._fatal_error_if_unsated: - logger = tf_logging.fatal - else: - logger = tf_logging.error - creation_stack = ''.join( - [line.rstrip() for line in traceback.format_stack(self._stack_frame)]) - logger( - '==================================\n' - 'Object was never used (type %s):\n%s\nIf you want to mark it as ' - 'used call its "mark_used()" method.\nIt was originally created ' - 'here:\n%s\n' - '==================================' % - (self._type, self._repr, creation_stack)) - - -def _new__init__(self, true_value, tf_should_use_helper): - # pylint: disable=protected-access - self._tf_should_use_helper = tf_should_use_helper - self._true_value = true_value - - -def _new__setattr__(self, key, value): - if key in ('_tf_should_use_helper', '_true_value'): - return object.__setattr__(self, key, value) - return setattr( - object.__getattribute__(self, '_true_value'), - key, value) - - -def _new__getattribute__(self, key): - if key not in ('_tf_should_use_helper', '_true_value'): - object.__getattribute__(self, '_tf_should_use_helper').sate() - if key in ('_tf_should_use_helper', 'mark_used', '__setatt__'): - return object.__getattribute__(self, key) - return getattr(object.__getattribute__(self, '_true_value'), key) - - -def _new_mark_used(self, *args, **kwargs): - object.__getattribute__(self, '_tf_should_use_helper').sate() - try: - mu = object.__getattribute__( - object.__getattribute__(self, '_true_value'), - 'mark_used') - return mu(*args, **kwargs) - except AttributeError: - pass - - -_WRAPPERS = dict() - - -def _get_wrapper(x, tf_should_use_helper): - """Create a wrapper for object x, whose class subclasses type(x). - - The wrapper will emit a warning if it is deleted without any of its - properties being accessed or methods being called. - - Args: - x: The instance to wrap. - tf_should_use_helper: The object that tracks usage. - - Returns: - An object wrapping `x`, of type `type(x)`. - """ - type_x = type(x) - memoized = _WRAPPERS.get(type_x, None) - if memoized: - return memoized(x, tf_should_use_helper) - - tx = copy.deepcopy(type_x) - copy_tx = type(tx.__name__, tx.__bases__, dict(tx.__dict__)) - copy_tx.__init__ = _new__init__ - copy_tx.__getattribute__ = _new__getattribute__ - copy_tx.mark_used = _new_mark_used - copy_tx.__setattr__ = _new__setattr__ - _WRAPPERS[type_x] = copy_tx - - return copy_tx(x, tf_should_use_helper) - - +# TODO(b/65412899): Re-implement to avoid leaking python objects. +# This function / class remains since the API is public (mark_used()). def _add_should_use_warning(x, fatal_error=False): """Wraps object x so that if it is never used, a warning is logged. + Does nothing when executing eagerly. + Args: x: Python object. fatal_error: Python bool. If `True`, tf.logging.fatal is raised @@ -144,22 +43,50 @@ def _add_should_use_warning(x, fatal_error=False): An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)` and is a very shallow wrapper for `x` which logs access into `x`. """ + del fatal_error if x is None or x == []: # pylint: disable=g-explicit-bool-comparison return x - # Extract the current frame for later use by traceback printing. - try: - raise ValueError() - except ValueError: - stack_frame = sys.exc_info()[2].tb_frame.f_back + if context.executing_eagerly(): + # Typically not needed when executing eagerly (the main use case is for ops + # which need to be incorporated into the graph), and even the no-op wrapper + # creates reference cycles which require garbage collection. + return x + + def override_method(method): + def fn(self, *args, **kwargs): + return method(self, *args, **kwargs) + return fn + + class TFShouldUseWarningWrapper(type(x)): + """Wrapper for objects that keeps track of their use.""" + + def __init__(self, true_self): + self.__dict__ = true_self.__dict__ - tf_should_use_helper = _TFShouldUseHelper( - type_=type(x), - repr_=repr(x), - stack_frame=stack_frame, - fatal_error_if_unsated=fatal_error) + # Not sure why this pylint warning is being used; this is not an + # old class form. + # pylint: disable=super-on-old-class + def __getattribute__(self, name): + return super(TFShouldUseWarningWrapper, self).__getattribute__(name) + + def mark_used(self, *args, **kwargs): + return - return _get_wrapper(x, tf_should_use_helper) + # pylint: enable=super-on-old-class + + for name in dir(TFShouldUseWarningWrapper): + method = getattr(TFShouldUseWarningWrapper, name) + if not isinstance(method, types.FunctionType): + continue + if name in ('__init__', '__getattribute__', '__del__', 'mark_used'): + continue + setattr(TFShouldUseWarningWrapper, name, + functools.wraps(method)(override_method(method))) + + wrapped = TFShouldUseWarningWrapper(x) + wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects. + return wrapped def should_use_result(fn): @@ -179,6 +106,8 @@ def should_use_result(fn): - `t != 0`. In this case, comparison is done on types / ids. - `isinstance(t, tf.Tensor)`. Similar to above. + Does nothing when executing eagerly. + Args: fn: The function to wrap. @@ -213,6 +142,8 @@ def must_use_result_or_fatal(fn): - `t != 0`. In this case, comparison is done on types / ids. - `isinstance(t, tf.Tensor)`. Similar to above. + Does nothing when executing eagerly. + Args: fn: The function to wrap. diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py index 4c09c2107e..4c6e48b11c 100644 --- a/tensorflow/python/util/tf_should_use_test.py +++ b/tensorflow/python/util/tf_should_use_test.py @@ -30,53 +30,48 @@ from tensorflow.python.util import tf_should_use @contextlib.contextmanager -def reroute_error(): +def reroute_error(captured): """Temporarily reroute errors written to tf_logging.error into `captured`.""" - with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error: - with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal: - yield error, fatal + del captured[:] + true_logger = tf_logging.error + def capture_errors(*args, **unused_kwargs): + captured.extend(args) + tf_logging.error = capture_errors + try: + yield + finally: + tf_logging.error = true_logger class TfShouldUseTest(test.TestCase): def testAddShouldUseWarningWhenNotUsed(self): + self.skipTest('b/65412899') c = constant_op.constant(0, name='blah0') - def in_this_function(): - h = tf_should_use._add_should_use_warning(c) - del h - with reroute_error() as (error, _): - in_this_function() - error.assert_called() - msg = '\n'.join(error.call_args[0]) - self.assertIn('Object was never used', msg) - self.assertIn('blah0:0', msg) - self.assertIn('in_this_function', msg) - self.assertFalse(gc.garbage) - - def testAddShouldUseFatalWhenNotUsed(self): - c = constant_op.constant(0, name='blah0') - def in_this_function(): - h = tf_should_use._add_should_use_warning(c, fatal_error=True) - del h - with reroute_error() as (_, fatal): + captured = [] + with reroute_error(captured): + def in_this_function(): + h = tf_should_use._add_should_use_warning(c) + del h in_this_function() - fatal.assert_called() - msg = '\n'.join(fatal.call_args[0]) - self.assertIn('Object was never used', msg) - self.assertIn('blah0:0', msg) - self.assertIn('in_this_function', msg) + self.assertIn('Object was never used', '\n'.join(captured)) + self.assertIn('blah0:0', '\n'.join(captured)) + self.assertIn('in_this_function', '\n'.join(captured)) + gc.collect() self.assertFalse(gc.garbage) def _testAddShouldUseWarningWhenUsed(self, fn, name): c = constant_op.constant(0, name=name) - with reroute_error() as (error, fatal): + captured = [] + with reroute_error(captured): h = tf_should_use._add_should_use_warning(c) fn(h) del h - error.assert_not_called() - fatal.assert_not_called() + self.assertNotIn('Object was never used', '\n'.join(captured)) + self.assertNotIn('%s:0' % name, '\n'.join(captured)) def testAddShouldUseWarningWhenUsedWithAdd(self): + self.skipTest('b/65412899') def add(h): _ = h + 1 self._testAddShouldUseWarningWhenUsed(add, name='blah_add') @@ -84,6 +79,7 @@ class TfShouldUseTest(test.TestCase): self.assertFalse(gc.garbage) def testAddShouldUseWarningWhenUsedWithGetName(self): + self.skipTest('b/65412899') def get_name(h): _ = h.name self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name') @@ -91,35 +87,35 @@ class TfShouldUseTest(test.TestCase): self.assertFalse(gc.garbage) def testShouldUseResult(self): + self.skipTest('b/65412899') @tf_should_use.should_use_result def return_const(value): return constant_op.constant(value, name='blah2') - with reroute_error() as (error, _): + captured = [] + with reroute_error(captured): return_const(0.0) - error.assert_called() - msg = '\n'.join(error.call_args[0]) - self.assertIn('Object was never used', msg) - self.assertIn('blah2:0', msg) - self.assertIn('return_const', msg) + self.assertIn('Object was never used', '\n'.join(captured)) + self.assertIn('blah2:0', '\n'.join(captured)) + self.assertIn('return_const', '\n'.join(captured)) gc.collect() self.assertFalse(gc.garbage) def testShouldUseResultWhenNotReallyUsed(self): + self.skipTest('b/65412899') @tf_should_use.should_use_result def return_const(value): return constant_op.constant(value, name='blah3') - with reroute_error() as (error, _): + captured = [] + with reroute_error(captured): with self.test_session(): return_const(0.0) # Creating another op and executing it does not mark the # unused op as being "used". v = constant_op.constant(1.0, name='meh') v.eval() - error.assert_called() - msg = '\n'.join(error.call_args[0]) - self.assertIn('Object was never used', msg) - self.assertIn('blah3:0', msg) - self.assertIn('return_const', msg) + self.assertIn('Object was never used', '\n'.join(captured)) + self.assertIn('blah3:0', '\n'.join(captured)) + self.assertIn('return_const', '\n'.join(captured)) gc.collect() self.assertFalse(gc.garbage) |