aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-08-08 13:04:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 13:08:27 -0700
commit151aed209dc5e11059046367a3a115bba324800d (patch)
tree6a4a5db0745cc1043615c7a8bdb4c2e846f1198f /tensorflow/python/util
parent054b0463ebd748b7fe41e5ac22337c8df0ed9821 (diff)
Bring back TFShouldUse without the memory leaks.
PiperOrigin-RevId: 207933109
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/tf_should_use.py169
-rw-r--r--tensorflow/python/util/tf_should_use_test.py80
2 files changed, 161 insertions, 88 deletions
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py
index 28e49afa02..ca6710bcf2 100644
--- a/tensorflow/python/util/tf_should_use.py
+++ b/tensorflow/python/util/tf_should_use.py
@@ -17,23 +17,124 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
-import types
+import copy
+import sys
+import traceback
import six # pylint: disable=unused-import
-from tensorflow.python.eager import context
+from tensorflow.python.platform import tf_logging
from tensorflow.python.util import tf_decorator
# pylint: enable=g-bad-import-order,g-import-not-at-top
-# TODO(b/65412899): Re-implement to avoid leaking python objects.
-# This function / class remains since the API is public (mark_used()).
+class _TFShouldUseHelper(object):
+ """Object stored in TFShouldUse-wrapped objects.
+
+ When it is deleted it will emit a warning or error if its `sate` method
+ has not been called by time of deletion.
+ """
+
+ def __init__(self, type_, repr_, stack_frame, fatal_error_if_unsated):
+ self._type = type_
+ self._repr = repr_
+ self._stack_frame = stack_frame
+ self._fatal_error_if_unsated = fatal_error_if_unsated
+ self._sated = False
+
+ def sate(self):
+ self._sated = True
+ self._type = None
+ self._repr = None
+ self._stack_frame = None
+ self._logging_module = None
+
+ def __del__(self):
+ if self._sated:
+ return
+ if self._fatal_error_if_unsated:
+ logger = tf_logging.fatal
+ else:
+ logger = tf_logging.error
+ creation_stack = ''.join(
+ [line.rstrip() for line in traceback.format_stack(self._stack_frame)])
+ logger(
+ '==================================\n'
+ 'Object was never used (type %s):\n%s\nIf you want to mark it as '
+ 'used call its "mark_used()" method.\nIt was originally created '
+ 'here:\n%s\n'
+ '==================================' %
+ (self._type, self._repr, creation_stack))
+
+
+def _new__init__(self, true_value, tf_should_use_helper):
+ # pylint: disable=protected-access
+ self._tf_should_use_helper = tf_should_use_helper
+ self._true_value = true_value
+
+
+def _new__setattr__(self, key, value):
+ if key in ('_tf_should_use_helper', '_true_value'):
+ return object.__setattr__(self, key, value)
+ return setattr(
+ object.__getattribute__(self, '_true_value'),
+ key, value)
+
+
+def _new__getattribute__(self, key):
+ if key not in ('_tf_should_use_helper', '_true_value'):
+ object.__getattribute__(self, '_tf_should_use_helper').sate()
+ if key in ('_tf_should_use_helper', 'mark_used', '__setatt__'):
+ return object.__getattribute__(self, key)
+ return getattr(object.__getattribute__(self, '_true_value'), key)
+
+
+def _new_mark_used(self, *args, **kwargs):
+ object.__getattribute__(self, '_tf_should_use_helper').sate()
+ try:
+ mu = object.__getattribute__(
+ object.__getattribute__(self, '_true_value'),
+ 'mark_used')
+ return mu(*args, **kwargs)
+ except AttributeError:
+ pass
+
+
+_WRAPPERS = dict()
+
+
+def _get_wrapper(x, tf_should_use_helper):
+ """Create a wrapper for object x, whose class subclasses type(x).
+
+ The wrapper will emit a warning if it is deleted without any of its
+ properties being accessed or methods being called.
+
+ Args:
+ x: The instance to wrap.
+ tf_should_use_helper: The object that tracks usage.
+
+ Returns:
+ An object wrapping `x`, of type `type(x)`.
+ """
+ type_x = type(x)
+ memoized = _WRAPPERS.get(type_x, None)
+ if memoized:
+ return memoized(x, tf_should_use_helper)
+
+ tx = copy.deepcopy(type_x)
+ copy_tx = type(tx.__name__, tx.__bases__, dict(tx.__dict__))
+ copy_tx.__init__ = _new__init__
+ copy_tx.__getattribute__ = _new__getattribute__
+ copy_tx.mark_used = _new_mark_used
+ copy_tx.__setattr__ = _new__setattr__
+ _WRAPPERS[type_x] = copy_tx
+
+ return copy_tx(x, tf_should_use_helper)
+
+
def _add_should_use_warning(x, fatal_error=False):
"""Wraps object x so that if it is never used, a warning is logged.
- Does nothing when executing eagerly.
-
Args:
x: Python object.
fatal_error: Python bool. If `True`, tf.logging.fatal is raised
@@ -43,50 +144,22 @@ def _add_should_use_warning(x, fatal_error=False):
An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)`
and is a very shallow wrapper for `x` which logs access into `x`.
"""
- del fatal_error
if x is None or x == []: # pylint: disable=g-explicit-bool-comparison
return x
- if context.executing_eagerly():
- # Typically not needed when executing eagerly (the main use case is for ops
- # which need to be incorporated into the graph), and even the no-op wrapper
- # creates reference cycles which require garbage collection.
- return x
-
- def override_method(method):
- def fn(self, *args, **kwargs):
- return method(self, *args, **kwargs)
- return fn
-
- class TFShouldUseWarningWrapper(type(x)):
- """Wrapper for objects that keeps track of their use."""
-
- def __init__(self, true_self):
- self.__dict__ = true_self.__dict__
+ # Extract the current frame for later use by traceback printing.
+ try:
+ raise ValueError()
+ except ValueError:
+ stack_frame = sys.exc_info()[2].tb_frame.f_back
- # Not sure why this pylint warning is being used; this is not an
- # old class form.
- # pylint: disable=super-on-old-class
- def __getattribute__(self, name):
- return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
-
- def mark_used(self, *args, **kwargs):
- return
+ tf_should_use_helper = _TFShouldUseHelper(
+ type_=type(x),
+ repr_=repr(x),
+ stack_frame=stack_frame,
+ fatal_error_if_unsated=fatal_error)
- # pylint: enable=super-on-old-class
-
- for name in dir(TFShouldUseWarningWrapper):
- method = getattr(TFShouldUseWarningWrapper, name)
- if not isinstance(method, types.FunctionType):
- continue
- if name in ('__init__', '__getattribute__', '__del__', 'mark_used'):
- continue
- setattr(TFShouldUseWarningWrapper, name,
- functools.wraps(method)(override_method(method)))
-
- wrapped = TFShouldUseWarningWrapper(x)
- wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
- return wrapped
+ return _get_wrapper(x, tf_should_use_helper)
def should_use_result(fn):
@@ -106,8 +179,6 @@ def should_use_result(fn):
- `t != 0`. In this case, comparison is done on types / ids.
- `isinstance(t, tf.Tensor)`. Similar to above.
- Does nothing when executing eagerly.
-
Args:
fn: The function to wrap.
@@ -142,8 +213,6 @@ def must_use_result_or_fatal(fn):
- `t != 0`. In this case, comparison is done on types / ids.
- `isinstance(t, tf.Tensor)`. Similar to above.
- Does nothing when executing eagerly.
-
Args:
fn: The function to wrap.
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index 4c6e48b11c..4c09c2107e 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -30,48 +30,53 @@ from tensorflow.python.util import tf_should_use
@contextlib.contextmanager
-def reroute_error(captured):
+def reroute_error():
"""Temporarily reroute errors written to tf_logging.error into `captured`."""
- del captured[:]
- true_logger = tf_logging.error
- def capture_errors(*args, **unused_kwargs):
- captured.extend(args)
- tf_logging.error = capture_errors
- try:
- yield
- finally:
- tf_logging.error = true_logger
+ with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error:
+ with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal:
+ yield error, fatal
class TfShouldUseTest(test.TestCase):
def testAddShouldUseWarningWhenNotUsed(self):
- self.skipTest('b/65412899')
c = constant_op.constant(0, name='blah0')
- captured = []
- with reroute_error(captured):
- def in_this_function():
- h = tf_should_use._add_should_use_warning(c)
- del h
+ def in_this_function():
+ h = tf_should_use._add_should_use_warning(c)
+ del h
+ with reroute_error() as (error, _):
in_this_function()
- self.assertIn('Object was never used', '\n'.join(captured))
- self.assertIn('blah0:0', '\n'.join(captured))
- self.assertIn('in_this_function', '\n'.join(captured))
- gc.collect()
+ error.assert_called()
+ msg = '\n'.join(error.call_args[0])
+ self.assertIn('Object was never used', msg)
+ self.assertIn('blah0:0', msg)
+ self.assertIn('in_this_function', msg)
+ self.assertFalse(gc.garbage)
+
+ def testAddShouldUseFatalWhenNotUsed(self):
+ c = constant_op.constant(0, name='blah0')
+ def in_this_function():
+ h = tf_should_use._add_should_use_warning(c, fatal_error=True)
+ del h
+ with reroute_error() as (_, fatal):
+ in_this_function()
+ fatal.assert_called()
+ msg = '\n'.join(fatal.call_args[0])
+ self.assertIn('Object was never used', msg)
+ self.assertIn('blah0:0', msg)
+ self.assertIn('in_this_function', msg)
self.assertFalse(gc.garbage)
def _testAddShouldUseWarningWhenUsed(self, fn, name):
c = constant_op.constant(0, name=name)
- captured = []
- with reroute_error(captured):
+ with reroute_error() as (error, fatal):
h = tf_should_use._add_should_use_warning(c)
fn(h)
del h
- self.assertNotIn('Object was never used', '\n'.join(captured))
- self.assertNotIn('%s:0' % name, '\n'.join(captured))
+ error.assert_not_called()
+ fatal.assert_not_called()
def testAddShouldUseWarningWhenUsedWithAdd(self):
- self.skipTest('b/65412899')
def add(h):
_ = h + 1
self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
@@ -79,7 +84,6 @@ class TfShouldUseTest(test.TestCase):
self.assertFalse(gc.garbage)
def testAddShouldUseWarningWhenUsedWithGetName(self):
- self.skipTest('b/65412899')
def get_name(h):
_ = h.name
self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
@@ -87,35 +91,35 @@ class TfShouldUseTest(test.TestCase):
self.assertFalse(gc.garbage)
def testShouldUseResult(self):
- self.skipTest('b/65412899')
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah2')
- captured = []
- with reroute_error(captured):
+ with reroute_error() as (error, _):
return_const(0.0)
- self.assertIn('Object was never used', '\n'.join(captured))
- self.assertIn('blah2:0', '\n'.join(captured))
- self.assertIn('return_const', '\n'.join(captured))
+ error.assert_called()
+ msg = '\n'.join(error.call_args[0])
+ self.assertIn('Object was never used', msg)
+ self.assertIn('blah2:0', msg)
+ self.assertIn('return_const', msg)
gc.collect()
self.assertFalse(gc.garbage)
def testShouldUseResultWhenNotReallyUsed(self):
- self.skipTest('b/65412899')
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah3')
- captured = []
- with reroute_error(captured):
+ with reroute_error() as (error, _):
with self.test_session():
return_const(0.0)
# Creating another op and executing it does not mark the
# unused op as being "used".
v = constant_op.constant(1.0, name='meh')
v.eval()
- self.assertIn('Object was never used', '\n'.join(captured))
- self.assertIn('blah3:0', '\n'.join(captured))
- self.assertIn('return_const', '\n'.join(captured))
+ error.assert_called()
+ msg = '\n'.join(error.call_args[0])
+ self.assertIn('Object was never used', msg)
+ self.assertIn('blah3:0', msg)
+ self.assertIn('return_const', msg)
gc.collect()
self.assertFalse(gc.garbage)