diff options
author | Asim Shankar <ashankar@google.com> | 2018-08-14 18:58:35 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-14 19:05:11 -0700 |
commit | 94ba1c4f0eccd234b4e0e5b504ddf1803067f1bc (patch) | |
tree | 87ec88e39bd3dc81b258f155549200ce676d6d89 /tensorflow/core/kernels/shape_ops.h | |
parent | c16517f22b4c4601c1e02ec5cb55193743443878 (diff) |
ExpandDims requires the 'dim' argument to be a scalar.
This change brings the kernel implementation in sync with the shape function.
Prior to this change, when executing eagerly, for example with:
import tensorflow as tf
tf.enable_eager_execution()
print(tf.expand_dims(1, axis=[0, 1]))
the operation would succeed (because the kernel was effectively considering
axis=0). However, the same line (tf.expand_dims(1, axis=[0, 1])) would fail
in graph construction since the shape function for the ExpandDims operation
required a scalar.
This change addresses this one discrepancy, but there are likely more and
a more comprehensive approach will still need some figuring out.
PiperOrigin-RevId: 208755018
Diffstat (limited to 'tensorflow/core/kernels/shape_ops.h')
-rw-r--r-- | tensorflow/core/kernels/shape_ops.h | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index 55be308901..f75723af7d 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -154,6 +154,9 @@ class ExpandDimsOp : public OpKernel { OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, errors::InvalidArgument("ExpandDims on Variant not supported")); + OP_REQUIRES( + ctx, (ctx->input(1).NumElements() == 1), + errors::InvalidArgument("'dim' must be a tensor with a single value")); Tdim dim = ctx->input(1).flat<Tdim>()(0); OP_REQUIRES( ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), @@ -236,9 +239,8 @@ class SqueezeOp : public OpKernel { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, errors::InvalidArgument( - "Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", existing_dim)); + "Can not squeeze dim[", i, + "], expected a dimension of 1, got ", existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); |