diff options
author | Rachel Lim <rachelim@google.com> | 2018-08-28 11:07:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 11:12:51 -0700 |
commit | cf879ef5dddee8d1b5081afe5bd8f49f15245d08 (patch) | |
tree | 7da03b5946bda44edd6b2220ea0aeb2a4f32e9a4 /tensorflow | |
parent | 8f99e5ad11040a6f0b5c12648e98bdbfe4dc3970 (diff) |
Adds a tf.ensure_shape function as a substitute for tensor.set_shape, which validates the true shape of the tensor at runtime.
PiperOrigin-RevId: 210570878
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt | 26 | ||||
-rw-r--r-- | tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/shape_ops.cc | 93 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 24 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 5 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/check_ops_test.py | 146 | ||||
-rw-r--r-- | tensorflow/python/ops/check_ops.py | 49 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 4 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/v2/tensorflow.pbtxt | 4 |
10 files changed, 357 insertions, 0 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt new file mode 100644 index 0000000000..1658472209 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt @@ -0,0 +1,26 @@ +op { + graph_op_name: "EnsureShape" + in_arg { + name: "input" + description: <<END +A tensor, whose shape is to be validated. +END + } + out_arg { + name: "output" + description: <<END +A tensor with the same shape and contents as the input tensor or value. +END + } + attr { + name: "shape" + description: <<END +The expected (possibly partially specified) shape of the input tensor. +END + } + summary: "Ensures that the tensor's shape matches the expected shape." + description: <<END +Raises an error if the input tensor's shape does not match the specified shape. +Returns the input tensor otherwise. +END +} diff --git a/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt new file mode 100644 index 0000000000..4414d973ac --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "EnsureShape" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index 28a39bae3f..ab1ce0f9c8 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // See docs in ../ops/array_ops.cc. #include "tensorflow/core/kernels/shape_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -460,4 +461,96 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze") SqueezeOp); #endif // TENSORFLOW_USE_SYCL +class EnsureShapeOp : public OpKernel { + public: + explicit EnsureShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_)); + } + + void Compute(OpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, + shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape)); + + if (!expected_shape_.IsCompatibleWith(shape)) { + ctx->SetStatus(errors::InvalidArgument( + "Shape of tensor ", this->def().input(0), " ", shape.DebugString(), + " is not compatible with expected shape ", + expected_shape_.DebugString(), ".")); + } + + // If shape matches, outputs the tensor. + if (IsRefType(ctx->input_dtype(0))) { + ctx->forward_ref_input_to_ref_output(0, 0); + } else { + ctx->set_output(0, ctx->input(0)); + } + } + + bool IsExpensive() override { return false; } + + private: + PartialTensorShape expected_shape_; +}; + +// NOTE(rachelim): The kernel registrations for EnsureShapeOp are identical to +// those of the identity op, since the ops have the same device type +// constraints. +REGISTER_KERNEL_BUILDER(Name("EnsureShape").Device(DEVICE_CPU), EnsureShapeOp); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("EnsureShape").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + EnsureShapeOp) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#undef REGISTER_SYCL_KERNEL + +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("EnsureShape") \ + .Device(DEVICE_SYCL) \ + .HostMemory("input") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + EnsureShapeOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(bool); + +#undef REGISTER_SYCL_HOST_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("EnsureShape").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ + EnsureShapeOp) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +REGISTER_GPU_KERNEL(Variant); + +#undef REGISTER_GPU_KERNEL + +#if GOOGLE_CUDA +// A special GPU kernel for int32 and bool. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("EnsureShape") \ + .Device(DEVICE_GPU) \ + .HostMemory("input") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + EnsureShapeOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(bool); +REGISTER_GPU_HOST_KERNEL(string); +REGISTER_GPU_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_GPU_HOST_KERNEL + +#endif } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 1d11ec00ce..7dbb18aa5d 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1446,6 +1446,30 @@ REGISTER_OP("ShapeN") .Attr("out_type: {int32, int64} = DT_INT32") .SetShapeFn(ShapeShapeFn); +REGISTER_OP("EnsureShape") + .Input("input: T") + .Output("output: T") + .Attr("shape: shape") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + // Merges desired shape and statically known shape of input + PartialTensorShape desired_shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); + + int rank = desired_shape.dims(); + ShapeHandle input_shape_handle; + ShapeHandle desired_shape_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle)); + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + desired_shape, &desired_shape_handle)); + + ShapeHandle merged_shape; + TF_RETURN_IF_ERROR( + c->Merge(desired_shape_handle, input_shape_handle, &merged_shape)); + c->set_output(0, merged_shape); + return Status::OK(); + }); + // -------------------------------------------------------------------------- REGISTER_OP("ReverseSequence") .Input("input: T") diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 192aadbaba..8d72eb39c0 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -515,6 +515,11 @@ class Tensor(_TensorLike): ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)]) ``` + NOTE: This shape is not enforced at runtime. Setting incorrect shapes can + result in inconsistencies between the statically-known graph and the runtime + value of tensors. For runtime validation of the shape, use `tf.ensure_shape` + instead. + Args: shape: A `TensorShape` representing the shape of this tensor, a `TensorShapeProto`, a list, a tuple, or None. diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index a9982a7ae0..f7c9d54758 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1388,6 +1388,8 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index bda6ca5ca9..05f998d0d2 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -18,8 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time import numpy as np +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,6 +33,8 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -745,6 +751,146 @@ class AssertPositiveTest(test.TestCase): self.evaluate(out) +class EnsureShapeTest(test.TestCase): + + # Static shape inference + def testStaticShape(self): + placeholder = array_ops.placeholder(dtypes.int32) + ensure_shape_op = check_ops.ensure_shape(placeholder, (3, 3, 3)) + self.assertEqual(ensure_shape_op.get_shape(), (3, 3, 3)) + + def testStaticShape_MergesShapes(self): + placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3)) + ensure_shape_op = check_ops.ensure_shape(placeholder, (5, 4, None)) + self.assertEqual(ensure_shape_op.get_shape(), (5, 4, 3)) + + def testStaticShape_RaisesErrorWhenRankIncompatible(self): + placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3)) + with self.assertRaises(ValueError): + check_ops.ensure_shape(placeholder, (2, 3)) + + def testStaticShape_RaisesErrorWhenDimIncompatible(self): + placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3)) + with self.assertRaises(ValueError): + check_ops.ensure_shape(placeholder, (2, 2, 4)) + + def testStaticShape_CanSetUnknownShape(self): + placeholder = array_ops.placeholder(dtypes.int32) + derived = placeholder / 3 + ensure_shape_op = check_ops.ensure_shape(derived, None) + self.assertEqual(ensure_shape_op.get_shape(), None) + + # Dynamic shape check + def testEnsuresDynamicShape_RaisesError(self): + placeholder = array_ops.placeholder(dtypes.int32) + derived = math_ops.divide(placeholder, 3, name="MyDivide") + derived = check_ops.ensure_shape(derived, (3, 3, 3)) + feed_val = [[1], [2]] + with self.test_session() as sess: + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + r"Shape of tensor MyDivide \[2,1\] is not compatible with " + r"expected shape \[3,3,3\]."): + sess.run(derived, feed_dict={placeholder: feed_val}) + + def testEnsuresDynamicShape_RaisesErrorDimUnknown(self): + placeholder = array_ops.placeholder(dtypes.int32) + derived = placeholder / 3 + derived = check_ops.ensure_shape(derived, (None, None, 3)) + feed_val = [[1], [2]] + with self.test_session() as sess: + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with " + r"expected shape \[\?,\?,3\]."): + sess.run(derived, feed_dict={placeholder: feed_val}) + + def testEnsuresDynamicShape(self): + placeholder = array_ops.placeholder(dtypes.int32) + derived = placeholder / 3 + derived = check_ops.ensure_shape(derived, (2, 1)) + feed_val = [[1], [2]] + with self.test_session() as sess: + sess.run(derived, feed_dict={placeholder: feed_val}) + + def testEnsuresDynamicShape_WithUnknownDims(self): + placeholder = array_ops.placeholder(dtypes.int32) + derived = placeholder / 3 + derived = check_ops.ensure_shape(derived, (None, None)) + feed_val = [[1], [2]] + with self.test_session() as sess: + sess.run(derived, feed_dict={placeholder: feed_val}) + + +class EnsureShapeBenchmark(test.Benchmark): + + def _grappler_all_off_config(self): + config = config_pb2.ConfigProto() + off = rewriter_config_pb2.RewriterConfig.OFF + config.graph_options.optimizer_options.opt_level = -1 + config.graph_options.rewrite_options.disable_model_pruning = 1 + config.graph_options.rewrite_options.constant_folding = off + config.graph_options.rewrite_options.layout_optimizer = off + config.graph_options.rewrite_options.arithmetic_optimization = off + config.graph_options.rewrite_options.dependency_optimization = off + return config + + def _run(self, op, feed_dict=None, num_iters=5000, name=None, **kwargs): + config = self._grappler_all_off_config() + with session.Session(config=config) as sess: + deltas = [] + # Warm up the session + for _ in range(5): + sess.run(op, feed_dict=feed_dict) + for _ in range(num_iters): + start = time.time() + sess.run(op, feed_dict=feed_dict) + end = time.time() + deltas.append(end - start) + mean_time = np.median(deltas) + mean_us = mean_time * 1e6 + # mean_us = (end - start) * 1e6 / num_iters + self.report_benchmark( + name=name, + wall_time=mean_us, + extras=kwargs, + ) + + def benchmark_const_op(self): + # In this case, we expect that the overhead of a `session.run` call + # far outweighs the time taken to execute the op... + shape = (3, 3, 100) + input_op = random_ops.random_normal(shape) + self._run(array_ops.identity(input_op), name="SingleConstOp") + + def benchmark_single_ensure_op(self): + # In this case, we expect that the overhead of a `session.run` call + # far outweighs the time taken to execute the op... + shape = (3, 3, 100) + input_op = random_ops.random_normal(shape) + ensure_shape_op = check_ops.ensure_shape(input_op, shape) + self._run(ensure_shape_op, name="SingleEnsureShapeOp") + + def _apply_n_times(self, op, target, n=1000): + for _ in range(n): + target = op(target) + return target + + def benchmark_n_ops(self): + shape = (1000,) + input_op = random_ops.random_normal(shape) + n_ops = self._apply_n_times(array_ops.identity, input_op) + self._run(n_ops, name="NIdentityOps_1000") + + def benchmark_n_ensure_ops(self): + shape = (1000,) + input_op = random_ops.random_normal(shape) + n_ensure_ops = self._apply_n_times( + lambda x: check_ops.ensure_shape(array_ops.identity(x), shape), + input_op) + self._run(n_ensure_ops, name="NEnsureShapeAndIdentityOps_1000") + + class AssertRankTest(test.TestCase): @test_util.run_in_graph_and_eager_modes diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index c5a0f2949e..6528062f3c 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1243,3 +1244,51 @@ def assert_scalar(tensor, name=None): raise ValueError('Expected scalar shape for %s, saw shape: %s.' % (tensor.name, shape)) return tensor + + +@tf_export('ensure_shape') +def ensure_shape(x, shape, name=None): + """Updates the shape of a tensor and checks at runtime that the shape holds. + + For example: + ```python + x = tf.placeholder(tf.int32) + print(x.shape) + ==> TensorShape(None) + y = x * 2 + print(y.shape) + ==> TensorShape(None) + + y = tf.ensure_shape(y, (None, 3, 3)) + print(y.shape) + ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)]) + + with tf.Session() as sess: + # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not + # compatible with the shape (None, 3, 3) + sess.run(y, feed_dict={x: [1, 2, 3]}) + + ``` + + NOTE: This differs from `Tensor.set_shape` in that it sets the static shape + of the resulting tensor and enforces it at runtime, raising an error if the + tensor's runtime shape is incompatible with the specified shape. + `Tensor.set_shape` sets the static shape of the tensor without enforcing it + at runtime, which may result in inconsistencies between the statically-known + shape of tensors and the runtime value of tensors. + + Args: + x: A `Tensor`. + shape: A `TensorShape` representing the shape of this tensor, a + `TensorShapeProto`, a list, a tuple, or None. + name: A name for this operation (optional). Defaults to "EnsureShape". + + Returns: + A `Tensor`. Has the same type and contents as `x`. At runtime, raises a + `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape + of `x`. + """ + if not isinstance(shape, tensor_shape.TensorShape): + shape = tensor_shape.TensorShape(shape) + + return array_ops.ensure_shape(x, shape, name=name) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 00fe63f55e..821ca7b140 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1049,6 +1049,10 @@ tf_module { argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { + name: "ensure_shape" + argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { name: "equal" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 807908617a..519cf66aa4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -1025,6 +1025,10 @@ tf_module { argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " } member_method { + name: "ensure_shape" + argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { name: "equal" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } |