author     Peter Hawkins <phawkins@google.com> 2018-06-27 12:12:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-06-27 12:15:39 -0700
commit     35cb434a9a95bef7ca8d7880d87dd9775eeba336 (patch)
tree       976358f9a935cbbdf76407f60688c08b6484aeae
parent     1536bba6be3e16f3983b79dd6931de313c900114 (diff)
[TF:XLA] Refactor TF/XLA code to use free functions in xla:: namespace to build XlaOps, rather than calling XlaBuilder methods.
PiperOrigin-RevId: 202348891
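
As an illustration of the pattern applied throughout this change, below is a minimal before/after sketch. BuildOld, BuildNew, and their operands are hypothetical stand-ins rather than code from this patch, but the ops shown (xla::Add, xla::Mul, xla::ConstantR0) are among those migrated in the hunks that follow.

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

// Before: each op is built by calling a method on the XlaBuilder.
xla::XlaOp BuildOld(xla::XlaBuilder* b, xla::XlaOp lhs, xla::XlaOp rhs) {
  return b->Mul(b->Add(lhs, rhs), b->ConstantR0<float>(0.5f));
}

// After: the same computation is built with free functions in the xla::
// namespace. Ops that take XlaOp operands recover the builder from them;
// operand-less ops such as constants take the builder as an explicit
// first argument.
xla::XlaOp BuildNew(xla::XlaBuilder* b, xla::XlaOp lhs, xla::XlaOp rhs) {
  return xla::Mul(xla::Add(lhs, rhs), xla::ConstantR0<float>(b, 0.5f));
}

Ops that take a span of operands (e.g. xla::ConcatInDim) or a computation (e.g. xla::Call) likewise take the builder explicitly, which is why those call sites below pass ctx->builder() as the first argument.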
-rw-r--r-- tensorflow/compiler/tf2xla/graph_compiler.cc | 5
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc | 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc | 51
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/bias_ops.cc | 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 120
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/bucketize_op.cc | 21
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/cast_op.cc | 10
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/categorical_op.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc | 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/concat_op.cc | 5
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/const_op.cc | 31
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 49
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/cross_op.cc | 21
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/cwise_ops.cc | 18
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc | 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/diag_op.cc | 39
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc | 4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc | 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/elu_op.cc | 28
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc | 59
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/fft_ops.cc | 4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/fill_op.cc | 5
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/gather_op.cc | 7
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/if_op.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/image_ops.cc | 87
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 80
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 18
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/l2loss_op.cc | 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/listdiff_op.cc | 7
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/lrn_ops.cc | 39
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/matmul_op.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc | 28
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc | 14
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc | 5
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/pack_op.cc | 6
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/pad_op.cc | 6
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/pooling_ops.cc | 86
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc | 13
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/random_ops.cc | 29
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reduction_ops.cc | 21
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc | 13
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/relu_op.cc | 20
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reshape_op.cc | 4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/retval_op.cc | 7
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reverse_op.cc | 5
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc | 99
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/scan_ops.cc | 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc | 5
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc | 10
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/select_op.cc | 9
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc | 4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/shape_op.cc | 9
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/slice_op.cc | 7
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/softmax_op.cc | 60
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/sort_ops.cc | 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc | 9
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc | 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/split_op.cc | 5
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/stack_ops.cc | 33
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc | 69
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc | 21
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 42
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tile_ops.cc | 10
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/topk_op.cc | 42
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/training_ops.cc | 77
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/transpose_op.cc | 7
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 160
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/unpack_op.cc | 6
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/variable_ops.cc | 22
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/while_op.cc | 15
-rw-r--r-- tensorflow/compiler/tf2xla/lib/batch_dot.cc | 18
-rw-r--r-- tensorflow/compiler/tf2xla/lib/cholesky.cc | 36
-rw-r--r-- tensorflow/compiler/tf2xla/lib/random.cc | 7
-rw-r--r-- tensorflow/compiler/tf2xla/lib/scatter.cc | 50
-rw-r--r-- tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 117
-rw-r--r-- tensorflow/compiler/tf2xla/lib/util.cc | 38
-rw-r--r-- tensorflow/compiler/tf2xla/lib/util_test.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/lib/while_loop.cc | 26
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.cc | 17
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.cc | 32
-rw-r--r-- tensorflow/compiler/tf2xla/xla_helpers.cc | 75
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 10
-rw-r--r-- tensorflow/compiler/tf2xla/xla_resource.cc | 25
86 files changed, 1167 insertions(+), 1104 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 4a6622ed73..4900af6df1 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -230,7 +231,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
XlaContext& context = XlaContext::Get(op_context);
auto* b = context.builder();
- auto output_handle = b->Call(*result.computation, handles);
+ auto output_handle = xla::Call(b, *result.computation, handles);
// The output handle of `Call` computation is a tuple type. Unzip it so
// that it can fit into future computations.
int computation_output = 0;
@@ -239,7 +240,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
} else {
xla_op_context.SetOutput(
- i, b->GetTupleElement(output_handle, computation_output));
+ i, xla::GetTupleElement(output_handle, computation_output));
++computation_output;
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
index 1e59868621..e335328280 100644
--- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -31,7 +32,7 @@ class AddNOp : public XlaOpKernel {
xla::XlaOp sum = ctx->Input(0);
for (int i = 1; i < ctx->num_inputs(); ++i) {
- sum = ctx->builder()->Add(sum, ctx->Input(i));
+ sum = xla::Add(sum, ctx->Input(i));
}
ctx->SetOutput(0, sum);
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 93fbc40461..259fe05b09 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -49,8 +50,6 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));
- xla::XlaBuilder* builder = ctx->builder();
-
xla::XlaOp input = ctx->Input(0);
TensorShape input_shape = ctx->InputShape(0);
@@ -60,30 +59,30 @@ class FusedBatchNormOp : public XlaOpKernel {
// TODO(b/69928690): support mixed precision in the XLA batch normalization
// operators. As a workaround, cast everything to the statistics type (which
// may be more precise than the input type).
- input = builder->ConvertElementType(input, scale_type);
+ input = xla::ConvertElementType(input, scale_type);
if (is_training_) {
- xla::XlaOp output = builder->BatchNormTraining(
+ xla::XlaOp output = xla::BatchNormTraining(
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
// In training mode, outputs the normalized value as well as the
// calculated mean and variance.
- ctx->SetOutput(0, builder->ConvertElementType(
- builder->GetTupleElement(output, 0), input_type));
- ctx->SetOutput(1, builder->GetTupleElement(output, 1));
- ctx->SetOutput(2, builder->GetTupleElement(output, 2));
+ ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0),
+ input_type));
+ ctx->SetOutput(1, xla::GetTupleElement(output, 1));
+ ctx->SetOutput(2, xla::GetTupleElement(output, 2));
// Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
// space 1 & 2". They are used to pass the per-batch mean and
// variance to the gradient. Here we maintain the same behavior by setting
// them to the mean and variance calculated by BatchNormTraining.
- ctx->SetOutput(3, builder->GetTupleElement(output, 1));
- ctx->SetOutput(4, builder->GetTupleElement(output, 2));
+ ctx->SetOutput(3, xla::GetTupleElement(output, 1));
+ ctx->SetOutput(4, xla::GetTupleElement(output, 2));
} else {
- xla::XlaOp output = builder->BatchNormInference(
+ xla::XlaOp output = xla::BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
- ctx->SetOutput(0, builder->ConvertElementType(output, input_type));
+ ctx->SetOutput(0, xla::ConvertElementType(output, input_type));
// Directly send input to output as mean and variance in inference mode.
ctx->SetOutput(1, ctx->Input(3));
ctx->SetOutput(2, ctx->Input(4));
@@ -144,12 +143,12 @@ class FusedBatchNormGradOp : public XlaOpKernel {
xla::XlaOp offset_backprop;
if (is_training_) {
xla::XlaOp output =
- b->BatchNormGrad(activations, scale, mean, var, grad_backprop,
- epsilon_, feature_index);
+ xla::BatchNormGrad(activations, scale, mean, var, grad_backprop,
+ epsilon_, feature_index);
- x_backprop = b->GetTupleElement(output, 0);
- scale_backprop = b->GetTupleElement(output, 1);
- offset_backprop = b->GetTupleElement(output, 2);
+ x_backprop = xla::GetTupleElement(output, 0);
+ scale_backprop = xla::GetTupleElement(output, 1);
+ offset_backprop = xla::GetTupleElement(output, 2);
} else {
// Reduce over all dimensions except the feature dim.
std::vector<int64> reduction_dims(input_dims - 1);
@@ -166,27 +165,27 @@ class FusedBatchNormGradOp : public XlaOpKernel {
auto converted =
XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type);
auto reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype);
// scratch1 = rsqrt(pop_var + epsilon)
auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5);
- auto scratch1 =
- b->Pow(b->Add(var, b->ConstantR0<float>(epsilon_)), neg_half);
+ auto scratch1 = xla::Pow(
+ xla::Add(var, xla::ConstantR0<float>(b, epsilon_)), neg_half);
// scratch2 = sum(y_backprop * (x - mean))
auto mul =
- b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index}));
+ xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index}));
converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type);
reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype);
x_backprop =
- b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index});
- scale_backprop = b->Mul(scratch1, scratch2);
+ xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index});
+ scale_backprop = xla::Mul(scratch1, scratch2);
}
ctx->SetOutput(0,
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 642278ab99..26130fd9e7 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -45,7 +46,6 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
", 2] instead of ",
xla::ShapeUtil::HumanString(crops.shape())));
- xla::XlaBuilder* b = ctx->builder();
const int64 batch_size = input_shape[0];
// Compute the product of the block_shape values.
@@ -72,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
reshaped_shape[block_rank] = batch_size / block_num_elems;
std::copy(input_shape.begin() + 1, input_shape.end(),
reshaped_shape.begin() + block_rank + 1);
- xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
// [batch / prod(block_shape),
@@ -90,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
}
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
- xla::XlaOp permuted = b->Transpose(reshaped, permutation);
+ xla::XlaOp permuted = xla::Transpose(reshaped, permutation);
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
// [batch / prod(block_shape),
@@ -110,7 +110,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_permuted_shape.begin() + 1 + block_rank);
- xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape);
+ xla::XlaOp reshaped_permuted =
+ xla::Reshape(permuted, reshaped_permuted_shape);
// 4. Crop the start and end of dimensions `[1, ..., M]` of
// `reshaped_permuted` according to `crops` to produce the output of shape:
@@ -138,7 +139,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
" end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
}
xla::XlaOp output =
- b->Slice(reshaped_permuted, start_indices, end_indices, strides);
+ xla::Slice(reshaped_permuted, start_indices, end_indices, strides);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index 9d677f4266..e9b2c0b16d 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -60,8 +61,7 @@ class BiasOp : public XlaOpKernel {
"of the input tensor: ",
bias_shape.DebugString(), " vs. ", input_shape.DebugString()));
- xla::XlaOp result =
- ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim});
+ xla::XlaOp result = xla::Add(ctx->Input(0), ctx->Input(1), {feature_dim});
ctx->SetOutput(0, result);
}
@@ -109,8 +109,8 @@ class BiasAddGradOp : public XlaOpKernel {
auto converted =
XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type);
auto reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), reduce_dims);
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), reduce_dims);
ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0)));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index fee939bdea..d6d4ae8937 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -41,18 +41,19 @@ namespace {
const BCast& broadcast_helper, \
const std::vector<int64>& extend_dimensions) override { \
xla::XlaBuilder* b = ctx->builder(); \
+ (void)b; \
return HLO; \
} \
}; \
REGISTER_XLA_OP(Name(#NAME), NAME##Op)
-XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
@@ -67,13 +68,13 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
- auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero));
- auto abs_x = b->Abs(x);
- auto abs_y = b->Abs(y);
- auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one));
- auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y));
+ auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
+ auto abs_x = xla::Abs(x);
+ auto abs_y = xla::Abs(y);
+ auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
+ auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
if (DataTypeIsFloating(dtype)) {
- result = b->Floor(result);
+ result = xla::Floor(result);
}
return result;
}
@@ -87,76 +88,78 @@ static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
- auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero));
- auto trunc_mod = b->Rem(x, y);
- return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y));
+ auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
+ auto trunc_mod = xla::Rem(x, y);
+ return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y));
}
XLA_MAKE_BINARY(FloorMod,
FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
-XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(BitwiseXor, b->Xor(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RightShift,
(DataTypeIsUnsigned(ctx->input_type(0))
- ? b->ShiftRightLogical(lhs, rhs, extend_dimensions)
- : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions)));
-
-XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs))));
+ ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions)
+ : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions)));
+
+XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs))));
XLA_MAKE_BINARY(
RsqrtGrad,
- b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
- b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
- extend_dimensions));
-XLA_MAKE_BINARY(SqrtGrad,
- b->Div(b->Mul(rhs,
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
- lhs, extend_dimensions));
+ xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
+ xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
+ extend_dimensions));
+XLA_MAKE_BINARY(
+ SqrtGrad,
+ xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
+ lhs, extend_dimensions));
static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) {
- return builder->Mul(x, x);
+ return xla::Mul(x, x);
}
XLA_MAKE_BINARY(SquaredDifference,
- Square(b, b->Sub(lhs, rhs, extend_dimensions)));
+ Square(b, xla::Sub(lhs, rhs, extend_dimensions)));
-XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions));
// Comparison ops
-XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions));
// Non-linear ops
XLA_MAKE_BINARY(SigmoidGrad,
- b->Mul(b->Mul(rhs, lhs),
- b->Sub(XlaHelpers::One(b, input_type(0)), lhs)));
+ xla::Mul(xla::Mul(rhs, lhs),
+ xla::Sub(XlaHelpers::One(b, input_type(0)), lhs)));
XLA_MAKE_BINARY(SoftplusGrad,
- b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
- XlaHelpers::One(b, input_type(1)))));
+ xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)),
+ XlaHelpers::One(b, input_type(1)))));
// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
- b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)),
- b->Abs(rhs)))));
+ xla::Div(lhs,
+ Square(b, xla::Add(XlaHelpers::One(b, input_type(0)),
+ xla::Abs(rhs)))));
-XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(lhs, lhs))));
+XLA_MAKE_BINARY(TanhGrad,
+ xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)),
+ xla::Mul(lhs, lhs))));
-XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions));
#undef XLA_MAKE_BINARY
@@ -169,12 +172,13 @@ class ApproximateEqualOp : public XlaOpKernel {
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
- auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1)));
+ auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1)));
auto abs_shape = b->GetShape(abs);
OP_REQUIRES_OK(ctx, abs_shape.status());
auto abs_type = abs_shape.ValueOrDie().element_type();
- auto result = b->Lt(
- abs, b->ConvertElementType(b->ConstantR0<float>(tolerance_), abs_type));
+ auto result =
+ xla::Lt(abs, xla::ConvertElementType(
+ xla::ConstantR0<float>(b, tolerance_), abs_type));
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
index ca9a6b4068..efbdb76eaa 100644
--- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
@@ -36,22 +37,22 @@ class BucketizeOp : public XlaOpKernel {
const DataType dtype = context->input_type(0);
xla::XlaOp input = context->Input(0);
- xla::XlaOp boundaries = builder->ConstantR1<float>(boundaries_);
+ xla::XlaOp boundaries = xla::ConstantR1<float>(builder, boundaries_);
// TODO(phawkins): the following behavior matches the behavior of the core
// Bucketize kernel. However, comparing an int32 or int64 against float may
// lead to inaccurate bucketing due to rounding.
if (dtype == DT_DOUBLE) {
- input = builder->ConvertElementType(input, xla::F64);
- boundaries = builder->ConvertElementType(boundaries, xla::F64);
+ input = xla::ConvertElementType(input, xla::F64);
+ boundaries = xla::ConvertElementType(boundaries, xla::F64);
} else {
- input = builder->ConvertElementType(input, xla::F32);
+ input = xla::ConvertElementType(input, xla::F32);
}
- xla::XlaOp comparison = builder->ConvertElementType(
- builder->Ge(builder->Broadcast(input, {1}), boundaries,
- /*broadcast_dimensions=*/{0}),
- xla::S32);
- xla::XlaOp buckets = builder->Reduce(
- comparison, /*init_value=*/builder->ConstantR0<int32>(0),
+ xla::XlaOp comparison =
+ xla::ConvertElementType(xla::Ge(xla::Broadcast(input, {1}), boundaries,
+ /*broadcast_dimensions=*/{0}),
+ xla::S32);
+ xla::XlaOp buckets = xla::Reduce(
+ comparison, /*init_value=*/xla::ConstantR0<int32>(builder, 0),
/*computation=*/xla::CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
context->SetOutput(0, buckets);
diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
index e9d98c7685..62eebf762b 100644
--- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -40,14 +41,14 @@ class CastOp : public XlaOpKernel {
if (src_dtype_ == dst_dtype_) {
output = input;
} else if (dst_dtype_ == DT_BOOL) {
- output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_));
+ output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_));
} else if (xla::primitive_util::IsComplexType(src_type_) &&
!xla::primitive_util::IsComplexType(dst_type_)) {
// As in cast_op.h, we replicate the numpy behavior of truncating the
// imaginary part.
- output = builder->ConvertElementType(builder->Real(input), dst_type_);
+ output = xla::ConvertElementType(xla::Real(input), dst_type_);
} else {
- output = builder->ConvertElementType(input, dst_type_);
+ output = xla::ConvertElementType(input, dst_type_);
}
ctx->SetOutput(0, output);
@@ -72,7 +73,6 @@ class BitcastOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output;
@@ -92,7 +92,7 @@ class BitcastOp : public XlaOpKernel {
xla::primitive_util::BitWidth(dst_type_),
errors::Unimplemented(
"Only bitcasts between equally sized types supported."));
- output = builder->BitcastConvertType(input, dst_type_);
+ output = xla::BitcastConvertType(input, dst_type_);
}
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 835a7f5689..c137d026bd 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -65,17 +66,17 @@ class CategoricalOp : public XlaOpKernel {
DataTypeToPrimitiveType(input_type(0), &uniform_xla_type));
xla::Shape uniform_shape =
xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array);
- auto uniforms = builder->RngUniform(
- XlaHelpers::Zero(builder, input_type(0)),
- XlaHelpers::One(builder, input_type(0)), uniform_shape);
+ auto uniforms =
+ xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)),
+ XlaHelpers::One(builder, input_type(0)), uniform_shape);
// Use Gumbel softmax trick to generate categorical samples.
// See:
// https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
// TODO(b/68769470): Switch to using a cumulative sum approach.
auto softmax_entries =
- builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))),
- /*broadcast_dimensions=*/{0, 2});
+ xla::Sub(logits, xla::Log(xla::Neg(xla::Log(uniforms))),
+ /*broadcast_dimensions=*/{0, 2});
TensorShape softmax_shape(uniform_shape_array);
xla::XlaOp argmax;
diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
index a00bc912f9..4e6d33304c 100644
--- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@@ -29,7 +30,6 @@ class ClipByValueOp : public XlaOpKernel {
const TensorShape min_shape = ctx->InputShape(1);
const TensorShape max_shape = ctx->InputShape(2);
- xla::XlaBuilder* builder = ctx->builder();
auto input = ctx->Input(0);
auto min = ctx->Input(1);
auto max = ctx->Input(2);
@@ -45,13 +45,13 @@ class ClipByValueOp : public XlaOpKernel {
if (shape != min_shape) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error());
- min = builder->Broadcast(min, shape.dim_sizes());
+ min = xla::Broadcast(min, shape.dim_sizes());
}
if (shape != max_shape) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error());
- max = builder->Broadcast(max, shape.dim_sizes());
+ max = xla::Broadcast(max, shape.dim_sizes());
}
- ctx->SetOutput(0, builder->Clamp(min, input, max));
+ ctx->SetOutput(0, xla::Clamp(min, input, max));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 78285affa1..e3a32a5c0e 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -88,7 +89,7 @@ class ConcatBaseOp : public XlaOpKernel {
"] = ", in_shape.DebugString()));
if (in_shape.dims() == 0) {
// Inputs that come in as scalars must be reshaped to 1-vectors.
- input_data.push_back(ctx->builder()->Reshape(handle, {1}));
+ input_data.push_back(xla::Reshape(handle, {1}));
} else {
input_data.push_back(handle);
}
@@ -96,7 +97,7 @@ class ConcatBaseOp : public XlaOpKernel {
}
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis));
+ ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index 59d06c654d..f4360d8c3f 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -53,41 +54,41 @@ class ConstOp : public XlaOpKernel {
switch (proto_.dtype()) {
case DT_BOOL:
if (proto_.bool_val_size() == 1) {
- ctx->SetOutput(0,
- b->Broadcast(b->ConstantR0<bool>(proto_.bool_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(
+ 0, xla::Broadcast(xla::ConstantR0<bool>(b, proto_.bool_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_FLOAT:
if (proto_.float_val_size() == 1) {
- ctx->SetOutput(
- 0, b->Broadcast(b->ConstantR0<float>(proto_.float_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<float>(
+ b, proto_.float_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_DOUBLE:
if (proto_.double_val_size() == 1) {
- ctx->SetOutput(
- 0, b->Broadcast(b->ConstantR0<double>(proto_.double_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<double>(
+ b, proto_.double_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_INT32:
if (proto_.int_val_size() == 1) {
- ctx->SetOutput(0,
- b->Broadcast(b->ConstantR0<int32>(proto_.int_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(
+ 0, xla::Broadcast(xla::ConstantR0<int32>(b, proto_.int_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_INT64:
if (proto_.int64_val_size() == 1) {
- ctx->SetOutput(
- 0, b->Broadcast(b->ConstantR0<int64>(proto_.int64_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<int64>(
+ b, proto_.int64_val(0)),
+ shape.dim_sizes()));
return;
}
break;
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 627bad12f3..5d41fc708a 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -51,8 +52,8 @@ xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return builder->Broadcast(XlaHelpers::Zero(builder, dtype),
- expanded_filter_shape.dim_sizes());
+ return xla::Broadcast(XlaHelpers::Zero(builder, dtype),
+ expanded_filter_shape.dim_sizes());
}
// Create a mask for depthwise convolution that will make a normal convolution
@@ -107,20 +108,20 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
// Divide the M*N sized linspace by the depthwise_multiplier to create
// [0 0 1 1 2 2] in the example in the function comment.
expanded_feature_iota =
- builder->Div(expanded_feature_iota,
- XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
- depthwise_multiplier));
+ xla::Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
// Broadcast the N*M linspace to [H, W, ..., M, M*N].
auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
expanded_feature_broadcast_dims.pop_back();
- auto broadcasted_expanded_feature_iota = builder->Broadcast(
- expanded_feature_iota, expanded_feature_broadcast_dims);
+ auto broadcasted_expanded_feature_iota =
+ xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
// Compare the broadcasted linspace to the input feature linspace in the
// input feature dimension to create a diagonal predicate.
- return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota,
- {expanded_filter_shape.dims() - 2});
+ return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dims() - 2});
}
// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
@@ -142,16 +143,16 @@ xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape,
implicit_broadcast_filter_shape.dims() - 1,
depthwise_multiplier * input_feature);
auto implicit_broadcast_filter =
- builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
+ xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
// Broadcast the filter to [H, W, ..., M, M*N].
auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder);
- auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero);
+ auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero);
// If the filter mask is set, choose the broadcasted filter, otherwise,
// choose zero.
- return builder->Select(CreateExpandedFilterMask(filter_shape, builder),
- expanded_filter, expanded_zero);
+ return xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+ expanded_filter, expanded_zero);
}
// Inverse of ExpandFilterForDepthwiseConvolution.
@@ -162,17 +163,17 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- auto masked_expanded_filter = builder->Select(
+ auto masked_expanded_filter = xla::Select(
CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
CreateExpandedZero(filter_shape, dtype, builder));
- return builder->Reshape(
+ return xla::Reshape(
// This reduce does not need inputs to be converted with
// XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with
// ExpandedZero guarantees that only one element is non zero, so there
// cannot be accumulated precision error.
- builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype),
- {expanded_filter_shape.dims() - 2}),
+ xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
+ *ctx->GetOrCreateAdd(dtype),
+ {expanded_filter_shape.dims() - 2}),
filter_shape.dim_sizes());
}
@@ -289,8 +290,8 @@ class ConvOp : public XlaOpKernel {
}
xla::XlaOp conv =
- b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
- lhs_dilation, rhs_dilation, dims);
+ xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
+ lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
}
@@ -435,11 +436,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
}
// Mirror the filter in the spatial dimensions.
- xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims);
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
- xla::XlaOp in_backprop = b->ConvGeneralDilated(
+ xla::XlaOp in_backprop = xla::ConvGeneralDilated(
out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
lhs_dilation, rhs_dilation, dnums);
@@ -638,8 +639,8 @@ class ConvBackpropFilterOp : public XlaOpKernel {
// This is done by specifying the window dilation factors in the
// convolution HLO below.
auto filter_backprop =
- b->ConvGeneralDilated(activations, gradients, window_strides, padding,
- /*lhs_dilation=*/ones, rhs_dilation, dnums);
+ xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums);
if (depthwise_) {
filter_backprop = ContractFilterForDepthwiseBackprop(
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
index 7fcd4170fb..500a564f3f 100644
--- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -58,21 +59,21 @@ class CrossOp : public XlaOpKernel {
auto in1 = ctx->Input(1);
starts.back() = 0;
limits.back() = 1;
- auto u1 = b->Slice(in0, starts, limits, strides);
- auto v1 = b->Slice(in1, starts, limits, strides);
+ auto u1 = xla::Slice(in0, starts, limits, strides);
+ auto v1 = xla::Slice(in1, starts, limits, strides);
starts.back() = 1;
limits.back() = 2;
- auto u2 = b->Slice(in0, starts, limits, strides);
- auto v2 = b->Slice(in1, starts, limits, strides);
+ auto u2 = xla::Slice(in0, starts, limits, strides);
+ auto v2 = xla::Slice(in1, starts, limits, strides);
starts.back() = 2;
limits.back() = 3;
- auto u3 = b->Slice(in0, starts, limits, strides);
- auto v3 = b->Slice(in1, starts, limits, strides);
+ auto u3 = xla::Slice(in0, starts, limits, strides);
+ auto v3 = xla::Slice(in1, starts, limits, strides);
- auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2));
- auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3));
- auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1));
- auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1);
+ auto s1 = xla::Sub(xla::Mul(u2, v3), xla::Mul(u3, v2));
+ auto s2 = xla::Sub(xla::Mul(u3, v1), xla::Mul(u1, v3));
+ auto s3 = xla::Sub(xla::Mul(u1, v2), xla::Mul(u2, v1));
+ auto output = xla::ConcatInDim(b, {s1, s2, s3}, in0_shape.dims() - 1);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index 01aa1a83e7..9ff3e02228 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -96,18 +96,16 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// First reshape the inputs, which should be a metadata-only
// operation since we are flattening the dimensions in order.
- auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape());
+ auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
+ auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
// Next broadcast the necessary input dimensions. We rely on the
// XLA optimizer to be smart about the fact that we are asking
// it to broadcast size 1 on some of these dimensions, to avoid
// adding complexity to this code.
- auto lhs_broadcast =
- builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast());
+ auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast =
- builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast());
+ auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
int rhs_size = broadcast_helper.y_bcast().size();
// Now reshape them to the correct output shape. After the
@@ -122,15 +120,15 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
lhs_reorder.push_back(i);
lhs_reorder.push_back(i + lhs_size);
}
- auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder,
- broadcast_helper.output_shape());
+ auto lhs_output =
+ xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
std::vector<int64> rhs_reorder;
for (int i = 0; i < rhs_size; ++i) {
rhs_reorder.push_back(i);
rhs_reorder.push_back(i + rhs_size);
}
- auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder,
- broadcast_helper.output_shape());
+ auto rhs_output =
+ xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
return {lhs_output, rhs_output};
}
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index 23243f6246..f314920025 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -50,7 +51,6 @@ class DepthToSpaceOp : public XlaOpKernel {
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
- xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
@@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel {
") is not divisible by square of the block size (",
block_size_, ")"));
- xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -141,7 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2],
// block_size_,
// depth / (block_size_ * block_size_)]
- xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -151,7 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2] * block_size_,
// depth / (block_size_ * block_size_)]
//
- xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 931705ba83..17bf0c069c 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -41,13 +42,13 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
xla::XlaOp iota;
TF_RETURN_IF_ERROR(
XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
- xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size});
- xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0});
+ xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size});
+ xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0});
// If this is a batched diagonal, broadcast the mask across the other
// dimensions.
if (!other_dims.empty()) {
- mask = builder->Broadcast(mask, other_dims);
+ mask = xla::Broadcast(mask, other_dims);
}
// Broadcast the input, and then use the mask computed above to select the
@@ -64,7 +65,7 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end());
broadcast_dims.push_back(1LL);
broadcast_dims.push_back(last_dim_size);
- xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims);
+ xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims);
broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
xla::PrimitiveType element_type;
@@ -74,8 +75,8 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
xla::XlaOp zeros = Zeros(builder, broadcast_shape);
- input_broadcast = builder->Add(input_broadcast, zeros);
- return builder->Select(mask, input_broadcast, zeros);
+ input_broadcast = xla::Add(input_broadcast, zeros);
+ return xla::Select(mask, input_broadcast, zeros);
}
class DiagOp : public XlaOpKernel {
@@ -104,7 +105,7 @@ class DiagOp : public XlaOpKernel {
// Flattens the input to 1D.
int64 size = input_shape.num_elements();
- input = builder->Reshape(input, {size});
+ input = xla::Reshape(input, {size});
// Create an R2 with the R1 diagonal.
auto diag_or_status =
@@ -116,7 +117,7 @@ class DiagOp : public XlaOpKernel {
std::vector<int64> new_dims(dims.size() * 2);
std::copy(dims.begin(), dims.end(), new_dims.begin());
std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size());
- diag = builder->Reshape(diag, new_dims);
+ diag = xla::Reshape(diag, new_dims);
ctx->SetOutput(0, diag);
}
@@ -170,21 +171,21 @@ class DiagPartOp : public XlaOpKernel {
// Flattens the input to 1D.
int64 size = input_shape.num_elements();
- diag = builder->Reshape(diag, {size});
+ diag = xla::Reshape(diag, {size});
// Adds padding after the last element of 'new_size'.
xla::PaddingConfig config;
auto* dim = config.add_dimensions();
dim->set_edge_padding_high(new_size);
auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = builder->Pad(diag, zero, config);
+ diag = xla::Pad(diag, zero, config);
// Reshapes so the diagonal is now in the first column.
- diag = builder->Reshape(diag, {new_size, new_size + 1});
+ diag = xla::Reshape(diag, {new_size, new_size + 1});
// Slices out the first column and reshapes to the final shape.
- diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
- diag = builder->Reshape(diag, new_dims);
+ diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
+ diag = xla::Reshape(diag, new_dims);
ctx->SetOutput(0, diag);
}
@@ -265,7 +266,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
// Collapses the last two dimensions.
std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1);
flattened_dims.back() *= dims.back();
- diag = builder->Reshape(diag, flattened_dims);
+ diag = xla::Reshape(diag, flattened_dims);
// Slices or pads the last dimension to 'target_size'.
int64 actual_size = flattened_dims.back();
@@ -276,13 +277,13 @@ class MatrixDiagPartOp : public XlaOpKernel {
auto* dim = config.mutable_dimensions(flattened_dims.size() - 1);
dim->set_edge_padding_high(target_size - actual_size);
auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = builder->Pad(diag, zero, config);
+ diag = xla::Pad(diag, zero, config);
} else if (actual_size > target_size) {
std::vector<int64> start(flattened_dims.size(), 0);
std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
std::vector<int64> strides(flattened_dims.size(), 1);
limits[flattened_dims.size() - 1] = target_size;
- diag = builder->Slice(diag, start, limits, strides);
+ diag = xla::Slice(diag, start, limits, strides);
}
// Reshape so the target values are in the first position of the last
@@ -290,18 +291,18 @@ class MatrixDiagPartOp : public XlaOpKernel {
std::vector<int64> unflattened_dims(dims.begin(), dims.end());
dims[last_dim - 1] = smaller_dim_size;
dims[last_dim] = last_dim_size + 1;
- diag = builder->Reshape(diag, dims);
+ diag = xla::Reshape(diag, dims);
// Slices out the first column and reshapes to the final shape.
std::vector<int64> start(dims.size(), 0);
std::vector<int64> limits(dims.begin(), dims.end());
std::vector<int64> strides(dims.size(), 1);
limits[last_dim] = 1;
- diag = builder->Slice(diag, start, limits, strides);
+ diag = xla::Slice(diag, start, limits, strides);
// Collapses away the last dimension.
dims.pop_back();
- diag = builder->Reshape(diag, dims);
+ diag = xla::Reshape(diag, dims);
ctx->SetOutput(0, diag);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index 0419de78b2..3b86ea34c9 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -57,8 +57,8 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
- xla::XlaOp result = ctx->builder()->DynamicUpdateSlice(
- ctx->Input(0), ctx->Input(1), ctx->Input(2));
+ xla::XlaOp result =
+ xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2));
ctx->SetOutput(0, result);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index dd4a169087..958231505b 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -150,8 +151,7 @@ class DynamicStitchOp : public XlaOpKernel {
if (new_shape == data_shapes[input_num]) {
input[input_num] = handle;
} else {
- input[input_num] =
- ctx->builder()->Reshape(handle, new_shape.dim_sizes());
+ input[input_num] = xla::Reshape(handle, new_shape.dim_sizes());
}
}
@@ -175,10 +175,10 @@ class DynamicStitchOp : public XlaOpKernel {
// And place it in the concat list in the place indicated by
// the index.
to_concat[index_num] =
- ctx->builder()->Slice(expression, slice_start, slice_limit, stride);
+ xla::Slice(expression, slice_start, slice_limit, stride);
}
- ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0));
+ ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), to_concat, 0));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 493781a1e6..2c76bcee25 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -34,9 +34,9 @@ class EluOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
- const auto pred = b->Gt(ctx->Input(0), zero);
- const auto expm1 = b->Expm1(ctx->Input(0));
- ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1));
+ const auto pred = xla::Gt(ctx->Input(0), zero);
+ const auto expm1 = xla::Expm1(ctx->Input(0));
+ ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), expm1));
}
};
@@ -51,9 +51,9 @@ class EluGradOp : public XlaOpKernel {
const auto one = XlaHelpers::One(b, input_type(0));
const auto grad = ctx->Input(0);
const auto activation = ctx->Input(1);
- const auto exp_grad = b->Mul(grad, b->Add(activation, one));
- const auto pred = b->Gt(activation, zero);
- ctx->SetOutput(0, b->Select(pred, grad, exp_grad));
+ const auto exp_grad = xla::Mul(grad, xla::Add(activation, one));
+ const auto pred = xla::Gt(activation, zero);
+ ctx->SetOutput(0, xla::Select(pred, grad, exp_grad));
}
};
@@ -71,10 +71,10 @@ class SeluOp : public XlaOpKernel {
1.0507009873554804934193349852946);
const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
1.7580993408473768599402175208123);
- const auto pred = b->Gt(ctx->Input(0), zero);
- const auto expm1 = b->Expm1(ctx->Input(0));
- ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)),
- b->Mul(scale_alpha, expm1)));
+ const auto pred = xla::Gt(ctx->Input(0), zero);
+ const auto expm1 = xla::Expm1(ctx->Input(0));
+ ctx->SetOutput(0, xla::Select(pred, xla::Mul(scale, ctx->Input(0)),
+ xla::Mul(scale_alpha, expm1)));
}
};
@@ -92,10 +92,10 @@ class SeluGradOp : public XlaOpKernel {
1.7580993408473768599402175208123);
const auto grad = ctx->Input(0);
const auto activation = ctx->Input(1);
- const auto lin_grad = b->Mul(grad, scale);
- const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha));
- const auto pred = b->Gt(activation, zero);
- ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad));
+ const auto lin_grad = xla::Mul(grad, scale);
+ const auto exp_grad = xla::Mul(grad, xla::Add(activation, scale_alpha));
+ const auto pred = xla::Gt(activation, zero);
+ ctx->SetOutput(0, xla::Select(pred, lin_grad, exp_grad));
}
};
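
[The ELU/SELU kernels above are pure compositions of element-wise free functions. As a compact illustration; a sketch only, and Elu here is a hypothetical standalone helper, not a function added by this patch:]

    // elu(x) = x            if x > 0
    //        = exp(x) - 1   otherwise
    xla::XlaOp Elu(xla::XlaBuilder* b, xla::XlaOp x) {
      auto zero = xla::ConstantR0<float>(b, 0.0f);  // scalar, broadcast implicitly
      return xla::Select(xla::Gt(x, zero), x, xla::Expm1(x));
    }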
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index 6df01cabbf..b2451236de 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -114,9 +115,9 @@ class ExtractImagePatchesOp : public XlaOpKernel {
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
kernel_size * depth, &iota));
- auto lhs = builder->Reshape(iota, lhs_shape);
- auto filter = builder->ConvertElementType(
- builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
+ auto lhs = xla::Reshape(iota, lhs_shape);
+ auto filter = xla::ConvertElementType(
+ xla::Eq(lhs, iota, {num_spatial_dims + 1}), type);
xla::ConvolutionDimensionNumbers dims;
std::vector<int64> window_strides(num_spatial_dims);
@@ -148,8 +149,8 @@ class ExtractImagePatchesOp : public XlaOpKernel {
}
xla::XlaOp conv =
- builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
- padding, lhs_dilation, rhs_dilation, dims);
+ xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
+ lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
index 8f0de0a524..2fd1a34741 100644
--- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -49,20 +50,20 @@ void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
const float quant_min_value, const float quant_max_value,
xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
xla::XlaOp* scale) {
- *scale = b->Div(b->Sub(max, min),
- XlaHelpers::FloatLiteral(b, data_type,
- quant_max_value - quant_min_value));
+ *scale = xla::Div(xla::Sub(max, min),
+ XlaHelpers::FloatLiteral(
+ b, data_type, quant_max_value - quant_min_value));
xla::XlaOp quant_min =
XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
- xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale));
+ xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale));
xla::XlaOp quant_max =
XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
xla::XlaOp nudged_zero_point =
- b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
- b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
- b->Round(zero_point_from_min)));
- *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale);
- *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
+ xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min,
+ xla::Select(xla::Ge(zero_point_from_min, quant_max),
+ quant_max, xla::Round(zero_point_from_min)));
+ *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale);
+ *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale);
}
xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
@@ -71,14 +72,14 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
const xla::XlaOp& nudged_input_max,
const xla::XlaOp& input_scale) {
xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
- xla::XlaOp inv_scale = b->Div(one, input_scale);
+ xla::XlaOp inv_scale = xla::Div(one, input_scale);
xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);
- xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max);
- xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min);
+ xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max);
+ xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min);
xla::XlaOp rounded =
- b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
- return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
+ xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half));
+ return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
@@ -163,11 +164,11 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
xla::XlaOp nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
- xla::XlaOp between_nudged_min_max =
- b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
- xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type),
- gradient_shape.dim_sizes());
- xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp between_nudged_min_max = xla::And(
+ xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
+ xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type),
+ gradient_shape.dim_sizes());
+ xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output);
}
@@ -249,25 +250,25 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
- xla::XlaOp between_nudged_min_max =
- b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
+ xla::XlaOp between_nudged_min_max = xla::And(
+ xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
- xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes());
- xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes());
+ xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output0);
- xla::XlaOp below_min = b->Lt(input, nudged_input_min);
- xla::XlaOp select1 = b->Select(below_min, gradient, zeroes);
- xla::XlaOp reduce1 = b->ReduceAll(
+ xla::XlaOp below_min = xla::Lt(input, nudged_input_min);
+ xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes);
+ xla::XlaOp reduce1 = xla::ReduceAll(
XlaHelpers::ConvertElementType(b, select1, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type);
ctx->SetOutput(1, output1);
- xla::XlaOp above_max = b->Gt(input, nudged_input_max);
- xla::XlaOp select2 = b->Select(above_max, gradient, zeroes);
- xla::XlaOp reduce2 = b->ReduceAll(
+ xla::XlaOp above_max = xla::Gt(input, nudged_input_max);
+ xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes);
+ xla::XlaOp reduce2 = xla::ReduceAll(
XlaHelpers::ConvertElementType(b, select2, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
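
[For reference, the nudge-and-quantize arithmetic that XlaNudge/Quantize assemble as a graph can be checked in plain C++; this is a self-contained demo of the math only, with example values, not code from the patch:]

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
      // Mirror XlaNudge/Quantize for an 8-bit range [0, 255].
      float min = -0.9f, max = 1.0f, qmin = 0.0f, qmax = 255.0f;
      float scale = (max - min) / (qmax - qmin);
      float zero_from_min = qmin - min / scale;
      float nudged_zero = zero_from_min <= qmin ? qmin
                        : zero_from_min >= qmax ? qmax
                        : std::round(zero_from_min);
      float nudged_min = (qmin - nudged_zero) * scale;
      float nudged_max = (qmax - nudged_zero) * scale;
      // e' = floor((clamp(e) - nudged_min) / scale + 0.5) * scale + nudged_min
      float e = 0.3f;
      float clamped = std::min(std::max(e, nudged_min), nudged_max);
      float q = std::floor((clamped - nudged_min) / scale + 0.5f) * scale
                + nudged_min;
      std::printf("%f -> %f\n", e, q);  // e snapped onto the nudged grid
      return 0;
    }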
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index 933924cad1..b2b00e51e3 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -62,8 +63,7 @@ class GenericFftOp : public XlaOpKernel {
}
}
- xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length);
+ xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length);
ctx->SetOutput(0, fft);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index e4467a0fb1..95faa1d058 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
@@ -59,11 +60,11 @@ class FillOp : public XlaOpKernel {
xla::XlaOp data = ctx->Input(1);
if (value_shape.dims() > 0) {
CHECK_EQ(value_shape.dims(), 1);
- data = ctx->builder()->Reshape(data, {});
+ data = xla::Reshape(data, {});
}
// Emit the actual computation, which broadcasts the scalar to the
// desired shape.
- auto result = ctx->builder()->Broadcast(data, broadcast);
+ auto result = xla::Broadcast(data, broadcast);
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index d13e25bcdd..5f041be5df 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -75,8 +76,8 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
out_shape.AppendShape(indices_shape_no_index_vectors);
out_shape.AppendShape(input_shape_post_axis);
- *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
- out_shape.dim_sizes());
+ *gather_output =
+ xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes());
return Status::OK();
}
@@ -142,7 +143,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
dim_numbers.add_gather_dims_to_operand_dims(i);
}
- *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds);
+ *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index d48c6eea75..f5fcf3cacd 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -199,13 +200,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
}
- xla::XlaOp outputs =
- b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
- b->Tuple(inputs), *else_result.computation);
+ xla::XlaOp outputs = xla::Conditional(
+ ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation,
+ xla::Tuple(b, inputs), *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
- xla::XlaOp output_handle = b->GetTupleElement(outputs, i);
+ xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i;
auto shape_or = b->GetShape(output_handle);
@@ -233,7 +234,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
- b->GetTupleElement(outputs, pos), b));
+ xla::GetTupleElement(outputs, pos), b));
}
VLOG(2) << "If variable: pos: " << update.input_index
<< " name: " << resource->name()
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 1568b33679..7a36514eb3 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -32,23 +33,26 @@ std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
auto red = rgb[0];
auto green = rgb[1];
auto blue = rgb[2];
- auto value = b->Max(b->Max(red, green), blue);
- auto minimum = b->Min(b->Min(red, green), blue);
- auto range = b->Sub(value, minimum);
-
- auto zeros = b->Broadcast(zero, shape.dim_sizes());
- auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros);
-
- auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
-
- auto hue = b->Select(b->Eq(green, value),
- b->Add(b->Mul(norm, b->Sub(blue, red)),
- XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
- b->Add(b->Mul(norm, b->Sub(red, green)),
- XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
- hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue);
- hue = b->Select(b->Gt(range, zero), hue, zeros);
- hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue);
+ auto value = xla::Max(xla::Max(red, green), blue);
+ auto minimum = xla::Min(xla::Min(red, green), blue);
+ auto range = xla::Sub(value, minimum);
+
+ auto zeros = xla::Broadcast(zero, shape.dim_sizes());
+ auto saturation =
+ xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);
+
+ auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
+
+ auto hue =
+ xla::Select(xla::Eq(green, value),
+ xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
+ XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
+ xla::Add(xla::Mul(norm, xla::Sub(red, green)),
+ XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
+ hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
+ hue);
+ hue = xla::Select(xla::Gt(range, zero), hue, zeros);
+ hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
return {hue, saturation, value};
}
@@ -66,15 +70,15 @@ std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
- auto dh = b->Mul(hue, six);
- auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one);
- auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one);
- auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one);
- auto one_minus_s = b->Sub(one, saturation);
+ auto dh = xla::Mul(hue, six);
+ auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
+ auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
+ auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
+ auto one_minus_s = xla::Sub(one, saturation);
- auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value);
- auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value);
- auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value);
+ auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
+ auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
+ auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
return {red, green, blue};
}
@@ -111,7 +115,7 @@ class RGBToHSVOp : public XlaOpKernel {
auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
channel_shape);
- context->SetOutput(0, b->ConcatInDim(hsv, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
}
};
REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
@@ -147,7 +151,7 @@ class HSVToRGBOp : public XlaOpKernel {
auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
context->input_type(0));
- context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
}
};
REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
@@ -182,18 +186,20 @@ class AdjustContrastOpV2 : public XlaOpKernel {
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
auto converted =
XlaHelpers::ConvertElementType(b, input, accumulation_type);
- auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *context->GetOrCreateAdd(accumulation_type),
- {height_dim, width_dim});
+ auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *context->GetOrCreateAdd(accumulation_type),
+ {height_dim, width_dim});
auto output = XlaHelpers::ConvertElementType(b, reduce, type);
- output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width));
+ output =
+ xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width));
std::vector<int64> broadcast_dims(input_shape.dims() - 2);
std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
broadcast_dims.back() = channel_dim;
- output = b->Add(b->Mul(input, factor),
- b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)),
- broadcast_dims);
+ output =
+ xla::Add(xla::Mul(input, factor),
+ xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
+ broadcast_dims);
context->SetOutput(0, output);
}
};
@@ -240,12 +246,12 @@ class AdjustSaturationOp : public XlaOpKernel {
auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
channel_shape);
- hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale),
- XlaHelpers::One(b, type));
+ hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale),
+ XlaHelpers::One(b, type));
auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0));
- context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
}
};
REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
@@ -294,12 +300,13 @@ class AdjustHueOp : public XlaOpKernel {
auto one = XlaHelpers::One(b, type);
auto& hue = hsv[0];
- hue = b->Rem(b->Add(hsv[0], delta), one);
- hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue);
+ hue = xla::Rem(xla::Add(hsv[0], delta), one);
+ hue =
+ xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);
auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0));
- context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
}
};
REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index 79d3a6979c..de971ce4ac 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -132,17 +133,16 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
TF_CHECK_OK(
XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
- auto diag = builder->ConvertElementType(
- builder->Eq(
- builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1,
+ auto diag = xla::ConvertElementType(
+ xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
2 * kernel_size[1] - 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
+ channels_iota, /*broadcast_dimensions=*/{2}),
xla::PrimitiveType::F32);
- return builder->Mul(
- builder->Mul(diag,
- builder->ConstantR1<float>(Make1DKernel(kernel_size[1])),
- /*broadcast_dimensions=*/{1}),
- builder->ConstantR1<float>(Make1DKernel(kernel_size[0])),
+ return xla::Mul(
+ xla::Mul(diag,
+ xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ /*broadcast_dimensions=*/{1}),
+ xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
/*broadcast_dimensions=*/{0});
}
@@ -154,21 +154,21 @@ xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
TF_CHECK_OK(
XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
- auto diag = builder->ConvertElementType(
- builder->Eq(builder->Broadcast(
- channels_iota,
- {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
- dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
+ auto diag = xla::ConvertElementType(
+ xla::Eq(
+ xla::Broadcast(channels_iota,
+ {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+ dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
+ channels_iota, /*broadcast_dimensions=*/{2}),
xla::PrimitiveType::F32);
if (dim == 1) {
- return builder->Mul(
- diag, builder->ConstantR1<float>(Make1DKernel(kernel_size[1])),
+ return xla::Mul(
+ diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
/*broadcast_dimensions=*/{1});
}
- return builder->Mul(diag,
- builder->ConstantR1<float>(Make1DKernel(kernel_size[0])),
- /*broadcast_dimensions=*/{0});
+ return xla::Mul(diag,
+ xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+ /*broadcast_dimensions=*/{0});
}
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -208,7 +208,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
- output = builder->ConvGeneralDilated(
+ output = xla::ConvGeneralDilated(
input, kernel, dims.stride,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
@@ -218,7 +218,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
- output = builder->ConvGeneralDilated(
+ output = xla::ConvGeneralDilated(
input, kernel0, {dims.stride[0], 1},
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
@@ -226,7 +226,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*rhs_dilation=*/{1, 1}, dimension_numbers);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
- output = builder->ConvGeneralDilated(
+ output = xla::ConvGeneralDilated(
output, kernel1, {1, dims.stride[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
@@ -238,8 +238,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
// size > 1 dimension.
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] == 1 && out_size[i] > 1) {
- output = builder->Add(output, builder->ConstantR1<float>(out_size[i], 0),
- /*broadcast_dimensions=*/{1 + i});
+ output = xla::Add(output, xla::ConstantR1<float>(builder, out_size[i], 0),
+ /*broadcast_dimensions=*/{1 + i});
}
}
return output;
@@ -279,12 +279,12 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] == 1 && grad_size[i] > 1) {
kernel =
- builder->Add(kernel, builder->ConstantR1<float>(grad_size[i], 0),
- /*broadcast_dimensions=*/{i});
+ xla::Add(kernel, xla::ConstantR1<float>(builder, grad_size[i], 0),
+ /*broadcast_dimensions=*/{i});
}
}
- output = builder->ConvGeneralDilated(
+ output = xla::ConvGeneralDilated(
grad, kernel, /*window_strides=*/dims.kernel_size,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
@@ -302,23 +302,23 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
// gradient contributions in that dimension.
if (in_size[0] == 1 && grad_size[0] > 1) {
kernel0 =
- builder->Add(kernel0, builder->ConstantR1<float>(grad_size[0], 0),
- /*broadcast_dimensions=*/{0});
+ xla::Add(kernel0, xla::ConstantR1<float>(builder, grad_size[0], 0),
+ /*broadcast_dimensions=*/{0});
}
if (in_size[1] == 1 && grad_size[1] > 1) {
kernel1 =
- builder->Add(kernel0, builder->ConstantR1<float>(grad_size[1], 0),
- /*broadcast_dimensions=*/{1});
+ xla::Add(kernel0, xla::ConstantR1<float>(builder, grad_size[1], 0),
+ /*broadcast_dimensions=*/{1});
}
- output = builder->ConvGeneralDilated(
+ output = xla::ConvGeneralDilated(
grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.stride[0], 1},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
- output = builder->ConvGeneralDilated(
+ output = xla::ConvGeneralDilated(
output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
@@ -337,7 +337,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
}
}
if (pad_output) {
- output = builder->Pad(output, builder->ConstantR0<float>(0.0f), padding);
+ output = xla::Pad(output, xla::ConstantR0<float>(builder, 0.0f), padding);
}
return output;
}
@@ -393,13 +393,13 @@ class ResizeBilinearOp : public XlaOpKernel {
}
}
if (slice_input) {
- input = b->Slice(input, {0, 0, 0, 0},
- {batch, slice_size[0], slice_size[1], channels},
- {1, 1, 1, 1});
+ input = xla::Slice(input, {0, 0, 0, 0},
+ {batch, slice_size[0], slice_size[1], channels},
+ {1, 1, 1, 1});
}
// Output is always type float.
- input = b->ConvertElementType(input, xla::F32);
+ input = xla::ConvertElementType(input, xla::F32);
// Special Case:
// Instead of doing a ResizeUsingDilationAndConvolution directly,
@@ -529,7 +529,7 @@ class ResizeBilinearGradOp : public XlaOpKernel {
}
}
- output = b->ConvertElementType(output, output_type_);
+ output = xla::ConvertElementType(output, output_type_);
ctx->SetOutput(0, output);
}
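
[The bilinear kernels above are separable triangle filters; Make1DKernel, defined earlier in this file, supplies the 1-D weights that xla::ConstantR1<float> now receives together with the builder. One detail worth flagging, hedged since it predates this rewrite: in the gradient path, kernel1 is computed from kernel0 (xla::Add(kernel0, ...) under the in_size[1] == 1 branch), which reads like a pre-existing typo for kernel1 that the mechanical conversion preserves. A plain C++ sketch of the triangle weights, assuming a kernel of size k has 2k-1 taps peaking at 1:]

    #include <cstdio>
    #include <vector>

    // Tent weights for a bilinear resize kernel of size k:
    // 1/k, 2/k, ..., (k-1)/k, 1, (k-1)/k, ..., 1/k  (2k-1 taps).
    std::vector<float> Make1DKernelDemo(int k) {
      std::vector<float> w;
      for (int i = 1; i <= k; ++i) w.push_back(float(i) / k);
      for (int i = k - 1; i >= 1; --i) w.push_back(float(i) / k);
      return w;
    }

    int main() {
      for (float v : Make1DKernelDemo(3)) std::printf("%g ", v);  // 1/3 2/3 1 2/3 1/3
      std::printf("\n");
      return 0;
    }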
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 2c2d88486f..a020ebc729 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
// XLA passes <out> to the function, so it is not included here.
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
+ args.push_back(xla::ConstantLiteral(
+ &b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
- args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim)));
+ args.push_back(xla::ConstantLiteral(
+ &b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
+ args.push_back(
+ xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
@@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
- output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape);
+ output =
+ xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
break;
case 2:
- output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape);
+ output =
+ xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
break;
default:
OP_REQUIRES(ctx, false,
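
[In this kernel the builder b is a local xla::XlaBuilder object rather than a pointer, hence the &b arguments; xla::ConstantLiteral and xla::CustomCall both take XlaBuilder* as their first parameter. A minimal sketch under the signatures used above, with an illustrative helper name:]

    xla::XlaOp CallArgmax1D(xla::XlaBuilder* b, xla::XlaOp input,
                            const xla::Shape& out_shape,
                            const xla::Literal& dims) {
      // Extra scalar/vector arguments are passed as constant operands.
      auto dims_arg = xla::ConstantLiteral(b, dims);
      return xla::CustomCall(b, "argmax_float_1d_xla_impl",
                             {input, dims_arg}, out_shape);
    }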
diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
index 1decf7d72d..9e64711051 100644
--- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
@@ -39,12 +39,12 @@ class L2LossOp : public XlaOpKernel {
const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
auto t =
XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type);
- auto square = b->Mul(t, t);
- auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), dims);
+ auto square = xla::Mul(t, t);
+ auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), dims);
auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype);
auto two = XlaHelpers::IntegerLiteral(b, dtype, 2);
- ctx->SetOutput(0, b->Div(deconverted, two));
+ ctx->SetOutput(0, xla::Div(deconverted, two));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
index 0388b4c830..2fb072f827 100644
--- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -90,8 +91,10 @@ class ListDiffOp : public XlaOpKernel {
idx_output.push_back(i);
}
- context->SetOutput(0, context->builder()->ConstantR1<Tval>(val_output));
- context->SetOutput(1, context->builder()->ConstantR1<Tidx>(idx_output));
+ context->SetOutput(0,
+ xla::ConstantR1<Tval>(context->builder(), val_output));
+ context->SetOutput(1,
+ xla::ConstantR1<Tidx>(context->builder(), idx_output));
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
index 39fbf98a62..dc934543cb 100644
--- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -50,8 +51,8 @@ class LRNOp : public XlaOpKernel {
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
XlaHelpers::ConvertElementType(builder, input, accumulation_type);
- auto squared = builder->Mul(converted, converted);
- auto reduce = builder->ReduceWindow(
+ auto squared = xla::Mul(converted, converted);
+ auto reduce = xla::ReduceWindow(
squared, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
@@ -59,12 +60,12 @@ class LRNOp : public XlaOpKernel {
auto sqr_sum =
XlaHelpers::ConvertElementType(builder, reduce, input_type(0));
- auto scale = builder->Pow(
- builder->Add(builder->ConstantR0<float>(bias_),
- builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)),
- builder->ConstantR0<float>(-beta_));
+ auto scale = xla::Pow(
+ xla::Add(xla::ConstantR0<float>(builder, bias_),
+ xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum)),
+ xla::ConstantR0<float>(builder, -beta_));
- ctx->SetOutput(0, builder->Mul(input, scale));
+ ctx->SetOutput(0, xla::Mul(input, scale));
}
private:
@@ -138,8 +139,8 @@ class LRNGradOp : public XlaOpKernel {
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
XlaHelpers::ConvertElementType(builder, in_image, accumulation_type);
- auto squared = builder->Mul(converted, converted);
- auto reduce = builder->ReduceWindow(
+ auto squared = xla::Mul(converted, converted);
+ auto reduce = xla::ReduceWindow(
squared, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
@@ -148,17 +149,17 @@ class LRNGradOp : public XlaOpKernel {
XlaHelpers::ConvertElementType(builder, reduce, input_type(0));
auto norm =
- builder->Add(builder->ConstantR0<float>(bias_),
- builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum));
+ xla::Add(xla::ConstantR0<float>(builder, bias_),
+ xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum));
- auto dy = builder->Mul(
- builder->Mul(builder->ConstantR0<float>(-2.0f * alpha_ * beta_),
- builder->Div(out_image, norm)),
+ auto dy = xla::Mul(
+ xla::Mul(xla::ConstantR0<float>(builder, -2.0f * alpha_ * beta_),
+ xla::Div(out_image, norm)),
in_grads);
auto converted_dy =
XlaHelpers::ConvertElementType(builder, dy, accumulation_type);
- auto dy_reduce = builder->ReduceWindow(
+ auto dy_reduce = xla::ReduceWindow(
converted_dy, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
@@ -166,10 +167,10 @@ class LRNGradOp : public XlaOpKernel {
auto dy_reduced =
XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0));
- xla::XlaOp gradients = builder->Add(
- builder->Mul(in_image, dy_reduced),
- builder->Mul(in_grads,
- builder->Pow(norm, builder->ConstantR0<float>(-beta_))));
+ xla::XlaOp gradients = xla::Add(
+ xla::Mul(in_image, dy_reduced),
+ xla::Mul(in_grads,
+ xla::Pow(norm, xla::ConstantR0<float>(builder, -beta_))));
ctx->SetOutput(0, gradients);
}
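
[Both LRN kernels above materialize the textbook expression scale = (bias + alpha * sqr_sum)^(-beta) over a depth window of 2*depth_radius + 1. A plain C++ check of the forward scale for one activation; demo only, with example values:]

    #include <cmath>
    #include <cstdio>

    int main() {
      // LRN forward for one activation: out = x * (bias + alpha*sqr_sum)^(-beta)
      float x = 2.0f, sqr_sum = 5.0f;  // sqr_sum: sum of squares over the window
      float bias = 1.0f, alpha = 1e-4f, beta = 0.75f;
      float scale = std::pow(bias + alpha * sqr_sum, -beta);
      std::printf("out = %f\n", x * scale);
      return 0;
    }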
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index 6949b296f4..844080b8cf 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
@@ -70,15 +71,15 @@ class MatMulOp : public XlaOpKernel {
xla::XlaOp b = ctx->Input(1);
if (is_sparse_) {
if (a_type_ == DT_BFLOAT16) {
- a = ctx->builder()->ConvertElementType(a, xla::F32);
+ a = xla::ConvertElementType(a, xla::F32);
}
if (b_type_ == DT_BFLOAT16) {
- b = ctx->builder()->ConvertElementType(b, xla::F32);
+ b = xla::ConvertElementType(b, xla::F32);
}
}
- auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a;
- auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b;
- ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs));
+ auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a;
+ auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b;
+ ctx->SetOutput(0, xla::Dot(lhs, rhs));
}
private:
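
[The MatMul lowering above reduces to three free functions. A sketch with an illustrative name, assuming rank-2 operands as the kernel does:]

    // C = op(A) * op(B), where op is an optional 2-D transpose.
    xla::XlaOp MatMul2D(xla::XlaOp a, xla::XlaOp b,
                        bool transpose_a, bool transpose_b) {
      auto lhs = transpose_a ? xla::Transpose(a, {1, 0}) : a;
      auto rhs = transpose_b ? xla::Transpose(b, {1, 0}) : b;
      return xla::Dot(lhs, rhs);
    }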
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index fbd5dc0fda..9d3575e331 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@@ -64,27 +65,26 @@ class MatrixBandPartOp : public XlaOpKernel {
xla::XlaOp iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
- auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
- /*broadcast_dimensions=*/{0});
+ auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m,
+ /*broadcast_dimensions=*/{0});
// If num_lower or num_upper are negative, include all lower/upper
// diagonals.
auto zero_index = XlaHelpers::Zero(builder, index_type);
- num_lower = builder->Select(
- builder->Lt(num_lower, zero_index),
- XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
- num_upper = builder->Select(
- builder->Lt(num_upper, zero_index),
- XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);
+ num_lower = xla::Select(xla::Lt(num_lower, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, m),
+ num_lower);
+ num_upper = xla::Select(xla::Lt(num_upper, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, n),
+ num_upper);
- auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
- builder->Le(offset, num_upper));
- indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+ auto indicator = xla::And(xla::Le(xla::Neg(num_lower), offset),
+ xla::Le(offset, num_upper));
+ indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());
auto zero_input = XlaHelpers::Zero(builder, input_type);
- auto output = builder->Select(
- indicator, input,
- builder->Broadcast(zero_input, input_shape.dim_sizes()));
+ auto output = xla::Select(
+ indicator, input, xla::Broadcast(zero_input, input_shape.dim_sizes()));
context->SetOutput(0, output);
}
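
[The indicator built above tests offset = col - row against the band, with negative bounds meaning "keep all lower (or upper) diagonals". The scalar predicate, as a plain C++ demo:]

    #include <cstdio>

    // Keep element (row, col) iff it lies in the band, as in the kernel above.
    bool InBand(long row, long col, long num_lower, long num_upper,
                long m, long n) {
      long offset = col - row;
      if (num_lower < 0) num_lower = m;  // negative: include all lower diagonals
      if (num_upper < 0) num_upper = n;  // negative: include all upper diagonals
      return -num_lower <= offset && offset <= num_upper;
    }

    int main() {
      std::printf("%d\n", InBand(2, 0, 1, 1, 4, 4));  // 0: element is below the band
      return 0;
    }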
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index db53f6fef8..7bf1894ea0 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -65,10 +66,9 @@ class MatrixSetDiagOp : public XlaOpKernel {
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
xla::XlaOp iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
- auto indicator = builder->Eq(iota_m,
- builder->Broadcast(iota_n, {m}),
- /*broadcast_dimensions=*/{0});
- indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+ auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}),
+ /*broadcast_dimensions=*/{0});
+ indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());
// Broadcast diag up to the input shape. Use an implicit broadcast (Add)
// because we need to broadcast on the right.
@@ -77,10 +77,10 @@ class MatrixSetDiagOp : public XlaOpKernel {
if (min_dim != m) {
diag_broadcast_dims.back() = rank - 1;
}
- diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()),
- /*broadcast_dimensions=*/diag_broadcast_dims);
+ diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
+ /*broadcast_dimensions=*/diag_broadcast_dims);
- auto output = builder->Select(indicator, diag, input);
+ auto output = xla::Select(indicator, diag, input);
context->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index c3326b4d11..caeba66b52 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
namespace tensorflow {
@@ -32,7 +33,7 @@ class MirrorPadOp : public XlaOpKernel {
xla::XlaOp accum = t;
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) {
- auto t_rev = b->Rev(accum, {dimno});
+ auto t_rev = xla::Rev(accum, {dimno});
TF_ASSIGN_OR_RETURN(int64 lhs_padding,
pad_literal.GetIntegralAsS64({dimno, 0}));
TF_ASSIGN_OR_RETURN(int64 rhs_padding,
@@ -41,7 +42,7 @@ class MirrorPadOp : public XlaOpKernel {
auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding,
dim_size - 1, 1, dimno);
auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
- accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno);
+ accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno);
}
return accum;
}
diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
index aecaabb6dc..3aed47de26 100644
--- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -76,11 +77,10 @@ class PackOp : public XlaOpKernel {
for (int i = 0; i < num; ++i) {
// Reshape the inputs to have an extra dimension of size 1.
- reshaped_inputs[i] =
- ctx->builder()->Reshape(values[i], child_shape.dim_sizes());
+ reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes());
}
- ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis));
+ ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index 17b85338f7..89fd610bc6 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
@@ -74,11 +75,10 @@ class PadOp : public XlaOpKernel {
if (ctx->num_inputs() == 3) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)),
errors::InvalidArgument("constant_values must be a scalar."));
- ctx->SetOutput(0,
- ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config));
+ ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config));
} else {
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
- ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config));
+ ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config));
}
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index eb8b5b130f..771dcbab21 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -113,8 +114,8 @@ class PoolingOp : public XlaOpKernel {
xla::XlaBuilder* const b = ctx->builder();
auto input =
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
- auto reduce = ctx->builder()->ReduceWindow(
- input, InitValue(b), *Reduction(ctx), ksize, stride, padding_);
+ auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize,
+ stride, padding_);
auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
ctx->SetOutput(0,
PostProcessOutput(ctx, pooled, input_type(0), input_shape));
@@ -190,7 +191,7 @@ static xla::XlaOp AvgPoolDivideByCount(
auto divisor =
XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
- return ctx->builder()->Div(output, divisor);
+ return xla::Div(output, divisor);
} else {
// For SAME padding, the padding shouldn't be included in the
// counts. We use another ReduceWindow to find the right counts.
@@ -212,18 +213,18 @@ static xla::XlaOp AvgPoolDivideByCount(
// Build a matrix of all 1s, with the same width/height as the input.
const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto ones = ctx->builder()->Broadcast(
+ auto ones = xla::Broadcast(
XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes);
// Perform a ReduceWindow with the same window size, strides, and padding
// to count the number of contributions to each result element.
- auto reduce = ctx->builder()->ReduceWindow(
+ auto reduce = xla::ReduceWindow(
ones, XlaHelpers::Zero(ctx->builder(), accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride,
xla::Padding::kSame);
auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype);
- return ctx->builder()->Div(output, counts, window_dims);
+ return xla::Div(output, counts, window_dims);
}
}
@@ -347,9 +348,9 @@ class MaxPoolGradOp : public XlaOpKernel {
xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
auto select = CreateScalarGeComputation(element_type, ctx->builder());
auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
- xla::XlaOp gradients = ctx->builder()->SelectAndScatter(
- input, select, ksize_, stride_, xla_padding, out_backprop, init_value,
- scatter);
+ xla::XlaOp gradients =
+ xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
+ out_backprop, init_value, scatter);
ctx->SetOutput(0, gradients);
}
@@ -485,12 +486,12 @@ class AvgPoolGradOp : public XlaOpKernel {
}
auto zero = XlaHelpers::Zero(b, dtype);
- auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config);
+ auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config);
// in_backprop = padded_gradients <conv> ones
std::vector<int64> ones(num_dims(), 1LL);
auto accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto in_backprop = b->ReduceWindow(
+ auto in_backprop = xla::ReduceWindow(
XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), ksize_,
@@ -614,17 +615,18 @@ class MaxPoolGradGradOp : public XlaOpKernel {
auto b = ctx->builder();
- auto sixteen = b->ConstantR0<uint32>(16);
+ auto sixteen = xla::ConstantR0<uint32>(b, 16);
// in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
- auto in_hi = b->BitcastConvertType(
- b->ConvertElementType(b->ConvertElementType(input, xla::BF16),
- xla::F32),
+ auto in_hi = xla::BitcastConvertType(
+ xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16),
+ xla::F32),
xla::U32);
- auto bp_int = b->BitcastConvertType(out_backprop, xla::U32);
- auto bp_hi = b->ShiftRightLogical(bp_int, sixteen);
- auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen);
- auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add.
- auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add.
+ auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
+ auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
+ auto bp_lo =
+ xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
+ auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add.
+ auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add.
auto init_value = XlaHelpers::MinValue(b, DT_FLOAT);
// We will reduce by taking the maximal value up to 16 bits (ignoring the lo
@@ -633,39 +635,41 @@ class MaxPoolGradGradOp : public XlaOpKernel {
{
// F32 parameters to satisfy lowering type restriction for reduce opcode.
const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
- auto lhs = rb->Parameter(0, scalar, "lhs");
- auto rhs = rb->Parameter(1, scalar, "rhs");
- auto sixteen = rb->ConstantR0<int32>(16);
- auto lhs_criteria = rb->ShiftLeft(
- rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen),
- sixteen);
- auto rhs_criteria = rb->ShiftLeft(
- rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen),
- sixteen);
+ auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
+ auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
+ auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
+ auto lhs_criteria =
+ xla::ShiftLeft(xla::ShiftRightLogical(
+ xla::BitcastConvertType(lhs, xla::S32), sixteen),
+ sixteen);
+ auto rhs_criteria =
+ xla::ShiftLeft(xla::ShiftRightLogical(
+ xla::BitcastConvertType(rhs, xla::S32), sixteen),
+ sixteen);
// Must use a F32 comparison, because S32 would not work for negatives.
- rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32),
- rb->BitcastConvertType(rhs_criteria, xla::F32)),
- lhs, rhs);
+ xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
+ xla::BitcastConvertType(rhs_criteria, xla::F32)),
+ lhs, rhs);
}
auto reduce = rb->BuildAndNoteError();
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
auto pooled_hi =
- b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32),
- init_value, reduce, ksize_, stride_, xla_padding);
+ xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
+ init_value, reduce, ksize_, stride_, xla_padding);
auto pooled_lo =
- b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32),
- init_value, reduce, ksize_, stride_, xla_padding);
+ xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
+ init_value, reduce, ksize_, stride_, xla_padding);
auto grads_hi =
- b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen);
- auto grads_lo = b->ShiftRightLogical(
- b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen),
+ xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
+ auto grads_lo = xla::ShiftRightLogical(
+ xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
sixteen);
- auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add.
+ auto grads = xla::Add(grads_hi, grads_lo); // Want an unsigned add.
xla::PrimitiveType element_type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
- ctx->SetOutput(0, b->BitcastConvertType(grads, element_type));
+ ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
}
protected:
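
[The MaxPoolGradGrad path above packs a bf16-rounded copy of the input into the high 16 bits of a u32 and splits the gradient into hi/lo 16-bit halves, so a single max-reduce keyed on the high bits carries the matching gradient bits along. The u32 packing arithmetic is easy to sanity-check in plain C++; this demos only the pack/unpack, not the reduction:]

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Pack: comparison key in the high 16 bits, payload in the low 16 bits.
      uint32_t in_hi = 0x3F80u << 16;     // bf16 bits of 1.0f, shifted high
      uint32_t bp    = 0x12345678u;       // some gradient bit pattern
      uint32_t bp_hi = bp >> 16;          // high payload half
      uint32_t bp_lo = (bp << 16) >> 16;  // low payload half
      uint32_t word_hi = in_hi + bp_hi;   // reduced keyed on the high 16 bits
      uint32_t word_lo = in_hi + bp_lo;
      // Unpack after the reduction: shift drops the key, keeps the payload.
      uint32_t grads = (word_hi << 16) + ((word_lo << 16) >> 16);
      std::printf("recovered payload: 0x%08X\n", grads);  // 0x12345678
      return 0;
    }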
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 661cd5923e..9576354c5f 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -58,11 +59,11 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type);
const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type);
input_min =
- b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
+ xla::ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
input_max =
- b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
+ xla::ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
}
- xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max));
+ xla::XlaOp m = xla::Max(xla::Abs(input_min), xla::Abs(input_max));
// Next, we choose our fixed-point quantization buckets, [min_fixed,
// max_fixed]. If signed_input is true, this is
@@ -86,14 +87,14 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
//
// s = (max_fixed - min_fixed) / (2 * m).
xla::XlaOp s =
- b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
- b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
+ xla::Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
+ xla::Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
// Now we can quantize and dequantize the elements of our tensor. An element
// e is transformed into e':
//
// e' = (e * s).round_to_nearest() / s.
- xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s);
+ xla::XlaOp result = xla::Div(xla::Round(xla::Mul(input, s)), s);
ctx->SetOutput(0, result);
}
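
[The closing formula above, e' = round(e*s)/s with s = (max_fixed - min_fixed) / (2*m), is easy to verify numerically. A plain C++ demo with example values:]

    #include <cmath>
    #include <cstdio>

    int main() {
      // Signed 8-bit buckets: [min_fixed, max_fixed] = [-127, 127].
      float m = 2.0f;  // m = max(|input_min|, |input_max|)
      float s = (127.0f - (-127.0f)) / (2.0f * m);
      float e = 0.333f;
      float e_quant = std::round(e * s) / s;  // e' = round(e*s)/s
      std::printf("%f -> %f (step %f)\n", e, e_quant, 1.0f / s);
      return 0;
    }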
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 3bab4ae917..51f2cdc9f4 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -46,8 +47,8 @@ class RandomUniformOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype),
- XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -79,8 +80,8 @@ class RandomShuffleOp : public XlaOpKernel {
// Generate the random swaps for the indices.
auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
auto swaps =
- builder->RngUniform(builder->ConstantR0<int32>(0),
- builder->ConstantR0<int32>(n), swaps_shape);
+ xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
+ xla::ConstantR0<int32>(builder, n), swaps_shape);
// Generate range(n) as the initial value for the indices to be swapped.
xla::XlaOp indices;
@@ -93,17 +94,17 @@ class RandomShuffleOp : public XlaOpKernel {
-> xla::StatusOr<std::vector<xla::XlaOp>> {
auto swaps = loop_vars[0];
auto indices = loop_vars[1];
- i = builder->Reshape(i, {1});
+ i = xla::Reshape(i, {1});
// temp = indices[i]
- auto temp = builder->DynamicSlice(indices, i, {1});
+ auto temp = xla::DynamicSlice(indices, i, {1});
// swap_index = swaps[i]
- auto swap_index = builder->DynamicSlice(swaps, i, {1});
+ auto swap_index = xla::DynamicSlice(swaps, i, {1});
// swap_value = indices[swaps[i]]
- auto swap_value = builder->DynamicSlice(indices, swap_index, {1});
+ auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
// indices[i] = indices[swaps[i]]
- indices = builder->DynamicUpdateSlice(indices, swap_value, i);
+ indices = xla::DynamicUpdateSlice(indices, swap_value, i);
// indices[swaps[i]] = temp
- indices = builder->DynamicUpdateSlice(indices, temp, swap_index);
+ indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
return std::vector<xla::XlaOp>{swaps, indices};
};
// for i in range(n):
@@ -153,7 +154,7 @@ class RandomUniformIntOp : public XlaOpKernel {
auto minval = ctx->Input(1);
auto maxval = ctx->Input(2);
- ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape));
+ ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape));
}
private:
@@ -179,8 +180,8 @@ class RandomStandardNormalOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
// Normal distribution with a mean of 0 and a standard deviation of 1:
- xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype),
- XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -209,7 +210,7 @@ class TruncatedNormalOp : public XlaOpKernel {
xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
xla::XlaOp min_positive =
XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min());
- auto uniform = b->RngUniform(min_positive, one, xla_shape);
+ auto uniform = xla::RngUniform(min_positive, one, xla_shape);
ctx->SetOutput(0, TruncatedNormal(dtype, uniform));
}
};
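
The while loop in RandomShuffleOp performs one swap per iteration against a precomputed vector of swap targets. A host-side sketch of the same semantics, assuming every swaps[i] lies in [0, n):

    #include <cstdint>
    #include <utility>
    #include <vector>

    std::vector<int32_t> ApplySwaps(std::vector<int32_t> indices,
                                    const std::vector<int32_t>& swaps) {
      for (size_t i = 0; i < indices.size(); ++i) {
        // Mirrors the temp / swap_value / two DynamicUpdateSlice steps above.
        std::swap(indices[i], indices[swaps[i]]);
      }
      return indices;
    }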
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index 08894489ac..76bd1e62aa 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -98,10 +99,10 @@ class ReduceWindowOp : public XlaOpKernel {
{
std::unique_ptr<xla::XlaBuilder> cb =
builder->CreateSubBuilder("wrapper");
- auto x = cb->Parameter(0, scalar_shape, "x");
- auto y = cb->Parameter(1, scalar_shape, "y");
- auto outputs = cb->Call(*reducer.computation, {x, y});
- cb->GetTupleElement(outputs, 0);
+ auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x");
+ auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y");
+ auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y});
+ xla::GetTupleElement(outputs, 0);
xla::StatusOr<xla::XlaComputation> result = cb->Build();
OP_REQUIRES_OK(context, result.status());
wrapper = std::move(result.ValueOrDie());
@@ -112,7 +113,7 @@ class ReduceWindowOp : public XlaOpKernel {
padding[i] = {padding_low_[i], padding_high_[i]};
}
- xla::XlaOp output = builder->ReduceWindowWithGeneralPadding(
+ xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), wrapper, window_dimensions_,
window_strides_, padding);
context->SetOutput(0, output);
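
One subtlety in the wrapper above: the result of xla::GetTupleElement(outputs, 0) is deliberately discarded. In this client API the root of the computation returned by Build() is the op added most recently, so merely building the GetTupleElement is enough to make the reducer's first output the wrapper's return value.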
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 0f42563779..d3573bac3d 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -35,7 +36,7 @@ class SumOp : public XlaReductionOp {
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Add(scalar_lhs, scalar_rhs);
+ xla::Add(scalar_lhs, scalar_rhs);
}
};
@@ -53,7 +54,7 @@ class ProdOp : public XlaReductionOp {
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Mul(scalar_lhs, scalar_rhs);
+ xla::Mul(scalar_lhs, scalar_rhs);
}
};
@@ -71,7 +72,7 @@ class MinOp : public XlaReductionOp {
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Min(scalar_lhs, scalar_rhs);
+ xla::Min(scalar_lhs, scalar_rhs);
}
};
@@ -88,7 +89,7 @@ class MaxOp : public XlaReductionOp {
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Max(scalar_lhs, scalar_rhs);
+ xla::Max(scalar_lhs, scalar_rhs);
}
};
@@ -105,7 +106,7 @@ class MeanOp : public XlaReductionOp {
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Add(scalar_lhs, scalar_rhs);
+ xla::Add(scalar_lhs, scalar_rhs);
}
xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder,
@@ -113,7 +114,7 @@ class MeanOp : public XlaReductionOp {
int64 num_elements_reduced) override {
auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
num_elements_reduced);
- return builder->Div(reduce_output, divisor);
+ return xla::Div(reduce_output, divisor);
}
};
@@ -126,12 +127,12 @@ class AllOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return builder->ConstantR0<bool>(true);
+ return xla::ConstantR0<bool>(builder, true);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->And(scalar_lhs, scalar_rhs);
+ xla::And(scalar_lhs, scalar_rhs);
}
};
@@ -143,12 +144,12 @@ class AnyOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return builder->ConstantR0<bool>(false);
+ return xla::ConstantR0<bool>(builder, false);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Or(scalar_lhs, scalar_rhs);
+ xla::Or(scalar_lhs, scalar_rhs);
}
};
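
These reducers illustrate the two flavors of the free-function API: constructors with no XlaOp operands (ConstantR0 here) take the builder as an explicit first argument, while ops over existing XlaOps (Add, Mul, Min, Max, And, Or) recover the builder from their operands. A minimal sketch of a reducer body under that convention:

    // Operand-free constructors name the builder explicitly ...
    xla::XlaOp init = xla::ConstantR0<float>(builder, 0.0f);
    // ... while ops on existing XlaOps infer it, so BuildReducer only needs
    // its builder parameter for literals.
    xla::Add(scalar_lhs, scalar_rhs);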
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 44510c731e..14506d65c4 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -101,20 +102,20 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
- auto data = b->ConvertElementType(ctx->Input(0), type);
+ auto data = xla::ConvertElementType(ctx->Input(0), type);
// Call virtual method to get the initial value.
- auto initial = b->ConvertElementType(InitialValue(b), type);
+ auto initial = xla::ConvertElementType(InitialValue(b), type);
// Make two scalar parameters of the desired type for the lambda.
- auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
- auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
+ auto rx = xla::Parameter(&r, 0, xla::ShapeUtil::MakeShape(type, {}), "x");
+ auto ry = xla::Parameter(&r, 1, xla::ShapeUtil::MakeShape(type, {}), "y");
// Call virtual method to build the reduction lambda.
BuildReducer(&r, rx, ry);
xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
- auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes);
+ auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes);
auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced);
- auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized;
+ auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
ctx->SetOutput(0, result);
}
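
Putting the pieces together, a self-contained sketch of the reduce pattern used here, for a sum over axis 0 (assumes a live xla::XlaBuilder* b and an F32 operand data; the names are illustrative):

    xla::XlaBuilder r("add_reducer");
    auto rx = xla::Parameter(&r, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
    auto ry = xla::Parameter(&r, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), "y");
    xla::Add(rx, ry);  // last op added; becomes the reducer's root
    xla::XlaComputation add = r.Build().ConsumeValueOrDie();
    auto sum = xla::Reduce(data, xla::ConstantR0<float>(b, 0.0f), add, {0});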
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index ba7d484d53..a4ba6c748a 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -34,7 +34,7 @@ class ReluOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
- ctx->SetOutput(0, builder->Max(zero, ctx->Input(0)));
+ ctx->SetOutput(0, xla::Max(zero, ctx->Input(0)));
}
};
@@ -46,7 +46,7 @@ class Relu6Op : public XlaOpKernel {
xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6);
- ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six));
+ ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six));
}
};
@@ -59,9 +59,9 @@ class ReluGradOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
- b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
- const auto pred = b->Gt(ctx->Input(1), zero);
- ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero));
+ xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto pred = xla::Gt(ctx->Input(1), zero);
+ ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero));
}
};
@@ -74,12 +74,12 @@ class Relu6GradOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
- b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
- const auto six = b->Broadcast(
+ xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto six = xla::Broadcast(
XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes());
- auto out =
- b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)),
- ctx->Input(0), zero);
+ auto out = xla::Select(
+ xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)),
+ ctx->Input(0), zero);
ctx->SetOutput(0, out);
}
};
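
The select/mask composition in Relu6GradOp reduces to a simple scalar rule; a host-side sketch for reference:

    // Pass the incoming gradient only where the forward value was strictly
    // inside (0, 6); zero it elsewhere.
    float Relu6Grad(float dy, float x) {
      return (x > 0.0f && x < 6.0f) ? dy : 0.0f;
    }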
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index af4d64b159..e0ca8dd8e2 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -90,8 +91,7 @@ class ReshapeOp : public XlaOpKernel {
VLOG(1) << "Reshape " << input_shape.DebugString() << " "
<< shape.DebugString();
- ctx->SetOutput(0,
- ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes()));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index a711278638..db7ea775e2 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -69,8 +69,7 @@ class RetvalOp : public XlaOpKernel {
xla::XlaOp output = input;
if (tc.is_entry_computation()) {
- output =
- ctx->builder()->Reshape(input, representation_shape.dim_sizes());
+ output = xla::Reshape(input, representation_shape.dim_sizes());
} else {
// The core from which a return value is returned depends on the
// device assignment of the input to the retval. Since we can't change
@@ -78,8 +77,8 @@ class RetvalOp : public XlaOpKernel {
// introduce an operator here, even if the shape does not change.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
- output = ctx->builder()->GetTupleElement(
- ctx->builder()->Tuple({output}), 0);
+ output =
+ xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0);
}
tc.AddRetval(index_, dtype_, shape, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index 2872a3c4d4..037c422258 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -62,7 +63,7 @@ class ReverseOp : public XlaOpKernel {
}
}
- ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions));
+ ctx->SetOutput(0, xla::Rev(ctx->Input(0), dimensions));
}
};
@@ -100,7 +101,7 @@ class ReverseV2Op : public XlaOpKernel {
x_shape.dims(), ")."));
}
- ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes));
+ ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 5d1c052684..16491002b4 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@@ -85,88 +86,83 @@ class ReverseSequenceOp : public XlaOpKernel {
auto condition_builder =
builder->CreateSubBuilder("reverse_sequence_condition");
{
- auto param = condition_builder->Parameter(0, tuple_shape, "param");
- auto i = condition_builder->GetTupleElement(param, 0);
- condition_builder->Lt(
- i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type,
- batch_size));
+ auto param =
+ xla::Parameter(condition_builder.get(), 0, tuple_shape, "param");
+ auto i = xla::GetTupleElement(param, 0);
+ xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(),
+ seq_lens_type, batch_size));
}
auto condition = condition_builder->Build();
OP_REQUIRES_OK(context, condition.status());
auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
{
- auto param = body_builder->Parameter(0, tuple_shape, "param");
- auto i = body_builder->GetTupleElement(param, 0);
- auto seq_lens = body_builder->GetTupleElement(param, 1);
- auto output = body_builder->GetTupleElement(param, 2);
+ auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param");
+ auto i = xla::GetTupleElement(param, 0);
+ auto seq_lens = xla::GetTupleElement(param, 1);
+ auto output = xla::GetTupleElement(param, 2);
// seq_len is the sequence length of the current batch element (rank 1)
- auto seq_len = body_builder->DynamicSlice(
- seq_lens, body_builder->Reshape(i, {1}), {1});
+ auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1});
// Indices is the offset of the batch element in the input.
- auto batch_element_indices = body_builder->Broadcast(
- XlaHelpers::Zero(body_builder.get(), seq_lens_type),
- {input_shape.dims()});
- batch_element_indices = body_builder->DynamicUpdateSlice(
- batch_element_indices, body_builder->Reshape(i, {1}),
- body_builder->Reshape(
- XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
- batch_dim_),
- {1}));
+ auto batch_element_indices =
+ xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
+ {input_shape.dims()});
+ batch_element_indices = xla::DynamicUpdateSlice(
+ batch_element_indices, xla::Reshape(i, {1}),
+ xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
+ seq_lens_type, batch_dim_),
+ {1}));
// Slice out the current batch element and pad it out in the sequence
// dimension.
TensorShape slice_shape = input_shape;
slice_shape.set_dim(batch_dim_, 1);
slice_shape.set_dim(seq_dim_, max_seq_len);
- auto slice = body_builder->DynamicSlice(output, batch_element_indices,
- slice_shape.dim_sizes());
+ auto slice = xla::DynamicSlice(output, batch_element_indices,
+ slice_shape.dim_sizes());
auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims());
padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
slice_shape.dim_size(seq_dim_));
- slice = body_builder->Pad(
- slice, XlaHelpers::Zero(body_builder.get(), input_type),
- padding_config);
+ slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type),
+ padding_config);
// Now slice out the reversed sequence from its actual start.
// sequence_start_indices is the offset of the start of the reversed
// sequence in the input. The slice will go into the padding, however, we
// will mask off these elements and replace them with elements from the
// original input so their values do not matter.
- auto sequence_start_indices = body_builder->Broadcast(
- XlaHelpers::Zero(body_builder.get(), seq_lens_type),
- {slice_shape.dims()});
- sequence_start_indices = body_builder->DynamicUpdateSlice(
+ auto sequence_start_indices =
+ xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
+ {slice_shape.dims()});
+ sequence_start_indices = xla::DynamicUpdateSlice(
sequence_start_indices,
- body_builder->Sub(XlaHelpers::IntegerLiteral(
- body_builder.get(), seq_lens_type, max_seq_len),
- seq_len),
- body_builder->Reshape(
- XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
- seq_dim_),
- {1}));
- slice = body_builder->DynamicSlice(slice, sequence_start_indices,
- slice_shape.dim_sizes());
+ xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
+ max_seq_len),
+ seq_len),
+ xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
+ seq_lens_type, seq_dim_),
+ {1}));
+ slice = xla::DynamicSlice(slice, sequence_start_indices,
+ slice_shape.dim_sizes());
// Shift the reversed sequence to the left.
- output = body_builder->DynamicUpdateSlice(output, slice,
- batch_element_indices);
+ output = xla::DynamicUpdateSlice(output, slice, batch_element_indices);
- body_builder->Tuple(
- {body_builder->Add(
- i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
+ xla::Tuple(
+ body_builder.get(),
+ {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
seq_lens, output});
}
auto body = body_builder->Build();
OP_REQUIRES_OK(context, body.status());
- auto loop_output = builder->While(
+ auto loop_output = xla::While(
condition.ValueOrDie(), body.ValueOrDie(),
- builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
- builder->Rev(input, {seq_dim_})}));
- auto output = builder->GetTupleElement(loop_output, 2);
+ xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
+ xla::Rev(input, {seq_dim_})}));
+ auto output = xla::GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
xla::XlaOp iota;
@@ -174,14 +170,13 @@ class ReverseSequenceOp : public XlaOpKernel {
context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
std::vector<int64> dims(input_shape.dims(), 1);
dims[batch_dim_] = batch_size;
- auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_});
+ auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_});
// Broadcast the mask up to the input shape.
- mask =
- builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false),
- input_shape.dim_sizes()));
+ mask = xla::Or(mask, xla::Broadcast(xla::ConstantR0<bool>(builder, false),
+ input_shape.dim_sizes()));
- output = builder->Select(mask, output, input);
+ output = xla::Select(mask, output, input);
context->SetOutput(0, output);
}
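
The condition/body construction above follows the general xla::While recipe: both are XlaComputations over a single tuple-typed parameter, built on sub-builders. A pared-down sketch (a hypothetical counter loop over a one-element tuple):

    const xla::Shape counter_shape =
        xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::S32, {})});
    auto cond_b = builder->CreateSubBuilder("count_cond");
    {
      auto p = xla::Parameter(cond_b.get(), 0, counter_shape, "p");
      xla::Lt(xla::GetTupleElement(p, 0),
              xla::ConstantR0<int32>(cond_b.get(), 10));
    }
    auto body_b = builder->CreateSubBuilder("count_body");
    {
      auto p = xla::Parameter(body_b.get(), 0, counter_shape, "p");
      xla::Tuple(body_b.get(),
                 {xla::Add(xla::GetTupleElement(p, 0),
                           xla::ConstantR0<int32>(body_b.get(), 1))});
    }
    auto final_state = xla::While(
        cond_b->Build().ValueOrDie(), body_b->Build().ValueOrDie(),
        xla::Tuple(builder, {xla::ConstantR0<int32>(builder, 0)}));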
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 1819fb5433..9247c7029a 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -100,7 +101,7 @@ class ScanOp : public XlaOpKernel {
init = XlaHelpers::One(builder, dtype);
reducer = ctx->GetOrCreateMul(dtype);
}
- auto output = builder->ReduceWindowWithGeneralPadding(
+ auto output = xla::ReduceWindowWithGeneralPadding(
XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
*reducer, window_dims, window_strides, padding);
output =
diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
index f2c63b4f90..14709bb6cb 100644
--- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -103,8 +104,8 @@ class ScatterNdOp : public XlaOpKernel {
updates_shape));
xla::XlaBuilder* builder = context->builder();
- auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
- buffer_shape.dim_sizes());
+ auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype),
+ buffer_shape.dim_sizes());
auto indices = context->Input(0);
auto updates = context->Input(1);
auto result =
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index ff14483347..db7e559420 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -75,7 +75,7 @@ class UnsortedSegmentReduce : public XlaOpKernel {
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
auto buffer =
- builder->Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
+ xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
xla::XlaBuilder* builder) {
@@ -102,7 +102,7 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce {
};
xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
xla::XlaBuilder* builder) override {
- return builder->Add(a, b);
+ return xla::Add(a, b);
};
};
@@ -120,7 +120,7 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce {
};
xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
xla::XlaBuilder* builder) override {
- return builder->Mul(a, b);
+ return xla::Mul(a, b);
};
};
@@ -138,7 +138,7 @@ class UnsortedSegmentMin : public UnsortedSegmentReduce {
};
xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
xla::XlaBuilder* builder) override {
- return builder->Min(a, b);
+ return xla::Min(a, b);
};
};
@@ -156,7 +156,7 @@ class UnsortedSegmentMax : public UnsortedSegmentReduce {
};
xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b,
xla::XlaBuilder* builder) override {
- return builder->Max(a, b);
+ return xla::Max(a, b);
};
};
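
Reference semantics for the Combine hooks above, sketched on the host for the Add case (dropping out-of-range segment ids is an assumption of this sketch, not something shown in these hunks):

    #include <vector>

    std::vector<float> UnsortedSegmentSum(const std::vector<float>& data,
                                          const std::vector<int>& segment_ids,
                                          int num_segments) {
      std::vector<float> out(num_segments, 0.0f);  // InitialValue() for sums
      for (size_t i = 0; i < data.size(); ++i) {
        const int s = segment_ids[i];
        if (s >= 0 && s < num_segments) out[s] += data[i];  // Combine()
      }
      return out;
    }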
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index f9f48164d6..5c010c9df2 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -40,8 +41,6 @@ class SelectOp : public XlaOpKernel {
"'then' and 'else' must have the same size. but received: ",
then_shape.DebugString(), " vs. ", else_shape.DebugString()));
- xla::XlaBuilder* builder = ctx->builder();
-
auto cond_handle = ctx->Input(0);
auto then_handle = ctx->Input(1);
auto else_handle = ctx->Input(2);
@@ -69,14 +68,14 @@ class SelectOp : public XlaOpKernel {
const auto dim_sizes = then_shape.dim_sizes();
gtl::ArraySlice<int64> bdims = dim_sizes;
bdims.pop_front();
- cond_handle = builder->Broadcast(cond_handle, bdims);
+ cond_handle = xla::Broadcast(cond_handle, bdims);
std::vector<int64> dim_order(then_shape.dims());
dim_order[0] = then_shape.dims() - 1;
std::iota(dim_order.begin() + 1, dim_order.end(), 0);
- cond_handle = builder->Transpose(cond_handle, dim_order);
+ cond_handle = xla::Transpose(cond_handle, dim_order);
}
- ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle));
+ ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle));
}
private:
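
The broadcast-then-transpose pair handles a rank-1 cond against higher-rank operands. The shape flow, assuming cond has shape [B] and then/else have shape [B, d1, ..., dk]:

    cond                                     [B]
    xla::Broadcast(cond, {d1, ..., dk})      [d1, ..., dk, B]  (new dims prepend)
    xla::Transpose(cond, {k, 0, ..., k-1})   [B, d1, ..., dk]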
diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
index 9ce01d0d44..6281d6c653 100644
--- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
@@ -45,7 +45,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) {
XlaCompiler* compiler = XlaContext::Get(ctx).compiler();
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel));
- ctx->builder()->Send(ctx->Input(0), channel);
+ xla::Send(ctx->Input(0), channel);
}
REGISTER_XLA_OP(Name("XlaSend"), SendOp);
@@ -76,7 +76,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) {
XlaCompiler* compiler = XlaContext::Get(ctx).compiler();
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel));
- ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel));
+ ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel));
}
REGISTER_XLA_OP(Name("XlaRecv"), RecvOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index d59720bef7..5798823cd5 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -147,7 +148,7 @@ class ExpandDimsOp : public XlaOpKernel {
dim = std::min<int32>(dim, existing_dims_size);
new_shape.emplace(new_shape.begin() + dim, 1);
- ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
+ ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
}
};
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp);
@@ -204,7 +205,7 @@ class SqueezeOp : public XlaOpKernel {
}
}
- ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
+ ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
}
private:
@@ -221,7 +222,7 @@ class ZerosLikeOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0);
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
- ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
}
};
@@ -235,7 +236,7 @@ class OnesLikeOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0);
auto one = XlaHelpers::One(ctx->builder(), input_type(0));
- ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes()));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index be1e97bf26..1864584ade 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -92,8 +93,7 @@ class SliceOp : public XlaOpKernel {
limits.push_back(begin[i] + size[i]);
}
std::vector<int64> strides(begin.size(), 1);
- ctx->SetOutput(
- 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides));
+ ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides));
} else {
// `begin` is not a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
@@ -106,8 +106,7 @@ class SliceOp : public XlaOpKernel {
input_shape.dim_size(i), "], but ",
"got ", size[i]));
}
- ctx->SetOutput(
- 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size));
+ ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size));
}
}
};
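
The two branches map onto the two slicing primitives. For example, to take rows 2..4 (size 3) of a [6, 8] input (illustrative values):

    // begin is a compile-time constant: limits = begin + size.
    xla::Slice(x, /*begin=*/{2, 0}, /*limits=*/{5, 8}, /*strides=*/{1, 1});
    // begin is only known at runtime: pass the begin op and the sizes.
    xla::DynamicSlice(x, begin_op, /*sizes=*/{3, 8});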
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index bbf5ee8b12..d1c69f08b0 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -47,25 +48,25 @@ class SoftmaxOp : public XlaOpKernel {
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
// Find the max in each batch, resulting in a tensor of shape [batch]
- auto logits_max =
- b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type),
+ max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b. Broadcasts
// along the batch dimension.
- auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
- auto exp_shifted = b->Exp(shifted_logits);
+ auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
+ auto exp_shifted = xla::Exp(shifted_logits);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
auto converted =
XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type);
auto reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto sum = XlaHelpers::ConvertElementType(b, reduce, type);
auto softmax =
log_
// softmax = shifted_logits - log(sum(exp(shifted_logits)))
- ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim})
+ ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim})
// softmax = exp(shifted_logits) / sum(exp(shifted_logits))
- : b->Div(exp_shifted, sum, {kBatchDim});
+ : xla::Div(exp_shifted, sum, {kBatchDim});
ctx->SetOutput(0, softmax);
}
@@ -87,43 +88,44 @@ std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
xla::XlaBuilder* b = ctx->builder();
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
- b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ xla::Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b.
// Broadcasts along the batch dimension.
- auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+ auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
// exp(logits - max_logits)
- auto exp_shifted_logits = b->Exp(shifted_logits);
+ auto exp_shifted_logits = xla::Exp(shifted_logits);
// sum_{class} (exp(logits - max_logits))
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
auto converted =
XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type);
- auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
+ auto reduce =
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type);
// log(sum(exp(logits - max_logits)))
- auto log_sum_exp = b->Log(sum_exp);
+ auto log_sum_exp = xla::Log(sum_exp);
// sum(-labels *
// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
// along classes
// (The subtraction broadcasts along the batch dimension.)
- auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim});
- auto mul = b->Mul(b->Neg(labels), sub);
+ auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim});
+ auto mul = xla::Mul(xla::Neg(labels), sub);
auto sum =
- b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type),
- XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
+ xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type),
+ XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto loss = XlaHelpers::ConvertElementType(b, sum, type);
// backprop: prob - labels, where
// prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
// (where the division broadcasts along the batch dimension)
xla::XlaOp backprop =
- b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+ xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
return {loss, backprop};
}
@@ -206,16 +208,14 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
// Builds a vector of {batch_size} that is 0 if the index is in range, or
// NaN otherwise; then add that vector to the labels to force out-of-range
// values to NaNs.
- xla::XlaOp nan_or_zero = builder->Select(
- builder->And(
- builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
- builder->Lt(indices, XlaHelpers::IntegerLiteral(
- builder, indices_type, depth))),
- builder->Broadcast(XlaHelpers::Zero(builder, logits_type),
- {batch_size}),
- builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
- {batch_size}));
- labels = builder->Add(labels, nan_or_zero, {0});
+ xla::XlaOp nan_or_zero = xla::Select(
+ xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices),
+ xla::Lt(indices, XlaHelpers::IntegerLiteral(
+ builder, indices_type, depth))),
+ xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}),
+ xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
+ {batch_size}));
+ labels = xla::Add(labels, nan_or_zero, {0});
xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
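
A host-side sketch of the numerically stable softmax both kernels build; subtracting the per-batch max only shifts the inputs so exp cannot overflow, and the shift cancels in the final ratio:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> Softmax(std::vector<float> logits) {
      const float mx = *std::max_element(logits.begin(), logits.end());
      float sum = 0.0f;
      for (float& v : logits) {
        v = std::exp(v - mx);  // exp(shifted_logits)
        sum += v;
      }
      for (float& v : logits) v /= sum;  // Div(exp_shifted, sum)
      return logits;
    }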
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index 204ae84582..faaf8964ff 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -25,8 +25,7 @@ class XlaSortOp : public XlaOpKernel {
explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
- xla::XlaBuilder* const b = context->builder();
- context->SetOutput(0, b->Sort(context->Input(0)));
+ context->SetOutput(0, xla::Sort(context->Input(0)));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index ec077924b5..8a8525efa1 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -73,7 +74,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
"The product of the block dimensions must be positive"));
xla::XlaOp padded =
- b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
+ xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
// 2. Reshape `padded` to `reshaped_padded` of shape:
//
@@ -100,7 +101,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_padded_shape.begin() + 1 + 2 * block_rank);
- xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape);
+ xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape);
// 3. Permute dimensions of `reshaped_padded` to produce
// `permuted_reshaped_padded` of shape:
@@ -120,7 +121,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
xla::XlaOp permuted_reshaped_padded =
- b->Transpose(reshaped_padded, permutation);
+ xla::Transpose(reshaped_padded, permutation);
// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -140,7 +141,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::copy(remainder_shape.begin(), remainder_shape.end(),
output_shape.begin() + 1 + block_rank);
- xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape);
+ xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape);
ctx->SetOutput(0, output);
}
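
A concrete trace of the four steps, assuming batch 1, a 4x4 spatial grid, block_shape = {2, 2}, C channels, and no padding:

    input               [1, 4, 4, C]
    1. padded           [1, 4, 4, C]        (nothing to pad in this example)
    2. reshaped_padded  [1, 2, 2, 2, 2, C]
    3. permuted         [2, 2, 1, 2, 2, C]  (block dims move to the front)
    4. output           [4, 2, 2, C]        (block dims fold into batch)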
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 4c5886ee2a..47d282fe9e 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -50,7 +51,6 @@ class SpaceToDepthOp : public XlaOpKernel {
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
- xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
@@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[1] / block_size_, block_size_,
// input_shape[2] / block_size_, block_size_,
// depth]
- xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -145,7 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_, block_size_,
// depth]
- xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -155,7 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_ * block_size_ * depth]
//
- xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 9b54058541..ca74cf2450 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -98,7 +99,7 @@ class SplitOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
begin[split_dim] = i * slice_size;
limits[split_dim] = (i + 1) * slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
+ ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
}
}
};
@@ -199,7 +200,7 @@ class SplitVOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
limits[split_dim] = begin[split_dim] + slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
+ ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
begin[split_dim] = limits[split_dim];
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 0fb05a2be7..591e61b4c8 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -144,24 +144,25 @@ class StackPushOp : public XlaOpKernel {
// Initializes the Stack, if the element shape was not already known.
OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));
- xla::XlaOp ta = b->GetTupleElement(resource->value(), 0);
- xla::XlaOp index = b->GetTupleElement(resource->value(), 1);
+ xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0);
+ xla::XlaOp index = xla::GetTupleElement(resource->value(), 1);
xla::XlaOp value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
- auto update = b->Reshape(value, slice_shape.dim_sizes());
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
- {b->DynamicUpdateSlice(ta, update, start_indices),
- b->Add(index, b->ConstantR0<int32>(1))})));
+ OP_REQUIRES_OK(ctx,
+ resource->SetValue(xla::Tuple(
+ b, {xla::DynamicUpdateSlice(ta, update, start_indices),
+ xla::Add(index, xla::ConstantR0<int32>(b, 1))})));
ctx->SetOutput(0, value);
}
@@ -197,27 +198,27 @@ class StackPopOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));
xla::XlaOp state = resource->value();
- xla::XlaOp ta = b->GetTupleElement(state, 0);
- xla::XlaOp index = b->GetTupleElement(state, 1);
+ xla::XlaOp ta = xla::GetTupleElement(state, 0);
+ xla::XlaOp index = xla::GetTupleElement(state, 1);
- index = b->Sub(index, b->ConstantR0<int32>(1));
- OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
+ index = Sub(index, xla::ConstantR0<int32>(b, 1));
+ OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index})));
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}}));
auto slice_shape = stack_shape.dim_sizes();
slice_shape[0] = 1LL;
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
- ctx->SetOutput(0, b->Reshape(read, value_shape));
+ ctx->SetOutput(0, xla::Reshape(read, value_shape));
}
private:
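
Both push and pop use the same Pad idiom to build a start_indices vector of [index, 0, ..., 0]. A host-side sketch of what it materializes:

    #include <cstdint>
    #include <vector>

    // xla::Reshape(index, {1}) yields a rank-1 [index]; padding the high edge
    // with k zeros yields [index, 0, ..., 0] of length k + 1.
    std::vector<int32_t> StartIndices(int32_t index, int k) {
      std::vector<int32_t> start(k + 1, 0);
      start[0] = index;
      return start;
    }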
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 43ab4642e9..3b19f8d872 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -33,9 +34,9 @@ namespace {
// Rotates a 32-bit integer 'v' left by 'distance' bits.
xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v,
int distance) {
- return builder->Or(
- builder->ShiftLeft(v, builder->ConstantR0<int>(distance)),
- builder->ShiftRightLogical(v, builder->ConstantR0<int>(32 - distance)));
+ return xla::Or(
+ xla::ShiftLeft(v, xla::ConstantR0<int>(builder, distance)),
+ xla::ShiftRightLogical(v, xla::ConstantR0<int>(builder, 32 - distance)));
}
using ThreeFry2x32State = std::array<xla::XlaOp, 2>;
@@ -51,22 +52,22 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder,
std::array<xla::XlaOp, 3> ks;
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
- ks[2] = builder->ConstantR0<int32>(0x1BD11BDA);
+ ks[2] = xla::ConstantR0<int32>(builder, 0x1BD11BDA);
for (int i = 0; i < 2; ++i) {
ks[i] = key[i];
x[i] = input[i];
- ks[2] = builder->Xor(ks[2], key[i]);
+ ks[2] = xla::Xor(ks[2], key[i]);
}
- x[0] = builder->Add(x[0], ks[0]);
- x[1] = builder->Add(x[1], ks[1]);
+ x[0] = xla::Add(x[0], ks[0]);
+ x[1] = xla::Add(x[1], ks[1]);
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
auto round = [builder](ThreeFry2x32State v, int rotation) {
- v[0] = builder->Add(v[0], v[1]);
+ v[0] = xla::Add(v[0], v[1]);
v[1] = RotateLeftS32(builder, v[1], rotation);
- v[1] = builder->Xor(v[0], v[1]);
+ v[1] = xla::Xor(v[0], v[1]);
return v;
};
@@ -76,36 +77,36 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder,
x = round(x, rotations[1]);
x = round(x, rotations[2]);
x = round(x, rotations[3]);
- x[0] = builder->Add(x[0], ks[1]);
- x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(1));
+ x[0] = xla::Add(x[0], ks[1]);
+ x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 1));
x = round(x, rotations[4]);
x = round(x, rotations[5]);
x = round(x, rotations[6]);
x = round(x, rotations[7]);
- x[0] = builder->Add(x[0], ks[2]);
- x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(2));
+ x[0] = xla::Add(x[0], ks[2]);
+ x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 2));
x = round(x, rotations[0]);
x = round(x, rotations[1]);
x = round(x, rotations[2]);
x = round(x, rotations[3]);
- x[0] = builder->Add(x[0], ks[0]);
- x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0<int32>(3));
+ x[0] = xla::Add(x[0], ks[0]);
+ x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0<int32>(builder, 3));
x = round(x, rotations[4]);
x = round(x, rotations[5]);
x = round(x, rotations[6]);
x = round(x, rotations[7]);
- x[0] = builder->Add(x[0], ks[1]);
- x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(4));
+ x[0] = xla::Add(x[0], ks[1]);
+ x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 4));
x = round(x, rotations[0]);
x = round(x, rotations[1]);
x = round(x, rotations[2]);
x = round(x, rotations[3]);
- x[0] = builder->Add(x[0], ks[2]);
- x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(5));
+ x[0] = xla::Add(x[0], ks[2]);
+ x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 5));
return x;
}
@@ -116,8 +117,8 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
const TensorShape& shape, double minval,
double maxval) {
// Split the seed into two 32-bit scalars to form a key.
- auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {});
- auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {});
+ auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
+ auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
ThreeFry2x32State key = {seed0, seed1};
const int64 size = shape.num_elements();
@@ -127,32 +128,32 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
// Fill the generator inputs with unique counter values.
ThreeFry2x32State inputs;
TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0]));
- inputs[1] = builder->Add(inputs[0], builder->ConstantR0<int32>(half_size));
+ inputs[1] = xla::Add(inputs[0], xla::ConstantR0<int32>(builder, half_size));
ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key);
if (size_is_odd) {
- outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1});
+ outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1});
}
auto bits =
- builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes());
+ xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes());
// Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit
// forces the random bits into the mantissa.
constexpr int kFloatBits = 32;
constexpr int kMantissaBits = 23;
- bits = builder->Or(
- builder->ShiftRightLogical(
- bits, builder->ConstantR0<int32>(kFloatBits - kMantissaBits)),
- builder->ConstantR0<int32>(bit_cast<int32>(1.0f)));
- auto floats = builder->BitcastConvertType(bits, xla::F32);
+ bits = xla::Or(
+ xla::ShiftRightLogical(
+ bits, xla::ConstantR0<int32>(builder, kFloatBits - kMantissaBits)),
+ xla::ConstantR0<int32>(builder, bit_cast<int32>(1.0f)));
+ auto floats = xla::BitcastConvertType(bits, xla::F32);
// We have a floating point number in the range [1.0, 2.0).
// Subtract 1.0f to shift to the range [0.0, 1.0)
- floats = builder->Sub(floats, builder->ConstantR0<float>(1.0f));
+ floats = xla::Sub(floats, xla::ConstantR0<float>(builder, 1.0f));
// Multiply and add to shift to the range [minval, maxval).
- floats = builder->Mul(floats, builder->ConstantR0<float>(maxval - minval));
- floats = builder->Add(floats, builder->ConstantR0<float>(minval));
+ floats = xla::Mul(floats, xla::ConstantR0<float>(builder, maxval - minval));
+ floats = xla::Add(floats, xla::ConstantR0<float>(builder, minval));
return floats;
}
@@ -207,8 +208,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
- auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)),
- ErfInv(uniform));
+ auto normal = xla::Mul(xla::ConstantR0<float>(builder, std::sqrt(2.0)),
+ ErfInv(uniform));
ctx->SetOutput(0, normal);
}
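
A host-side sketch of the bit trick in RandomUniform; the constant 0x3f800000 is bit_cast<int32>(1.0f), i.e. the exponent of 1.0f with a zero mantissa:

    #include <cstdint>
    #include <cstring>

    float BitsToUnitFloat(uint32_t bits) {
      // Shift the random bits down into the 23-bit mantissa, then OR in the
      // exponent of 1.0f: the result is uniform in [1.0, 2.0).
      const uint32_t f = (bits >> 9) | 0x3f800000u;
      float out;
      std::memcpy(&out, &f, sizeof(out));
      return out - 1.0f;  // shift to [0.0, 1.0)
    }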
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 55254c746e..c2165ccd86 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -92,12 +93,12 @@ class StridedSliceOp : public XlaOpKernel {
xla::XlaOp slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
- slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
+ slice = xla::Rev(slice, dimensions_to_reverse);
}
- slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
+ slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
- slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
+ slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
}
@@ -171,7 +172,7 @@ class StridedSliceGradOp : public XlaOpKernel {
xla::XlaOp grad = ctx->Input(4);
// Undo any new/shrink axes.
- grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());
+ grad = xla::Reshape(grad, processing_shape.dim_sizes());
// Pad the input gradients.
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
@@ -204,9 +205,9 @@ class StridedSliceGradOp : public XlaOpKernel {
}
}
if (!dimensions_to_reverse.empty()) {
- grad = ctx->builder()->Rev(grad, dimensions_to_reverse);
+ grad = xla::Rev(grad, dimensions_to_reverse);
}
- grad = ctx->builder()->Pad(grad, zero, padding_config);
+ grad = xla::Pad(grad, zero, padding_config);
ctx->SetOutput(0, grad);
}
@@ -306,17 +307,17 @@ class StridedSliceAssignOp : public XlaOpKernel {
}
if (!dimensions_to_reverse.empty()) {
- rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
+ rhs = xla::Rev(rhs, dimensions_to_reverse);
}
- rhs = ctx->builder()->Reshape(rhs, slice_dims);
+ rhs = xla::Reshape(rhs, slice_dims);
if (lhs_shape.dims() == 0) {
// TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
// and remove this workaround.
lhs = rhs;
} else {
- lhs = ctx->builder()->DynamicUpdateSlice(
- lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
+ lhs = xla::DynamicUpdateSlice(
+ lhs, rhs, xla::ConstantR1<int64>(ctx->builder(), slice_begin));
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 9adee78a1f..2f650ce305 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -123,10 +124,9 @@ xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
const xla::XlaOp& update,
const gtl::ArraySlice<int64>& update_dims,
const xla::XlaOp& start_indices) {
- xla::XlaOp current =
- builder->DynamicSlice(operand, start_indices, update_dims);
- xla::XlaOp sum = builder->Add(current, update);
- return builder->DynamicUpdateSlice(operand, sum, start_indices);
+ xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
+ xla::XlaOp sum = xla::Add(current, update);
+ return xla::DynamicUpdateSlice(operand, sum, start_indices);
}
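For reference, DynamicAddSlice is a read-add-write composition: with operand {1, 2, 3, 4}, update {10, 20}, and start_indices {1}, DynamicSlice yields {2, 3}, Add yields {12, 23}, and DynamicUpdateSlice produces {1, 12, 23, 4}.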
class TensorArrayOp : public XlaOpKernel {
@@ -162,7 +162,7 @@ class TensorArrayOp : public XlaOpKernel {
ta_shape.AddDim(size);
ta_shape.AppendShape(shape);
xla::XlaOp zero = XlaHelpers::Zero(b, dtype_);
- value = b->Broadcast(zero, ta_shape.dim_sizes());
+ value = xla::Broadcast(zero, ta_shape.dim_sizes());
}
XlaContext& xc = XlaContext::Get(ctx);
@@ -215,12 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
- auto update = b->Reshape(value, slice_shape.dim_sizes());
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
xla::XlaOp written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
@@ -259,17 +259,17 @@ class TensorArrayReadOp : public XlaOpKernel {
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
- xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
- ctx->SetOutput(0, b->Reshape(read, value_shape));
+ ctx->SetOutput(0, xla::Reshape(read, value_shape));
}
private:
@@ -326,7 +326,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
for (auto i = 1; i < ta_shape.dims(); i++) {
end[i] = ta_shape.dim_size(i);
}
- ctx->SetOutput(0, b->Slice(ta, begin, end, strides));
+ ctx->SetOutput(0, xla::Slice(ta, begin, end, strides));
return;
}
}
@@ -391,7 +391,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
if (scatter_all_elements_in_order) {
- ta = b->Add(ta, value);
+ ta = xla::Add(ta, value);
} else {
auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
@@ -407,13 +407,13 @@ class TensorArrayScatterOp : public XlaOpKernel {
// Slice out part of the value.
value_starts[0] = i;
value_ends[0] = i + 1;
- auto slice = b->Slice(value, value_starts, value_ends, value_strides);
+ auto slice = xla::Slice(value, value_starts, value_ends, value_strides);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto index = b->Slice(indices, {i}, {i + 1}, {1});
+ auto index = xla::Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
}
@@ -452,7 +452,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
shape[0] *= ta_shape.dim_size(0);
- ctx->SetOutput(0, b->Reshape(ta, shape));
+ ctx->SetOutput(0, xla::Reshape(ta, shape));
Tensor lengths(DT_INT64, {ta_dims[0]});
auto lengths_vec = lengths.vec<int64>();
@@ -522,8 +522,8 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
- ta, b->Reshape(value, ta_shape.dim_sizes()))));
+ OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add(
+ ta, xla::Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index e91075196b..c9e5694262 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -93,9 +94,9 @@ class TileOp : public XlaOpKernel {
if (one_dimension_is_broadcasted_without_multiple) {
// Create a zero constant of the output shape to leverage binary
// operation broadcast semantics.
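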
- auto broadcasted_zero = ctx->builder()->Broadcast(
+ auto broadcasted_zero = xla::Broadcast(
XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape);
- ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input));
+ ctx->SetOutput(0, xla::Add(broadcasted_zero, input));
return;
}
@@ -103,7 +104,7 @@ class TileOp : public XlaOpKernel {
// dimension. This prepends the broadcasted dimensions, so an
// input of shape [2,3,1] broadcast with multiples [5,4,3] will
// end up with shape [5,4,3,2,3,1].
- auto broadcasted = ctx->builder()->Broadcast(input, multiples_array);
+ auto broadcasted = xla::Broadcast(input, multiples_array);
// Now flatten and reshape. The broadcasted dimensions are
// paired with the original dimensions so in the above example
// we flatten [0,3,1,4,2,5] then reshape to [10,12,3].
@@ -112,8 +113,7 @@ class TileOp : public XlaOpKernel {
flattened.push_back(i);
flattened.push_back(i + output_shape.size());
}
- xla::XlaOp output =
- ctx->builder()->Reshape(broadcasted, flattened, output_shape);
+ xla::XlaOp output = xla::Reshape(broadcasted, flattened, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index cbe3c8aaff..beb7cf263d 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -68,7 +68,7 @@ class TopKOp : public XlaOpKernel {
// TODO(b/73891930): add a key-value sort to HLO, rather than using
// bit-packing tricks here.
- xla::XlaOp zero = b->ConstantR0<int32>(0);
+ xla::XlaOp zero = xla::ConstantR0<int32>(b, 0);
// max can either be 0x7FFFFFFF or 0x80000000. Neither choice is totally
// ideal. The implications of the choice are:
@@ -83,22 +83,22 @@ class TopKOp : public XlaOpKernel {
// 1. +0.0 == -0.0
// 2. All -0.0 in the input are replaced with +0.0 in the output.
// 3. The sort is stable.
- xla::XlaOp max = b->ConstantR0<int32>(0x80000000);
- xla::XlaOp index_mask = b->ConstantR0<int32>(0x0000FFFF);
- xla::XlaOp value_mask = b->ConstantR0<int32>(0xFFFF0000);
+ xla::XlaOp max = xla::ConstantR0<int32>(b, 0x80000000);
+ xla::XlaOp index_mask = xla::ConstantR0<int32>(b, 0x0000FFFF);
+ xla::XlaOp value_mask = xla::ConstantR0<int32>(b, 0xFFFF0000);
// Convert from bf16 to f32. The lower 16-bits are zero due to the
// definition of bf16.
- xla::XlaOp input_f32 = b->ConvertElementType(input_bf16, xla::F32);
+ xla::XlaOp input_f32 = xla::ConvertElementType(input_bf16, xla::F32);
// Negate the input to reverse the sort order. The lower 16-bits are zero,
// because negating a float just flips the sign bit.
- xla::XlaOp negative_input_f32 = b->Neg(input_f32);
+ xla::XlaOp negative_input_f32 = xla::Neg(input_f32);
// Convert to a sign magnitude integer. The lower 16-bits are zero, since
// bitcast convert doesn't change any bits.
xla::XlaOp negative_input_sm32 =
- b->BitcastConvertType(negative_input_f32, xla::S32);
+ xla::BitcastConvertType(negative_input_f32, xla::S32);
// Convert from sign magnitude integer to two's complement integer. The
// lower 16-bits are zero on both sides of the select. On the false side,
@@ -106,8 +106,8 @@ class TopKOp : public XlaOpKernel {
// are all zero, so the lower 16-bits of the result of the subtraction will
// also be zero.
xla::XlaOp negative_input_s32 =
- b->Select(b->Lt(negative_input_sm32, zero),
- b->Sub(max, negative_input_sm32), negative_input_sm32);
+ xla::Select(xla::Lt(negative_input_sm32, zero),
+ xla::Sub(max, negative_input_sm32), negative_input_sm32);
// In order for the Or with iota_s32 to work properly, the lower 16-bits
// of negative_input_s32 must be zero.
@@ -115,32 +115,32 @@ class TopKOp : public XlaOpKernel {
// Pack elements as:
// * upper 16 bits are the value
// * lower 16 bits are the index.
- xla::XlaOp packed_s32 = b->Or(negative_input_s32, iota_s32);
+ xla::XlaOp packed_s32 = xla::Or(negative_input_s32, iota_s32);
// TODO(phawkins): use a more efficient algorithm that does not require a
// full sort.
- xla::XlaOp sorted_s32 = b->Slice(b->Sort(packed_s32),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1});
+ xla::XlaOp sorted_s32 = xla::Slice(xla::Sort(packed_s32),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1});
// Unpack the value/index.
- xla::XlaOp indices_s32 = b->And(sorted_s32, index_mask);
- xla::XlaOp negative_values_s32 = b->And(sorted_s32, value_mask);
+ xla::XlaOp indices_s32 = xla::And(sorted_s32, index_mask);
+ xla::XlaOp negative_values_s32 = xla::And(sorted_s32, value_mask);
// Convert from two's complement integer to sign magnitude integer.
xla::XlaOp negative_values_sm32 =
- b->Select(b->Lt(negative_values_s32, zero),
- b->Sub(max, negative_values_s32), negative_values_s32);
+ xla::Select(xla::Lt(negative_values_s32, zero),
+ xla::Sub(max, negative_values_s32), negative_values_s32);
xla::XlaOp negative_values_f32 =
- b->BitcastConvertType(negative_values_sm32, xla::F32);
+ xla::BitcastConvertType(negative_values_sm32, xla::F32);
// Negate the values to get back the original inputs.
- xla::XlaOp values_f32 = b->Neg(negative_values_f32);
+ xla::XlaOp values_f32 = xla::Neg(negative_values_f32);
// Convert from f32 to bf16.
- xla::XlaOp values_bf16 = b->ConvertElementType(values_f32, xla::BF16);
+ xla::XlaOp values_bf16 = xla::ConvertElementType(values_f32, xla::BF16);
context->SetOutput(0, values_bf16);
context->SetOutput(1, indices_s32);
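A minimal sketch of the pack/unpack trick on plain integers (illustrative only, not part of the patch; names are hypothetical):
uint32 Pack(uint32 value_bits, uint32 index) {
  // The bf16 payload lives in the upper 16 bits and the index in the lower
  // 16 bits, so one integer sort orders by value and breaks ties by index.
  return (value_bits & 0xFFFF0000u) | (index & 0x0000FFFFu);
}
uint32 UnpackValueBits(uint32 packed) { return packed & 0xFFFF0000u; }
uint32 UnpackIndex(uint32 packed) { return packed & 0x0000FFFFu; }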
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 34caefa050..2e5d61e111 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -31,7 +31,6 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp handle;
- xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(1);
TensorShape var_shape;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
@@ -48,7 +47,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
var_shape.DebugString(), " vs ",
delta_shape.DebugString()));
- handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
+ handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2)));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -63,8 +62,6 @@ class ResourceApplyMomentum : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
@@ -97,14 +94,14 @@ class ResourceApplyMomentum : public XlaOpKernel {
xla::XlaOp grad = ctx->Input(3);
xla::XlaOp momentum = ctx->Input(4);
- accum = b->Add(b->Mul(accum, momentum), grad);
+ accum = xla::Add(xla::Mul(accum, momentum), grad);
if (use_nesterov_) {
// See https://github.com/tensorflow/tensorflow/pull/2798 for an
// explanation of the reparameterization used here.
- var = b->Sub(
- var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr)));
+ var = xla::Sub(var, xla::Add(xla::Mul(grad, lr),
+ xla::Mul(xla::Mul(accum, momentum), lr)));
} else {
- var = b->Sub(var, b->Mul(accum, lr));
+ var = xla::Sub(var, xla::Mul(accum, lr));
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
@@ -149,10 +146,12 @@ class ResourceApplyAdagrad : public XlaOpKernel {
xla::XlaOp lr = ctx->Input(2);
xla::XlaOp grad = ctx->Input(3);
- accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
- var = b->Sub(
- var, b->Mul(b->Mul(grad, lr),
- b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
+ accum =
+ xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
+ var = xla::Sub(
+ var,
+ xla::Mul(xla::Mul(grad, lr),
+ xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
@@ -232,12 +231,13 @@ class ResourceApplyAdam : public XlaOpKernel {
xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
xla::XlaOp alpha =
- b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)),
- b->Sub(one, beta1_power));
- m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1)));
- v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2)));
- var =
- b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon)));
+ xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)),
+ xla::Sub(one, beta1_power));
+ m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1)));
+ v = xla::Add(
+ v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2)));
+ var = xla::Sub(var, xla::Div(xla::Mul(m, alpha),
+ xla::Add(xla::Pow(v, half), epsilon)));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
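Spelled out in the style of the surrounding comments, this hunk computes the standard Adam update, where beta1_power and beta2_power are beta1^t and beta2^t at step t:
  alpha = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
  m <- m + (grad - m) * (1 - beta1)       // == beta1 * m + (1 - beta1) * grad
  v <- v + (grad^2 - v) * (1 - beta2)     // == beta2 * v + (1 - beta2) * grad^2
  var <- var - m * alpha / (sqrt(v) + epsilon)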
@@ -320,16 +320,17 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::XlaOp new_ms = b->Add(
- ms,
- b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms),
- b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
+ xla::XlaOp new_ms = xla::Add(
+ ms, xla::Mul(
+ xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)),
+ ms),
+ xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
xla::XlaOp new_mom =
- b->Add(b->Mul(mom, momentum),
- b->Mul(b->Mul(grad, lr),
- b->Pow(b->Add(new_ms, epsilon),
- XlaHelpers::FloatLiteral(b, type, -0.5))));
- xla::XlaOp new_var = b->Sub(var, new_mom);
+ xla::Add(xla::Mul(mom, momentum),
+ xla::Mul(xla::Mul(grad, lr),
+ xla::Pow(xla::Add(new_ms, epsilon),
+ XlaHelpers::FloatLiteral(b, type, -0.5))));
+ xla::XlaOp new_var = xla::Sub(var, new_mom);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
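Equivalently:
  new_ms  = rho * ms + (1 - rho) * grad^2
  new_mom = momentum * mom + lr * grad / sqrt(new_ms + epsilon)
  new_var = var - new_mom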
@@ -424,21 +425,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
xla::XlaOp grad_to_use;
if (has_l2_shrinkage) {
- grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var)));
+ grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var)));
} else {
grad_to_use = grad;
}
- xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two));
- xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power));
- xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power));
- linear = b->Add(
+ xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two));
+ xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power));
+ xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power));
+ linear = xla::Add(
linear,
- b->Sub(grad_to_use,
- b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var)));
- xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1);
- xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2));
- var = b->Div(b->Sub(linear_clipped, linear), quadratic);
+ xla::Sub(grad_to_use,
+ xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr),
+ var)));
+ xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1);
+ xla::XlaOp quadratic =
+ xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2));
+ var = xla::Div(xla::Sub(linear_clipped, linear), quadratic);
accum = new_accum;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
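In equation form, with p = lr_power, the FTRL-proximal update computed above is:
  grad_to_use = grad + 2 * l2_shrinkage * var       // only with l2 shrinkage
  new_accum   = accum + grad_to_use^2
  sigma       = (new_accum^-p - accum^-p) / lr
  linear     <- linear + grad_to_use - sigma * var
  var         = (clamp(linear, -l1, l1) - linear) / (new_accum^-p / lr + 2 * l2)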
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
index ef5aae81a8..6c721c48fe 100644
--- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -84,12 +85,12 @@ class TransposeOp : public XlaOpKernel {
if (dims <= 1 || is_identity) {
transposed = ctx->Input(0);
} else {
- transposed = ctx->builder()->Transpose(ctx->Input(0), transposed_order);
+ transposed = xla::Transpose(ctx->Input(0), transposed_order);
}
// Conjugate the transposed result if this is ConjugateTransposeOp.
if (conjugate_) {
- ctx->SetOutput(0, ctx->builder()->Conj(transposed));
+ ctx->SetOutput(0, xla::Conj(transposed));
} else {
ctx->SetOutput(0, transposed);
}
@@ -146,7 +147,7 @@ class InvertPermutationOp : public XlaOpKernel {
output[d] = i;
}
- ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output));
+ ctx->SetOutput(0, xla::ConstantR1<int32>(ctx->builder(), output));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index a39e5dcfc5..3823f5c087 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -27,15 +27,13 @@ limitations under the License.
namespace tensorflow {
namespace {
-// A subclass of a TlaUnaryOp must build the lambda computation that
-// describes the scalar->scalar function to apply to each element of
-// the input.
#define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \
class NAME##Op : public XlaOpKernel { \
public: \
explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \
void Compile(XlaOpKernelContext* ctx) { \
xla::XlaBuilder* b = ctx->builder(); \
+ (void)b; \
xla::XlaOp x = ctx->Input(0); \
xla::XlaOp y = COMPUTATION; \
ctx->SetOutput(0, y); \
@@ -43,84 +41,88 @@ namespace {
}; \
REGISTER_XLA_OP(Name(#NAME), NAME##Op);
-XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x));
+XLAJIT_MAKE_UNARY(ComplexAbs, xla::Abs(x));
-XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x)));
+XLAJIT_MAKE_UNARY(Angle, xla::Atan2(xla::Imag(x), xla::Real(x)));
-XLAJIT_MAKE_UNARY(Conj, b->Conj(x));
+XLAJIT_MAKE_UNARY(Conj, xla::Conj(x));
// Return x if x>0, otherwise -x.
-XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
+XLAJIT_MAKE_UNARY(Abs, xla::Abs(x));
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
XLAJIT_MAKE_UNARY(
Acos,
- b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(x, x)),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
- b->Add(XlaHelpers::One(b, input_type(0)), x))));
+ xla::Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
+ xla::Atan2(xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)),
+ xla::Mul(x, x)),
+ XlaHelpers::FloatLiteral(b, input_type(0),
+ 0.5)),
+ xla::Add(XlaHelpers::One(b, input_type(0)), x))));
// acosh(x) = log(x + sqrt(x^2 - 1))
// = log(x + sqrt((x+1)*(x-1)))
XLAJIT_MAKE_UNARY(
Acosh,
- b->Log(b->Add(x,
- b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))),
- b->Sub(x, XlaHelpers::One(b, input_type(0)))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+ xla::Log(xla::Add(
+ x, xla::Pow(xla::Mul(xla::Add(x, XlaHelpers::One(b, input_type(0))),
+ xla::Sub(x, XlaHelpers::One(b, input_type(0)))),
+ XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XLAJIT_MAKE_UNARY(
Asin,
- b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)),
- b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(x, x)),
+ xla::Mul(
+ XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
+ xla::Atan2(x,
+ xla::Add(XlaHelpers::One(b, input_type(0)),
+ xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)),
+ xla::Mul(x, x)),
XlaHelpers::FloatLiteral(b, input_type(0),
0.5))))));
// asinh(x) = log(x + sqrt(x^2 + 1))
XLAJIT_MAKE_UNARY(
Asinh,
- b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x),
- XlaHelpers::One(b, input_type(0))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+ xla::Log(xla::Add(
+ x, xla::Pow(xla::Add(xla::Mul(x, x), XlaHelpers::One(b, input_type(0))),
+ XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
-XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0))));
+XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, XlaHelpers::One(b, input_type(0))));
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
XLAJIT_MAKE_UNARY(
- Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x),
- b->Sub(XlaHelpers::One(b, input_type(0)), x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x));
-XLAJIT_MAKE_UNARY(Cos, b->Cos(x));
+ Atanh,
+ xla::Mul(xla::Log(xla::Div(xla::Add(XlaHelpers::One(b, input_type(0)), x),
+ xla::Sub(XlaHelpers::One(b, input_type(0)), x))),
+ XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
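The identities above are exact on the principal branches: acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) and asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) are the tangent half-angle forms, atanh(x) = 0.5 * log((1 + x) / (1 - x)) is the standard log form, and acosh factors the radicand via x^2 - 1 = (x + 1) * (x - 1), presumably to reduce cancellation near x = 1.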
+XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x));
+XLAJIT_MAKE_UNARY(Cos, xla::Cos(x));
XLAJIT_MAKE_UNARY(Cosh,
- b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Sin, b->Sin(x));
-XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
-
-XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x));
-
-XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
-XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
-XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x),
- XlaHelpers::FloatLiteral(
- b, input_type(0),
- std::numeric_limits<double>::infinity())));
-XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x));
+ xla::Mul(xla::Add(xla::Exp(x), xla::Exp(xla::Neg(x))),
+ XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Sin, xla::Sin(x));
+XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
+
+XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
+
+XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
+XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
+XLAJIT_MAKE_UNARY(IsInf, xla::Eq(xla::Abs(x),
+ XlaHelpers::FloatLiteral(
+ b, input_type(0),
+ std::numeric_limits<double>::infinity())));
+XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
// Return 1/x
-XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Log, b->Log(x));
+XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Log, xla::Log(x));
XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x));
-XLAJIT_MAKE_UNARY(Invert, b->Not(x));
-XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x));
-XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
+XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
+XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
+XLAJIT_MAKE_UNARY(Neg, xla::Neg(x));
// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
@@ -130,35 +132,35 @@ static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype,
auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
- auto round_val = b->Floor(x);
- auto fraction = b->Sub(x, round_val);
+ auto round_val = xla::Floor(x);
+ auto fraction = xla::Sub(x, round_val);
auto nearest_even_int =
- b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x))));
- auto is_odd = b->Eq(nearest_even_int, one);
- return b->Select(
- b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)),
- b->Add(round_val, one), round_val);
+ xla::Sub(round_val, xla::Mul(two, xla::Floor(xla::Mul(half, x))));
+ auto is_odd = xla::Eq(nearest_even_int, one);
+ return xla::Select(xla::Or(xla::Gt(fraction, half),
+ xla::And(xla::Eq(fraction, half), is_odd)),
+ xla::Add(round_val, one), round_val);
}
XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
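A quick trace of the tie-breaking: for x = 2.5, round_val = 2, fraction = 0.5, nearest_even_int = 2 - 2 * floor(1.25) = 0, so is_odd is false and the result stays 2; for x = 3.5, nearest_even_int = 3 - 2 * floor(1.75) = 1, is_odd is true, and the result is 4.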
-XLAJIT_MAKE_UNARY(Rsqrt,
- b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
+XLAJIT_MAKE_UNARY(Rsqrt, xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0),
+ -0.5)));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype,
const xla::XlaOp& x) {
auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
- return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
+ return xla::Add(half, xla::Mul(half, xla::Tanh(xla::Mul(half, x))));
}
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
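The rescaling is exact: tanh(x/2) = (1 - exp(-x)) / (1 + exp(-x)), so 0.5 + 0.5 * tanh(x/2) = 1 / (1 + exp(-x)) = sigmoid(x).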
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
-XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
+XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
XLAJIT_MAKE_UNARY(Sinh,
- b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+ xla::Mul(xla::Sub(xla::Exp(x), xla::Exp(xla::Neg(x))),
+ XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
// softplus(x) = log(1 + exp(x))
//
@@ -169,21 +171,21 @@ XLAJIT_MAKE_UNARY(Sinh,
// This is equivalent to:
// max(x, 0) + log1p(exp(-abs(x)))
XLAJIT_MAKE_UNARY(Softplus,
- b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))),
- b->Log1p(b->Exp(b->Neg(b->Abs(x))))));
+ xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))),
+ b->Log1p(xla::Exp(xla::Neg(xla::Abs(x))))));
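The max/log1p form is an exact rewrite that avoids overflow in exp: for x >= 0, log(1 + exp(x)) = x + log1p(exp(-x)), and for x < 0 it is already log1p(exp(x)), so both cases collapse to max(x, 0) + log1p(exp(-abs(x))).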
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign,
- b->Div(x,
- b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0)))));
+ xla::Div(x, xla::Add(xla::Abs(x),
+ XlaHelpers::One(b, input_type(0)))));
XLAJIT_MAKE_UNARY(Sqrt,
- b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
-XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x)));
-XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
+ xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Square, xla::Mul(x, x));
+XLAJIT_MAKE_UNARY(Tan, xla::Div(xla::Sin(x), xla::Cos(x)));
+XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x));
-XLAJIT_MAKE_UNARY(Real, b->Real(x));
-XLAJIT_MAKE_UNARY(Imag, b->Imag(x));
+XLAJIT_MAKE_UNARY(Real, xla::Real(x));
+XLAJIT_MAKE_UNARY(Imag, xla::Imag(x));
#undef XLAJIT_MAKE_UNARY
@@ -197,13 +199,14 @@ class ErfOp : public XlaOpKernel {
xla::PrimitiveType primitive_type;
xla::XlaOp one = XlaHelpers::One(b, input_type(0));
xla::XlaOp x = ctx->Input(0);
- xla::XlaOp abs_x = b->Abs(x);
+ xla::XlaOp abs_x = xla::Abs(x);
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(input_type(0), &primitive_type));
- auto y = b->Select(b->Gt(abs_x, one), b->Sub(one, Erfc(x, primitive_type)),
- Erf(x, primitive_type));
+ auto y =
+ xla::Select(xla::Gt(abs_x, one), xla::Sub(one, Erfc(x, primitive_type)),
+ Erf(x, primitive_type));
ctx->SetOutput(0, y);
}
};
@@ -216,14 +219,15 @@ class ErfcOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp one = XlaHelpers::One(b, input_type(0));
xla::XlaOp x = ctx->Input(0);
- xla::XlaOp abs_x = b->Abs(x);
+ xla::XlaOp abs_x = xla::Abs(x);
xla::PrimitiveType primitive_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(input_type(0), &primitive_type));
- auto y = b->Select(b->Lt(abs_x, one), b->Sub(one, Erf(x, primitive_type)),
- Erfc(x, primitive_type));
+ auto y =
+ xla::Select(xla::Lt(abs_x, one), xla::Sub(one, Erf(x, primitive_type)),
+ Erfc(x, primitive_type));
ctx->SetOutput(0, y);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
index f87586ba57..0e5d58ecba 100644
--- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -74,10 +75,9 @@ class UnpackOp : public XlaOpKernel {
for (int i = 0; i < num; ++i) {
start_indices[axis] = i;
limit_indices[axis] = i + 1;
- auto slice = ctx->builder()->Slice(input, start_indices, limit_indices,
- strides);
+ auto slice = xla::Slice(input, start_indices, limit_indices, strides);
// Reshape to drop the 'axis' dimension.
- auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes());
+ auto result = xla::Reshape(slice, output_shape.dim_sizes());
ctx->SetOutput(i, result);
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index ad51396bdf..febac82873 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -33,8 +33,8 @@ class VarIsInitializedOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
XlaResource* variable;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
- ctx->SetOutput(0,
- ctx->builder()->ConstantR0<bool>(variable->initialized()));
+ ctx->SetOutput(
+ 0, xla::ConstantR0<bool>(ctx->builder(), variable->initialized()));
}
};
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
@@ -96,7 +96,7 @@ class AssignAddVariableOp : public XlaOpKernel {
xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
- handle = ctx->builder()->Add(handle, ctx->Input(1));
+ handle = xla::Add(handle, ctx->Input(1));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -112,7 +112,7 @@ class AssignSubVariableOp : public XlaOpKernel {
xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
- handle = ctx->builder()->Sub(handle, ctx->Input(1));
+ handle = xla::Sub(handle, ctx->Input(1));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -191,7 +191,7 @@ class ResourceScatterAddOp : public ResourceScatterOp {
private:
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
xla::XlaBuilder* builder) {
- return builder->Add(x, y);
+ return xla::Add(x, y);
}
};
REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp);
@@ -204,7 +204,7 @@ class ResourceScatterSubOp : public ResourceScatterOp {
private:
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
xla::XlaBuilder* builder) {
- return builder->Sub(x, y);
+ return xla::Sub(x, y);
}
};
REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp);
@@ -217,7 +217,7 @@ class ResourceScatterMulOp : public ResourceScatterOp {
private:
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
xla::XlaBuilder* builder) {
- return builder->Mul(x, y);
+ return xla::Mul(x, y);
}
};
REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp);
@@ -230,7 +230,7 @@ class ResourceScatterDivOp : public ResourceScatterOp {
private:
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
xla::XlaBuilder* builder) {
- return builder->Div(x, y);
+ return xla::Div(x, y);
}
};
REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp);
@@ -243,7 +243,7 @@ class ResourceScatterMinOp : public ResourceScatterOp {
private:
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
xla::XlaBuilder* builder) {
- return builder->Min(x, y);
+ return xla::Min(x, y);
}
};
REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp);
@@ -256,7 +256,7 @@ class ResourceScatterMaxOp : public ResourceScatterOp {
private:
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
xla::XlaBuilder* builder) {
- return builder->Max(x, y);
+ return xla::Max(x, y);
}
};
REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp);
@@ -286,7 +286,7 @@ class ResourceScatterNdAddOp : public ResourceScatterOp {
private:
static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
xla::XlaBuilder* builder) {
- return builder->Add(x, y);
+ return xla::Add(x, y);
}
};
REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 5467c5d994..340165bac6 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -246,7 +246,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}
}
- xla::XlaOp init = builder->Tuple(inputs);
+ xla::XlaOp init = xla::Tuple(builder, inputs);
VLOG(1) << "Building while loop";
@@ -255,22 +255,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
{
std::unique_ptr<xla::XlaBuilder> cb =
builder->CreateSubBuilder("cond_wrapper");
- auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
- auto outputs = cb->Call(*cond.computation, {inputs});
- cb->GetTupleElement(outputs, 0);
+ auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs");
+ auto outputs = xla::Call(cb.get(), *cond.computation, {inputs});
+ xla::GetTupleElement(outputs, 0);
xla::StatusOr<xla::XlaComputation> result = cb->Build();
OP_REQUIRES_OK(ctx, result.status());
cond_wrapper = std::move(result.ValueOrDie());
}
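(The bare xla::GetTupleElement(outputs, 0) works because XlaBuilder::Build treats the last instruction added as the computation's root, so the wrapper returns the predicate element of the cond tuple; the loop-body builders below rely on the same convention.)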
- xla::XlaOp while_result =
- builder->While(cond_wrapper, *body.computation, init);
+ xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init);
// Sets non-variable outputs.
for (int i = 0; i < ctx->num_outputs(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
ctx->SetOutput(body.input_mapping[i],
- builder->GetTupleElement(while_result, i));
+ xla::GetTupleElement(while_result, i));
}
}
@@ -284,7 +283,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
- builder->GetTupleElement(while_result, pos), builder));
+ xla::GetTupleElement(while_result, pos), builder));
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name() << " modified: " << update.modified
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index ee0bb91a6b..dd29bafcd9 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -81,25 +82,26 @@ xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
dimensions.push_back(x_shape.dimensions(x_outer_dim));
dimensions.push_back(y_shape.dimensions(y_outer_dim));
- return builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())),
+ return xla::Broadcast(
+ xla::ConstantLiteral(builder,
+ xla::Literal::Zero(x_shape.element_type())),
dimensions);
}
if (x_shape.element_type() == xla::C64 && conjugate_x) {
- x = builder->Conj(x);
+ x = xla::Conj(x);
}
if (y_shape.element_type() == xla::C64 && conjugate_y) {
- y = builder->Conj(y);
+ y = xla::Conj(y);
}
// If there are no batch dimensions, use a regular Dot.
// TODO(b/69062148) Remove this code when Dot emitters can be passed
// dimensions to transpose directly (i.e. without requiring a Transpose HLO).
if (batch_dimension_numbers.empty()) {
- auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
- auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
- return builder->Dot(lhs, rhs);
+ auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
+ auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
+ return xla::Dot(lhs, rhs);
}
xla::DotDimensionNumbers dot_dnums;
@@ -109,7 +111,7 @@ xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
}
- return builder->DotGeneral(x, y, dot_dnums);
+ return xla::DotGeneral(x, y, dot_dnums);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 20925118bf..397f0e3a72 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -80,21 +81,20 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
std::vector<int32> mask_vector(n);
std::iota(mask_vector.begin(), mask_vector.end(), 0);
- auto mask_range = body_builder->ConstantR1<int32>(mask_vector);
- auto mask_range_row = body_builder->Broadcast(
- body_builder->Reshape(mask_range, {0}, {1, n}), major_dims);
- auto mask_range_col = body_builder->Broadcast(
- body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims);
+ auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
+ auto mask_range_row =
+ xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
+ auto mask_range_col =
+ xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
auto body_a = loop_vars[0];
auto body_l = loop_vars[1];
// row = l[..., i, :i]
// select the whole i-th row, then mask out all columns past i-1
- auto zero = body_builder->ConstantR0<int32>(0);
+ auto zero = xla::ConstantR0<int32>(body_builder, 0);
TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l,
{i, zero}, {1, n}));
- auto row = body_builder->Select(body_builder->Ge(mask_range_row, i),
- mask_zeros_row, l_i);
+ auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
// a[..., i, i]
TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
{i, i}, {1, 1}));
@@ -105,16 +105,15 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
/*transpose_y=*/true));
// l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
// np.swapaxes(row, -1, -2)))
- auto l_ii = body_builder->Pow(
- body_builder->Sub(a_ii, diag_dot),
- FloatLiteral(body_builder, a_shape.element_type(), 0.5));
+ auto l_ii =
+ xla::Pow(xla::Sub(a_ii, diag_dot),
+ FloatLiteral(body_builder, a_shape.element_type(), 0.5));
// a[..., i+1:, i]
// select the whole i-th column, then mask out all rows above i+1
TF_ASSIGN_OR_RETURN(
auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1}));
- auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i),
- mask_zeros_col, a_0i);
+ auto a_ip1i = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
// l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
// l[..., i, i]
@@ -125,11 +124,9 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
/*transpose_x=*/false,
/*transpose_y=*/true));
// np.dot(l[..., i+1:, :i], r.T)
- auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i),
- mask_zeros_col, dot);
+ auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
- auto col_update =
- body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii);
+ auto col_update = xla::Div(xla::Sub(a_ip1i, dot_ip1), l_ii);
TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
body_builder, body_l, col_update, {i}));
// Assign the diagonal after the rest of the column because otherwise the
@@ -191,9 +188,8 @@ xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
/*conjugate_y=*/false));
TF_ASSIGN_OR_RETURN(auto before,
SliceInMinorDims(builder, a, {i, i}, {n, i + k}));
- TF_ASSIGN_OR_RETURN(
- a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta),
- {i, i}));
+ TF_ASSIGN_OR_RETURN(a, UpdateSliceInMinorDims(
+ builder, a, xla::Sub(before, delta), {i, i}));
}
// l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
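For reference, the loop body above computes the standard unblocked Cholesky recurrence
  l[i, i]    = sqrt(a[i, i] - dot(row, row.T))
  l[i+1:, i] = (a[i+1:, i] - dot(l[i+1:, :i], row.T)) / l[i, i]
with the masked Selects standing in for the :i slice bounds, since intermediate array sizes cannot depend on the loop variable.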
diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc
index e4f195901e..3dfa66029c 100644
--- a/tensorflow/compiler/tf2xla/lib/random.cc
+++ b/tensorflow/compiler/tf2xla/lib/random.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
@@ -49,9 +50,9 @@ xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) {
XlaHelpers::FloatLiteral(builder, dtype, kAlphaNormalCdf);
// probit(p) = sqrt(2) * erfinv(2*p-1)
- auto p = builder->Add(alpha_normal_cdf, builder->Mul(z, uniform));
- auto erfinv_input = builder->Sub(builder->Mul(p, two), one);
- return builder->Mul(sqrt_2, ErfInv(erfinv_input));
+ auto p = xla::Add(alpha_normal_cdf, xla::Mul(z, uniform));
+ auto erfinv_input = xla::Sub(xla::Mul(p, two), one);
+ return xla::Mul(sqrt_2, ErfInv(erfinv_input));
}
} // namespace tensorflow
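The probit identity follows from the normal CDF Phi(x) = 0.5 * (1 + erf(x / sqrt(2))): solving Phi(x) = p for x gives probit(p) = sqrt(2) * erfinv(2 * p - 1).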
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index d5a27abb25..85e3d3ab85 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -97,8 +98,8 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
buffer_shape_post_axes.end());
// Construct the initial values of the loop-carried Tensors.
- auto flat_indices = builder->Reshape(indices, flat_indices_shape);
- auto flat_updates = builder->Reshape(updates, flat_updates_shape);
+ auto flat_indices = xla::Reshape(indices, flat_indices_shape);
+ auto flat_updates = xla::Reshape(updates, flat_updates_shape);
auto init = {flat_indices, flat_updates, buffer};
// Constructs the loop body. The implementation of scatter is essentially:
@@ -112,46 +113,44 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
auto updates = loop_vars[1];
auto buffer = loop_vars[2];
- auto zero_index = body_builder->ConstantLiteral(
- xla::Literal::Zero(indices_shape.element_type()));
+ auto zero_index = xla::ConstantLiteral(
+ body_builder, xla::Literal::Zero(indices_shape.element_type()));
// Slice the i-th index from the indices array.
xla::XlaOp index;
- auto indices_offset = body_builder->Reshape(i, {1});
+ auto indices_offset = xla::Reshape(i, {1});
if (indices_are_vectors) {
- indices_offset = body_builder->Pad(indices_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, 1}}));
+ indices_offset = xla::Pad(indices_offset, zero_index,
+ xla::MakeEdgePaddingConfig({{0, 1}}));
- index = body_builder->DynamicSlice(indices, indices_offset,
- {1, num_index_dims});
- index = body_builder->Collapse(index, {0, 1});
+ index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims});
+ index = xla::Collapse(index, {0, 1});
} else {
- index = body_builder->DynamicSlice(indices, indices_offset, {1});
+ index = xla::DynamicSlice(indices, indices_offset, {1});
}
// Discard updates with negative indices, since some users expect this.
- auto index_in_range =
- body_builder->ReduceAll(body_builder->Le(zero_index, index),
- body_builder->ConstantR0<bool>(true),
- xla::CreateScalarAndComputation(body_builder));
+ auto index_in_range = xla::ReduceAll(
+ xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
+ xla::CreateScalarAndComputation(body_builder));
// Clamp the index into bounds to prevent implementation-defined behavior.
- index = body_builder->Max(index, zero_index);
- index = body_builder->Pad(
+ index = xla::Max(index, zero_index);
+ index = xla::Pad(
index, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
// Slice the i-th index from the updates array.
- auto updates_offset = body_builder->Reshape(i, {1});
- updates_offset = body_builder->Pad(
+ auto updates_offset = xla::Reshape(i, {1});
+ updates_offset = xla::Pad(
updates_offset, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
std::vector<int64> flat_updates_slice_shape({1});
flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
- auto update = body_builder->DynamicSlice(updates, updates_offset,
- flat_updates_slice_shape);
+ auto update =
+ xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape);
// Unflatten the major (iteration) dimensions of the slice to their
// original shape.
@@ -159,20 +158,19 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
updates_slice_shape.insert(updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
- update = body_builder->Reshape(update, updates_slice_shape);
+ update = xla::Reshape(update, updates_slice_shape);
// Apply the update to the buffer. If there is a combiner, use it to merge
// the current values with the update.
- auto current_value =
- body_builder->DynamicSlice(buffer, index, updates_slice_shape);
+ auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape);
if (combiner) {
update = combiner(current_value, update, body_builder);
}
// Use the current value instead of the update if the index is out of
// bounds.
- update = body_builder->Select(index_in_range, update, current_value);
+ update = xla::Select(index_in_range, update, current_value);
// Apply the update.
- buffer = body_builder->DynamicUpdateSlice(buffer, update, index);
+ buffer = xla::DynamicUpdateSlice(buffer, update, index);
return std::vector<xla::XlaOp>{indices, updates, buffer};
};
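A minimal sketch of the semantics the loop body implements, over plain arrays rather than XLA ops (all names illustrative):
for (int64 i = 0; i < num_updates; ++i) {
  int64 index = flat_indices[i];
  if (index < 0) continue;  // out-of-range updates keep the current value
  buffer[index] = combiner ? combiner(buffer[index], flat_updates[i])
                           : flat_updates[i];
}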
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index b4503601f9..b9f695ac4b 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -90,8 +91,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
tensorflow::strings::StrCat("trsm_base_", k));
- auto a_param = sub->Parameter(
- 0,
+ auto a_param = xla::Parameter(
+ sub.get(), 0,
xla::ShapeUtil::MakeShape(
b_shape.element_type(),
PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
@@ -103,8 +104,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
} else {
b_lastd = {m, k};
}
- auto b_param = sub->Parameter(
- 1,
+ auto b_param = xla::Parameter(
+ sub.get(), 1,
xla::ShapeUtil::MakeShape(
b_shape.element_type(),
PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
@@ -165,11 +166,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
+ update = xla::Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
@@ -196,7 +197,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
/*conjugate_y=*/conjugate_a));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
- b_update = builder->Sub(b_slice_2, b_update);
+ b_update = xla::Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
}
@@ -217,11 +218,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
+ update = xla::Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
@@ -247,7 +248,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
/*conjugate_y=*/false));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
- b_update = builder->Sub(b_slice_2, b_update);
+ b_update = xla::Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
}
@@ -268,11 +269,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
+ update = xla::Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
@@ -299,7 +300,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
/*conjugate_y=*/conjugate_a));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {0, 0}, {m, i}));
- b_update = builder->Sub(b_slice_2, b_update);
+ b_update = xla::Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
}
@@ -320,11 +321,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
+ update = xla::Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
@@ -350,7 +351,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
/*conjugate_y=*/false));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {0, 0}, {i, n}));
- b_update = builder->Sub(b_slice_2, b_update);
+ b_update = xla::Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
}
@@ -394,7 +395,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
- auto update = builder->Div(b_slice, a_slice_conj);
+ auto update = xla::Div(b_slice, a_slice_conj);
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
}
@@ -414,8 +415,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
// The right-hand-side matrix b is a loop invariant.
b_shape};
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
- auto init = builder->Tuple({init_i, output, a, b});
+ auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1);
+ auto init = xla::Tuple(builder, {init_i, output, a, b});
// Construct the loop condition function,
// def cond_fun(loop_carry):
@@ -424,14 +425,14 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
std::unique_ptr<xla::XlaBuilder> condb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
{
- auto i = condb->GetTupleElement(
- condb->Parameter(0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple"),
+ auto i = xla::GetTupleElement(
+ xla::Parameter(condb.get(), 0, tuple_shape,
+ "TriangularSolveLeftLookingWhileTuple"),
0);
if (transpose_a) {
- condb->Ge(i, condb->ConstantR0<int32>(0));
+ xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
} else {
- condb->Lt(i, condb->ConstantR0<int32>(m));
+ xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m));
}
}
TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
@@ -454,15 +455,15 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
std::unique_ptr<xla::XlaBuilder> bodyb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
{
- auto input_tuple = bodyb->Parameter(0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple");
+ auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
+ "TriangularSolveLeftLookingWhileTuple");
// i, output, a, b = loop_carry
- auto i = bodyb->GetTupleElement(input_tuple, 0);
- auto body_out = bodyb->GetTupleElement(input_tuple, 1);
- auto body_a = bodyb->GetTupleElement(input_tuple, 2);
- auto body_b = bodyb->GetTupleElement(input_tuple, 3);
- auto zero = bodyb->ConstantR0<int32>(0);
+ auto i = xla::GetTupleElement(input_tuple, 0);
+ auto body_out = xla::GetTupleElement(input_tuple, 1);
+ auto body_a = xla::GetTupleElement(input_tuple, 2);
+ auto body_b = xla::GetTupleElement(input_tuple, 3);
+ auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
// We'd like to implement this:
// if transpose_a:
@@ -491,14 +492,14 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
TF_ASSIGN_OR_RETURN(
auto result_row_slice,
DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n}));
- auto result_row = bodyb->Sub(result_row_slice, b_update);
+ auto result_row = xla::Sub(result_row_slice, b_update);
// body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a,
{i, i}, {1, 1}));
TF_ASSIGN_OR_RETURN(auto a_elt_conj,
MaybeConjugate(bodyb.get(), a_elt, conjugate_a));
- auto div_result = bodyb->Div(result_row, a_elt_conj);
+ auto div_result = xla::Div(result_row, a_elt_conj);
TF_ASSIGN_OR_RETURN(body_out,
DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
div_result, {i, zero}));
@@ -507,15 +508,16 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
// return (i - 1, body_out, a, b)
// else:
// return (i + 1, body_out, a, b)
- auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1));
- bodyb->Tuple({next_i, body_out, body_a, body_b});
+ auto next_i =
+ xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1));
+ xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
}
TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
// Construct the While loop and return the result,
// return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = builder->While(cond, body, init);
- return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
+ auto triangular_solve_left_looking_while = xla::While(cond, body, init);
+ return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
}
xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
@@ -553,8 +555,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
// The right-hand-side matrix b is a loop invariant.
b_shape};
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = builder->ConstantR0<int32>(transpose_a ? 0 : n - 1);
- auto init = builder->Tuple({init_i, output, a, b});
+ auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1);
+ auto init = xla::Tuple(builder, {init_i, output, a, b});
// Construct the loop condition function,
// def cond_fun(loop_carry):
@@ -563,14 +565,14 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
std::unique_ptr<xla::XlaBuilder> condb =
builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond");
{
- auto i = condb->GetTupleElement(
- condb->Parameter(0, tuple_shape,
- "TriangularSolveRightLookingWhileTuple"),
+ auto i = xla::GetTupleElement(
+ xla::Parameter(condb.get(), 0, tuple_shape,
+ "TriangularSolveRightLookingWhileTuple"),
0);
if (transpose_a) {
- condb->Lt(i, condb->ConstantR0<int32>(n));
+ xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n));
} else {
- condb->Ge(i, condb->ConstantR0<int32>(0));
+ xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
}
}
TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
@@ -593,15 +595,15 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
std::unique_ptr<xla::XlaBuilder> bodyb =
builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody");
{
- auto input_tuple = bodyb->Parameter(
- 0, tuple_shape, "TriangularSolveRightLookingWhileTuple");
+ auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
+ "TriangularSolveRightLookingWhileTuple");
// i, output, a, b = loop_carry
- auto i = bodyb->GetTupleElement(input_tuple, 0);
- auto body_out = bodyb->GetTupleElement(input_tuple, 1);
- auto body_a = bodyb->GetTupleElement(input_tuple, 2);
- auto body_b = bodyb->GetTupleElement(input_tuple, 3);
- auto zero = bodyb->ConstantR0<int32>(0);
+ auto i = xla::GetTupleElement(input_tuple, 0);
+ auto body_out = xla::GetTupleElement(input_tuple, 1);
+ auto body_a = xla::GetTupleElement(input_tuple, 2);
+ auto body_b = xla::GetTupleElement(input_tuple, 3);
+ auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
// We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :,
// i:i+1]). But since we can't have intermediate array sizes depend on the
@@ -613,7 +615,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
/*conjugate_x=*/false,
/*conjugate_y=*/conjugate_a));
// result = b - np.matmul(output, a)
- auto result = bodyb->Sub(body_b, b_update);
+ auto result = xla::Sub(body_b, b_update);
// result_row = result[..., :, i:i+1]
TF_ASSIGN_OR_RETURN(
auto result_row,
@@ -624,7 +626,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
{i, i}, {1, 1}));
TF_ASSIGN_OR_RETURN(auto a_ii_conj,
MaybeConjugate(bodyb.get(), a_ii, conjugate_a));
- auto div_result = bodyb->Div(result_row, a_ii_conj);
+ auto div_result = xla::Div(result_row, a_ii_conj);
TF_ASSIGN_OR_RETURN(body_out,
DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
div_result, {zero, i}));
@@ -633,15 +635,16 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
// return (i + 1, body_out, a, b)
// else:
// return (i - 1, body_out, a, b)
- auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? 1 : -1));
- bodyb->Tuple({next_i, body_out, body_a, body_b});
+ auto next_i =
+ xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1));
+ xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
}
TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
// Construct the While loop and return the result,
// return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = builder->While(cond, body, init);
- return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
+ auto triangular_solve_right_looking_while = xla::While(cond, body, init);
+ return xla::GetTupleElement(triangular_solve_right_looking_while, 1);
}
} // namespace tensorflow
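For reference, the while-loop scaffolding shared by TriangularSolveLeftLooking and TriangularSolveRightLooking reduces to the sketch below, written against the new xla:: free-function API. The counter bound and single-element loop-carry tuple are illustrative, not taken from this file; the usual tf2xla includes (xla_builder.h, status_macros.h for TF_ASSIGN_OR_RETURN) are assumed.

xla::StatusOr<xla::XlaOp> CountToTen(xla::XlaBuilder* builder) {
  xla::Shape counter_shape = xla::ShapeUtil::MakeShape(xla::S32, {});
  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape({counter_shape});
  auto init = xla::Tuple(builder, {xla::ConstantR0<int32>(builder, 0)});

  // Condition: i < 10. The last op emitted into a sub-builder becomes the
  // root of the computation it builds.
  std::unique_ptr<xla::XlaBuilder> condb = builder->CreateSubBuilder("cond");
  {
    auto t = xla::Parameter(condb.get(), 0, tuple_shape, "t");
    xla::Lt(xla::GetTupleElement(t, 0),
            xla::ConstantR0<int32>(condb.get(), 10));
  }
  TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

  // Body: increment the counter and re-pack the loop-carry tuple.
  std::unique_ptr<xla::XlaBuilder> bodyb = builder->CreateSubBuilder("body");
  {
    auto t = xla::Parameter(bodyb.get(), 0, tuple_shape, "t");
    auto next_i = xla::Add(xla::GetTupleElement(t, 0),
                           xla::ConstantR0<int32>(bodyb.get(), 1));
    xla::Tuple(bodyb.get(), {next_i});
  }
  TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

  auto loop = xla::While(cond, body, init);
  return xla::GetTupleElement(loop, 0);
}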
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index d9ff7e6259..11774dde08 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -28,8 +29,8 @@ limitations under the License.
namespace tensorflow {
xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
- return builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
+ return xla::Broadcast(
+ xla::ConstantLiteral(builder, xla::Literal::Zero(shape.element_type())),
xla::AsInt64Slice(shape.dimensions()));
}
@@ -37,19 +38,19 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
double value) {
switch (type) {
case xla::F16:
- return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
+ return xla::ConstantR0<xla::half>(builder, static_cast<xla::half>(value));
break;
case xla::BF16:
- return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value));
+ return xla::ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
break;
case xla::F32:
- return builder->ConstantR0<float>(static_cast<float>(value));
+ return xla::ConstantR0<float>(builder, static_cast<float>(value));
break;
case xla::F64:
- return builder->ConstantR0<double>(value);
+ return xla::ConstantR0<double>(builder, value);
break;
case xla::C64:
- return builder->ConstantR0<xla::complex64>(value);
+ return xla::ConstantR0<xla::complex64>(builder, value);
break;
default:
LOG(FATAL) << "unhandled element type " << type;
@@ -107,7 +108,7 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
default:
LOG(FATAL) << "unhandled element type " << type;
}
- return builder->ConstantLiteral(literal);
+ return xla::ConstantLiteral(builder, literal);
}
xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
@@ -136,7 +137,7 @@ xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
std::vector<int64> strides(n_dims, 1);
- return builder->Slice(x, padded_start, padded_end, strides);
+ return xla::Slice(x, padded_start, padded_end, strides);
}
std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
@@ -163,7 +164,7 @@ xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
TF_ASSIGN_OR_RETURN(auto padded_starts,
PrependZerosInMajorDims(builder, x, starts));
auto padded_sizes = PrependMajorDims(builder, major_dims, sizes);
- return builder->DynamicSlice(x, padded_starts, padded_sizes);
+ return xla::DynamicSlice(x, padded_starts, padded_sizes);
}
xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
@@ -172,7 +173,7 @@ xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
gtl::ArraySlice<int64> start) {
// TODO(phawkins): make int64 work on all backends, remove the int32 cast.
std::vector<int32> start_as_int32(start.begin(), start.end());
- auto start_constant = builder->ConstantR1<int32>(start_as_int32);
+ auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32);
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
@@ -180,7 +181,7 @@ xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
const int64 start_length =
xla::ShapeUtil::GetDimension(start_constant_shape, -1);
TF_RET_CHECK(start_length == n_dims);
- return builder->DynamicUpdateSlice(x, update, start_constant);
+ return xla::DynamicUpdateSlice(x, update, start_constant);
}
xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
@@ -202,7 +203,7 @@ xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
const std::vector<xla::XlaOp>& starts) {
TF_ASSIGN_OR_RETURN(auto padded_starts,
PrependZerosInMajorDims(builder, x, starts));
- return builder->DynamicUpdateSlice(x, update, padded_starts);
+ return xla::DynamicUpdateSlice(x, update, padded_starts);
}
xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
@@ -210,13 +211,12 @@ xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
const std::vector<xla::XlaOp>& starts) {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(shape);
- auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
+ auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1});
std::vector<xla::XlaOp> padded_starts(n_dims, zero);
for (int i = 0; i < starts.size(); ++i) {
- padded_starts[n_dims - starts.size() + i] =
- builder->Reshape(starts[i], {1});
+ padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1});
}
- return builder->ConcatInDim(padded_starts, 0);
+ return xla::ConcatInDim(builder, padded_starts, 0);
}
xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
@@ -227,14 +227,14 @@ xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
- return builder->Transpose(x, permutation);
+ return xla::Transpose(x, permutation);
}
xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder,
const xla::XlaOp& x, bool conjugate) {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
auto perform_conj = shape.element_type() == xla::C64 && conjugate;
- return perform_conj ? builder->Conj(x) : x;
+ return perform_conj ? xla::Conj(x) : x;
}
} // namespace tensorflow
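The mechanical shape of the change in this file, side by side (illustrative one-liners, not from the diff): ops that create state on a builder, such as constants, now take the builder as an explicit first argument, while ops that already consume an XlaOp infer it from their operand.

// Old style: every op is a method on the builder.
xla::XlaOp AddOneOld(xla::XlaBuilder* b, const xla::XlaOp& x) {
  return b->Add(x, b->ConstantR0<float>(1.0f));
}

// New style: free functions in xla::; only the constant needs the builder.
xla::XlaOp AddOneNew(xla::XlaBuilder* b, const xla::XlaOp& x) {
  return xla::Add(x, xla::ConstantR0<float>(b, 1.0f));
}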
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index 5f408f2ed0..2a332c933f 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -86,9 +86,10 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
- TF_ASSERT_OK(DynamicSliceInMinorDims(
- &builder, a, {index, builder.ConstantR0<int32>(0)}, {1, 4})
- .status());
+ TF_ASSERT_OK(
+ DynamicSliceInMinorDims(
+ &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, 4})
+ .status());
ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
{a_data.get(), index_data.get()});
@@ -129,8 +130,8 @@ XLA_TEST_F(UtilTest, RowBatchDot) {
TF_ASSERT_OK_AND_ASSIGN(
auto l_index,
- DynamicSliceInMinorDims(&builder, a,
- {index, builder.ConstantR0<int32>(0)}, {1, n}));
+ DynamicSliceInMinorDims(
+ &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n}));
TF_ASSERT_OK(BatchDot(&builder, l_index, row,
/*transpose_x=*/false, /*transpose_y=*/true)
.status());
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 09ce594930..7cc88f34d2 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -39,7 +40,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
xla::XlaBuilder* builder) {
std::vector<xla::XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) {
- elements[i] = builder->GetTupleElement(tuple, i);
+ elements[i] = xla::GetTupleElement(tuple, i);
}
return elements;
};
@@ -48,7 +49,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
std::unique_ptr<xla::XlaBuilder> cond_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
{
- auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");
+ auto parameter =
+ xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
@@ -61,7 +63,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
std::unique_ptr<xla::XlaBuilder> body_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_body"));
{
- auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");
+ auto parameter =
+ xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
TF_ASSIGN_OR_RETURN(
auto result,
@@ -69,11 +72,11 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
body_builder.get()));
TF_RET_CHECK(result.size() == initial_values.size());
- body_builder->Tuple(result);
+ xla::Tuple(body_builder.get(), result);
}
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
- auto outputs = builder->While(cond, body, builder->Tuple(initial_values));
+ auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values));
return unpack_tuple(outputs, arity, builder);
}
@@ -86,9 +89,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
auto while_cond_fn =
[&](gtl::ArraySlice<xla::XlaOp> values,
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
- return cond_builder->Lt(
- values[0],
- IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
+ return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
+ num_iterations));
};
auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
xla::XlaBuilder* body_builder)
@@ -97,9 +99,9 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
std::vector<xla::XlaOp> updated_values;
updated_values.reserve(values.size());
- updated_values.push_back(body_builder->Add(
- iteration,
- body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
+ updated_values.push_back(xla::Add(
+ iteration, xla::ConstantLiteral(
+ body_builder, xla::Literal::One(num_iterations_type))));
values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
@@ -112,7 +114,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
std::vector<xla::XlaOp> values;
values.reserve(initial_values.size() + 1);
values.push_back(
- builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
+ xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type)));
values.insert(values.end(), initial_values.begin(), initial_values.end());
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
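A hypothetical caller of XlaWhileLoop in the new style, for flavor: values[0] carries the induction variable as in XlaForEachIndex above, and values[1] accumulates a running sum. The lambda signatures follow the ones in this file; the computation itself is invented, and this file's includes are assumed.

xla::StatusOr<xla::XlaOp> SumToTen(xla::XlaBuilder* builder) {
  auto cond_fn = [](gtl::ArraySlice<xla::XlaOp> values,
                    xla::XlaBuilder* b) -> xla::StatusOr<xla::XlaOp> {
    return xla::Lt(values[0], xla::ConstantR0<int32>(b, 10));
  };
  auto body_fn = [](gtl::ArraySlice<xla::XlaOp> values, xla::XlaBuilder* b)
      -> xla::StatusOr<std::vector<xla::XlaOp>> {
    auto one = xla::ConstantR0<int32>(b, 1);
    // (i, sum) -> (i + 1, sum + i)
    return std::vector<xla::XlaOp>{xla::Add(values[0], one),
                                   xla::Add(values[1], values[0])};
  };
  std::vector<xla::XlaOp> init = {xla::ConstantR0<int32>(builder, 0),
                                  xla::ConstantR0<int32>(builder, 0)};
  TF_ASSIGN_OR_RETURN(auto results, XlaWhileLoop(cond_fn, body_fn, init,
                                                 "sum_to_ten", builder));
  return results[1];
}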
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 8a0fb41afb..82dd48d43d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -416,7 +417,7 @@ Status BuildComputation(
// create a tuple/get-tuple-element combination so that sharding
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
- handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
+ handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
elems.push_back(handle);
}
@@ -426,7 +427,7 @@ Status BuildComputation(
// Builds the XLA computation.
if (always_return_tuple || elems.size() != 1) {
- builder->Tuple(elems);
+ xla::Tuple(builder, elems);
}
builder->ClearOpMetadata();
@@ -554,16 +555,16 @@ Status XlaCompiler::BuildArguments(
}
xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
tuple_sharding);
- tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
+ tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
} else {
- tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
+ tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
}
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
- arg_handles[i] = builder->GetTupleElement(tuple, i);
+ arg_handles[i] = xla::GetTupleElement(tuple, i);
}
} else {
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
@@ -571,8 +572,8 @@ Status XlaCompiler::BuildArguments(
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
- arg_handles[i] =
- builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
+ arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
+ strings::StrCat("arg", i));
}
}
@@ -603,7 +604,7 @@ Status XlaCompiler::BuildArguments(
// return values of functions, and then reshape unconditionally.
if (is_entry_computation) {
arg_expression.set_handle(
- builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
+ xla::Reshape(arg_handles[i], arg.shape.dim_sizes()));
} else {
arg_expression.set_handle(arg_handles[i]);
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 67174b251d..d0b5606907 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -131,9 +131,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
xla::XlaBuilder b("max<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Max(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Max(x, y);
return b.Build().ConsumeValueOrDie();
});
}
@@ -145,9 +147,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
xla::XlaBuilder b("min<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Min(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Min(x, y);
return b.Build().ConsumeValueOrDie();
});
}
@@ -159,9 +163,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
xla::XlaBuilder b("add<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Add(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Add(x, y);
return b.Build().ConsumeValueOrDie();
});
}
@@ -173,9 +179,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
xla::XlaBuilder b("mul<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Mul(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Mul(x, y);
return b.Build().ConsumeValueOrDie();
});
}
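The four cached computations above (Max, Min, Add, Mul) all follow one pattern, shown here as a standalone sketch with a hypothetical helper name; only the root binary op differs per case: a fresh XlaBuilder, two scalar parameters, a single root op, then Build().

xla::XlaComputation MakeScalarBinop(xla::PrimitiveType xla_type) {
  xla::XlaBuilder b("add<scalar>");
  xla::Shape scalar = xla::ShapeUtil::MakeShape(xla_type, {});
  auto x = xla::Parameter(&b, 0, scalar, "x");
  auto y = xla::Parameter(&b, 1, scalar, "y");
  xla::Add(x, y);  // the last op built becomes the computation's root
  return b.Build().ConsumeValueOrDie();
}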
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 31115eea60..917ef4037d 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -50,14 +50,14 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
xla::PrimitiveType xla_output_type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type));
- xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer,
- /*dimensions_to_reduce=*/{axis});
+ xla::XlaOp input_max = xla::Reduce(input, init_value, *reducer,
+ /*dimensions_to_reduce=*/{axis});
std::vector<int64> broadcast_dims(input_shape.dims() - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
// Compute a mask that has 1s for elements equal to the maximum.
- xla::XlaOp partial_mask = builder->ConvertElementType(
- builder->Eq(input, input_max, broadcast_dims), xla_output_type);
+ xla::XlaOp partial_mask = xla::ConvertElementType(
+ xla::Eq(input, input_max, broadcast_dims), xla_output_type);
// In order to make identity elements for a bitwise And, we:
// Left shift the 1 to the leftmost bit, yielding 0b10...0
@@ -67,8 +67,8 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1;
xla::XlaOp shift_amount =
XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type);
- xla::XlaOp full_mask = builder->ShiftRightArithmetic(
- builder->ShiftLeft(partial_mask, shift_amount), shift_amount);
+ xla::XlaOp full_mask = xla::ShiftRightArithmetic(
+ xla::ShiftLeft(partial_mask, shift_amount), shift_amount);
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
// index.
@@ -77,14 +77,14 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
const int64 axis_size = input_shape.dim_size(axis);
TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota));
xla::XlaOp product =
- builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis});
+ xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis});
// If there are multiple maximum elements, choose the one with the highest
// index.
xla::XlaOp output =
- builder->Reduce(product, XlaHelpers::MinValue(builder, output_type),
- *ctx->GetOrCreateMax(output_type),
- /*dimensions_to_reduce=*/{axis});
+ xla::Reduce(product, XlaHelpers::MinValue(builder, output_type),
+ *ctx->GetOrCreateMax(output_type),
+ /*dimensions_to_reduce=*/{axis});
*argminmax = output;
return Status::OK();
}
@@ -94,7 +94,7 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::MinValue(type));
+ return xla::ConstantLiteral(b, xla::Literal::MinValue(type));
}
xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) {
@@ -102,23 +102,23 @@ xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) {
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
switch (type) {
case xla::F16:
- return b->ConstantR0<Eigen::half>(
- Eigen::NumTraits<Eigen::half>::lowest());
+ return xla::ConstantR0<Eigen::half>(
+ b, Eigen::NumTraits<Eigen::half>::lowest());
case xla::BF16:
- return b->ConstantR0<bfloat16>(bfloat16::lowest());
+ return xla::ConstantR0<bfloat16>(b, bfloat16::lowest());
case xla::F32:
- return b->ConstantR0<float>(-std::numeric_limits<float>::max());
+ return xla::ConstantR0<float>(b, -std::numeric_limits<float>::max());
case xla::F64:
- return b->ConstantR0<double>(-std::numeric_limits<double>::max());
+ return xla::ConstantR0<double>(b, -std::numeric_limits<double>::max());
default:
- return b->ConstantLiteral(xla::Literal::MinValue(type));
+ return xla::ConstantLiteral(b, xla::Literal::MinValue(type));
}
}
xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::MaxValue(type));
+ return xla::ConstantLiteral(b, xla::Literal::MaxValue(type));
}
xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) {
@@ -126,42 +126,43 @@ xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) {
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
switch (type) {
case xla::F16:
- return b->ConstantR0<Eigen::half>(
- Eigen::NumTraits<Eigen::half>::highest());
+ return xla::ConstantR0<Eigen::half>(
+ b, Eigen::NumTraits<Eigen::half>::highest());
case xla::BF16:
- return b->ConstantR0<bfloat16>(bfloat16::highest());
+ return xla::ConstantR0<bfloat16>(b, bfloat16::highest());
case xla::F32:
- return b->ConstantR0<float>(std::numeric_limits<float>::max());
+ return xla::ConstantR0<float>(b, std::numeric_limits<float>::max());
case xla::F64:
- return b->ConstantR0<double>(std::numeric_limits<double>::max());
+ return xla::ConstantR0<double>(b, std::numeric_limits<double>::max());
default:
- return b->ConstantLiteral(xla::Literal::MaxValue(type));
+ return xla::ConstantLiteral(b, xla::Literal::MaxValue(type));
}
}
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::Zero(type));
+ return xla::ConstantLiteral(b, xla::Literal::Zero(type));
}
xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::One(type));
+ return xla::ConstantLiteral(b, xla::Literal::One(type));
}
xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) {
switch (data_type) {
case DT_HALF:
- return b->ConstantR0<Eigen::half>(
+ return xla::ConstantR0<Eigen::half>(
+ b,
static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
case DT_BFLOAT16:
- return b->ConstantR0<bfloat16>(bfloat16::epsilon());
+ return xla::ConstantR0<bfloat16>(b, bfloat16::epsilon());
case DT_FLOAT:
- return b->ConstantR0<float>(std::numeric_limits<float>::epsilon());
+ return xla::ConstantR0<float>(b, std::numeric_limits<float>::epsilon());
case DT_DOUBLE:
- return b->ConstantR0<double>(std::numeric_limits<double>::epsilon());
+ return xla::ConstantR0<double>(b, std::numeric_limits<double>::epsilon());
default:
LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: "
<< DataTypeString(data_type);
@@ -250,7 +251,7 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
xla::BorrowingLiteral linspace_literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
- *iota = builder->ConstantLiteral(linspace_literal);
+ *iota = xla::ConstantLiteral(builder, linspace_literal);
return Status::OK();
}
@@ -292,13 +293,13 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
std::vector<int64> broadcast_dims(indices_shape.dims());
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
- xla::XlaOp one_hot_bool = builder->Eq(
- indices, builder->ConstantLiteral(linspace_literal), broadcast_dims);
+ xla::XlaOp one_hot_bool = xla::Eq(
+ indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims);
// Selects the user-provided off_value and on_value values.
- *one_hot = builder->Select(
- one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()),
- builder->Broadcast(off_value, output_shape.dim_sizes()));
+ *one_hot = xla::Select(one_hot_bool,
+ xla::Broadcast(on_value, output_shape.dim_sizes()),
+ xla::Broadcast(off_value, output_shape.dim_sizes()));
return Status::OK();
}
@@ -316,7 +317,7 @@ xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder,
const DataType new_element_type) {
xla::PrimitiveType convert_to;
TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
- return builder->ConvertElementType(operand, convert_to);
+ return xla::ConvertElementType(operand, convert_to);
}
} // end namespace tensorflow
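The mask arithmetic in ArgMinMax above, worked through on plain int32_t scalars instead of XlaOps (standalone C++, not part of the diff): a 0/1 equality result becomes an all-zeros/all-ones And-identity by shifting the bit up to the sign position and arithmetically shifting it back.

#include <cassert>
#include <cstdint>

int32_t FullMask(int32_t partial) {
  // Shift via uint32_t to avoid signed-overflow UB, then rely on arithmetic
  // right shift (two's complement) to smear the sign bit across the word.
  int32_t top = static_cast<int32_t>(static_cast<uint32_t>(partial) << 31);
  return top >> 31;
}

int main() {
  assert(FullMask(1) == -1);       // 0xFFFFFFFF: identity for bitwise And
  assert(FullMask(0) == 0);        // 0x00000000: absorbs the index
  assert((FullMask(1) & 7) == 7);  // equal elements keep their index
  assert((FullMask(0) & 7) == 0);  // unequal elements contribute nothing
  return 0;
}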
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index b58959bd6c..c2298b97e1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
@@ -128,7 +129,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
xla::XlaOp handle = expression->handle();
if (new_shape != tensor.shape()) {
// Reshape the handle to the desired shape.
- handle = builder()->Reshape(handle, new_shape.dim_sizes());
+ handle = xla::Reshape(handle, new_shape.dim_sizes());
}
// The XLA layout is specified minor to major, and TensorFlow's minor
@@ -342,8 +343,7 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
if (representation_shape == variable->shape()) {
*value = variable->value();
} else {
- *value =
- builder()->Reshape(variable->value(), variable->shape().dim_sizes());
+ *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
}
return Status::OK();
}
@@ -394,7 +394,7 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
xla::BorrowingLiteral literal;
OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
- xla::XlaOp handle = builder()->ConstantLiteral(literal);
+ xla::XlaOp handle = xla::ConstantLiteral(builder(), literal);
CHECK(handle.valid());
// Make the Tensor that will refer to the expression.
@@ -462,7 +462,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
TensorShape representation_shape =
xla_context.RepresentationShape(shape, type);
if (shape != representation_shape) {
- handle = builder()->Reshape(handle, representation_shape.dim_sizes());
+ handle = xla::Reshape(handle, representation_shape.dim_sizes());
}
return variable->SetValue(handle);
}
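The reshape guard that appears twice in this file (ReadVariableInput and AssignVariable), as a hypothetical helper: reshape only when the logical TensorFlow shape and the on-device representation shape actually differ.

xla::XlaOp ReshapeIfNeeded(xla::XlaOp handle, const TensorShape& from,
                           const TensorShape& to) {
  // xla::Reshape infers the builder from its operand in the new API.
  return from == to ? handle : xla::Reshape(handle, to.dim_sizes());
}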
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index 540c65c597..baea814965 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -89,16 +90,16 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
}
switch (kind_) {
case kVariable: {
- value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
- shape_.dim_sizes());
+ value_ =
+ xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
break;
}
case kTensorArray: {
TensorShape ta_shape;
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
- value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
- ta_shape.dim_sizes());
+ value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes());
break;
}
case kStack: {
@@ -106,9 +107,9 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
value_ =
- builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
- ta_shape.dim_sizes()),
- builder->ConstantR0<int32>(0)});
+ xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes()),
+ xla::ConstantR0<int32>(builder, 0)});
break;
}
@@ -130,8 +131,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
TensorShape ta_shape;
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
- xla::XlaOp gradient_value = builder->Broadcast(
- XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
+ xla::XlaOp gradient_value =
+ xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/strings::StrCat("TensorArrayGrad: ", name_),
@@ -152,7 +153,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
for (const auto& gradient : tensor_array_gradients_) {
elems.push_back(gradient.second->value_);
}
- *pack = builder->Tuple(elems);
+ *pack = xla::Tuple(builder, elems);
}
return Status::OK();
}
@@ -168,7 +169,7 @@ Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
} else {
TF_RET_CHECK(kind_ == kTensorArray);
int pos = 0;
- auto v = builder->GetTupleElement(pack, pos++);
+ auto v = xla::GetTupleElement(pack, pos++);
if (!initialized()) {
initial_value_ = v;
}
@@ -178,7 +179,7 @@ Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
XlaResource* gradient;
TF_RETURN_IF_ERROR(
GetOrCreateTensorArrayGradient(source, builder, &gradient));
- auto v = builder->GetTupleElement(pack, pos++);
+ auto v = xla::GetTupleElement(pack, pos++);
if (!gradient->initialized()) {
gradient->initial_value_ = v;
}