aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/shape_ops.h
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-08-14 18:58:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-14 19:05:11 -0700
commit94ba1c4f0eccd234b4e0e5b504ddf1803067f1bc (patch)
tree87ec88e39bd3dc81b258f155549200ce676d6d89 /tensorflow/core/kernels/shape_ops.h
parentc16517f22b4c4601c1e02ec5cb55193743443878 (diff)
ExpandDims requires the 'dim' argument to be a scalar.
This change brings the kernel implementation in sync with the shape function. Prior to this change, when executing eagerly, for example with: import tensorflow as tf tf.enable_eager_execution() print(tf.expand_dims(1, axis=[0, 1])) the operation would succeed (because the kernel was effectively considering axis=0). However, the same line (tf.expand_dims(1, axis=[0, 1])) would fail in graph construction since the shape function for the ExpandDims operation required a scalar. This change addresses this one discrepancy, but there are likely more and a more comprehensive approach will still need some figuring out. PiperOrigin-RevId: 208755018
Diffstat (limited to 'tensorflow/core/kernels/shape_ops.h')
-rw-r--r--tensorflow/core/kernels/shape_ops.h8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h
index 55be308901..f75723af7d 100644
--- a/tensorflow/core/kernels/shape_ops.h
+++ b/tensorflow/core/kernels/shape_ops.h
@@ -154,6 +154,9 @@ class ExpandDimsOp : public OpKernel {
OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
errors::InvalidArgument("ExpandDims on Variant not supported"));
+ OP_REQUIRES(
+ ctx, (ctx->input(1).NumElements() == 1),
+ errors::InvalidArgument("'dim' must be a tensor with a single value"));
Tdim dim = ctx->input(1).flat<Tdim>()(0);
OP_REQUIRES(
ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
@@ -236,9 +239,8 @@ class SqueezeOp : public OpKernel {
if (wrapped_squeeze_dims.count(i) > 0) {
OP_REQUIRES(ctx, existing_dim == 1,
errors::InvalidArgument(
- "Tried to explicitly squeeze "
- "dimension ",
- i, " but dimension was not 1: ", existing_dim));
+ "Can not squeeze dim[", i,
+ "], expected a dimension of 1, got ", existing_dim));
} else {
// This dimension is not being squeezed.
new_shape.push_back(existing_dim);