diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 18 |
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 2c2d88486f..22a45b2a11 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
     // XLA passes <out> to the function, so it is not included here.
     std::vector<xla::XlaOp> args;
     args.push_back(ctx->Input(0));
-    args.push_back(b.ConstantLiteral(
-        *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
+    args.push_back(xla::ConstantLiteral(
+        &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
     if (input_shape.dims() > 1) {
       // Don't bother passing the output shape and dim for the 1d case, since
       // the shape is always a scalar and the dim is always 0.
-      args.push_back(b.ConstantLiteral(
-          *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
-      args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim)));
+      args.push_back(xla::ConstantLiteral(
+          &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+      args.push_back(
+          xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
     }
 
     xla::Shape xla_shape =
@@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
     xla::XlaOp output;
     switch (input_shape.dims()) {
       case 1:
-        output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape);
+        output =
+            xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
         break;
       case 2:
-        output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape);
+        output =
+            xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
         break;
       default:
         OP_REQUIRES(ctx, false,