aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-05-01 11:55:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-01 13:09:58 -0700
commitdbdc9fc2ed2e8f9b247c8b2980b0bcd079e39359 (patch)
tree3cc3c5b059c28c0f6203f91b43f0a0ff3c9b87fc /tensorflow
parentbb17b9665c189d1349d783219306100204ef2352 (diff)
Add should-use for commonly misused ops.
Fixed a bunch of invalid callers (by initially using should_use_with_fatal and looking for failing unit tests) Change: 154748211
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py3
-rw-r--r--tensorflow/python/kernel_tests/tensor_array_ops_test.py2
-rw-r--r--tensorflow/python/ops/control_flow_ops.py2
-rw-r--r--tensorflow/python/ops/image_ops_impl.py38
-rw-r--r--tensorflow/python/ops/image_ops_test.py14
-rw-r--r--tensorflow/python/ops/resources.py2
-rw-r--r--tensorflow/python/ops/tensor_array_ops.py6
-rw-r--r--tensorflow/python/ops/variables.py7
-rw-r--r--tensorflow/python/util/tf_should_use.py39
-rw-r--r--tensorflow/python/util/tf_should_use_test.py8
11 files changed, 81 insertions, 42 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index ac027f0f43..8aebb79b91 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -232,7 +232,7 @@ class QuantizedDistribution(distributions.Distribution):
graph_parents = self._dist._graph_parents # pylint: disable=protected-access
checks = []
- if low is not None and high is not None:
+ if validate_args and low is not None and high is not None:
message = "low must be strictly less than high."
checks.append(
check_ops.assert_less(
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 4ca3120dec..123db50d32 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -1287,9 +1287,6 @@ class Estimator(BaseEstimator):
else:
saver_for_restore = saver.Saver(sharded=True)
with tf_session.Session('') as session:
- variables.initialize_local_variables()
- data_flow_ops.tables_initializer()
- resources.initialize_resources(resources.shared_resources())
saver_for_restore.restore(session, checkpoint_path)
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 0fec42e1db..41fe29e006 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -1128,7 +1128,7 @@ class TensorArrayTest(test.TestCase):
dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=True)
self.assertEqual(0, ta.size().eval())
# Don't actually perform the pack. This stores the static shape.
- ta.unstack(array_ops.zeros([0, 3, 5]))
+ ta.unstack(array_ops.zeros([0, 3, 5])).mark_used()
packed = ta.stack()
self.assertAllEqual([0, 3, 5], packed.eval().shape)
# Concatenating zero tensors along their first dimension gives a
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 75a537643e..ebe5259de5 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -72,6 +72,7 @@ from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_should_use
# We override the 'tuple' for a control flow op, so we keep python's
@@ -84,6 +85,7 @@ _basetuple = tuple
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
+@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true.
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 78621d3b57..c79f413c5e 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -218,7 +218,8 @@ def random_flip_up_down(image, seed=None):
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(mirror_cond,
@@ -246,7 +247,8 @@ def random_flip_left_right(image, seed=None):
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(mirror_cond,
@@ -273,7 +275,8 @@ def flip_left_right(image):
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
return fix_image_flip_shape(image, array_ops.reverse(image, [1]))
@@ -295,7 +298,8 @@ def flip_up_down(image):
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
return fix_image_flip_shape(image, array_ops.reverse(image, [0]))
@@ -312,7 +316,8 @@ def rot90(image, k=1, name=None):
"""
with ops.name_scope(name, 'rot90', [image, k]) as scope:
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
k.get_shape().assert_has_rank(0)
k = math_ops.mod(k, 4)
@@ -350,7 +355,8 @@ def transpose_image(image):
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
return array_ops.transpose(image, [1, 0, 2], name='transpose_image')
@@ -379,12 +385,14 @@ def central_crop(image, central_fraction):
3-D float Tensor
"""
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
if central_fraction <= 0.0 or central_fraction > 1.0:
raise ValueError('central_fraction must be within (0, 1]')
if central_fraction == 1.0:
return image
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
+
img_shape = array_ops.shape(image)
depth = image.get_shape()[2]
fraction_offset = int(1 / ((1 - central_fraction) / 2.0))
@@ -435,9 +443,6 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
"""
image = ops.convert_to_tensor(image, name='image')
- assert_ops = []
- assert_ops += _CheckAtLeast3DImage(image, require_static=False)
-
is_batch = True
image_shape = image.get_shape()
if image_shape.ndims == 3:
@@ -450,6 +455,8 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
elif image_shape.ndims != 4:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
+ assert_ops = _CheckAtLeast3DImage(image, require_static=False)
+
batch, height, width, depth = _ImageDimensions(image, rank=4)
after_padding_width = target_width - offset_width - width
@@ -515,9 +522,6 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
"""
image = ops.convert_to_tensor(image, name='image')
- assert_ops = []
- assert_ops += _CheckAtLeast3DImage(image, require_static=False)
-
is_batch = True
image_shape = image.get_shape()
if image_shape.ndims == 3:
@@ -530,6 +534,8 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
elif image_shape.ndims != 4:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
+ assert_ops = _CheckAtLeast3DImage(image, require_static=False)
+
batch, height, width, depth = _ImageDimensions(image, rank=4)
assert_ops += _assert(offset_width >= 0, ValueError,
@@ -602,8 +608,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
elif image_shape.ndims != 4:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
- assert_ops = []
- assert_ops += _CheckAtLeast3DImage(image, require_static=False)
+ assert_ops = _CheckAtLeast3DImage(image, require_static=False)
assert_ops += _assert(target_width > 0, ValueError,
'target_width must be > 0.')
assert_ops += _assert(target_height > 0, ValueError,
@@ -800,7 +805,8 @@ def per_image_standardization(image):
ValueError: if the shape of 'image' is incompatible with this function.
"""
image = ops.convert_to_tensor(image, name='image')
- _Check3DImage(image, require_static=False)
+ image = control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
num_pixels = math_ops.reduce_prod(array_ops.shape(image))
image = math_ops.cast(image, dtype=dtypes.float32)
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 716d767b0c..1a70d46507 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -1175,12 +1175,7 @@ class CropToBoundingBoxTest(test_util.TensorFlowTestCase):
offset_height, offset_width = [0, 0]
target_height, target_width = [2, 2]
- for x_shape in ([3, 5],):
- self._assertRaises(x, x_shape, offset_height, offset_width, target_height,
- target_width,
- "'image' must be at least three-dimensional.")
-
- for x_shape in ([1, 3, 5, 1, 1],):
+ for x_shape in ([3, 5], [1, 3, 5, 1, 1]):
self._assertRaises(x, x_shape, offset_height, offset_width, target_height,
target_width,
"'image' must have either 3 or 4 dimensions.")
@@ -1426,12 +1421,7 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase):
offset_height, offset_width = [0, 0]
target_height, target_width = [2, 2]
- for x_shape in ([3, 5],):
- self._assertRaises(x, x_shape, offset_height, offset_width, target_height,
- target_width,
- "'image' must be at least three-dimensional")
-
- for x_shape in ([1, 3, 5, 1, 1],):
+ for x_shape in ([3, 5], [1, 3, 5, 1, 1]):
self._assertRaises(x, x_shape, offset_height, offset_width, target_height,
target_width,
"'image' must have either 3 or 4 dimensions.")
diff --git a/tensorflow/python/ops/resources.py b/tensorflow/python/ops/resources.py
index 41fb8a74a9..57ba0084e8 100644
--- a/tensorflow/python/ops/resources.py
+++ b/tensorflow/python/ops/resources.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_should_use
_Resource = collections.namedtuple("_Resource",
@@ -98,6 +99,7 @@ def report_uninitialized_resources(resource_list=None,
return array_ops.boolean_mask(variable_names_tensor, variables_mask)
+@tf_should_use.should_use_result
def initialize_resources(resource_list, name="init"):
"""Initializes the resources in the given list.
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index d1013c53dd..b1c7d74a0c 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_should_use
def _maybe_set_device(handle_op, value_t):
@@ -252,6 +253,7 @@ class TensorArray(object):
value.set_shape(self._element_shape[0].dims)
return value
+ @tf_should_use.should_use_result
def write(self, index, value, name=None):
"""Write `value` into index `index` of the TensorArray.
@@ -358,6 +360,7 @@ class TensorArray(object):
value.set_shape([None] + self._element_shape[0].dims[1:])
return value
+ @tf_should_use.should_use_result
def unstack(self, value, name=None):
"""Unstack the values of a `Tensor` in the TensorArray.
@@ -380,6 +383,7 @@ class TensorArray(object):
return self.scatter(
indices=math_ops.range(0, num_elements), value=value, name=name)
+ @tf_should_use.should_use_result
def scatter(self, indices, value, name=None):
"""Scatter the values of a `Tensor` in specific indices of a `TensorArray`.
@@ -418,6 +422,7 @@ class TensorArray(object):
ta._merge_element_shape(element_shape)
return ta
+ @tf_should_use.should_use_result
def split(self, value, lengths, name=None):
"""Split the values of a `Tensor` into the TensorArray.
@@ -466,6 +471,7 @@ class TensorArray(object):
return gen_data_flow_ops._tensor_array_size_v3(
handle=self._handle, flow_in=self.flow, name=name)
+ @tf_should_use.should_use_result
def close(self, name=None):
"""Close the current TensorArray."""
with ops.colocate_with(self._handle):
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 8b508e45a4..33523f1a71 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
@@ -1152,6 +1153,7 @@ def variables_initializer(var_list, name="init"):
return control_flow_ops.no_op(name=name)
+@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
"""See `tf.variables_initializer`."""
@@ -1169,6 +1171,7 @@ def global_variables_initializer():
return variables_initializer(global_variables())
+@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
"""See `tf.global_variables_initializer`."""
@@ -1186,12 +1189,14 @@ def local_variables_initializer():
return variables_initializer(local_variables())
+@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
"""See `tf.local_variables_initializer`."""
return local_variables_initializer()
+@tf_should_use.should_use_result
def is_variable_initialized(variable):
"""Tests if a variable has been initialized.
@@ -1205,6 +1210,7 @@ def is_variable_initialized(variable):
return state_ops.is_variable_initialized(variable)
+@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
"""Returns an Op to check if variables are initialized.
@@ -1246,6 +1252,7 @@ def assert_variables_initialized(var_list=None):
return array_ops.stack(ranks)
+@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
name="report_uninitialized_variables"):
"""Adds ops to list the names of uninitialized variables.
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py
index a6a1ad4892..88df3351e6 100644
--- a/tensorflow/python/util/tf_should_use.py
+++ b/tensorflow/python/util/tf_should_use.py
@@ -22,6 +22,7 @@ import traceback
import types
from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import tf_decorator
def _add_should_use_warning(x, fatal_error=False):
@@ -36,6 +37,13 @@ def _add_should_use_warning(x, fatal_error=False):
An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)`
and is a very shallow wrapper for `x` which logs access into `x`.
"""
+ if x is None: # special corner case where x is None
+ return x
+ has_been_used = getattr(x, '_tf_object_has_been_used', None)
+ if has_been_used is not None:
+ x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access
+ return x
+
def override_method(method):
def fn(self, *args, **kwargs):
self._tf_object_has_been_used = True # pylint: disable=protected-access
@@ -67,18 +75,27 @@ def _add_should_use_warning(x, fatal_error=False):
else:
logger = tf_logging.error
logger(
- 'Object was never used: %s.\nIt was originally created here:\n%s'
- % (self, self._tf_object_creation_stack))
+ '==================================\n'
+ 'Object was never used (type %s):\n%s\nIf you want to mark it as '
+ 'used call its "mark_used()" method.\nIt was originally created '
+ 'here:\n%s\n'
+ '==================================' %
+ (type(x), x, self._tf_object_creation_stack))
if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'):
return super(TFShouldUseWarningWrapper, self).__del__()
+
+ def mark_used(self, *args, **kwargs):
+ self._tf_object_has_been_used = True
+ if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
+ return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
# pylint: enable=super-on-old-class
for name in dir(TFShouldUseWarningWrapper):
method = getattr(TFShouldUseWarningWrapper, name)
if not isinstance(method, types.FunctionType):
continue
- if name in ('__init__', '__getattribute__', '__del__'):
+ if name in ('__init__', '__getattribute__', '__del__', 'mark_used'):
continue
setattr(TFShouldUseWarningWrapper, name,
functools.wraps(method)(override_method(method)))
@@ -114,7 +131,13 @@ def should_use_result(fn):
"""
def wrapped(*args, **kwargs):
return _add_should_use_warning(fn(*args, **kwargs))
- return functools.wraps(fn)(wrapped)
+ return tf_decorator.make_decorator(
+ fn, wrapped, 'should_use_result',
+ ((fn.__doc__ or '') +
+ ('\n\n '
+ '**NOTE** The output of this function should be used. If it is not, '
+ 'a warning will be logged. To mark the output as used, '
+ 'call its .mark_used() method.')))
def must_use_result_or_fatal(fn):
@@ -142,4 +165,10 @@ def must_use_result_or_fatal(fn):
"""
def wrapped(*args, **kwargs):
return _add_should_use_warning(fn(*args, **kwargs), fatal_error=True)
- return functools.wraps(fn)(wrapped)
+ return tf_decorator.make_decorator(
+ fn, wrapped, 'must_use_result_or_fatal',
+ ((fn.__doc__ or '') +
+ ('\n\n '
+ '**NOTE** The output of this function must be used. If it is not, '
+ 'a fatal error will be raised. To mark the output as used, '
+ 'call its .mark_used() method.')))
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index 09130eed3a..71d48e3dde 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -52,7 +52,7 @@ class TfShouldUseTest(test.TestCase):
h = tf_should_use._add_should_use_warning(c)
del h
in_this_function()
- self.assertIn('Object was never used:', '\n'.join(captured))
+ self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('in_this_function', '\n'.join(captured))
@@ -63,7 +63,7 @@ class TfShouldUseTest(test.TestCase):
h = tf_should_use._add_should_use_warning(c)
fn(h)
del h
- self.assertNotIn('Object was never used:', '\n'.join(captured))
+ self.assertNotIn('Object was never used', '\n'.join(captured))
self.assertNotIn('blah:0', '\n'.join(captured))
def testAddShouldUseWarningWhenUsedWithAdd(self):
@@ -83,7 +83,7 @@ class TfShouldUseTest(test.TestCase):
captured = []
with reroute_error(captured):
return_const(0.0)
- self.assertIn('Object was never used:', '\n'.join(captured))
+ self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))
@@ -99,7 +99,7 @@ class TfShouldUseTest(test.TestCase):
# unused op as being "used".
v = constant_op.constant(1.0, name='meh')
v.eval()
- self.assertIn('Object was never used:', '\n'.join(captured))
+ self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))