author    Peter Hawkins <phawkins@google.com>           2018-06-27 12:12:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-27 12:15:39 -0700
commit 35cb434a9a95bef7ca8d7880d87dd9775eeba336 (patch)
tree   976358f9a935cbbdf76407f60688c08b6484aeae /tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
parent 1536bba6be3e16f3983b79dd6931de313c900114 (diff)
[TF:XLA] Refactor TF/XLA code to use free functions in xla:: namespace to build XlaOps, rather than calling XlaBuilder methods.
PiperOrigin-RevId: 202348891
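
The refactor is mechanical: each op-building method on XlaBuilder becomes a free function in the xla:: namespace that takes the builder as an explicit first argument. A minimal sketch of the pattern, based on the ConstantLiteral call site in this diff; the wrapper function, its name, and the extra includes are illustrative assumptions, not part of the commit:

    #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/core/platform/types.h"

    // Hypothetical helper showing the before/after shape of each call site.
    xla::XlaOp MakeDimConstant(xla::XlaBuilder* b, tensorflow::int32 dim) {
      // Before this commit: op construction was an XlaBuilder method.
      //   return b->ConstantLiteral(
      //       *xla::Literal::CreateR0<tensorflow::int32>(dim));
      // After: a free function in xla::, with the builder passed explicitly.
      return xla::ConstantLiteral(
          b, *xla::Literal::CreateR0<tensorflow::int32>(dim));
    }

The same pattern covers the CustomCall sites in the hunks below.
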
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 2c2d88486f..a020ebc729 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
// XLA passes <out> to the function, so it is not included here.
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
+ args.push_back(xla::ConstantLiteral(
+ &b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
- args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim)));
+ args.push_back(xla::ConstantLiteral(
+ &b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
+ args.push_back(
+ xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
@@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
- output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape);
+ output =
+ xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
break;
case 2:
- output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape);
+ output =
+ xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
break;
default:
OP_REQUIRES(ctx, false,