about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/shape_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/shape_ops.h')
-rw-r--r-- tensorflow/core/kernels/shape_ops.h | 13
1 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h
index 8d9d0ea846..ac607f4e8b 100644
--- a/tensorflow/core/kernels/shape_ops.h
+++ b/tensorflow/core/kernels/shape_ops.h
@@ -145,7 +145,6 @@ class SizeOp : public OpKernel {
bool IsExpensive() override { return false; }
};
-template <typename Tdim>
class ExpandDimsOp : public OpKernel {
public:
explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -154,7 +153,7 @@ class ExpandDimsOp : public OpKernel {
OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
errors::InvalidArgument("ExpandDims on Variant not supported"));
- Tdim dim = ctx->input(1).flat<Tdim>()(0);
+ int32 dim = ctx->input(1).flat<int32>()(0);
OP_REQUIRES(
ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
errors::InvalidArgument("Tried to expand dim index ", dim,
@@ -176,7 +175,7 @@ class ExpandDimsOp : public OpKernel {
}
// Clamp to the end if needed.
- dim = std::min<Tdim>(dim, existing_dims_size);
+ dim = std::min<int32>(dim, existing_dims_size);
new_shape.emplace(new_shape.begin() + dim, 1);
const TensorShape output_shape(new_shape);
@@ -235,10 +234,10 @@ class SqueezeOp : public OpKernel {
if (!wrapped_squeeze_dims.empty()) {
if (wrapped_squeeze_dims.count(i) > 0) {
OP_REQUIRES(ctx, existing_dim == 1,
- errors::InvalidArgument("Tried to explicitly squeeze "
- "dimension ",
- i, " but dimension was not 1: ",
- existing_dim));
+ errors::InvalidArgument(
+ "Tried to explicitly squeeze "
+ "dimension ",
+ i, " but dimension was not 1: ", existing_dim));
} else {
// This dimension is not being squeezed.
new_shape.push_back(existing_dim);