aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/index_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/index_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops.cc12
1 files changed, 4 insertions, 8 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index 36eb4c7545..f396474858 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -60,19 +60,15 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
input_shape.DebugString()));
DataType index_type = output_type(0);
+ xla::PrimitiveType index_xla_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(index_type, &index_xla_type));
- xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
-
xla::XlaOp output;
if (is_min_) {
- OP_REQUIRES_OK(ctx,
- XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0),
- index_type, axis, &output));
+ output = XlaHelpers::ArgMin(input, index_xla_type, axis);
} else {
- OP_REQUIRES_OK(ctx,
- XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0),
- index_type, axis, &output));
+ output = XlaHelpers::ArgMax(input, index_xla_type, axis);
}
ctx->SetOutput(0, output);