diff options
author | 2018-06-27 12:12:33 -0700 | |
---|---|---|
committer | 2018-06-27 12:15:39 -0700 | |
commit | 35cb434a9a95bef7ca8d7880d87dd9775eeba336 (patch) | |
tree | 976358f9a935cbbdf76407f60688c08b6484aeae /tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | |
parent | 1536bba6be3e16f3983b79dd6931de313c900114 (diff) |
[TF:XLA] Refactor TF/XLA code to use free functions in xla:: namespace to build XlaOps, rather than calling XlaBuilder methods.
PiperOrigin-RevId: 202348891
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 2c2d88486f..a020ebc729 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // XLA passes <out> to the function, so it is not included here. std::vector<xla::XlaOp> args; args.push_back(ctx->Input(0)); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1<int64>(input_shape.dim_sizes()))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1<int64>(output_shape.dim_sizes()))); - args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes()))); + args.push_back( + xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim))); } xla::Shape xla_shape = @@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel { xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); break; case 2: - output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); break; default: OP_REQUIRES(ctx, false, |