aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Derek Murray <mrry@google.com> 2016-03-15 11:02:22 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-03-15 13:02:40 -0700
commit: 9a8c5ad18c61cb0695d31e2ce969008c82999c7c (patch)
tree: 9691b3e87d136aa7b653f45c1949678245d989e2
parent: ab3119e565e762828a5a818de8cdab61e5056fec (diff)
Prevent the feeding of tensors whose values are used to calculate shapes.
This change prevents feeding a tensor if its constant value has been accessed. For example, the constant value is often used in shape functions (e.g. in `tf.reshape(data, indices)`, `indices` is a tensor, but it is often constant) in order to infer more precise shapes. It is also used in the `tf.image.*` functions to generate simpler, more efficient graphs. However, doing so is potentially unsafe, because the tensor can later be fed with a different value, which invalidates the previously obtained constant value and can lead to subtle bugs.

IF THIS BREAKS YOU
------------------

Replace the tensor that you are feeding with a `tf.placeholder()` of the appropriate dtype and shape.

Change: 117263031
-rw-r--r-- tensorflow/core/kernels/identity_op.cc 4
-rw-r--r-- tensorflow/core/ops/array_ops.cc 15
-rw-r--r-- tensorflow/core/ops/compat/ops_history.v0.pbtxt 19
-rw-r--r-- tensorflow/core/ops/ops.pbtxt 24
-rw-r--r-- tensorflow/python/client/session.py 13
-rw-r--r-- tensorflow/python/client/session_test.py 18
-rw-r--r-- tensorflow/python/framework/ops.py 10
-rw-r--r-- tensorflow/python/framework/tensor_shape.py 11
-rw-r--r-- tensorflow/python/framework/tensor_util.py 54
-rw-r--r-- tensorflow/python/kernel_tests/cast_op_test.py 2
-rw-r--r-- tensorflow/python/kernel_tests/constant_op_test.py 34
-rw-r--r-- tensorflow/python/kernel_tests/reshape_op_test.py 4
-rw-r--r-- tensorflow/python/ops/array_ops.py 22
-rw-r--r-- tensorflow/python/ops/io_ops.py 1
14 files changed, 196 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc
index dabffc583a..40d72c22c2 100644
--- a/tensorflow/core/kernels/identity_op.cc
+++ b/tensorflow/core/kernels/identity_op.cc
@@ -27,6 +27,10 @@ REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_CPU), IdentityOp);
// StopGradient does the same thing as Identity, but has a different
// gradient registered.
REGISTER_KERNEL_BUILDER(Name("StopGradient").Device(DEVICE_CPU), IdentityOp);
+// PlaceholderWithDefault does the same thing as Identity, but has a
+// different shape function (and constant value function) registered.
+REGISTER_KERNEL_BUILDER(Name("PlaceholderWithDefault").Device(DEVICE_CPU),
+ IdentityOp);
REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp);
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index c56af21814..2eadd8364a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1005,6 +1005,21 @@ shape: (Optional) The shape of the tensor. If the shape has 0 dimensions, the
)doc");
// --------------------------------------------------------------------------
+REGISTER_OP("PlaceholderWithDefault")
+ .Input("input: dtype")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("shape: shape")
+ .Doc(R"doc(
+A placeholder op that passes through `input` when its output is not fed.
+
+input: The default value to produce when `output` is not fed.
+output: A placeholder tensor that defaults to `input` if it is not fed.
+dtype: The type of elements in the tensor.
+shape: The (possibly partial) shape of the tensor.
+)doc");
+
+// --------------------------------------------------------------------------
REGISTER_OP("ExpandDims")
.Input("input: T")
.Input("dim: int32")
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index 7970b94be2..7d4b14e965 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -7511,6 +7511,25 @@ op {
}
}
op {
+ name: "PlaceholderWithDefault"
+ input_arg {
+ name: "input"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+}
+op {
name: "Pow"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 121c03bd6d..130c2e4156 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -6274,6 +6274,30 @@ op {
description: "N.B. This operation will fail with an error if it is executed. It is\nintended as a way to represent a value that will always be fed, and to\nprovide attrs that enable the fed value to be checked at runtime."
}
op {
+ name: "PlaceholderWithDefault"
+ input_arg {
+ name: "input"
+ description: "The default value to produce when `output` is not fed."
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ description: "A placeholder tensor that defaults to `input` if it is not fed."
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ description: "The type of elements in the tensor."
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ description: "The (possibly partial) shape of the tensor."
+ }
+ summary: "A placeholder op that passes through `input` when its output is not fed."
+}
+op {
name: "Pow"
input_arg {
name: "x"
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 80cc504427..a77cdffda5 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -523,12 +523,13 @@ class BaseSession(SessionInterface):
'strings, lists, or numpy ndarrays.')
np_val = np.array(subfeed_val, dtype=subfeed_t.dtype.as_numpy_dtype)
- if subfeed_t.op.type == 'Placeholder':
- if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
- raise ValueError(
- 'Cannot feed value of shape %r for Tensor %r, '
- 'which has shape %r'
- % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
+ if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
+ raise ValueError(
+ 'Cannot feed value of shape %r for Tensor %r, '
+ 'which has shape %r'
+ % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
+ if not self.graph.is_feedable(subfeed_t):
+ raise ValueError('Tensor %s may not be fed.' % subfeed_t)
feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val
# Run request and get response.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 549ada459c..55868328ff 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -70,7 +70,7 @@ class SessionTest(test_util.TensorFlowTestCase):
def testCreate(self):
with session.Session():
- inp = constant_op.constant(10.0, name='W1')
+ inp = constant_op.constant(10.0, shape=[2, 3], name='W1')
copy = array_ops.identity(inp)
# Test with feed.
# TODO(mrry): Investigate why order='F' didn't work.
@@ -79,7 +79,8 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(arr, copy_val)
# Test without feed.
copy_val = copy.eval()
- self.assertAllEqual(np.asarray(10.0, dtype=np.float32), copy_val)
+ self.assertAllEqual(np.asarray([[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]],
+ dtype=np.float32), copy_val)
def testManyCPUs(self):
# TODO(keveman): Implement ListDevices and test for the number of
@@ -931,5 +932,18 @@ class SessionTest(test_util.TensorFlowTestCase):
step_stats.CopyFrom(run_outputs.step_stats)
self.assertEquals(len(step_stats.dev_stats), 1)
+ def testFeedShapeCompatibility(self):
+ with session.Session() as sess:
+ some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
+ new_shape = constant_op.constant([2, 2])
+ reshaped_tensor = array_ops.reshape(some_tensor, new_shape)
+
+ with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'):
+ sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
+
+ with self.assertRaisesRegexp(ValueError, 'may not be fed'):
+ sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index d7ba56c479..018a6d1ca9 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1842,6 +1842,8 @@ class Graph(object):
min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
# Stack of colocate_with ops
self._colocation_stack = []
+ # Set of tensors that are dangerous to feed!
+ self._unfeedable_tensors = set()
def _check_not_finalized(self):
"""Check if the graph is finalized.
@@ -3045,6 +3047,14 @@ class Graph(object):
del self._gradient_override_map[op_type]
# pylint: enable=g-doc-return-or-yield
+ def prevent_feeding(self, tensor):
+ """Marks the given `tensor` as unfeedable in this graph."""
+ self._unfeedable_tensors.add(tensor)
+
+ def is_feedable(self, tensor):
+ """Returns `True` if and only if `tensor` is feedable."""
+ return tensor not in self._unfeedable_tensors
+
def device(device_name_or_function):
"""Wrapper for `Graph.device()` using the default graph.
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index aafe32c3ec..e8643109bc 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -411,10 +411,13 @@ class TensorShape(object):
if dims is None:
self._dims = None
elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
- self._dims = [
- # Protos store variable-size dimensions as -1
- as_dimension(dim.size if dim.size != -1 else None)
- for dim in dims.dim]
+ if dims.unknown_rank:
+ self._dims = None
+ else:
+ self._dims = [
+ # Protos store variable-size dimensions as -1
+ as_dimension(dim.size if dim.size != -1 else None)
+ for dim in dims.dim]
else:
try:
dims_iter = iter(dims)
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index ae4e73a363..b1b39f0651 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -487,26 +487,7 @@ def ShapeEquals(tensor_proto, shape):
return all(x == y for x, y in zip(tensor_shape_list, shape))
-def constant_value(tensor):
- """Returns the constant value of the given tensor, if efficiently calculable.
-
- This function attempts to partially evaluate the given tensor, and
- returns its value as a numpy ndarray if this succeeds.
-
- TODO(mrry): Consider whether this function should use a registration
- mechanism like gradients and ShapeFunctions, so that it is easily
- extensible.
-
- Args:
- tensor: The Tensor to be evaluated.
-
- Returns:
- A numpy ndarray containing the constant value of the given `tensor`,
- or None if it cannot be calculated.
-
- Raises:
- TypeError: if tensor is not an ops.Tensor.
- """
+def _ConstantValue(tensor):
# TODO(touts): Support Variables?
if not isinstance(tensor, ops.Tensor):
raise TypeError("tensor is not a Tensor")
@@ -561,3 +542,36 @@ def constant_value(tensor):
return np.concatenate(values, axis=dim)
else:
return None
+
+
+def constant_value(tensor):
+ """Returns the constant value of the given tensor, if efficiently calculable.
+
+ This function attempts to partially evaluate the given tensor, and
+ returns its value as a numpy ndarray if this succeeds.
+
+ TODO(mrry): Consider whether this function should use a registration
+ mechanism like gradients and ShapeFunctions, so that it is easily
+ extensible.
+
+ NOTE: If `constant_value(tensor)` returns a non-`None` result, it will no
+ longer be possible to feed a different value for `tensor`. This allows the
+ result of this function to influence the graph that is constructed, and
+ permits static shape optimizations.
+
+ Args:
+ tensor: The Tensor to be evaluated.
+
+ Returns:
+ A numpy ndarray containing the constant value of the given `tensor`,
+ or None if it cannot be calculated.
+
+ Raises:
+ TypeError: if tensor is not an ops.Tensor.
+ """
+ ret = _ConstantValue(tensor)
+ if ret is not None:
+ # The caller may now depend on the constant value of `tensor`, so we
+ # conservatively prevent it from being fed.
+ tensor.graph.prevent_feeding(tensor)
+ return ret
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 3a0fe60344..e89d6d5c9c 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -156,7 +156,7 @@ class CastOpTest(tf.test.TestCase):
x = tf.constant(1.0, src_t)
z = tf.identity(x)
y = tf.cast(z, dst_t)
- err = tf.test.compute_gradient_error(x, [1], y, [1])
+ err = tf.test.compute_gradient_error(x, [], y, [])
self.assertLess(err, 1e-3)
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index d93020e825..2a49e3581d 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -560,5 +560,39 @@ class PlaceholderTest(tf.test.TestCase):
"<tf.Tensor 'c:0' shape=(32, ?, 2) dtype=qint32>",
repr(c))
+
+class PlaceholderWithDefaultTest(tf.test.TestCase):
+
+ def testFullShape(self):
+ with self.test_session():
+ p = tf.placeholder_with_default([[2, 2], [2, 2]], shape=[2, 2])
+ a = tf.identity(p)
+ self.assertAllEqual([[2, 2], [2, 2]], a.eval())
+ self.assertAllEqual([[3, 3], [3, 3]],
+ a.eval(feed_dict={p: [[3, 3], [3, 3]]}))
+
+ with self.assertRaises(ValueError):
+ a.eval(feed_dict={p: [[6, 6, 6], [6, 6, 6]]})
+
+ def testPartialShape(self):
+ with self.test_session():
+ p = tf.placeholder_with_default([1, 2, 3], shape=[None])
+ a = tf.identity(p)
+ self.assertAllEqual([1, 2, 3], a.eval())
+ self.assertAllEqual([3, 37], a.eval(feed_dict={p: [3, 37]}))
+
+ with self.assertRaises(ValueError):
+ a.eval(feed_dict={p: [[2, 2], [2, 2]]})
+
+ def testNoShape(self):
+ with self.test_session():
+ p = tf.placeholder_with_default([17], shape=None)
+ a = tf.identity(p)
+ self.assertAllEqual([17], a.eval())
+ self.assertAllEqual([3, 37], a.eval(feed_dict={p: [3, 37]}))
+ self.assertAllEqual([[3, 3], [3, 3]],
+ a.eval(feed_dict={p: [[3, 3], [3, 3]]}))
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index 04f7f8fa7c..d27106afdf 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -72,10 +72,10 @@ class ReshapeTest(tf.test.TestCase):
# reports errors.
def testFloatReshapeGradThreeDimensions(self):
- x = np.arange(1., 25.).reshape([1, 24]).astype(np.float32)
+ x = np.arange(1., 25.).reshape([2, 3, 4]).astype(np.float32)
s = list(np.shape(x))
with self.test_session():
- input_tensor = tf.constant(x, shape=[2, 3, 4])
+ input_tensor = tf.constant(x)
reshape_out = tf.reshape(input_tensor, [1, 8, 3])
err = tf.test.compute_gradient_error(input_tensor,
s,
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 00f35510e9..ccf38d5be1 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1596,3 +1596,25 @@ def _OneHotShape(op):
new_shape.insert(axis % (indices_dims + 1), depth)
return [tensor_shape.TensorShape(new_shape)]
+
+
+@ops.RegisterShape("PlaceholderWithDefault")
+def _PlaceholderWithDefaultShape(op):
+ """Shape function for the PlaceholderWithDefault op.
+
+ This op acts as an identity when it is not fed (passing through a
+ default value), but allows the user to feed it with tensors of a
+ possibly less precise shape than its default value.
+
+ Args:
+ op: A PlaceholderWithDefault `Operation`.
+
+ Returns:
+ A single-element list containing the shape of the output.
+ """
+ input_shape = op.inputs[0].get_shape()
+ output_shape = tensor_shape.TensorShape(op.get_attr("shape"))
+ # NOTE(mrry): We don't merge these shapes, because `output_shape`
+ # may be *less* precise than `input_shape`.
+ input_shape.assert_is_compatible_with(output_shape)
+ return [output_shape]
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 3e6a1a5951..15cf4736a3 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -21,6 +21,7 @@ on execution. For more info, see the section on [Feeding
data](../../how_tos/reading_data/index.md#feeding).
@@placeholder
+@@placeholder_with_default
## Readers