From 9a8c5ad18c61cb0695d31e2ce969008c82999c7c Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Tue, 15 Mar 2016 11:02:22 -0800
Subject: Prevent the feeding of tensors whose values are used to calculate
 shapes.

This change prevents feeding a tensor once its constant value has been
accessed. The constant value is often used in shape functions (e.g. in
`tf.reshape(data, indices)`, `indices` is a tensor, but it is often
constant) in order to infer more precise shapes. It is also used in the
`tf.image.*` functions to generate simpler, more efficient graphs.
However, relying on the constant value is potentially unsafe, because
the tensor can later be fed with a different value, which invalidates
the previously obtained constant value and can lead to subtle bugs.

IF THIS BREAKS YOU
------------------
Replace the tensor that you are feeding with a `tf.placeholder()` of
the appropriate dtype and shape.

Change: 117263031
---
 tensorflow/core/kernels/identity_op.cc              |  4 ++
 tensorflow/core/ops/array_ops.cc                     | 15 ++++++
 tensorflow/core/ops/compat/ops_history.v0.pbtxt      | 19 ++++++
 tensorflow/core/ops/ops.pbtxt                        | 24 ++++++++
 tensorflow/python/client/session.py                  | 13 +++---
 tensorflow/python/client/session_test.py             | 18 +++++++-
 tensorflow/python/framework/ops.py                   | 10 ++++
 tensorflow/python/framework/tensor_shape.py          | 11 +++--
 tensorflow/python/framework/tensor_util.py           | 54 ++++++++++++++--------
 tensorflow/python/kernel_tests/cast_op_test.py       |  2 +-
 tensorflow/python/kernel_tests/constant_op_test.py   | 34 ++++++++++++++
 tensorflow/python/kernel_tests/reshape_op_test.py    |  4 +-
 tensorflow/python/ops/array_ops.py                   | 22 +++++++++
 tensorflow/python/ops/io_ops.py                      |  1 +
 14 files changed, 196 insertions(+), 35 deletions(-)

diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc
index dabffc583a..40d72c22c2 100644
--- a/tensorflow/core/kernels/identity_op.cc
+++ b/tensorflow/core/kernels/identity_op.cc
@@ -27,6 +27,10 @@ REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_CPU), IdentityOp);
 // StopGradient does the same thing as Identity, but has a different
 // gradient registered.
 REGISTER_KERNEL_BUILDER(Name("StopGradient").Device(DEVICE_CPU), IdentityOp);
+// PlaceholderWithDefault does the same thing as Identity, but has a
+// different shape function (and constant value function) registered.
+REGISTER_KERNEL_BUILDER(Name("PlaceholderWithDefault").Device(DEVICE_CPU),
+                        IdentityOp);
 
 REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp);
 
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index c56af21814..2eadd8364a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1004,6 +1004,21 @@ shape: (Optional) The shape of the tensor. If the shape has 0 dimensions, the
   shape is unconstrained.
 )doc");
 
+// --------------------------------------------------------------------------
+REGISTER_OP("PlaceholderWithDefault")
+    .Input("input: dtype")
+    .Output("output: dtype")
+    .Attr("dtype: type")
+    .Attr("shape: shape")
+    .Doc(R"doc(
+A placeholder op that passes through `input` when its output is not fed.
+
+input: The default value to produce when `output` is not fed.
+output: A placeholder tensor that defaults to `input` if it is not fed.
+dtype: The type of elements in the tensor.
+shape: The (possibly partial) shape of the tensor.
+)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("ExpandDims") .Input("input: T") diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt index 7970b94be2..7d4b14e965 100644 --- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt @@ -7510,6 +7510,25 @@ op { } } } +op { + name: "PlaceholderWithDefault" + input_arg { + name: "input" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + } +} op { name: "Pow" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 121c03bd6d..130c2e4156 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -6273,6 +6273,30 @@ op { summary: "A placeholder op for a value that will be fed into the computation." description: "N.B. This operation will fail with an error if it is executed. It is\nintended as a way to represent a value that will always be fed, and to\nprovide attrs that enable the fed value to be checked at runtime." } +op { + name: "PlaceholderWithDefault" + input_arg { + name: "input" + description: "The default value to produce when `output` is not fed." + type_attr: "dtype" + } + output_arg { + name: "output" + description: "A placeholder tensor that defaults to `input` if it is not fed." + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + description: "The type of elements in the tensor." + } + attr { + name: "shape" + type: "shape" + description: "The (possibly partial) shape of the tensor." + } + summary: "A placeholder op that passes though `input` when its output is not fed." +} op { name: "Pow" input_arg { diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 80cc504427..a77cdffda5 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -523,12 +523,13 @@ class BaseSession(SessionInterface): 'strings, lists, or numpy ndarrays.') np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype) - if subfeed_t.op.type == 'Placeholder': - if not subfeed_t.get_shape().is_compatible_with(np_val.shape): - raise ValueError( - 'Cannot feed value of shape %r for Tensor %r, ' - 'which has shape %r' - % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) + if not subfeed_t.get_shape().is_compatible_with(np_val.shape): + raise ValueError( + 'Cannot feed value of shape %r for Tensor %r, ' + 'which has shape %r' + % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) + if not self.graph.is_feedable(subfeed_t): + raise ValueError('Tensor %s may not be fed.' % subfeed_t) feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val # Run request and get response. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 549ada459c..55868328ff 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -70,7 +70,7 @@ class SessionTest(test_util.TensorFlowTestCase): def testCreate(self): with session.Session(): - inp = constant_op.constant(10.0, name='W1') + inp = constant_op.constant(10.0, shape=[2, 3], name='W1') copy = array_ops.identity(inp) # Test with feed. # TODO(mrry): Investigate why order='F' didn't work. 
@@ -79,7 +79,8 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertAllEqual(arr, copy_val)
       # Test without feed.
       copy_val = copy.eval()
-      self.assertAllEqual(np.asarray(10.0, dtype=np.float32), copy_val)
+      self.assertAllEqual(np.asarray([[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]],
+                                     dtype=np.float32), copy_val)
 
   def testManyCPUs(self):
     # TODO(keveman): Implement ListDevices and test for the number of
@@ -931,5 +932,18 @@ class SessionTest(test_util.TensorFlowTestCase):
       step_stats.CopyFrom(run_outputs.step_stats)
       self.assertEquals(len(step_stats.dev_stats), 1)
 
+  def testFeedShapeCompatibility(self):
+    with session.Session() as sess:
+      some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
+      new_shape = constant_op.constant([2, 2])
+      reshaped_tensor = array_ops.reshape(some_tensor, new_shape)
+
+      with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'):
+        sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
+
+      with self.assertRaisesRegexp(ValueError, 'may not be fed'):
+        sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
+
+
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index d7ba56c479..018a6d1ca9 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1842,6 +1842,8 @@ class Graph(object):
         min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
     # Stack of colocate_with ops
     self._colocation_stack = []
+    # Set of tensors that are dangerous to feed!
+    self._unfeedable_tensors = set()
 
   def _check_not_finalized(self):
     """Check if the graph is finalized.
@@ -3045,6 +3047,14 @@ class Graph(object):
         del self._gradient_override_map[op_type]
   # pylint: enable=g-doc-return-or-yield
 
+  def prevent_feeding(self, tensor):
+    """Marks the given `tensor` as unfeedable in this graph."""
+    self._unfeedable_tensors.add(tensor)
+
+  def is_feedable(self, tensor):
+    """Returns `True` if and only if `tensor` is feedable."""
+    return tensor not in self._unfeedable_tensors
+
 
 def device(device_name_or_function):
   """Wrapper for `Graph.device()` using the default graph.
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index aafe32c3ec..e8643109bc 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -411,10 +411,13 @@ class TensorShape(object):
     if dims is None:
       self._dims = None
     elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
-      self._dims = [
-          # Protos store variable-size dimensions as -1
-          as_dimension(dim.size if dim.size != -1 else None)
-          for dim in dims.dim]
+      if dims.unknown_rank:
+        self._dims = None
+      else:
+        self._dims = [
+            # Protos store variable-size dimensions as -1
+            as_dimension(dim.size if dim.size != -1 else None)
+            for dim in dims.dim]
     else:
       try:
         dims_iter = iter(dims)
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index ae4e73a363..b1b39f0651 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -487,26 +487,7 @@ def ShapeEquals(tensor_proto, shape):
   return all(x == y for x, y in zip(tensor_shape_list, shape))
 
 
-def constant_value(tensor):
-  """Returns the constant value of the given tensor, if efficiently calculable.
-
-  This function attempts to partially evaluate the given tensor, and
-  returns its value as a numpy ndarray if this succeeds.
-
-  TODO(mrry): Consider whether this function should use a registration
-  mechanism like gradients and ShapeFunctions, so that it is easily
-  extensible.
-
-  Args:
-    tensor: The Tensor to be evaluated.
-
-  Returns:
-    A numpy ndarray containing the constant value of the given `tensor`,
-    or None if it cannot be calculated.
-
-  Raises:
-    TypeError: if tensor is not an ops.Tensor.
-  """
+def _ConstantValue(tensor):
   # TODO(touts): Support Variables?
   if not isinstance(tensor, ops.Tensor):
     raise TypeError("tensor is not a Tensor")
@@ -561,3 +542,36 @@ def constant_value(tensor):
     return np.concatenate(values, axis=dim)
   else:
     return None
+
+
+def constant_value(tensor):
+  """Returns the constant value of the given tensor, if efficiently calculable.
+
+  This function attempts to partially evaluate the given tensor, and
+  returns its value as a numpy ndarray if this succeeds.
+
+  TODO(mrry): Consider whether this function should use a registration
+  mechanism like gradients and ShapeFunctions, so that it is easily
+  extensible.
+
+  NOTE: If `constant_value(tensor)` returns a non-`None` result, it will no
+  longer be possible to feed a different value for `tensor`. This allows the
+  result of this function to influence the graph that is constructed, and
+  permits static shape optimizations.
+
+  Args:
+    tensor: The Tensor to be evaluated.
+
+  Returns:
+    A numpy ndarray containing the constant value of the given `tensor`,
+    or None if it cannot be calculated.
+
+  Raises:
+    TypeError: if tensor is not an ops.Tensor.
+  """
+  ret = _ConstantValue(tensor)
+  if ret is not None:
+    # The caller may now depend on the constant value of `tensor`, so we
+    # conservatively prevent it from being fed.
+    tensor.graph.prevent_feeding(tensor)
+  return ret
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 3a0fe60344..e89d6d5c9c 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -156,7 +156,7 @@ class CastOpTest(tf.test.TestCase):
           x = tf.constant(1.0, src_t)
           z = tf.identity(x)
           y = tf.cast(z, dst_t)
-          err = tf.test.compute_gradient_error(x, [1], y, [1])
+          err = tf.test.compute_gradient_error(x, [], y, [])
           self.assertLess(err, 1e-3)
 
 
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index d93020e825..2a49e3581d 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -560,5 +560,39 @@ class PlaceholderTest(tf.test.TestCase):
         "", repr(c))
 
+
+class PlaceholderWithDefaultTest(tf.test.TestCase):
+
+  def testFullShape(self):
+    with self.test_session():
+      p = tf.placeholder_with_default([[2, 2], [2, 2]], shape=[2, 2])
+      a = tf.identity(p)
+      self.assertAllEqual([[2, 2], [2, 2]], a.eval())
+      self.assertAllEqual([[3, 3], [3, 3]],
+                          a.eval(feed_dict={p: [[3, 3], [3, 3]]}))
+
+      with self.assertRaises(ValueError):
+        a.eval(feed_dict={p: [[6, 6, 6], [6, 6, 6]]})
+
+  def testPartialShape(self):
+    with self.test_session():
+      p = tf.placeholder_with_default([1, 2, 3], shape=[None])
+      a = tf.identity(p)
+      self.assertAllEqual([1, 2, 3], a.eval())
+      self.assertAllEqual([3, 37], a.eval(feed_dict={p: [3, 37]}))
+
+      with self.assertRaises(ValueError):
+        a.eval(feed_dict={p: [[2, 2], [2, 2]]})
+
+  def testNoShape(self):
+    with self.test_session():
+      p = tf.placeholder_with_default([17], shape=None)
+      a = tf.identity(p)
+      self.assertAllEqual([17], a.eval())
+      self.assertAllEqual([3, 37],
+                          a.eval(feed_dict={p: [3, 37]}))
+      self.assertAllEqual([[3, 3], [3, 3]],
+                          a.eval(feed_dict={p: [[3, 3], [3, 3]]}))
+
 
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index 04f7f8fa7c..d27106afdf 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -72,10 +72,10 @@ class ReshapeTest(tf.test.TestCase):
   # reports errors.
 
   def testFloatReshapeGradThreeDimensions(self):
-    x = np.arange(1., 25.).reshape([1, 24]).astype(np.float32)
+    x = np.arange(1., 25.).reshape([2, 3, 4]).astype(np.float32)
     s = list(np.shape(x))
     with self.test_session():
-      input_tensor = tf.constant(x, shape=[2, 3, 4])
+      input_tensor = tf.constant(x)
       reshape_out = tf.reshape(input_tensor, [1, 8, 3])
       err = tf.test.compute_gradient_error(input_tensor,
                                            s,
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 00f35510e9..ccf38d5be1 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1596,3 +1596,25 @@ def _OneHotShape(op):
 
   new_shape.insert(axis % (indices_dims + 1), depth)
   return [tensor_shape.TensorShape(new_shape)]
+
+
+@ops.RegisterShape("PlaceholderWithDefault")
+def _PlaceholderWithDefaultShape(op):
+  """Shape function for the PlaceholderWithDefault op.
+
+  This op acts as an identity when it is not fed (passing through a
+  default value), but allows the user to feed it with tensors of a
+  possibly less precise shape than its default value.
+
+  Args:
+    op: A PlaceholderWithDefault `Operation`.
+
+  Returns:
+    A single-element list containing the shape of the output.
+  """
+  input_shape = op.inputs[0].get_shape()
+  output_shape = tensor_shape.TensorShape(op.get_attr("shape"))
+  # NOTE(mrry): We don't merge these shapes, because `output_shape`
+  # may be *less* precise than `input_shape`.
+  input_shape.assert_is_compatible_with(output_shape)
+  return [output_shape]
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 3e6a1a5951..15cf4736a3 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -21,6 +21,7 @@ on execution.
 For more info, see the section on [Feeding
 data](../../how_tos/reading_data/index.md#feeding).
 
 @@placeholder
+@@placeholder_with_default
 
 ## Readers
-- 
cgit v1.2.3
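
For readers hit by the new "may not be fed" error, here is a rough sketch of
the migration suggested under "IF THIS BREAKS YOU" above. It is written against
the TensorFlow 0.x Python API of this patch's era, and the tensor names
(`data`, `new_shape`, `fed_shape`, `shape_or_default`) are illustrative rather
than taken from the patch:

    import tensorflow as tf

    data = tf.constant([2.0, 2.0, 2.0, 2.0])

    # Before this change, `new_shape` could be fed even though tf.reshape's
    # shape function had already read its constant value. After this change,
    # feeding it raises "ValueError: Tensor ... may not be fed."
    new_shape = tf.constant([2, 2])
    reshaped = tf.reshape(data, new_shape)

    # Fix: feed a placeholder of the appropriate dtype and shape instead.
    fed_shape = tf.placeholder(tf.int32, shape=[2])
    reshaped_fed = tf.reshape(data, fed_shape)

    # Alternatively, the PlaceholderWithDefault op added by this patch keeps a
    # default value while remaining feedable.
    shape_or_default = tf.placeholder_with_default(tf.constant([2, 2]), shape=[2])
    reshaped_default = tf.reshape(data, shape_or_default)

    with tf.Session() as sess:
      print(sess.run(reshaped_fed, feed_dict={fed_shape: [4, 1]}))   # shape (4, 1)
      print(sess.run(reshaped_default))                              # default shape (2, 2)
      print(sess.run(reshaped_default, feed_dict={shape_or_default: [1, 4]}))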
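
And a minimal sketch of the mechanism itself, poking directly at the internal
helpers this patch touches (`tensor_util.constant_value()`,
`Graph.prevent_feeding()` and `Graph.is_feedable()`); it assumes a build that
includes this change:

    import tensorflow as tf
    from tensorflow.python.framework import tensor_util

    with tf.Graph().as_default() as g:
      data = tf.constant([1.0, 2.0, 3.0, 4.0])
      indices = tf.constant([2, 2])

      # Shape functions (e.g. tf.reshape's) call constant_value() on tensors
      # like `indices` to infer a more precise static shape for their output.
      print(tensor_util.constant_value(indices))  # [2 2]

      # Because that value may now be baked into the graph, constant_value()
      # conservatively calls g.prevent_feeding(indices), so:
      print(g.is_feedable(indices))  # False
      print(g.is_feedable(data))     # True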