aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-08-28 11:07:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 11:12:51 -0700
commitcf879ef5dddee8d1b5081afe5bd8f49f15245d08 (patch)
tree7da03b5946bda44edd6b2220ea0aeb2a4f32e9a4 /tensorflow
parent8f99e5ad11040a6f0b5c12648e98bdbfe4dc3970 (diff)
Adds a tf.ensure_shape function as a substitute for tensor.set_shape, which validates the true shape of the tensor at runtime.
PiperOrigin-RevId: 210570878
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt26
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt4
-rw-r--r--tensorflow/core/kernels/shape_ops.cc93
-rw-r--r--tensorflow/core/ops/array_ops.cc24
-rw-r--r--tensorflow/python/framework/ops.py5
-rw-r--r--tensorflow/python/kernel_tests/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py146
-rw-r--r--tensorflow/python/ops/check_ops.py49
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt4
10 files changed, 357 insertions, 0 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt
new file mode 100644
index 0000000000..1658472209
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt
@@ -0,0 +1,26 @@
+op {
+ graph_op_name: "EnsureShape"
+ in_arg {
+ name: "input"
+ description: <<END
+A tensor, whose shape is to be validated.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A tensor with the same shape and contents as the input tensor or value.
+END
+ }
+ attr {
+ name: "shape"
+ description: <<END
+The expected (possibly partially specified) shape of the input tensor.
+END
+ }
+ summary: "Ensures that the tensor's shape matches the expected shape."
+ description: <<END
+Raises an error if the input tensor's shape does not match the specified shape.
+Returns the input tensor otherwise.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt
new file mode 100644
index 0000000000..4414d973ac
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EnsureShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index 28a39bae3f..ab1ce0f9c8 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#include "tensorflow/core/kernels/shape_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
@@ -460,4 +461,96 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze")
SqueezeOp);
#endif // TENSORFLOW_USE_SYCL
+class EnsureShapeOp : public OpKernel {
+ public:
+ explicit EnsureShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx,
+ shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
+
+ if (!expected_shape_.IsCompatibleWith(shape)) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Shape of tensor ", this->def().input(0), " ", shape.DebugString(),
+ " is not compatible with expected shape ",
+ expected_shape_.DebugString(), "."));
+ }
+
+ // If shape matches, outputs the tensor.
+ if (IsRefType(ctx->input_dtype(0))) {
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ ctx->set_output(0, ctx->input(0));
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ PartialTensorShape expected_shape_;
+};
+
+// NOTE(rachelim): The kernel registrations for EnsureShapeOp are identical to
+// those of the identity op, since the ops have the same device type
+// constraints.
+REGISTER_KERNEL_BUILDER(Name("EnsureShape").Device(DEVICE_CPU), EnsureShapeOp);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EnsureShape").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#undef REGISTER_SYCL_KERNEL
+
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("EnsureShape") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("input") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(bool);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EnsureShape").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(Variant);
+
+#undef REGISTER_GPU_KERNEL
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32 and bool.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("EnsureShape") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("input") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(bool);
+REGISTER_GPU_HOST_KERNEL(string);
+REGISTER_GPU_HOST_KERNEL(ResourceHandle);
+
+#undef REGISTER_GPU_HOST_KERNEL
+
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 1d11ec00ce..7dbb18aa5d 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1446,6 +1446,30 @@ REGISTER_OP("ShapeN")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(ShapeShapeFn);
+REGISTER_OP("EnsureShape")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("shape: shape")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ // Merges desired shape and statically known shape of input
+ PartialTensorShape desired_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
+
+ int rank = desired_shape.dims();
+ ShapeHandle input_shape_handle;
+ ShapeHandle desired_shape_handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ desired_shape, &desired_shape_handle));
+
+ ShapeHandle merged_shape;
+ TF_RETURN_IF_ERROR(
+ c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
+ c->set_output(0, merged_shape);
+ return Status::OK();
+ });
+
// --------------------------------------------------------------------------
REGISTER_OP("ReverseSequence")
.Input("input: T")
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 192aadbaba..8d72eb39c0 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -515,6 +515,11 @@ class Tensor(_TensorLike):
==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
```
+ NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
+ result in inconsistencies between the statically-known graph and the runtime
+ value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
+ instead.
+
Args:
shape: A `TensorShape` representing the shape of this tensor, a
`TensorShapeProto`, a list, a tuple, or None.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index a9982a7ae0..f7c9d54758 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1388,6 +1388,8 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index bda6ca5ca9..05f998d0d2 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -18,8 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
import numpy as np
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,6 +33,8 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -745,6 +751,146 @@ class AssertPositiveTest(test.TestCase):
self.evaluate(out)
+class EnsureShapeTest(test.TestCase):
+
+ # Static shape inference
+ def testStaticShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ ensure_shape_op = check_ops.ensure_shape(placeholder, (3, 3, 3))
+ self.assertEqual(ensure_shape_op.get_shape(), (3, 3, 3))
+
+ def testStaticShape_MergesShapes(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ ensure_shape_op = check_ops.ensure_shape(placeholder, (5, 4, None))
+ self.assertEqual(ensure_shape_op.get_shape(), (5, 4, 3))
+
+ def testStaticShape_RaisesErrorWhenRankIncompatible(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ with self.assertRaises(ValueError):
+ check_ops.ensure_shape(placeholder, (2, 3))
+
+ def testStaticShape_RaisesErrorWhenDimIncompatible(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ with self.assertRaises(ValueError):
+ check_ops.ensure_shape(placeholder, (2, 2, 4))
+
+ def testStaticShape_CanSetUnknownShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ ensure_shape_op = check_ops.ensure_shape(derived, None)
+ self.assertEqual(ensure_shape_op.get_shape(), None)
+
+ # Dynamic shape check
+ def testEnsuresDynamicShape_RaisesError(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = math_ops.divide(placeholder, 3, name="MyDivide")
+ derived = check_ops.ensure_shape(derived, (3, 3, 3))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ r"Shape of tensor MyDivide \[2,1\] is not compatible with "
+ r"expected shape \[3,3,3\]."):
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (None, None, 3))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
+ r"expected shape \[\?,\?,3\]."):
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (2, 1))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape_WithUnknownDims(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (None, None))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+
+class EnsureShapeBenchmark(test.Benchmark):
+
+ def _grappler_all_off_config(self):
+ config = config_pb2.ConfigProto()
+ off = rewriter_config_pb2.RewriterConfig.OFF
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.disable_model_pruning = 1
+ config.graph_options.rewrite_options.constant_folding = off
+ config.graph_options.rewrite_options.layout_optimizer = off
+ config.graph_options.rewrite_options.arithmetic_optimization = off
+ config.graph_options.rewrite_options.dependency_optimization = off
+ return config
+
+ def _run(self, op, feed_dict=None, num_iters=5000, name=None, **kwargs):
+ config = self._grappler_all_off_config()
+ with session.Session(config=config) as sess:
+ deltas = []
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op, feed_dict=feed_dict)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(op, feed_dict=feed_dict)
+ end = time.time()
+ deltas.append(end - start)
+ mean_time = np.median(deltas)
+ mean_us = mean_time * 1e6
+ # mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ wall_time=mean_us,
+ extras=kwargs,
+ )
+
+ def benchmark_const_op(self):
+ # In this case, we expect that the overhead of a `session.run` call
+ # far outweighs the time taken to execute the op...
+ shape = (3, 3, 100)
+ input_op = random_ops.random_normal(shape)
+ self._run(array_ops.identity(input_op), name="SingleConstOp")
+
+ def benchmark_single_ensure_op(self):
+ # In this case, we expect that the overhead of a `session.run` call
+ # far outweighs the time taken to execute the op...
+ shape = (3, 3, 100)
+ input_op = random_ops.random_normal(shape)
+ ensure_shape_op = check_ops.ensure_shape(input_op, shape)
+ self._run(ensure_shape_op, name="SingleEnsureShapeOp")
+
+ def _apply_n_times(self, op, target, n=1000):
+ for _ in range(n):
+ target = op(target)
+ return target
+
+ def benchmark_n_ops(self):
+ shape = (1000,)
+ input_op = random_ops.random_normal(shape)
+ n_ops = self._apply_n_times(array_ops.identity, input_op)
+ self._run(n_ops, name="NIdentityOps_1000")
+
+ def benchmark_n_ensure_ops(self):
+ shape = (1000,)
+ input_op = random_ops.random_normal(shape)
+ n_ensure_ops = self._apply_n_times(
+ lambda x: check_ops.ensure_shape(array_ops.identity(x), shape),
+ input_op)
+ self._run(n_ensure_ops, name="NEnsureShapeAndIdentityOps_1000")
+
+
class AssertRankTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index c5a0f2949e..6528062f3c 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -1243,3 +1244,51 @@ def assert_scalar(tensor, name=None):
raise ValueError('Expected scalar shape for %s, saw shape: %s.'
% (tensor.name, shape))
return tensor
+
+
+@tf_export('ensure_shape')
+def ensure_shape(x, shape, name=None):
+ """Updates the shape of a tensor and checks at runtime that the shape holds.
+
+ For example:
+ ```python
+ x = tf.placeholder(tf.int32)
+ print(x.shape)
+ ==> TensorShape(None)
+ y = x * 2
+ print(y.shape)
+ ==> TensorShape(None)
+
+ y = tf.ensure_shape(y, (None, 3, 3))
+ print(y.shape)
+ ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])
+
+ with tf.Session() as sess:
+ # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
+ # compatible with the shape (None, 3, 3)
+ sess.run(y, feed_dict={x: [1, 2, 3]})
+
+ ```
+
+ NOTE: This differs from `Tensor.set_shape` in that it sets the static shape
+ of the resulting tensor and enforces it at runtime, raising an error if the
+ tensor's runtime shape is incompatible with the specified shape.
+ `Tensor.set_shape` sets the static shape of the tensor without enforcing it
+ at runtime, which may result in inconsistencies between the statically-known
+ shape of tensors and the runtime value of tensors.
+
+ Args:
+ x: A `Tensor`.
+ shape: A `TensorShape` representing the shape of this tensor, a
+ `TensorShapeProto`, a list, a tuple, or None.
+ name: A name for this operation (optional). Defaults to "EnsureShape".
+
+ Returns:
+ A `Tensor`. Has the same type and contents as `x`. At runtime, raises a
+ `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape
+ of `x`.
+ """
+ if not isinstance(shape, tensor_shape.TensorShape):
+ shape = tensor_shape.TensorShape(shape)
+
+ return array_ops.ensure_shape(x, shape, name=name)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 00fe63f55e..821ca7b140 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1049,6 +1049,10 @@ tf_module {
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "ensure_shape"
+ argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 807908617a..519cf66aa4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -1025,6 +1025,10 @@ tf_module {
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "ensure_shape"
+ argspec: "args=[\'x\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}