aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/util/tf_should_use.py93
-rw-r--r--tensorflow/python/util/tf_should_use_test.py33
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh3
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh3
-rw-r--r--tensorflow/tools/pip_package/setup.py1
5 files changed, 94 insertions, 39 deletions
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py
index 88df3351e6..05c99856d2 100644
--- a/tensorflow/python/util/tf_should_use.py
+++ b/tensorflow/python/util/tf_should_use.py
@@ -17,14 +17,52 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import functools
+import itertools
import traceback
import types
+import six # pylint: disable=unused-import
+
+from backports import weakref # pylint: disable=g-bad-import-order
+
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import tf_decorator
+class _RefInfoField(
+ collections.namedtuple(
+ '_RefInfoField', ('type_', 'repr_', 'creation_stack', 'object_used'))):
+ pass
+
+
+# Thread-safe up to int32max/2 thanks to python's GIL; and may be safe even for
+# higher values in Python 3.4+. We don't expect to ever count higher than this.
+# https://mail.python.org/pipermail/python-list/2005-April/342279.html
+_REF_ITER = itertools.count()
+
+# Dictionary mapping id(obj) => _RefInfoField.
+_REF_INFO = {}
+
+
+def _deleted(obj_id, fatal_error):
+ obj = _REF_INFO[obj_id]
+ del _REF_INFO[obj_id]
+ if not obj.object_used:
+ if fatal_error:
+ logger = tf_logging.fatal
+ else:
+ logger = tf_logging.error
+ logger(
+ '==================================\n'
+ 'Object was never used (type %s):\n%s\nIf you want to mark it as '
+ 'used call its "mark_used()" method.\nIt was originally created '
+ 'here:\n%s\n'
+ '==================================' %
+ (obj.type_, obj.repr_, obj.creation_stack))
+
+
def _add_should_use_warning(x, fatal_error=False):
"""Wraps object x so that if it is never used, a warning is logged.
@@ -39,14 +77,14 @@ def _add_should_use_warning(x, fatal_error=False):
"""
if x is None: # special corner case where x is None
return x
- has_been_used = getattr(x, '_tf_object_has_been_used', None)
- if has_been_used is not None:
- x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access
+ if hasattr(x, '_tf_ref_id'): # this is already a TFShouldUseWarningWrapper
return x
def override_method(method):
def fn(self, *args, **kwargs):
- self._tf_object_has_been_used = True # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
+ object_used=True)
return method(self, *args, **kwargs)
return fn
@@ -55,38 +93,36 @@ def _add_should_use_warning(x, fatal_error=False):
def __init__(self, true_self):
self.__dict__ = true_self.__dict__
- stack = [x.strip() for x in traceback.format_stack()]
+ stack = [s.strip() for s in traceback.format_stack()]
# Remove top three stack entries from adding the wrapper
- self._tf_object_creation_stack = '\n'.join(stack[:-3])
- self._tf_object_has_been_used = False
+ self.creation_stack = '\n'.join(stack[:-3])
+ self._tf_ref_id = next(_REF_ITER)
+ _REF_INFO[self._tf_ref_id] = _RefInfoField(
+ type_=type(x),
+ repr_=repr(x),
+ creation_stack=stack,
+ object_used=False)
+
+ # Create a finalizer for self, which will be called when self is
+ # garbage collected. Can't add self as the args because the
+ # loop will break garbage collection. We keep track of
+ # ourselves via python ids.
+ weakref.finalize(self, _deleted, self._tf_ref_id, fatal_error)
# Not sure why this pylint warning is being used; this is not an
# old class form.
# pylint: disable=super-on-old-class
def __getattribute__(self, name):
- if name != '_tf_object_has_been_used':
- self._tf_object_has_been_used = True
+ if name == '_tf_ref_id':
+ return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
+ if self._tf_ref_id in _REF_INFO:
+ _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
+ object_used=True)
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
- def __del__(self):
- if not self._tf_object_has_been_used:
- if fatal_error:
- logger = tf_logging.fatal
- else:
- logger = tf_logging.error
- logger(
- '==================================\n'
- 'Object was never used (type %s):\n%s\nIf you want to mark it as '
- 'used call its "mark_used()" method.\nIt was originally created '
- 'here:\n%s\n'
- '==================================' %
- (type(x), x, self._tf_object_creation_stack))
-
- if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'):
- return super(TFShouldUseWarningWrapper, self).__del__()
-
def mark_used(self, *args, **kwargs):
- self._tf_object_has_been_used = True
+ _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
+ object_used=True)
if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
# pylint: enable=super-on-old-class
@@ -102,7 +138,8 @@ def _add_should_use_warning(x, fatal_error=False):
wrapped = TFShouldUseWarningWrapper(x)
wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
- wrapped._tf_object_has_been_used = False # pylint: disable=protected-access
+ ref_id = wrapped._tf_ref_id # pylint: disable=protected-access
+ _REF_INFO[ref_id] = _REF_INFO[ref_id]._replace(object_used=False)
return wrapped
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index 71d48e3dde..c826874400 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
+import gc
import sys
from tensorflow.python.framework import constant_op
@@ -45,7 +46,7 @@ def reroute_error(captured):
class TfShouldUseTest(test.TestCase):
def testAddShouldUseWarningWhenNotUsed(self):
- c = constant_op.constant(0, name='blah')
+ c = constant_op.constant(0, name='blah0')
captured = []
with reroute_error(captured):
def in_this_function():
@@ -53,44 +54,52 @@ class TfShouldUseTest(test.TestCase):
del h
in_this_function()
self.assertIn('Object was never used', '\n'.join(captured))
- self.assertIn('blah:0', '\n'.join(captured))
+ self.assertIn('blah0:0', '\n'.join(captured))
self.assertIn('in_this_function', '\n'.join(captured))
+ gc.collect()
+ self.assertFalse(gc.garbage)
- def _testAddShouldUseWarningWhenUsed(self, fn):
- c = constant_op.constant(0, name='blah')
+ def _testAddShouldUseWarningWhenUsed(self, fn, name):
+ c = constant_op.constant(0, name=name)
captured = []
with reroute_error(captured):
h = tf_should_use._add_should_use_warning(c)
fn(h)
del h
self.assertNotIn('Object was never used', '\n'.join(captured))
- self.assertNotIn('blah:0', '\n'.join(captured))
+ self.assertNotIn('%s:0' % name, '\n'.join(captured))
def testAddShouldUseWarningWhenUsedWithAdd(self):
def add(h):
_ = h + 1
- self._testAddShouldUseWarningWhenUsed(add)
+ self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
+ gc.collect()
+ self.assertFalse(gc.garbage)
def testAddShouldUseWarningWhenUsedWithGetName(self):
def get_name(h):
_ = h.name
- self._testAddShouldUseWarningWhenUsed(get_name)
+ self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
+ gc.collect()
+ self.assertFalse(gc.garbage)
def testShouldUseResult(self):
@tf_should_use.should_use_result
def return_const(value):
- return constant_op.constant(value, name='blah')
+ return constant_op.constant(value, name='blah2')
captured = []
with reroute_error(captured):
return_const(0.0)
self.assertIn('Object was never used', '\n'.join(captured))
- self.assertIn('blah:0', '\n'.join(captured))
+ self.assertIn('blah2:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))
+ gc.collect()
+ self.assertFalse(gc.garbage)
def testShouldUseResultWhenNotReallyUsed(self):
@tf_should_use.should_use_result
def return_const(value):
- return constant_op.constant(value, name='blah')
+ return constant_op.constant(value, name='blah3')
captured = []
with reroute_error(captured):
with self.test_session():
@@ -100,8 +109,10 @@ class TfShouldUseTest(test.TestCase):
v = constant_op.constant(1.0, name='meh')
v.eval()
self.assertIn('Object was never used', '\n'.join(captured))
- self.assertIn('blah:0', '\n'.join(captured))
+ self.assertIn('blah3:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))
+ gc.collect()
+ self.assertFalse(gc.garbage)
if __name__ == '__main__':
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index b8f9fc8453..8768852dc7 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -85,3 +85,6 @@ pip2 install mock
pip2 install portpicker
pip3 install portpicker
+
+pip2 install backports.weakref==1.0rc1
+pip3 install backports.weakref==1.0rc1
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index e7e2d256cd..edfc4e3a98 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -89,3 +89,6 @@ pip3.5 install wheel==0.29.0
pip3.5 install portpicker
pip3.5 install werkzeug
+
+pip3.5 install backports.weakref==1.0rc1
+
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index a85a220270..a1676203c7 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -39,6 +39,7 @@ REQUIRED_PACKAGES = [
'html5lib == 0.9999999', # identical to 1.0b8
'markdown == 2.2.0',
'bleach == 1.5.0',
+ 'backports.weakref == 1.0rc1',
]
project_name = 'tensorflow'