aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-08-08 15:28:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 15:32:36 -0700
commit80dae290b7d4e24b005d419da866f2c22410d818 (patch)
treefab4391ce4bb6920d5d6c650e809a3ca42b9e9a0 /tensorflow/python/util
parenta1915c5f008cd7e6f01d563f83b36de783a76a0a (diff)
Automated rollback of commit 151aed209dc5e11059046367a3a115bba324800d
PiperOrigin-RevId: 207956477
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/tf_should_use.py169
-rw-r--r--tensorflow/python/util/tf_should_use_test.py80
2 files changed, 88 insertions, 161 deletions
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py
index ca6710bcf2..28e49afa02 100644
--- a/tensorflow/python/util/tf_should_use.py
+++ b/tensorflow/python/util/tf_should_use.py
@@ -17,124 +17,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import copy
-import sys
-import traceback
+import functools
+import types
import six # pylint: disable=unused-import
-from tensorflow.python.platform import tf_logging
+from tensorflow.python.eager import context
from tensorflow.python.util import tf_decorator
# pylint: enable=g-bad-import-order,g-import-not-at-top
-class _TFShouldUseHelper(object):
- """Object stored in TFShouldUse-wrapped objects.
-
- When it is deleted it will emit a warning or error if its `sate` method
- has not been called by time of deletion.
- """
-
- def __init__(self, type_, repr_, stack_frame, fatal_error_if_unsated):
- self._type = type_
- self._repr = repr_
- self._stack_frame = stack_frame
- self._fatal_error_if_unsated = fatal_error_if_unsated
- self._sated = False
-
- def sate(self):
- self._sated = True
- self._type = None
- self._repr = None
- self._stack_frame = None
- self._logging_module = None
-
- def __del__(self):
- if self._sated:
- return
- if self._fatal_error_if_unsated:
- logger = tf_logging.fatal
- else:
- logger = tf_logging.error
- creation_stack = ''.join(
- [line.rstrip() for line in traceback.format_stack(self._stack_frame)])
- logger(
- '==================================\n'
- 'Object was never used (type %s):\n%s\nIf you want to mark it as '
- 'used call its "mark_used()" method.\nIt was originally created '
- 'here:\n%s\n'
- '==================================' %
- (self._type, self._repr, creation_stack))
-
-
-def _new__init__(self, true_value, tf_should_use_helper):
- # pylint: disable=protected-access
- self._tf_should_use_helper = tf_should_use_helper
- self._true_value = true_value
-
-
-def _new__setattr__(self, key, value):
- if key in ('_tf_should_use_helper', '_true_value'):
- return object.__setattr__(self, key, value)
- return setattr(
- object.__getattribute__(self, '_true_value'),
- key, value)
-
-
-def _new__getattribute__(self, key):
- if key not in ('_tf_should_use_helper', '_true_value'):
- object.__getattribute__(self, '_tf_should_use_helper').sate()
- if key in ('_tf_should_use_helper', 'mark_used', '__setatt__'):
- return object.__getattribute__(self, key)
- return getattr(object.__getattribute__(self, '_true_value'), key)
-
-
-def _new_mark_used(self, *args, **kwargs):
- object.__getattribute__(self, '_tf_should_use_helper').sate()
- try:
- mu = object.__getattribute__(
- object.__getattribute__(self, '_true_value'),
- 'mark_used')
- return mu(*args, **kwargs)
- except AttributeError:
- pass
-
-
-_WRAPPERS = dict()
-
-
-def _get_wrapper(x, tf_should_use_helper):
- """Create a wrapper for object x, whose class subclasses type(x).
-
- The wrapper will emit a warning if it is deleted without any of its
- properties being accessed or methods being called.
-
- Args:
- x: The instance to wrap.
- tf_should_use_helper: The object that tracks usage.
-
- Returns:
- An object wrapping `x`, of type `type(x)`.
- """
- type_x = type(x)
- memoized = _WRAPPERS.get(type_x, None)
- if memoized:
- return memoized(x, tf_should_use_helper)
-
- tx = copy.deepcopy(type_x)
- copy_tx = type(tx.__name__, tx.__bases__, dict(tx.__dict__))
- copy_tx.__init__ = _new__init__
- copy_tx.__getattribute__ = _new__getattribute__
- copy_tx.mark_used = _new_mark_used
- copy_tx.__setattr__ = _new__setattr__
- _WRAPPERS[type_x] = copy_tx
-
- return copy_tx(x, tf_should_use_helper)
-
-
+# TODO(b/65412899): Re-implement to avoid leaking python objects.
+# This function / class remains since the API is public (mark_used()).
def _add_should_use_warning(x, fatal_error=False):
"""Wraps object x so that if it is never used, a warning is logged.
+ Does nothing when executing eagerly.
+
Args:
x: Python object.
fatal_error: Python bool. If `True`, tf.logging.fatal is raised
@@ -144,22 +43,50 @@ def _add_should_use_warning(x, fatal_error=False):
An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)`
and is a very shallow wrapper for `x` which logs access into `x`.
"""
+ del fatal_error
if x is None or x == []: # pylint: disable=g-explicit-bool-comparison
return x
- # Extract the current frame for later use by traceback printing.
- try:
- raise ValueError()
- except ValueError:
- stack_frame = sys.exc_info()[2].tb_frame.f_back
+ if context.executing_eagerly():
+ # Typically not needed when executing eagerly (the main use case is for ops
+ # which need to be incorporated into the graph), and even the no-op wrapper
+ # creates reference cycles which require garbage collection.
+ return x
+
+ def override_method(method):
+ def fn(self, *args, **kwargs):
+ return method(self, *args, **kwargs)
+ return fn
+
+ class TFShouldUseWarningWrapper(type(x)):
+ """Wrapper for objects that keeps track of their use."""
+
+ def __init__(self, true_self):
+ self.__dict__ = true_self.__dict__
- tf_should_use_helper = _TFShouldUseHelper(
- type_=type(x),
- repr_=repr(x),
- stack_frame=stack_frame,
- fatal_error_if_unsated=fatal_error)
+ # Not sure why this pylint warning is being used; this is not an
+ # old class form.
+ # pylint: disable=super-on-old-class
+ def __getattribute__(self, name):
+ return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
+
+ def mark_used(self, *args, **kwargs):
+ return
- return _get_wrapper(x, tf_should_use_helper)
+ # pylint: enable=super-on-old-class
+
+ for name in dir(TFShouldUseWarningWrapper):
+ method = getattr(TFShouldUseWarningWrapper, name)
+ if not isinstance(method, types.FunctionType):
+ continue
+ if name in ('__init__', '__getattribute__', '__del__', 'mark_used'):
+ continue
+ setattr(TFShouldUseWarningWrapper, name,
+ functools.wraps(method)(override_method(method)))
+
+ wrapped = TFShouldUseWarningWrapper(x)
+ wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
+ return wrapped
def should_use_result(fn):
@@ -179,6 +106,8 @@ def should_use_result(fn):
- `t != 0`. In this case, comparison is done on types / ids.
- `isinstance(t, tf.Tensor)`. Similar to above.
+ Does nothing when executing eagerly.
+
Args:
fn: The function to wrap.
@@ -213,6 +142,8 @@ def must_use_result_or_fatal(fn):
- `t != 0`. In this case, comparison is done on types / ids.
- `isinstance(t, tf.Tensor)`. Similar to above.
+ Does nothing when executing eagerly.
+
Args:
fn: The function to wrap.
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index 4c09c2107e..4c6e48b11c 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -30,53 +30,48 @@ from tensorflow.python.util import tf_should_use
@contextlib.contextmanager
-def reroute_error():
+def reroute_error(captured):
"""Temporarily reroute errors written to tf_logging.error into `captured`."""
- with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error:
- with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal:
- yield error, fatal
+ del captured[:]
+ true_logger = tf_logging.error
+ def capture_errors(*args, **unused_kwargs):
+ captured.extend(args)
+ tf_logging.error = capture_errors
+ try:
+ yield
+ finally:
+ tf_logging.error = true_logger
class TfShouldUseTest(test.TestCase):
def testAddShouldUseWarningWhenNotUsed(self):
+ self.skipTest('b/65412899')
c = constant_op.constant(0, name='blah0')
- def in_this_function():
- h = tf_should_use._add_should_use_warning(c)
- del h
- with reroute_error() as (error, _):
- in_this_function()
- error.assert_called()
- msg = '\n'.join(error.call_args[0])
- self.assertIn('Object was never used', msg)
- self.assertIn('blah0:0', msg)
- self.assertIn('in_this_function', msg)
- self.assertFalse(gc.garbage)
-
- def testAddShouldUseFatalWhenNotUsed(self):
- c = constant_op.constant(0, name='blah0')
- def in_this_function():
- h = tf_should_use._add_should_use_warning(c, fatal_error=True)
- del h
- with reroute_error() as (_, fatal):
+ captured = []
+ with reroute_error(captured):
+ def in_this_function():
+ h = tf_should_use._add_should_use_warning(c)
+ del h
in_this_function()
- fatal.assert_called()
- msg = '\n'.join(fatal.call_args[0])
- self.assertIn('Object was never used', msg)
- self.assertIn('blah0:0', msg)
- self.assertIn('in_this_function', msg)
+ self.assertIn('Object was never used', '\n'.join(captured))
+ self.assertIn('blah0:0', '\n'.join(captured))
+ self.assertIn('in_this_function', '\n'.join(captured))
+ gc.collect()
self.assertFalse(gc.garbage)
def _testAddShouldUseWarningWhenUsed(self, fn, name):
c = constant_op.constant(0, name=name)
- with reroute_error() as (error, fatal):
+ captured = []
+ with reroute_error(captured):
h = tf_should_use._add_should_use_warning(c)
fn(h)
del h
- error.assert_not_called()
- fatal.assert_not_called()
+ self.assertNotIn('Object was never used', '\n'.join(captured))
+ self.assertNotIn('%s:0' % name, '\n'.join(captured))
def testAddShouldUseWarningWhenUsedWithAdd(self):
+ self.skipTest('b/65412899')
def add(h):
_ = h + 1
self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
@@ -84,6 +79,7 @@ class TfShouldUseTest(test.TestCase):
self.assertFalse(gc.garbage)
def testAddShouldUseWarningWhenUsedWithGetName(self):
+ self.skipTest('b/65412899')
def get_name(h):
_ = h.name
self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
@@ -91,35 +87,35 @@ class TfShouldUseTest(test.TestCase):
self.assertFalse(gc.garbage)
def testShouldUseResult(self):
+ self.skipTest('b/65412899')
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah2')
- with reroute_error() as (error, _):
+ captured = []
+ with reroute_error(captured):
return_const(0.0)
- error.assert_called()
- msg = '\n'.join(error.call_args[0])
- self.assertIn('Object was never used', msg)
- self.assertIn('blah2:0', msg)
- self.assertIn('return_const', msg)
+ self.assertIn('Object was never used', '\n'.join(captured))
+ self.assertIn('blah2:0', '\n'.join(captured))
+ self.assertIn('return_const', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)
def testShouldUseResultWhenNotReallyUsed(self):
+ self.skipTest('b/65412899')
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah3')
- with reroute_error() as (error, _):
+ captured = []
+ with reroute_error(captured):
with self.test_session():
return_const(0.0)
# Creating another op and executing it does not mark the
# unused op as being "used".
v = constant_op.constant(1.0, name='meh')
v.eval()
- error.assert_called()
- msg = '\n'.join(error.call_args[0])
- self.assertIn('Object was never used', msg)
- self.assertIn('blah3:0', msg)
- self.assertIn('return_const', msg)
+ self.assertIn('Object was never used', '\n'.join(captured))
+ self.assertIn('blah3:0', '\n'.join(captured))
+ self.assertIn('return_const', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)