diff options
Diffstat (limited to 'tensorflow/core/kernels/shape_ops.h')
-rw-r--r-- | tensorflow/core/kernels/shape_ops.h | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index ac607f4e8b..8d9d0ea846 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -145,6 +145,7 @@ class SizeOp : public OpKernel { bool IsExpensive() override { return false; } }; +template <typename Tdim> class ExpandDimsOp : public OpKernel { public: explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -153,7 +154,7 @@ class ExpandDimsOp : public OpKernel { OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, errors::InvalidArgument("ExpandDims on Variant not supported")); - int32 dim = ctx->input(1).flat<int32>()(0); + Tdim dim = ctx->input(1).flat<Tdim>()(0); OP_REQUIRES( ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), errors::InvalidArgument("Tried to expand dim index ", dim, @@ -175,7 +176,7 @@ class ExpandDimsOp : public OpKernel { } // Clamp to the end if needed. - dim = std::min<int32>(dim, existing_dims_size); + dim = std::min<Tdim>(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); const TensorShape output_shape(new_shape); @@ -234,10 +235,10 @@ class SqueezeOp : public OpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument( - "Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", existing_dim)); + errors::InvalidArgument("Tried to explicitly squeeze " + "dimension ", + i, " but dimension was not 1: ", + existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); |