diff options
Diffstat (limited to 'tensorflow/core/kernels/shape_ops.h')
-rw-r--r-- | tensorflow/core/kernels/shape_ops.h | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index 8d9d0ea846..ac607f4e8b 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -145,7 +145,6 @@ class SizeOp : public OpKernel { bool IsExpensive() override { return false; } }; -template <typename Tdim> class ExpandDimsOp : public OpKernel { public: explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -154,7 +153,7 @@ class ExpandDimsOp : public OpKernel { OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, errors::InvalidArgument("ExpandDims on Variant not supported")); - Tdim dim = ctx->input(1).flat<Tdim>()(0); + int32 dim = ctx->input(1).flat<int32>()(0); OP_REQUIRES( ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), errors::InvalidArgument("Tried to expand dim index ", dim, @@ -176,7 +175,7 @@ class ExpandDimsOp : public OpKernel { } // Clamp to the end if needed. - dim = std::min<Tdim>(dim, existing_dims_size); + dim = std::min<int32>(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); const TensorShape output_shape(new_shape); @@ -235,10 +234,10 @@ class SqueezeOp : public OpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze " + "dimension ", + i, " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); |