aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/shape_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/shape_ops.h')
-rw-r--r--tensorflow/core/kernels/shape_ops.h5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h
index ac607f4e8b..55be308901 100644
--- a/tensorflow/core/kernels/shape_ops.h
+++ b/tensorflow/core/kernels/shape_ops.h
@@ -145,6 +145,7 @@ class SizeOp : public OpKernel {
bool IsExpensive() override { return false; }
};
+template <typename Tdim>
class ExpandDimsOp : public OpKernel {
public:
explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -153,7 +154,7 @@ class ExpandDimsOp : public OpKernel {
OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
errors::InvalidArgument("ExpandDims on Variant not supported"));
- int32 dim = ctx->input(1).flat<int32>()(0);
+ Tdim dim = ctx->input(1).flat<Tdim>()(0);
OP_REQUIRES(
ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
errors::InvalidArgument("Tried to expand dim index ", dim,
@@ -175,7 +176,7 @@ class ExpandDimsOp : public OpKernel {
}
// Clamp to the end if needed.
- dim = std::min<int32>(dim, existing_dims_size);
+ dim = std::min<Tdim>(dim, existing_dims_size);
new_shape.emplace(new_shape.begin() + dim, 1);
const TensorShape output_shape(new_shape);