diff options
author | Peter Hawkins <phawkins@google.com> | 2018-06-27 12:12:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-27 12:15:39 -0700 |
commit | 35cb434a9a95bef7ca8d7880d87dd9775eeba336 (patch) | |
tree | 976358f9a935cbbdf76407f60688c08b6484aeae | |
parent | 1536bba6be3e16f3983b79dd6931de313c900114 (diff) |
[TF:XLA] Refactor TF/XLA code to use free functions in xla:: namespace to build XlaOps, rather than calling XlaBuilder methods.
PiperOrigin-RevId: 202348891
86 files changed, 1167 insertions, 1104 deletions
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 4a6622ed73..4900af6df1 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -230,7 +231,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, XlaContext& context = XlaContext::Get(op_context); auto* b = context.builder(); - auto output_handle = b->Call(*result.computation, handles); + auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. int computation_output = 0; @@ -239,7 +240,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { xla_op_context.SetOutput( - i, b->GetTupleElement(output_handle, computation_output)); + i, xla::GetTupleElement(output_handle, computation_output)); ++computation_output; } } diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 1e59868621..e335328280 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -31,7 +32,7 @@ class AddNOp : public XlaOpKernel { xla::XlaOp sum = ctx->Input(0); for (int i = 1; i < ctx->num_inputs(); ++i) { - sum = ctx->builder()->Add(sum, ctx->Input(i)); + sum = xla::Add(sum, ctx->Input(i)); } ctx->SetOutput(0, sum); diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 93fbc40461..259fe05b09 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -49,8 +50,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); - xla::XlaBuilder* builder = ctx->builder(); - xla::XlaOp input = ctx->Input(0); TensorShape input_shape = ctx->InputShape(0); @@ -60,30 +59,30 @@ class FusedBatchNormOp : public XlaOpKernel { // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. As a workaround, cast everything to the statistics type (which // may be more precise than the input type). - input = builder->ConvertElementType(input, scale_type); + input = xla::ConvertElementType(input, scale_type); if (is_training_) { - xla::XlaOp output = builder->BatchNormTraining( + xla::XlaOp output = xla::BatchNormTraining( input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - ctx->SetOutput(0, builder->ConvertElementType( - builder->GetTupleElement(output, 0), input_type)); - ctx->SetOutput(1, builder->GetTupleElement(output, 1)); - ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0), + input_type)); + ctx->SetOutput(1, xla::GetTupleElement(output, 1)); + ctx->SetOutput(2, xla::GetTupleElement(output, 2)); // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, builder->GetTupleElement(output, 1)); - ctx->SetOutput(4, builder->GetTupleElement(output, 2)); + ctx->SetOutput(3, xla::GetTupleElement(output, 1)); + ctx->SetOutput(4, xla::GetTupleElement(output, 2)); } else { - xla::XlaOp output = builder->BatchNormInference( + xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); - ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); + ctx->SetOutput(0, xla::ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -144,12 +143,12 @@ class FusedBatchNormGradOp : public XlaOpKernel { xla::XlaOp offset_backprop; if (is_training_) { xla::XlaOp output = - b->BatchNormGrad(activations, scale, mean, var, grad_backprop, - epsilon_, feature_index); + xla::BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); - x_backprop = b->GetTupleElement(output, 0); - scale_backprop = b->GetTupleElement(output, 1); - offset_backprop = b->GetTupleElement(output, 2); + x_backprop = xla::GetTupleElement(output, 0); + scale_backprop = xla::GetTupleElement(output, 1); + offset_backprop = xla::GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. std::vector<int64> reduction_dims(input_dims - 1); @@ -166,27 +165,27 @@ class FusedBatchNormGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); - auto scratch1 = - b->Pow(b->Add(var, b->ConstantR0<float>(epsilon_)), neg_half); + auto scratch1 = xla::Pow( + xla::Add(var, xla::ConstantR0<float>(b, epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) auto mul = - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index})); converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = - b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); - scale_backprop = b->Mul(scratch1, scratch2); + xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index}); + scale_backprop = xla::Mul(scratch1, scratch2); } ctx->SetOutput(0, diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 642278ab99..26130fd9e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -45,7 +46,6 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ", 2] instead of ", xla::ShapeUtil::HumanString(crops.shape()))); - xla::XlaBuilder* b = ctx->builder(); const int64 batch_size = input_shape[0]; // Compute the product of the block_shape values. @@ -72,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, reshaped_shape[block_rank] = batch_size / block_num_elems; std::copy(input_shape.begin() + 1, input_shape.end(), reshaped_shape.begin() + block_rank + 1); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce `permuted` of shape // [batch / prod(block_shape), @@ -90,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, } std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::XlaOp permuted = b->Transpose(reshaped, permutation); + xla::XlaOp permuted = xla::Transpose(reshaped, permutation); // 3. Reshape `permuted` to produce `reshaped_permuted` of shape // [batch / prod(block_shape), @@ -110,7 +110,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_permuted_shape.begin() + 1 + block_rank); - xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape); + xla::XlaOp reshaped_permuted = + xla::Reshape(permuted, reshaped_permuted_shape); // 4. Crop the start and end of dimensions `[1, ..., M]` of // `reshaped_permuted` according to `crops` to produce the output of shape: @@ -138,7 +139,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); } xla::XlaOp output = - b->Slice(reshaped_permuted, start_indices, end_indices, strides); + xla::Slice(reshaped_permuted, start_indices, end_indices, strides); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 9d677f4266..e9b2c0b16d 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" @@ -60,8 +61,7 @@ class BiasOp : public XlaOpKernel { "of the input tensor: ", bias_shape.DebugString(), " vs. ", input_shape.DebugString())); - xla::XlaOp result = - ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); + xla::XlaOp result = xla::Add(ctx->Input(0), ctx->Input(1), {feature_dim}); ctx->SetOutput(0, result); } @@ -109,8 +109,8 @@ class BiasAddGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index fee939bdea..d6d4ae8937 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -41,18 +41,19 @@ namespace { const BCast& broadcast_helper, \ const std::vector<int64>& extend_dimensions) override { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ return HLO; \ } \ }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) -XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { @@ -67,13 +68,13 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); - auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero)); - auto abs_x = b->Abs(x); - auto abs_y = b->Abs(y); - auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); - auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); + auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); + auto abs_x = xla::Abs(x); + auto abs_y = xla::Abs(y); + auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one)); + auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y)); if (DataTypeIsFloating(dtype)) { - result = b->Floor(result); + result = xla::Floor(result); } return result; } @@ -87,76 +88,78 @@ static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); - auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); - auto trunc_mod = b->Rem(x, y); - return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y)); + auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); + auto trunc_mod = xla::Rem(x, y); + return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y)); } XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(BitwiseXor, b->Xor(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(RightShift, (DataTypeIsUnsigned(ctx->input_type(0)) - ? b->ShiftRightLogical(lhs, rhs, extend_dimensions) - : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions))); - -XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs)))); + ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions) + : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions))); + +XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs)))); XLA_MAKE_BINARY( RsqrtGrad, - b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), - b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), - extend_dimensions)); -XLA_MAKE_BINARY(SqrtGrad, - b->Div(b->Mul(rhs, - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - lhs, extend_dimensions)); + xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), + xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), + extend_dimensions)); +XLA_MAKE_BINARY( + SqrtGrad, + xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + lhs, extend_dimensions)); static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return builder->Mul(x, x); + return xla::Mul(x, x); } XLA_MAKE_BINARY(SquaredDifference, - Square(b, b->Sub(lhs, rhs, extend_dimensions))); + Square(b, xla::Sub(lhs, rhs, extend_dimensions))); -XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); // Comparison ops -XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions)); // Non-linear ops XLA_MAKE_BINARY(SigmoidGrad, - b->Mul(b->Mul(rhs, lhs), - b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); + xla::Mul(xla::Mul(rhs, lhs), + xla::Sub(XlaHelpers::One(b, input_type(0)), lhs))); XLA_MAKE_BINARY(SoftplusGrad, - b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); + xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, - b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)), - b->Abs(rhs))))); + xla::Div(lhs, + Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); -XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(lhs, lhs)))); +XLA_MAKE_BINARY(TanhGrad, + xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(lhs, lhs)))); -XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); #undef XLA_MAKE_BINARY @@ -169,12 +172,13 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); + auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1))); auto abs_shape = b->GetShape(abs); OP_REQUIRES_OK(ctx, abs_shape.status()); auto abs_type = abs_shape.ValueOrDie().element_type(); - auto result = b->Lt( - abs, b->ConvertElementType(b->ConstantR0<float>(tolerance_), abs_type)); + auto result = + xla::Lt(abs, xla::ConvertElementType( + xla::ConstantR0<float>(b, tolerance_), abs_type)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index ca9a6b4068..efbdb76eaa 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -36,22 +37,22 @@ class BucketizeOp : public XlaOpKernel { const DataType dtype = context->input_type(0); xla::XlaOp input = context->Input(0); - xla::XlaOp boundaries = builder->ConstantR1<float>(boundaries_); + xla::XlaOp boundaries = xla::ConstantR1<float>(builder, boundaries_); // TODO(phawkins): the following behavior matches the behavior of the core // Bucketize kernel. However, comparing an int32 or int64 against float may // lead to inaccurate bucketing due to rounding. if (dtype == DT_DOUBLE) { - input = builder->ConvertElementType(input, xla::F64); - boundaries = builder->ConvertElementType(boundaries, xla::F64); + input = xla::ConvertElementType(input, xla::F64); + boundaries = xla::ConvertElementType(boundaries, xla::F64); } else { - input = builder->ConvertElementType(input, xla::F32); + input = xla::ConvertElementType(input, xla::F32); } - xla::XlaOp comparison = builder->ConvertElementType( - builder->Ge(builder->Broadcast(input, {1}), boundaries, - /*broadcast_dimensions=*/{0}), - xla::S32); - xla::XlaOp buckets = builder->Reduce( - comparison, /*init_value=*/builder->ConstantR0<int32>(0), + xla::XlaOp comparison = + xla::ConvertElementType(xla::Ge(xla::Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = xla::Reduce( + comparison, /*init_value=*/xla::ConstantR0<int32>(builder, 0), /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); context->SetOutput(0, buckets); diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index e9d98c7685..62eebf762b 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -40,14 +41,14 @@ class CastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else if (dst_dtype_ == DT_BOOL) { - output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); + output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_)); } else if (xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_)) { // As in cast_op.h, we replicate the numpy behavior of truncating the // imaginary part. - output = builder->ConvertElementType(builder->Real(input), dst_type_); + output = xla::ConvertElementType(xla::Real(input), dst_type_); } else { - output = builder->ConvertElementType(input, dst_type_); + output = xla::ConvertElementType(input, dst_type_); } ctx->SetOutput(0, output); @@ -72,7 +73,6 @@ class BitcastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); xla::XlaOp input = ctx->Input(0); xla::XlaOp output; @@ -92,7 +92,7 @@ class BitcastOp : public XlaOpKernel { xla::primitive_util::BitWidth(dst_type_), errors::Unimplemented( "Only bitcasts between equally sized types supported.")); - output = builder->BitcastConvertType(input, dst_type_); + output = xla::BitcastConvertType(input, dst_type_); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 835a7f5689..c137d026bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -65,17 +66,17 @@ class CategoricalOp : public XlaOpKernel { DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); xla::Shape uniform_shape = xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); - auto uniforms = builder->RngUniform( - XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. auto softmax_entries = - builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))), - /*broadcast_dimensions=*/{0, 2}); + xla::Sub(logits, xla::Log(xla::Neg(xla::Log(uniforms))), + /*broadcast_dimensions=*/{0, 2}); TensorShape softmax_shape(uniform_shape_array); xla::XlaOp argmax; diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index a00bc912f9..4e6d33304c 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -29,7 +30,6 @@ class ClipByValueOp : public XlaOpKernel { const TensorShape min_shape = ctx->InputShape(1); const TensorShape max_shape = ctx->InputShape(2); - xla::XlaBuilder* builder = ctx->builder(); auto input = ctx->Input(0); auto min = ctx->Input(1); auto max = ctx->Input(2); @@ -45,13 +45,13 @@ class ClipByValueOp : public XlaOpKernel { if (shape != min_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); - min = builder->Broadcast(min, shape.dim_sizes()); + min = xla::Broadcast(min, shape.dim_sizes()); } if (shape != max_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); - max = builder->Broadcast(max, shape.dim_sizes()); + max = xla::Broadcast(max, shape.dim_sizes()); } - ctx->SetOutput(0, builder->Clamp(min, input, max)); + ctx->SetOutput(0, xla::Clamp(min, input, max)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 78285affa1..e3a32a5c0e 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -88,7 +89,7 @@ class ConcatBaseOp : public XlaOpKernel { "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. - input_data.push_back(ctx->builder()->Reshape(handle, {1})); + input_data.push_back(xla::Reshape(handle, {1})); } else { input_data.push_back(handle); } @@ -96,7 +97,7 @@ class ConcatBaseOp : public XlaOpKernel { } VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 59d06c654d..f4360d8c3f 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -53,41 +54,41 @@ class ConstOp : public XlaOpKernel { switch (proto_.dtype()) { case DT_BOOL: if (proto_.bool_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0<bool>(proto_.bool_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0<bool>(b, proto_.bool_val(0)), + shape.dim_sizes())); return; } break; case DT_FLOAT: if (proto_.float_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0<float>(proto_.float_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<float>( + b, proto_.float_val(0)), + shape.dim_sizes())); return; } break; case DT_DOUBLE: if (proto_.double_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0<double>(proto_.double_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<double>( + b, proto_.double_val(0)), + shape.dim_sizes())); return; } break; case DT_INT32: if (proto_.int_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0<int32>(proto_.int_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0<int32>(b, proto_.int_val(0)), + shape.dim_sizes())); return; } break; case DT_INT64: if (proto_.int64_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0<int64>(proto_.int64_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<int64>( + b, proto_.int64_val(0)), + shape.dim_sizes())); return; } break; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 627bad12f3..5d41fc708a 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -51,8 +52,8 @@ xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); + return xla::Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); } // Create a mask for depthwise convolution that will make a normal convolution @@ -107,20 +108,20 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, // Divide the M*N sized linspace by the depthwise_multiplier to create // [0 0 1 1 2 2] in the example in the function comment. expanded_feature_iota = - builder->Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); // Broadcast the N*M linspace to [H, W, ..., M, M*N]. auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = builder->Broadcast( - expanded_feature_iota, expanded_feature_broadcast_dims); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); // Compare the broadcasted linspace to the input feature linspace in the // input feature dimension to create a diagonal predicate. - return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); } // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding @@ -142,16 +143,16 @@ xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, implicit_broadcast_filter_shape.dims() - 1, depthwise_multiplier * input_feature); auto implicit_broadcast_filter = - builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); // Broadcast the filter to [H, W, ..., M, M*N]. auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); - auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero); // If the filter mask is set, choose the broadcasted filter, othwerwise, // choose zero. - return builder->Select(CreateExpandedFilterMask(filter_shape, builder), - expanded_filter, expanded_zero); + return xla::Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. @@ -162,17 +163,17 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - auto masked_expanded_filter = builder->Select( + auto masked_expanded_filter = xla::Select( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); - return builder->Reshape( + return xla::Reshape( // This reduce does not need inputs to be converted with // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with // ExpandedZero guarantees that only one element is non zero, so there // cannot be accumulated precision error. - builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), - {expanded_filter_shape.dims() - 2}), + xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), filter_shape.dim_sizes()); } @@ -289,8 +290,8 @@ class ConvOp : public XlaOpKernel { } xla::XlaOp conv = - b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } @@ -435,11 +436,11 @@ class ConvBackpropInputOp : public XlaOpKernel { } // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims); + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) <conv> mirrored_weights - xla::XlaOp in_backprop = b->ConvGeneralDilated( + xla::XlaOp in_backprop = xla::ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums); @@ -638,8 +639,8 @@ class ConvBackpropFilterOp : public XlaOpKernel { // This is done by specifying the window dilation factors in the // convolution HLO below. auto filter_backprop = - b->ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); if (depthwise_) { filter_backprop = ContractFilterForDepthwiseBackprop( diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 7fcd4170fb..500a564f3f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -58,21 +59,21 @@ class CrossOp : public XlaOpKernel { auto in1 = ctx->Input(1); starts.back() = 0; limits.back() = 1; - auto u1 = b->Slice(in0, starts, limits, strides); - auto v1 = b->Slice(in1, starts, limits, strides); + auto u1 = xla::Slice(in0, starts, limits, strides); + auto v1 = xla::Slice(in1, starts, limits, strides); starts.back() = 1; limits.back() = 2; - auto u2 = b->Slice(in0, starts, limits, strides); - auto v2 = b->Slice(in1, starts, limits, strides); + auto u2 = xla::Slice(in0, starts, limits, strides); + auto v2 = xla::Slice(in1, starts, limits, strides); starts.back() = 2; limits.back() = 3; - auto u3 = b->Slice(in0, starts, limits, strides); - auto v3 = b->Slice(in1, starts, limits, strides); + auto u3 = xla::Slice(in0, starts, limits, strides); + auto v3 = xla::Slice(in1, starts, limits, strides); - auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2)); - auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3)); - auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1)); - auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1); + auto s1 = xla::Sub(xla::Mul(u2, v3), xla::Mul(u3, v2)); + auto s2 = xla::Sub(xla::Mul(u3, v1), xla::Mul(u1, v3)); + auto s3 = xla::Sub(xla::Mul(u1, v2), xla::Mul(u2, v1)); + auto output = xla::ConcatInDim(b, {s1, s2, s3}, in0_shape.dims() - 1); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 01aa1a83e7..9ff3e02228 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -96,18 +96,16 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // First reshape the inputs, which should be a metadata-only // operation since we are flattening the dimensions in order. - auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape()); + auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); + auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); // Next broadcast the necessary input dimensions. We rely on the // XLA optimizer to be smart about the fact that we are asking // it to broadcast size 1 on some of these dimensions, to avoid // adding complexity to this code. - auto lhs_broadcast = - builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast()); + auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = - builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast()); + auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); int rhs_size = broadcast_helper.y_bcast().size(); // Now reshape them to the correct output shape. After the @@ -122,15 +120,15 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { lhs_reorder.push_back(i); lhs_reorder.push_back(i + lhs_size); } - auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder, - broadcast_helper.output_shape()); + auto lhs_output = + xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); std::vector<int64> rhs_reorder; for (int i = 0; i < rhs_size; ++i) { rhs_reorder.push_back(i); rhs_reorder.push_back(i + rhs_size); } - auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder, - broadcast_helper.output_shape()); + auto rhs_output = + xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); return {lhs_output, rhs_output}; } diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 23243f6246..f314920025 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class DepthToSpaceOp : public XlaOpKernel { const gtl::InlinedVector<int64, 4> input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel { ") is not divisible by square of the block size (", block_size_, ")")); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -141,7 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2], // block_size_, // depth / (block_size_ * block_size_)] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -151,7 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 931705ba83..17bf0c069c 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -41,13 +42,13 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal( xla::XlaOp iota; TF_RETURN_IF_ERROR( XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); - xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size}); - xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0}); + xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size}); + xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0}); // If this is a batched diagonal, broadcast the mask across the other // dimensions. if (!other_dims.empty()) { - mask = builder->Broadcast(mask, other_dims); + mask = xla::Broadcast(mask, other_dims); } // Broadcast the input, and then use the mask computed above to select the @@ -64,7 +65,7 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal( std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end()); broadcast_dims.push_back(1LL); broadcast_dims.push_back(last_dim_size); - xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims); + xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; xla::PrimitiveType element_type; @@ -74,8 +75,8 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal( xla::ShapeUtil::MakeShape(element_type, broadcast_dims); xla::XlaOp zeros = Zeros(builder, broadcast_shape); - input_broadcast = builder->Add(input_broadcast, zeros); - return builder->Select(mask, input_broadcast, zeros); + input_broadcast = xla::Add(input_broadcast, zeros); + return xla::Select(mask, input_broadcast, zeros); } class DiagOp : public XlaOpKernel { @@ -104,7 +105,7 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - input = builder->Reshape(input, {size}); + input = xla::Reshape(input, {size}); // Create an R2 with the R1 diagonal. auto diag_or_status = @@ -116,7 +117,7 @@ class DiagOp : public XlaOpKernel { std::vector<int64> new_dims(dims.size() * 2); std::copy(dims.begin(), dims.end(), new_dims.begin()); std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size()); - diag = builder->Reshape(diag, new_dims); + diag = xla::Reshape(diag, new_dims); ctx->SetOutput(0, diag); } @@ -170,21 +171,21 @@ class DiagPartOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + diag = xla::Reshape(diag, {size}); // Adds padding after the last element of 'new_size'. xla::PaddingConfig config; auto* dim = config.add_dimensions(); dim->set_edge_padding_high(new_size); auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); + diag = xla::Pad(diag, zero, config); // Reshapes so the diagonal is now in the first column. - diag = builder->Reshape(diag, {new_size, new_size + 1}); + diag = xla::Reshape(diag, {new_size, new_size + 1}); // Slices out the first column and reshapes to the final shape. - diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); - diag = builder->Reshape(diag, new_dims); + diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); + diag = xla::Reshape(diag, new_dims); ctx->SetOutput(0, diag); } @@ -265,7 +266,7 @@ class MatrixDiagPartOp : public XlaOpKernel { // Collapses the last two dimensions. std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1); flattened_dims.back() *= dims.back(); - diag = builder->Reshape(diag, flattened_dims); + diag = xla::Reshape(diag, flattened_dims); // Slices or pads the last dimension to 'target_size'. int64 actual_size = flattened_dims.back(); @@ -276,13 +277,13 @@ class MatrixDiagPartOp : public XlaOpKernel { auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); dim->set_edge_padding_high(target_size - actual_size); auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); + diag = xla::Pad(diag, zero, config); } else if (actual_size > target_size) { std::vector<int64> start(flattened_dims.size(), 0); std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end()); std::vector<int64> strides(flattened_dims.size(), 1); limits[flattened_dims.size() - 1] = target_size; - diag = builder->Slice(diag, start, limits, strides); + diag = xla::Slice(diag, start, limits, strides); } // Reshape so the target values are in the first position of the last @@ -290,18 +291,18 @@ class MatrixDiagPartOp : public XlaOpKernel { std::vector<int64> unflattened_dims(dims.begin(), dims.end()); dims[last_dim - 1] = smaller_dim_size; dims[last_dim] = last_dim_size + 1; - diag = builder->Reshape(diag, dims); + diag = xla::Reshape(diag, dims); // Slices out the first column and reshapes to the final shape. std::vector<int64> start(dims.size(), 0); std::vector<int64> limits(dims.begin(), dims.end()); std::vector<int64> strides(dims.size(), 1); limits[last_dim] = 1; - diag = builder->Slice(diag, start, limits, strides); + diag = xla::Slice(diag, start, limits, strides); // Collapses away the last dimension. dims.pop_back(); - diag = builder->Reshape(diag, dims); + diag = xla::Reshape(diag, dims); ctx->SetOutput(0, diag); } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 0419de78b2..3b86ea34c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -57,8 +57,8 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = ctx->builder()->DynamicUpdateSlice( - ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = + xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index dd4a169087..958231505b 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -150,8 +151,7 @@ class DynamicStitchOp : public XlaOpKernel { if (new_shape == data_shapes[input_num]) { input[input_num] = handle; } else { - input[input_num] = - ctx->builder()->Reshape(handle, new_shape.dim_sizes()); + input[input_num] = xla::Reshape(handle, new_shape.dim_sizes()); } } @@ -175,10 +175,10 @@ class DynamicStitchOp : public XlaOpKernel { // And place it in the concat list in the place indicated by // the index. to_concat[index_num] = - ctx->builder()->Slice(expression, slice_start, slice_limit, stride); + xla::Slice(expression, slice_start, slice_limit, stride); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), to_concat, 0)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 493781a1e6..2c76bcee25 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -34,9 +34,9 @@ class EluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Expm1(ctx->Input(0)); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), expm1)); } }; @@ -51,9 +51,9 @@ class EluGradOp : public XlaOpKernel { const auto one = XlaHelpers::One(b, input_type(0)); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto exp_grad = b->Mul(grad, b->Add(activation, one)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, one)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, grad, exp_grad)); } }; @@ -71,10 +71,10 @@ class SeluOp : public XlaOpKernel { 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Expm1(ctx->Input(0)); - ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), - b->Mul(scale_alpha, expm1))); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, xla::Mul(scale, ctx->Input(0)), + xla::Mul(scale_alpha, expm1))); } }; @@ -92,10 +92,10 @@ class SeluGradOp : public XlaOpKernel { 1.7580993408473768599402175208123); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto lin_grad = b->Mul(grad, scale); - const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad)); + const auto lin_grad = xla::Mul(grad, scale); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, scale_alpha)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, lin_grad, exp_grad)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 6df01cabbf..b2451236de 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -114,9 +115,9 @@ class ExtractImagePatchesOp : public XlaOpKernel { TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, kernel_size * depth, &iota)); - auto lhs = builder->Reshape(iota, lhs_shape); - auto filter = builder->ConvertElementType( - builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); + auto lhs = xla::Reshape(iota, lhs_shape); + auto filter = xla::ConvertElementType( + xla::Eq(lhs, iota, {num_spatial_dims + 1}), type); xla::ConvolutionDimensionNumbers dims; std::vector<int64> window_strides(num_spatial_dims); @@ -148,8 +149,8 @@ class ExtractImagePatchesOp : public XlaOpKernel { } xla::XlaOp conv = - builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, - padding, lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 8f0de0a524..2fd1a34741 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -49,20 +50,20 @@ void XlaNudge(xla::XlaBuilder* b, const DataType data_type, const float quant_min_value, const float quant_max_value, xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, xla::XlaOp* scale) { - *scale = b->Div(b->Sub(max, min), - XlaHelpers::FloatLiteral(b, data_type, - quant_max_value - quant_min_value)); + *scale = xla::Div(xla::Sub(max, min), + XlaHelpers::FloatLiteral( + b, data_type, quant_max_value - quant_min_value)); xla::XlaOp quant_min = XlaHelpers::FloatLiteral(b, data_type, quant_min_value); - xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale)); + xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); xla::XlaOp quant_max = XlaHelpers::FloatLiteral(b, data_type, quant_max_value); xla::XlaOp nudged_zero_point = - b->Select(b->Le(zero_point_from_min, quant_min), quant_min, - b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, - b->Round(zero_point_from_min))); - *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale); - *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); + xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min, + xla::Select(xla::Ge(zero_point_from_min, quant_max), + quant_max, xla::Round(zero_point_from_min))); + *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale); + *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale); } xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, @@ -71,14 +72,14 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, const xla::XlaOp& nudged_input_max, const xla::XlaOp& input_scale) { xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); - xla::XlaOp inv_scale = b->Div(one, input_scale); + xla::XlaOp inv_scale = xla::Div(one, input_scale); xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); - xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max); - xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min); + xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max); + xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min); xla::XlaOp rounded = - b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); - return b->Add(b->Mul(rounded, input_scale), nudged_input_min); + xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half)); + return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min); } class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { @@ -163,11 +164,11 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type), - gradient_shape.dim_sizes()); - xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); + xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type), + gradient_shape.dim_sizes()); + xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -249,25 +250,25 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes()); - xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); - xla::XlaOp below_min = b->Lt(input, nudged_input_min); - xla::XlaOp select1 = b->Select(below_min, gradient, zeroes); - xla::XlaOp reduce1 = b->ReduceAll( + xla::XlaOp below_min = xla::Lt(input, nudged_input_min); + xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes); + xla::XlaOp reduce1 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); - xla::XlaOp above_max = b->Gt(input, nudged_input_max); - xla::XlaOp select2 = b->Select(above_max, gradient, zeroes); - xla::XlaOp reduce2 = b->ReduceAll( + xla::XlaOp above_max = xla::Gt(input, nudged_input_max); + xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes); + xla::XlaOp reduce2 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 933924cad1..b2b00e51e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,8 +63,7 @@ class GenericFftOp : public XlaOpKernel { } } - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length); ctx->SetOutput(0, fft); } diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index e4467a0fb1..95faa1d058 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -59,11 +60,11 @@ class FillOp : public XlaOpKernel { xla::XlaOp data = ctx->Input(1); if (value_shape.dims() > 0) { CHECK_EQ(value_shape.dims(), 1); - data = ctx->builder()->Reshape(data, {}); + data = xla::Reshape(data, {}); } // Emit the actual computation, which broadcasts the scalar to the // desired shape. - auto result = ctx->builder()->Broadcast(data, broadcast); + auto result = xla::Broadcast(data, broadcast); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index d13e25bcdd..5f041be5df 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -75,8 +76,8 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, out_shape.AppendShape(indices_shape_no_index_vectors); out_shape.AppendShape(input_shape_post_axis); - *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); + *gather_output = + xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); return Status::OK(); } @@ -142,7 +143,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, dim_numbers.add_gather_dims_to_operand_dims(i); } - *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index d48c6eea75..f5fcf3cacd 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -199,13 +200,13 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp outputs = - b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, - b->Tuple(inputs), *else_result.computation); + xla::XlaOp outputs = xla::Conditional( + ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, + xla::Tuple(b, inputs), *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { - xla::XlaOp output_handle = b->GetTupleElement(outputs, i); + xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); if (VLOG_IS_ON(2)) { LOG(INFO) << "Setting output " << i; auto shape_or = b->GetShape(output_handle); @@ -233,7 +234,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - b->GetTupleElement(outputs, pos), b)); + xla::GetTupleElement(outputs, pos), b)); } VLOG(2) << "If variable: pos: " << update.input_index << " name: " << resource->name() diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 1568b33679..7a36514eb3 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -32,23 +33,26 @@ std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, auto red = rgb[0]; auto green = rgb[1]; auto blue = rgb[2]; - auto value = b->Max(b->Max(red, green), blue); - auto minimum = b->Min(b->Min(red, green), blue); - auto range = b->Sub(value, minimum); - - auto zeros = b->Broadcast(zero, shape.dim_sizes()); - auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros); - - auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); - - auto hue = b->Select(b->Eq(green, value), - b->Add(b->Mul(norm, b->Sub(blue, red)), - XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), - b->Add(b->Mul(norm, b->Sub(red, green)), - XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); - hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue); - hue = b->Select(b->Gt(range, zero), hue, zeros); - hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue); + auto value = xla::Max(xla::Max(red, green), blue); + auto minimum = xla::Min(xla::Min(red, green), blue); + auto range = xla::Sub(value, minimum); + + auto zeros = xla::Broadcast(zero, shape.dim_sizes()); + auto saturation = + xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); + + auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); + + auto hue = + xla::Select(xla::Eq(green, value), + xla::Add(xla::Mul(norm, xla::Sub(blue, red)), + XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), + xla::Add(xla::Mul(norm, xla::Sub(red, green)), + XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); + hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)), + hue); + hue = xla::Select(xla::Gt(range, zero), hue, zeros); + hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue); return {hue, saturation, value}; } @@ -66,15 +70,15 @@ std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b, auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); - auto dh = b->Mul(hue, six); - auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one); - auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one); - auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one); - auto one_minus_s = b->Sub(one, saturation); + auto dh = xla::Mul(hue, six); + auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one); + auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one); + auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one); + auto one_minus_s = xla::Sub(one, saturation); - auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value); - auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value); - auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value); + auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value); + auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value); + auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value); return {red, green, blue}; } @@ -111,7 +115,7 @@ class RGBToHSVOp : public XlaOpKernel { auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - context->SetOutput(0, b->ConcatInDim(hsv, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim)); } }; REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); @@ -147,7 +151,7 @@ class HSVToRGBOp : public XlaOpKernel { auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); @@ -182,18 +186,20 @@ class AdjustContrastOpV2 : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, input, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *context->GetOrCreateAdd(accumulation_type), - {height_dim, width_dim}); + auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), + {height_dim, width_dim}); auto output = XlaHelpers::ConvertElementType(b, reduce, type); - output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + output = + xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector<int64> broadcast_dims(input_shape.dims() - 2); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); broadcast_dims.back() = channel_dim; - output = b->Add(b->Mul(input, factor), - b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)), - broadcast_dims); + output = + xla::Add(xla::Mul(input, factor), + xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)), + broadcast_dims); context->SetOutput(0, output); } }; @@ -240,12 +246,12 @@ class AdjustSaturationOp : public XlaOpKernel { auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale), - XlaHelpers::One(b, type)); + hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale), + XlaHelpers::One(b, type)); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); @@ -294,12 +300,13 @@ class AdjustHueOp : public XlaOpKernel { auto one = XlaHelpers::One(b, type); auto& hue = hsv[0]; - hue = b->Rem(b->Add(hsv[0], delta), one); - hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue); + hue = xla::Rem(xla::Add(hsv[0], delta), one); + hue = + xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 79d3a6979c..de971ce4ac 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/math/math_util.h" @@ -132,17 +133,16 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, TF_CHECK_OK( XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); - auto diag = builder->ConvertElementType( - builder->Eq( - builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, + auto diag = xla::ConvertElementType( + xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, 2 * kernel_size[1] - 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); - return builder->Mul( - builder->Mul(diag, - builder->ConstantR1<float>(Make1DKernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}), - builder->ConstantR1<float>(Make1DKernel(kernel_size[0])), + return xla::Mul( + xla::Mul(diag, + xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}), + xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } @@ -154,21 +154,21 @@ xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, TF_CHECK_OK( XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); - auto diag = builder->ConvertElementType( - builder->Eq(builder->Broadcast( - channels_iota, - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + auto diag = xla::ConvertElementType( + xla::Eq( + xla::Broadcast(channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); if (dim == 1) { - return builder->Mul( - diag, builder->ConstantR1<float>(Make1DKernel(kernel_size[1])), + return xla::Mul( + diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}); } - return builder->Mul(diag, - builder->ConstantR1<float>(Make1DKernel(kernel_size[0])), - /*broadcast_dimensions=*/{0}); + return xla::Mul(diag, + xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); } xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, @@ -208,7 +208,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( input, kernel, dims.stride, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -218,7 +218,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( input, kernel0, {dims.stride[0], 1}, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, @@ -226,7 +226,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers); xla::XlaOp kernel1 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, @@ -238,8 +238,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, // size > 1 dimension. for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] == 1 && out_size[i] > 1) { - output = builder->Add(output, builder->ConstantR1<float>(out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); + output = xla::Add(output, xla::ConstantR1<float>(builder, out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); } } return output; @@ -279,12 +279,12 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] == 1 && grad_size[i] > 1) { kernel = - builder->Add(kernel, builder->ConstantR1<float>(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + xla::Add(kernel, xla::ConstantR1<float>(builder, grad_size[i], 0), + /*broadcast_dimensions=*/{i}); } } - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -302,23 +302,23 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, // gradient contributions in that dimension. if (in_size[0] == 1 && grad_size[0] > 1) { kernel0 = - builder->Add(kernel0, builder->ConstantR1<float>(grad_size[0], 0), - /*broadcast_dimensions=*/{0}); + xla::Add(kernel0, xla::ConstantR1<float>(builder, grad_size[0], 0), + /*broadcast_dimensions=*/{0}); } if (in_size[1] == 1 && grad_size[1] > 1) { kernel1 = - builder->Add(kernel0, builder->ConstantR1<float>(grad_size[1], 0), - /*broadcast_dimensions=*/{1}); + xla::Add(kernel0, xla::ConstantR1<float>(builder, grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, /*lhs_dilation=*/{dims.stride[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, @@ -337,7 +337,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } } if (pad_output) { - output = builder->Pad(output, builder->ConstantR0<float>(0.0f), padding); + output = xla::Pad(output, xla::ConstantR0<float>(builder, 0.0f), padding); } return output; } @@ -393,13 +393,13 @@ class ResizeBilinearOp : public XlaOpKernel { } } if (slice_input) { - input = b->Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = xla::Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); } // Output is always type float. - input = b->ConvertElementType(input, xla::F32); + input = xla::ConvertElementType(input, xla::F32); // Special Case: // Instead of doing a ResizeUsingDilationAndConvolution directly, @@ -529,7 +529,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { } } - output = b->ConvertElementType(output, output_type_); + output = xla::ConvertElementType(output, output_type_); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 2c2d88486f..a020ebc729 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // XLA passes <out> to the function, so it is not included here. std::vector<xla::XlaOp> args; args.push_back(ctx->Input(0)); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1<int64>(input_shape.dim_sizes()))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1<int64>(output_shape.dim_sizes()))); - args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim))); + args.push_back(xla::ConstantLiteral( + &b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes()))); + args.push_back( + xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim))); } xla::Shape xla_shape = @@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel { xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); break; case 2: - output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); break; default: OP_REQUIRES(ctx, false, diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 1decf7d72d..9e64711051 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -39,12 +39,12 @@ class L2LossOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto t = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); - auto square = b->Mul(t, t); - auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), dims); + auto square = xla::Mul(t, t); + auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - ctx->SetOutput(0, b->Div(deconverted, two)); + ctx->SetOutput(0, xla::Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 0388b4c830..2fb072f827 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" @@ -90,8 +91,10 @@ class ListDiffOp : public XlaOpKernel { idx_output.push_back(i); } - context->SetOutput(0, context->builder()->ConstantR1<Tval>(val_output)); - context->SetOutput(1, context->builder()->ConstantR1<Tidx>(idx_output)); + context->SetOutput(0, + xla::ConstantR1<Tval>(context->builder(), val_output)); + context->SetOutput(1, + xla::ConstantR1<Tidx>(context->builder(), idx_output)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 39fbf98a62..dc934543cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -50,8 +51,8 @@ class LRNOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, input, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -59,12 +60,12 @@ class LRNOp : public XlaOpKernel { auto sqr_sum = XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); - auto scale = builder->Pow( - builder->Add(builder->ConstantR0<float>(bias_), - builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)), - builder->ConstantR0<float>(-beta_)); + auto scale = xla::Pow( + xla::Add(xla::ConstantR0<float>(builder, bias_), + xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum)), + xla::ConstantR0<float>(builder, -beta_)); - ctx->SetOutput(0, builder->Mul(input, scale)); + ctx->SetOutput(0, xla::Mul(input, scale)); } private: @@ -138,8 +139,8 @@ class LRNGradOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -148,17 +149,17 @@ class LRNGradOp : public XlaOpKernel { XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = - builder->Add(builder->ConstantR0<float>(bias_), - builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)); + xla::Add(xla::ConstantR0<float>(builder, bias_), + xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum)); - auto dy = builder->Mul( - builder->Mul(builder->ConstantR0<float>(-2.0f * alpha_ * beta_), - builder->Div(out_image, norm)), + auto dy = xla::Mul( + xla::Mul(xla::ConstantR0<float>(builder, -2.0f * alpha_ * beta_), + xla::Div(out_image, norm)), in_grads); auto converted_dy = XlaHelpers::ConvertElementType(builder, dy, accumulation_type); - auto dy_reduce = builder->ReduceWindow( + auto dy_reduce = xla::ReduceWindow( converted_dy, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -166,10 +167,10 @@ class LRNGradOp : public XlaOpKernel { auto dy_reduced = XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); - xla::XlaOp gradients = builder->Add( - builder->Mul(in_image, dy_reduced), - builder->Mul(in_grads, - builder->Pow(norm, builder->ConstantR0<float>(-beta_)))); + xla::XlaOp gradients = xla::Add( + xla::Mul(in_image, dy_reduced), + xla::Mul(in_grads, + xla::Pow(norm, xla::ConstantR0<float>(builder, -beta_)))); ctx->SetOutput(0, gradients); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6949b296f4..844080b8cf 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -70,15 +71,15 @@ class MatMulOp : public XlaOpKernel { xla::XlaOp b = ctx->Input(1); if (is_sparse_) { if (a_type_ == DT_BFLOAT16) { - a = ctx->builder()->ConvertElementType(a, xla::F32); + a = xla::ConvertElementType(a, xla::F32); } if (b_type_ == DT_BFLOAT16) { - b = ctx->builder()->ConvertElementType(b, xla::F32); + b = xla::ConvertElementType(b, xla::F32); } } - auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a; - auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b; - ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs)); + auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a; + auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b; + ctx->SetOutput(0, xla::Dot(lhs, rhs)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index fbd5dc0fda..9d3575e331 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -64,27 +65,26 @@ class MatrixBandPartOp : public XlaOpKernel { xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); - auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, - /*broadcast_dimensions=*/{0}); + auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, + /*broadcast_dimensions=*/{0}); // If num_lower or num_upper are negative, include all lower/upper // diagonals. auto zero_index = XlaHelpers::Zero(builder, index_type); - num_lower = builder->Select( - builder->Lt(num_lower, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower); - num_upper = builder->Select( - builder->Lt(num_upper, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper); + num_lower = xla::Select(xla::Lt(num_lower, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, m), + num_lower); + num_upper = xla::Select(xla::Lt(num_upper, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, n), + num_upper); - auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset), - builder->Le(offset, num_upper)); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + auto indicator = xla::And(xla::Le(xla::Neg(num_lower), offset), + xla::Le(offset, num_upper)); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); auto zero_input = XlaHelpers::Zero(builder, input_type); - auto output = builder->Select( - indicator, input, - builder->Broadcast(zero_input, input_shape.dim_sizes())); + auto output = xla::Select( + indicator, input, xla::Broadcast(zero_input, input_shape.dim_sizes())); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index db53f6fef8..7bf1894ea0 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -65,10 +66,9 @@ class MatrixSetDiagOp : public XlaOpKernel { OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); - auto indicator = builder->Eq(iota_m, - builder->Broadcast(iota_n, {m}), - /*broadcast_dimensions=*/{0}); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), + /*broadcast_dimensions=*/{0}); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); // Broadcast diag up to the input shape. Use an implicit broadcast (Add) // because we need to broadcast on the right. @@ -77,10 +77,10 @@ class MatrixSetDiagOp : public XlaOpKernel { if (min_dim != m) { diag_broadcast_dims.back() = rank - 1; } - diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); + diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()), + /*broadcast_dimensions=*/diag_broadcast_dims); - auto output = builder->Select(indicator, diag, input); + auto output = xla::Select(indicator, diag, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index c3326b4d11..caeba66b52 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { @@ -32,7 +33,7 @@ class MirrorPadOp : public XlaOpKernel { xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { - auto t_rev = b->Rev(accum, {dimno}); + auto t_rev = xla::Rev(accum, {dimno}); TF_ASSIGN_OR_RETURN(int64 lhs_padding, pad_literal.GetIntegralAsS64({dimno, 0})); TF_ASSIGN_OR_RETURN(int64 rhs_padding, @@ -41,7 +42,7 @@ class MirrorPadOp : public XlaOpKernel { auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding, dim_size - 1, 1, dimno); auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); - accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno); + accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; } diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index aecaabb6dc..3aed47de26 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,11 +77,10 @@ class PackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. - reshaped_inputs[i] = - ctx->builder()->Reshape(values[i], child_shape.dim_sizes()); + reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes()); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 17b85338f7..89fd610bc6 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -74,11 +75,10 @@ class PadOp : public XlaOpKernel { if (ctx->num_inputs() == 3) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, - ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index eb8b5b130f..771dcbab21 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -113,8 +114,8 @@ class PoolingOp : public XlaOpKernel { xla::XlaBuilder* const b = ctx->builder(); auto input = XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = ctx->builder()->ReduceWindow( - input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); + auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, + stride, padding_); auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); ctx->SetOutput(0, PostProcessOutput(ctx, pooled, input_type(0), input_shape)); @@ -190,7 +191,7 @@ static xla::XlaOp AvgPoolDivideByCount( auto divisor = XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return ctx->builder()->Div(output, divisor); + return xla::Div(output, divisor); } else { // For SAME padding, the padding shouldn't be included in the // counts. We use another ReduceWindow to find the right counts. @@ -212,18 +213,18 @@ static xla::XlaOp AvgPoolDivideByCount( // Build a matrix of all 1s, with the same width/height as the input. const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = ctx->builder()->Broadcast( + auto ones = xla::Broadcast( XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto reduce = ctx->builder()->ReduceWindow( + auto reduce = xla::ReduceWindow( ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - return ctx->builder()->Div(output, counts, window_dims); + return xla::Div(output, counts, window_dims); } } @@ -347,9 +348,9 @@ class MaxPoolGradOp : public XlaOpKernel { xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); auto select = CreateScalarGeComputation(element_type, ctx->builder()); auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); - xla::XlaOp gradients = ctx->builder()->SelectAndScatter( - input, select, ksize_, stride_, xla_padding, out_backprop, init_value, - scatter); + xla::XlaOp gradients = + xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding, + out_backprop, init_value, scatter); ctx->SetOutput(0, gradients); } @@ -485,12 +486,12 @@ class AvgPoolGradOp : public XlaOpKernel { } auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); + auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients <conv> ones std::vector<int64> ones(num_dims(), 1LL); auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = b->ReduceWindow( + auto in_backprop = xla::ReduceWindow( XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), ksize_, @@ -614,17 +615,18 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto b = ctx->builder(); - auto sixteen = b->ConstantR0<uint32>(16); + auto sixteen = xla::ConstantR0<uint32>(b, 16); // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 - auto in_hi = b->BitcastConvertType( - b->ConvertElementType(b->ConvertElementType(input, xla::BF16), - xla::F32), + auto in_hi = xla::BitcastConvertType( + xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16), + xla::F32), xla::U32); - auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); - auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); - auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); - auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. - auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. + auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen); + auto bp_lo = + xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add. auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); // We will reduce by taking the maximal value up to 16 bits (ignoring the lo @@ -633,39 +635,41 @@ class MaxPoolGradGradOp : public XlaOpKernel { { // F32 parameters to satisfy lowering type restriction for reduce opcode. const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto lhs = rb->Parameter(0, scalar, "lhs"); - auto rhs = rb->Parameter(1, scalar, "rhs"); - auto sixteen = rb->ConstantR0<int32>(16); - auto lhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), - sixteen); - auto rhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), - sixteen); + auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs"); + auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs"); + auto sixteen = xla::ConstantR0<int32>(rb.get(), 16); + auto lhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); // Must use a F32 comparison, because S32 would not work for negatives. - rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), - rb->BitcastConvertType(rhs_criteria, xla::F32)), - lhs, rhs); + xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32), + xla::BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); } auto reduce = rb->BuildAndNoteError(); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; auto pooled_hi = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto pooled_lo = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto grads_hi = - b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); - auto grads_lo = b->ShiftRightLogical( - b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = xla::ShiftRightLogical( + xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen), sixteen); - auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + auto grads = xla::Add(grads_hi, grads_lo); // Want an unsigned add. xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); - ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type)); } protected: diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 661cd5923e..9576354c5f 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -58,11 +59,11 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); input_min = - b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); + xla::ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); input_max = - b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); + xla::ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); } - xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max)); + xla::XlaOp m = xla::Max(xla::Abs(input_min), xla::Abs(input_max)); // Next, we choose our fixed-point quantization buckets, [min_fixed, // max_fixed]. If signed_input is true, this is @@ -86,14 +87,14 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // // s = (max_fixed - min_fixed) / (2 * m). xla::XlaOp s = - b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), - b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); + xla::Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), + xla::Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); // Now we can quantize and dequantize the elements of our tensor. An element // e is transformed into e': // // e' = (e * s).round_to_nearest() / s. - xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s); + xla::XlaOp result = xla::Div(xla::Round(xla::Mul(input, s)), s); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 3bab4ae917..51f2cdc9f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -46,8 +47,8 @@ class RandomUniformOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -79,8 +80,8 @@ class RandomShuffleOp : public XlaOpKernel { // Generate the random swaps for the indices. auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); auto swaps = - builder->RngUniform(builder->ConstantR0<int32>(0), - builder->ConstantR0<int32>(n), swaps_shape); + xla::RngUniform(xla::ConstantR0<int32>(builder, 0), + xla::ConstantR0<int32>(builder, n), swaps_shape); // Generate range(n) as the initial value for the indices to be swapped. xla::XlaOp indices; @@ -93,17 +94,17 @@ class RandomShuffleOp : public XlaOpKernel { -> xla::StatusOr<std::vector<xla::XlaOp>> { auto swaps = loop_vars[0]; auto indices = loop_vars[1]; - i = builder->Reshape(i, {1}); + i = xla::Reshape(i, {1}); // temp = indices[i] - auto temp = builder->DynamicSlice(indices, i, {1}); + auto temp = xla::DynamicSlice(indices, i, {1}); // swap_index = swaps[i] - auto swap_index = builder->DynamicSlice(swaps, i, {1}); + auto swap_index = xla::DynamicSlice(swaps, i, {1}); // swap_value = indices[swaps[i]] - auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); + auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); // indices[i] = indices[swaps[i]] - indices = builder->DynamicUpdateSlice(indices, swap_value, i); + indices = xla::DynamicUpdateSlice(indices, swap_value, i); // indices[swaps[i]] = temp - indices = builder->DynamicUpdateSlice(indices, temp, swap_index); + indices = xla::DynamicUpdateSlice(indices, temp, swap_index); return std::vector<xla::XlaOp>{swaps, indices}; }; // for i in range(n): @@ -153,7 +154,7 @@ class RandomUniformIntOp : public XlaOpKernel { auto minval = ctx->Input(1); auto maxval = ctx->Input(2); - ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape)); + ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape)); } private: @@ -179,8 +180,8 @@ class RandomStandardNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); // Normal distribution with a mean of 0 and a standard deviation of 1: - xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -209,7 +210,7 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); xla::XlaOp min_positive = XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min()); - auto uniform = b->RngUniform(min_positive, one, xla_shape); + auto uniform = xla::RngUniform(min_positive, one, xla_shape); ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 08894489ac..76bd1e62aa 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -98,10 +99,10 @@ class ReduceWindowOp : public XlaOpKernel { { std::unique_ptr<xla::XlaBuilder> cb = builder->CreateSubBuilder("wrapper"); - auto x = cb->Parameter(0, scalar_shape, "x"); - auto y = cb->Parameter(1, scalar_shape, "y"); - auto outputs = cb->Call(*reducer.computation, {x, y}); - cb->GetTupleElement(outputs, 0); + auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); + auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); + auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); + xla::GetTupleElement(outputs, 0); xla::StatusOr<xla::XlaComputation> result = cb->Build(); OP_REQUIRES_OK(context, result.status()); wrapper = std::move(result.ValueOrDie()); @@ -112,7 +113,7 @@ class ReduceWindowOp : public XlaOpKernel { padding[i] = {padding_low_[i], padding_high_[i]}; } - xla::XlaOp output = builder->ReduceWindowWithGeneralPadding( + xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), wrapper, window_dimensions_, window_strides_, padding); context->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 0f42563779..d3573bac3d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -35,7 +36,7 @@ class SumOp : public XlaReductionOp { } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } }; @@ -53,7 +54,7 @@ class ProdOp : public XlaReductionOp { void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Mul(scalar_lhs, scalar_rhs); + xla::Mul(scalar_lhs, scalar_rhs); } }; @@ -71,7 +72,7 @@ class MinOp : public XlaReductionOp { void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Min(scalar_lhs, scalar_rhs); + xla::Min(scalar_lhs, scalar_rhs); } }; @@ -88,7 +89,7 @@ class MaxOp : public XlaReductionOp { void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Max(scalar_lhs, scalar_rhs); + xla::Max(scalar_lhs, scalar_rhs); } }; @@ -105,7 +106,7 @@ class MeanOp : public XlaReductionOp { } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, @@ -113,7 +114,7 @@ class MeanOp : public XlaReductionOp { int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); - return builder->Div(reduce_output, divisor); + return xla::Div(reduce_output, divisor); } }; @@ -126,12 +127,12 @@ class AllOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0<bool>(true); + return xla::ConstantR0<bool>(builder, true); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->And(scalar_lhs, scalar_rhs); + xla::And(scalar_lhs, scalar_rhs); } }; @@ -143,12 +144,12 @@ class AnyOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0<bool>(false); + return xla::ConstantR0<bool>(builder, false); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Or(scalar_lhs, scalar_rhs); + xla::Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 44510c731e..14506d65c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -101,20 +102,20 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); - auto data = b->ConvertElementType(ctx->Input(0), type); + auto data = xla::ConvertElementType(ctx->Input(0), type); // Call virtual method to get the initial value. - auto initial = b->ConvertElementType(InitialValue(b), type); + auto initial = xla::ConvertElementType(InitialValue(b), type); // Make two scalar parameters of the desired type for the lambda. - auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); + auto rx = xla::Parameter(&r, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = xla::Parameter(&r, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); - auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); - auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; + auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index ba7d484d53..a4ba6c748a 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -34,7 +34,7 @@ class ReluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); - ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); + ctx->SetOutput(0, xla::Max(zero, ctx->Input(0))); } }; @@ -46,7 +46,7 @@ class Relu6Op : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); - ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); + ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six)); } }; @@ -59,9 +59,9 @@ class ReluGradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto pred = b->Gt(ctx->Input(1), zero); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero)); + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = xla::Gt(ctx->Input(1), zero); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero)); } }; @@ -74,12 +74,12 @@ class Relu6GradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto six = b->Broadcast( + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto six = xla::Broadcast( XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); - auto out = - b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), - ctx->Input(0), zero); + auto out = xla::Select( + xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); ctx->SetOutput(0, out); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index af4d64b159..e0ca8dd8e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -90,8 +91,7 @@ class ReshapeOp : public XlaOpKernel { VLOG(1) << "Reshape " << input_shape.DebugString() << " " << shape.DebugString(); - ctx->SetOutput(0, - ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes())); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index a711278638..db7ea775e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -69,8 +69,7 @@ class RetvalOp : public XlaOpKernel { xla::XlaOp output = input; if (tc.is_entry_computation()) { - output = - ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + output = xla::Reshape(input, representation_shape.dim_sizes()); } else { // The core from which a return value is returned depends on the // device assignment of the input to the retval. Since we can't change @@ -78,8 +77,8 @@ class RetvalOp : public XlaOpKernel { // introduce an operator here, even if the shape does not change. // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. - output = ctx->builder()->GetTupleElement( - ctx->builder()->Tuple({output}), 0); + output = + xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); } tc.AddRetval(index_, dtype_, shape, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 2872a3c4d4..037c422258 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -62,7 +63,7 @@ class ReverseOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), dimensions)); } }; @@ -100,7 +101,7 @@ class ReverseV2Op : public XlaOpKernel { x_shape.dims(), ").")); } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 5d1c052684..16491002b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -85,88 +86,83 @@ class ReverseSequenceOp : public XlaOpKernel { auto condition_builder = builder->CreateSubBuilder("reverse_sequence_condition"); { - auto param = condition_builder->Parameter(0, tuple_shape, "param"); - auto i = condition_builder->GetTupleElement(param, 0); - condition_builder->Lt( - i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, - batch_size)); + auto param = + xla::Parameter(condition_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(), + seq_lens_type, batch_size)); } auto condition = condition_builder->Build(); OP_REQUIRES_OK(context, condition.status()); auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); { - auto param = body_builder->Parameter(0, tuple_shape, "param"); - auto i = body_builder->GetTupleElement(param, 0); - auto seq_lens = body_builder->GetTupleElement(param, 1); - auto output = body_builder->GetTupleElement(param, 2); + auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + auto seq_lens = xla::GetTupleElement(param, 1); + auto output = xla::GetTupleElement(param, 2); // seq_len is the sequence length of the current batch element (rank 1) - auto seq_len = body_builder->DynamicSlice( - seq_lens, body_builder->Reshape(i, {1}), {1}); + auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto batch_element_indices = body_builder->Broadcast( - XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {input_shape.dims()}); - batch_element_indices = body_builder->DynamicUpdateSlice( - batch_element_indices, body_builder->Reshape(i, {1}), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - batch_dim_), - {1})); + auto batch_element_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + batch_element_indices = xla::DynamicUpdateSlice( + batch_element_indices, xla::Reshape(i, {1}), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, batch_dim_), + {1})); // Slice out the current batch element and pad it out in the sequence // dimension. TensorShape slice_shape = input_shape; slice_shape.set_dim(batch_dim_, 1); slice_shape.set_dim(seq_dim_, max_seq_len); - auto slice = body_builder->DynamicSlice(output, batch_element_indices, - slice_shape.dim_sizes()); + auto slice = xla::DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( slice_shape.dim_size(seq_dim_)); - slice = body_builder->Pad( - slice, XlaHelpers::Zero(body_builder.get(), input_type), - padding_config); + slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); // Now slice out the reversed sequence from its actual start. // sequence_start_indices is the offset of the start of the reversed // sequence in the input. The slice will go into the padding, however, we // will mask off these elements and replace them with elements from the // original input so their values do not matter. - auto sequence_start_indices = body_builder->Broadcast( - XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {slice_shape.dims()}); - sequence_start_indices = body_builder->DynamicUpdateSlice( + auto sequence_start_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = xla::DynamicUpdateSlice( sequence_start_indices, - body_builder->Sub(XlaHelpers::IntegerLiteral( - body_builder.get(), seq_lens_type, max_seq_len), - seq_len), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - seq_dim_), - {1})); - slice = body_builder->DynamicSlice(slice, sequence_start_indices, - slice_shape.dim_sizes()); + xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + max_seq_len), + seq_len), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, seq_dim_), + {1})); + slice = xla::DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, - batch_element_indices); + output = xla::DynamicUpdateSlice(output, slice, batch_element_indices); - body_builder->Tuple( - {body_builder->Add( - i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + xla::Tuple( + body_builder.get(), + {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)), seq_lens, output}); } auto body = body_builder->Build(); OP_REQUIRES_OK(context, body.status()); - auto loop_output = builder->While( + auto loop_output = xla::While( condition.ValueOrDie(), body.ValueOrDie(), - builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, - builder->Rev(input, {seq_dim_})})); - auto output = builder->GetTupleElement(loop_output, 2); + xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + xla::Rev(input, {seq_dim_})})); + auto output = xla::GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. xla::XlaOp iota; @@ -174,14 +170,13 @@ class ReverseSequenceOp : public XlaOpKernel { context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); std::vector<int64> dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; - auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); // Broadcast the mask up to the input shape. - mask = - builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false), - input_shape.dim_sizes())); + mask = xla::Or(mask, xla::Broadcast(xla::ConstantR0<bool>(builder, false), + input_shape.dim_sizes())); - output = builder->Select(mask, output, input); + output = xla::Select(mask, output, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 1819fb5433..9247c7029a 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -100,7 +101,7 @@ class ScanOp : public XlaOpKernel { init = XlaHelpers::One(builder, dtype); reducer = ctx->GetOrCreateMul(dtype); } - auto output = builder->ReduceWindowWithGeneralPadding( + auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, *reducer, window_dims, window_strides, padding); output = diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index f2c63b4f90..14709bb6cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -103,8 +104,8 @@ class ScatterNdOp : public XlaOpKernel { updates_shape)); xla::XlaBuilder* builder = context->builder(); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - buffer_shape.dim_sizes()); + auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype), + buffer_shape.dim_sizes()); auto indices = context->Input(0); auto updates = context->Input(1); auto result = diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index ff14483347..db7e559420 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -75,7 +75,7 @@ class UnsortedSegmentReduce : public XlaOpKernel { buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); auto buffer = - builder->Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); + xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); auto combiner = [this](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { @@ -102,7 +102,7 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce { }; xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) override { - return builder->Add(a, b); + return xla::Add(a, b); }; }; @@ -120,7 +120,7 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce { }; xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) override { - return builder->Mul(a, b); + return xla::Mul(a, b); }; }; @@ -138,7 +138,7 @@ class UnsortedSegmentMin : public UnsortedSegmentReduce { }; xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) override { - return builder->Min(a, b); + return xla::Min(a, b); }; }; @@ -156,7 +156,7 @@ class UnsortedSegmentMax : public UnsortedSegmentReduce { }; xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) override { - return builder->Max(a, b); + return xla::Max(a, b); }; }; diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index f9f48164d6..5c010c9df2 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -40,8 +41,6 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs. ", else_shape.DebugString())); - xla::XlaBuilder* builder = ctx->builder(); - auto cond_handle = ctx->Input(0); auto then_handle = ctx->Input(1); auto else_handle = ctx->Input(2); @@ -69,14 +68,14 @@ class SelectOp : public XlaOpKernel { const auto dim_sizes = then_shape.dim_sizes(); gtl::ArraySlice<int64> bdims = dim_sizes; bdims.pop_front(); - cond_handle = builder->Broadcast(cond_handle, bdims); + cond_handle = xla::Broadcast(cond_handle, bdims); std::vector<int64> dim_order(then_shape.dims()); dim_order[0] = then_shape.dims() - 1; std::iota(dim_order.begin() + 1, dim_order.end(), 0); - cond_handle = builder->Transpose(cond_handle, dim_order); + cond_handle = xla::Transpose(cond_handle, dim_order); } - ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle)); + ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 9ce01d0d44..6281d6c653 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -45,7 +45,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->builder()->Send(ctx->Input(0), channel); + xla::Send(ctx->Input(0), channel); } REGISTER_XLA_OP(Name("XlaSend"), SendOp); @@ -76,7 +76,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel)); + ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel)); } REGISTER_XLA_OP(Name("XlaRecv"), RecvOp); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index d59720bef7..5798823cd5 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -147,7 +148,7 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min<int32>(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); @@ -204,7 +205,7 @@ class SqueezeOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } private: @@ -221,7 +222,7 @@ class ZerosLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes())); } }; @@ -235,7 +236,7 @@ class OnesLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto one = XlaHelpers::One(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index be1e97bf26..1864584ade 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,8 +93,7 @@ class SliceOp : public XlaOpKernel { limits.push_back(begin[i] + size[i]); } std::vector<int64> strides(begin.size(), 1); - ctx->SetOutput( - 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides)); + ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides)); } else { // `begin` is not a compile-time constant. for (int i = 0; i < input_dims; ++i) { @@ -106,8 +106,7 @@ class SliceOp : public XlaOpKernel { input_shape.dim_size(i), "], but ", "got ", size[i])); } - ctx->SetOutput( - 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size)); + ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index bbf5ee8b12..d1c69f08b0 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -47,25 +48,25 @@ class SoftmaxOp : public XlaOpKernel { const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); // Find the max in each batch, resulting in a tensor of shape [batch] - auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type), + max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - auto exp_shifted = b->Exp(shifted_logits); + auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); + auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum = XlaHelpers::ConvertElementType(b, reduce, type); auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) - ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim}) // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - : b->Div(exp_shifted, sum, {kBatchDim}); + : xla::Div(exp_shifted, sum, {kBatchDim}); ctx->SetOutput(0, softmax); } @@ -87,43 +88,44 @@ std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits( xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. // Broadcasts along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); + auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); // exp(logits - max_logits) - auto exp_shifted_logits = b->Exp(shifted_logits); + auto exp_shifted_logits = xla::Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto reduce = + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) - auto log_sum_exp = b->Log(sum_exp); + auto log_sum_exp = xla::Log(sum_exp); // sum(-labels * // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - auto mul = b->Mul(b->Neg(labels), sub); + auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = xla::Mul(xla::Neg(labels), sub); auto sum = - b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) // (where the division broadcasts along the batch dimension) xla::XlaOp backprop = - b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); + xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); return {loss, backprop}; } @@ -206,16 +208,14 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // Builds a vector of {batch_size} that is 0 if the index is in range, or // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. - xla::XlaOp nan_or_zero = builder->Select( - builder->And( - builder->Le(XlaHelpers::Zero(builder, indices_type), indices), - builder->Lt(indices, XlaHelpers::IntegerLiteral( - builder, indices_type, depth))), - builder->Broadcast(XlaHelpers::Zero(builder, logits_type), - {batch_size}), - builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), - {batch_size})); - labels = builder->Add(labels, nan_or_zero, {0}); + xla::XlaOp nan_or_zero = xla::Select( + xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices), + xla::Lt(indices, XlaHelpers::IntegerLiteral( + builder, indices_type, depth))), + xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}), + xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), + {batch_size})); + labels = xla::Add(labels, nan_or_zero, {0}); xla::XlaOp loss, backprop; std::tie(loss, backprop) = diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index 204ae84582..faaf8964ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,8 +25,7 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - xla::XlaBuilder* const b = context->builder(); - context->SetOutput(0, b->Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index ec077924b5..8a8525efa1 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -73,7 +74,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, "The product of the block dimensions must be positive")); xla::XlaOp padded = - b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); + xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); // 2. Reshape `padded` to `reshaped_padded` of shape: // @@ -100,7 +101,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_padded_shape.begin() + 1 + 2 * block_rank); - xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape); + xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape); // 3. Permute dimensions of `reshaped_padded` to produce // `permuted_reshaped_padded` of shape: @@ -120,7 +121,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); xla::XlaOp permuted_reshaped_padded = - b->Transpose(reshaped_padded, permutation); + xla::Transpose(reshaped_padded, permutation); // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -140,7 +141,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), output_shape.begin() + 1 + block_rank); - xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4c5886ee2a..47d282fe9e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class SpaceToDepthOp : public XlaOpKernel { const gtl::InlinedVector<int64, 4> input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -145,7 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_, block_size_, // depth] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -155,7 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 9b54058541..ca74cf2450 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -98,7 +99,7 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension. begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); } } }; @@ -199,7 +200,7 @@ class SplitVOp : public XlaOpKernel { // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); begin[split_dim] = limits[split_dim]; } } diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 0fb05a2be7..591e61b4c8 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -144,24 +144,25 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::XlaOp ta = b->GetTupleElement(resource->value(), 0); - xla::XlaOp index = b->GetTupleElement(resource->value(), 1); + xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0); + xla::XlaOp index = xla::GetTupleElement(resource->value(), 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( - {b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0<int32>(1))}))); + OP_REQUIRES_OK(ctx, + resource->SetValue(xla::Tuple( + b, {xla::DynamicUpdateSlice(ta, update, start_indices), + xla::Add(index, xla::ConstantR0<int32>(b, 1))}))); ctx->SetOutput(0, value); } @@ -197,27 +198,27 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); xla::XlaOp state = resource->value(); - xla::XlaOp ta = b->GetTupleElement(state, 0); - xla::XlaOp index = b->GetTupleElement(state, 1); + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); - index = b->Sub(index, b->ConstantR0<int32>(1)); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); + index = Sub(index, xla::ConstantR0<int32>(b, 1)); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), + xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 43ab4642e9..3b19f8d872 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -33,9 +34,9 @@ namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, int distance) { - return builder->Or( - builder->ShiftLeft(v, builder->ConstantR0<int>(distance)), - builder->ShiftRightLogical(v, builder->ConstantR0<int>(32 - distance))); + return xla::Or( + xla::ShiftLeft(v, xla::ConstantR0<int>(builder, distance)), + xla::ShiftRightLogical(v, xla::ConstantR0<int>(builder, 32 - distance))); } using ThreeFry2x32State = std::array<xla::XlaOp, 2>; @@ -51,22 +52,22 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, std::array<xla::XlaOp, 3> ks; // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. - ks[2] = builder->ConstantR0<int32>(0x1BD11BDA); + ks[2] = xla::ConstantR0<int32>(builder, 0x1BD11BDA); for (int i = 0; i < 2; ++i) { ks[i] = key[i]; x[i] = input[i]; - ks[2] = builder->Xor(ks[2], key[i]); + ks[2] = xla::Xor(ks[2], key[i]); } - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(x[1], ks[1]); + x[0] = xla::Add(x[0], ks[0]); + x[1] = xla::Add(x[1], ks[1]); // Performs a single round of the Threefry2x32 algorithm, with a rotation // amount 'rotation'. auto round = [builder](ThreeFry2x32State v, int rotation) { - v[0] = builder->Add(v[0], v[1]); + v[0] = xla::Add(v[0], v[1]); v[1] = RotateLeftS32(builder, v[1], rotation); - v[1] = builder->Xor(v[0], v[1]); + v[1] = xla::Xor(v[0], v[1]); return v; }; @@ -76,36 +77,36 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(1)); + x[0] = xla::Add(x[0], ks[1]); + x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 1)); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(2)); + x[0] = xla::Add(x[0], ks[2]); + x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 2)); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0<int32>(3)); + x[0] = xla::Add(x[0], ks[0]); + x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0<int32>(builder, 3)); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(4)); + x[0] = xla::Add(x[0], ks[1]); + x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 4)); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(5)); + x[0] = xla::Add(x[0], ks[2]); + x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 5)); return x; } @@ -116,8 +117,8 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, const TensorShape& shape, double minval, double maxval) { // Split the seed into two 32-bit scalars to form a key. - auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); ThreeFry2x32State key = {seed0, seed1}; const int64 size = shape.num_elements(); @@ -127,32 +128,32 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, // Fill the generator inputs with unique counter values. ThreeFry2x32State inputs; TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); - inputs[1] = builder->Add(inputs[0], builder->ConstantR0<int32>(half_size)); + inputs[1] = xla::Add(inputs[0], xla::ConstantR0<int32>(builder, half_size)); ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); if (size_is_odd) { - outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); + outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1}); } auto bits = - builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); + xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes()); // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit // forces the random bits into the mantissa. constexpr int kFloatBits = 32; constexpr int kMantissaBits = 23; - bits = builder->Or( - builder->ShiftRightLogical( - bits, builder->ConstantR0<int32>(kFloatBits - kMantissaBits)), - builder->ConstantR0<int32>(bit_cast<int32>(1.0f))); - auto floats = builder->BitcastConvertType(bits, xla::F32); + bits = xla::Or( + xla::ShiftRightLogical( + bits, xla::ConstantR0<int32>(builder, kFloatBits - kMantissaBits)), + xla::ConstantR0<int32>(builder, bit_cast<int32>(1.0f))); + auto floats = xla::BitcastConvertType(bits, xla::F32); // We have a floating point number in the range [1.0, 2.0). // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = builder->Sub(floats, builder->ConstantR0<float>(1.0f)); + floats = xla::Sub(floats, xla::ConstantR0<float>(builder, 1.0f)); // Multiply and add to shift to the range [minval, maxval). - floats = builder->Mul(floats, builder->ConstantR0<float>(maxval - minval)); - floats = builder->Add(floats, builder->ConstantR0<float>(minval)); + floats = xla::Mul(floats, xla::ConstantR0<float>(builder, maxval - minval)); + floats = xla::Add(floats, xla::ConstantR0<float>(builder, minval)); return floats; } @@ -207,8 +208,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) - auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)), - ErfInv(uniform)); + auto normal = xla::Mul(xla::ConstantR0<float>(builder, std::sqrt(2.0)), + ErfInv(uniform)); ctx->SetOutput(0, normal); } diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 55254c746e..c2165ccd86 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,12 +93,12 @@ class StridedSliceOp : public XlaOpKernel { xla::XlaOp slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { - slice = ctx->builder()->Rev(slice, dimensions_to_reverse); + slice = xla::Rev(slice, dimensions_to_reverse); } - slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides); + slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); - slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } @@ -171,7 +172,7 @@ class StridedSliceGradOp : public XlaOpKernel { xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. - grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); + grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. gtl::InlinedVector<int64, 4> dimensions_to_reverse; @@ -204,9 +205,9 @@ class StridedSliceGradOp : public XlaOpKernel { } } if (!dimensions_to_reverse.empty()) { - grad = ctx->builder()->Rev(grad, dimensions_to_reverse); + grad = xla::Rev(grad, dimensions_to_reverse); } - grad = ctx->builder()->Pad(grad, zero, padding_config); + grad = xla::Pad(grad, zero, padding_config); ctx->SetOutput(0, grad); } @@ -306,17 +307,17 @@ class StridedSliceAssignOp : public XlaOpKernel { } if (!dimensions_to_reverse.empty()) { - rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse); + rhs = xla::Rev(rhs, dimensions_to_reverse); } - rhs = ctx->builder()->Reshape(rhs, slice_dims); + rhs = xla::Reshape(rhs, slice_dims); if (lhs_shape.dims() == 0) { // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix // and remove this workaround. lhs = rhs; } else { - lhs = ctx->builder()->DynamicUpdateSlice( - lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin)); + lhs = xla::DynamicUpdateSlice( + lhs, rhs, xla::ConstantR1<int64>(ctx->builder(), slice_begin)); } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 9adee78a1f..2f650ce305 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -123,10 +124,9 @@ xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, const gtl::ArraySlice<int64>& update_dims, const xla::XlaOp& start_indices) { - xla::XlaOp current = - builder->DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = builder->Add(current, update); - return builder->DynamicUpdateSlice(operand, sum, start_indices); + xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); + xla::XlaOp sum = xla::Add(current, update); + return xla::DynamicUpdateSlice(operand, sum, start_indices); } class TensorArrayOp : public XlaOpKernel { @@ -162,7 +162,7 @@ class TensorArrayOp : public XlaOpKernel { ta_shape.AddDim(size); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); - value = b->Broadcast(zero, ta_shape.dim_sizes()); + value = xla::Broadcast(zero, ta_shape.dim_sizes()); } XlaContext& xc = XlaContext::Get(ctx); @@ -215,12 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); @@ -259,17 +259,17 @@ class TensorArrayReadOp : public XlaOpKernel { // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), + xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: @@ -326,7 +326,7 @@ class TensorArrayGatherOp : public XlaOpKernel { for (auto i = 1; i < ta_shape.dims(); i++) { end[i] = ta_shape.dim_size(i); } - ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + ctx->SetOutput(0, xla::Slice(ta, begin, end, strides)); return; } } @@ -391,7 +391,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = b->Add(ta, value); + ta = xla::Add(ta, value); } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -407,13 +407,13 @@ class TensorArrayScatterOp : public XlaOpKernel { // Slice out part of the value. value_starts[0] = i; value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + auto slice = xla::Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto index = xla::Slice(indices, {i}, {i + 1}, {1}); auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } } @@ -452,7 +452,7 @@ class TensorArrayConcatOp : public XlaOpKernel { auto ta_dims = ta_shape.dim_sizes(); std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end()); shape[0] *= ta_shape.dim_size(0); - ctx->SetOutput(0, b->Reshape(ta, shape)); + ctx->SetOutput(0, xla::Reshape(ta, shape)); Tensor lengths(DT_INT64, {ta_dims[0]}); auto lengths_vec = lengths.vec<int64>(); @@ -522,8 +522,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( - ta, b->Reshape(value, ta_shape.dim_sizes())))); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( + ta, xla::Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index e91075196b..c9e5694262 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -93,9 +94,9 @@ class TileOp : public XlaOpKernel { if (one_dimension_is_broadcasted_without_multiple) { // Create a constant Zero the size of the output shape to leverage binary // operation broadcast semantics. - auto broadcasted_zero = ctx->builder()->Broadcast( + auto broadcasted_zero = xla::Broadcast( XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); - ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input)); + ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); return; } @@ -103,7 +104,7 @@ class TileOp : public XlaOpKernel { // dimension. This prepends the broadcasted dimensions, so an // input of shape [2,3,1] broadcast with multiples [5,4,3] will // end up with shape [5,4,3,2,3,1]. - auto broadcasted = ctx->builder()->Broadcast(input, multiples_array); + auto broadcasted = xla::Broadcast(input, multiples_array); // Now flatten and reshape. The broadcasted dimensions are // paired with the original dimensions so in the above example // we flatten [0,3,1,4,2,5] then reshape to [10,12,3]. @@ -112,8 +113,7 @@ class TileOp : public XlaOpKernel { flattened.push_back(i); flattened.push_back(i + output_shape.size()); } - xla::XlaOp output = - ctx->builder()->Reshape(broadcasted, flattened, output_shape); + xla::XlaOp output = xla::Reshape(broadcasted, flattened, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index cbe3c8aaff..beb7cf263d 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -68,7 +68,7 @@ class TopKOp : public XlaOpKernel { // TODO(b/73891930): add a key-value sort to HLO, rather than using // bit-packing tricks here. - xla::XlaOp zero = b->ConstantR0<int32>(0); + xla::XlaOp zero = xla::ConstantR0<int32>(b, 0); // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally // ideal. The implications of the choice are: @@ -83,22 +83,22 @@ class TopKOp : public XlaOpKernel { // 1. +0.0 == -0.0 // 2. All -0.0 in the input are replaced with +0.0 in the output. // 3. The sort is stable. - xla::XlaOp max = b->ConstantR0<int32>(0x80000000); - xla::XlaOp index_mask = b->ConstantR0<int32>(0x0000FFFF); - xla::XlaOp value_mask = b->ConstantR0<int32>(0xFFFF0000); + xla::XlaOp max = xla::ConstantR0<int32>(b, 0x80000000); + xla::XlaOp index_mask = xla::ConstantR0<int32>(b, 0x0000FFFF); + xla::XlaOp value_mask = xla::ConstantR0<int32>(b, 0xFFFF0000); // Convert to from bf16 to f32. The lower 16-bits are zero due to the // definition of bf16. - xla::XlaOp input_f32 = b->ConvertElementType(input_bf16, xla::F32); + xla::XlaOp input_f32 = xla::ConvertElementType(input_bf16, xla::F32); // Negate the input to reverse sort it. The lower 16-bits are zero, because // negating a float is just inverting the high-bit. - xla::XlaOp negative_input_f32 = b->Neg(input_f32); + xla::XlaOp negative_input_f32 = xla::Neg(input_f32); // Convert to a sign magnitude integer. The lower 16-bits are zero, since // bitcast convert doesn't change any bits. xla::XlaOp negative_input_sm32 = - b->BitcastConvertType(negative_input_f32, xla::S32); + xla::BitcastConvertType(negative_input_f32, xla::S32); // Convert from sign magnitude integer to two's complement integer. The // lower 16-bits are zero on both sides of the select. On the false side, @@ -106,8 +106,8 @@ class TopKOp : public XlaOpKernel { // are all zero, so the lower 16-bits of the result of the subtraction will // also be zero. xla::XlaOp negative_input_s32 = - b->Select(b->Lt(negative_input_sm32, zero), - b->Sub(max, negative_input_sm32), negative_input_sm32); + xla::Select(xla::Lt(negative_input_sm32, zero), + xla::Sub(max, negative_input_sm32), negative_input_sm32); // In order for the Or with iota_s32 to to work properly, the lower 16-bits // of negative_input_32 must be zero. @@ -115,32 +115,32 @@ class TopKOp : public XlaOpKernel { // Pack elements as: // * upper 16 bits are the value // * lower 16 bits are the index. - xla::XlaOp packed_s32 = b->Or(negative_input_s32, iota_s32); + xla::XlaOp packed_s32 = xla::Or(negative_input_s32, iota_s32); // TODO(phawkins): use a more efficient algorithm that does not require a // full sort. - xla::XlaOp sorted_s32 = b->Slice(b->Sort(packed_s32), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1}); + xla::XlaOp sorted_s32 = xla::Slice(xla::Sort(packed_s32), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1}); // Unpack the value/index. - xla::XlaOp indices_s32 = b->And(sorted_s32, index_mask); - xla::XlaOp negative_values_s32 = b->And(sorted_s32, value_mask); + xla::XlaOp indices_s32 = xla::And(sorted_s32, index_mask); + xla::XlaOp negative_values_s32 = xla::And(sorted_s32, value_mask); // Convert from two's complement integer to sign magnitude integer. xla::XlaOp negative_values_sm32 = - b->Select(b->Lt(negative_values_s32, zero), - b->Sub(max, negative_values_s32), negative_values_s32); + xla::Select(xla::Lt(negative_values_s32, zero), + xla::Sub(max, negative_values_s32), negative_values_s32); xla::XlaOp negative_values_f32 = - b->BitcastConvertType(negative_values_sm32, xla::F32); + xla::BitcastConvertType(negative_values_sm32, xla::F32); // Negate the values to get back the original inputs. - xla::XlaOp values_f32 = b->Neg(negative_values_f32); + xla::XlaOp values_f32 = xla::Neg(negative_values_f32); // Convert from f32 to bf16. - xla::XlaOp values_bf16 = b->ConvertElementType(values_f32, xla::BF16); + xla::XlaOp values_bf16 = xla::ConvertElementType(values_f32, xla::BF16); context->SetOutput(0, values_bf16); context->SetOutput(1, indices_s32); diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 34caefa050..2e5d61e111 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -31,7 +31,6 @@ class ResourceApplyGradientDescent : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp handle; - xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(1); TensorShape var_shape; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); @@ -48,7 +47,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel { var_shape.DebugString(), " vs ", delta_shape.DebugString())); - handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); + handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2))); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -63,8 +62,6 @@ class ResourceApplyMomentum : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -97,14 +94,14 @@ class ResourceApplyMomentum : public XlaOpKernel { xla::XlaOp grad = ctx->Input(3); xla::XlaOp momentum = ctx->Input(4); - accum = b->Add(b->Mul(accum, momentum), grad); + accum = xla::Add(xla::Mul(accum, momentum), grad); if (use_nesterov_) { // See https://github.com/tensorflow/tensorflow/pull/2798 for an // explanation of the reparameterization used here. - var = b->Sub( - var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr))); + var = xla::Sub(var, xla::Add(xla::Mul(grad, lr), + xla::Mul(xla::Mul(accum, momentum), lr))); } else { - var = b->Sub(var, b->Mul(accum, lr)); + var = xla::Sub(var, xla::Mul(accum, lr)); } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); @@ -149,10 +146,12 @@ class ResourceApplyAdagrad : public XlaOpKernel { xla::XlaOp lr = ctx->Input(2); xla::XlaOp grad = ctx->Input(3); - accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); - var = b->Sub( - var, b->Mul(b->Mul(grad, lr), - b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); + accum = + xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); + var = xla::Sub( + var, + xla::Mul(xla::Mul(grad, lr), + xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } @@ -232,12 +231,13 @@ class ResourceApplyAdam : public XlaOpKernel { xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); xla::XlaOp alpha = - b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), - b->Sub(one, beta1_power)); - m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); - v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2))); - var = - b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon))); + xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)), + xla::Sub(one, beta1_power)); + m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1))); + v = xla::Add( + v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2))); + var = xla::Sub(var, xla::Div(xla::Mul(m, alpha), + xla::Add(xla::Pow(v, half), epsilon))); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); @@ -320,16 +320,17 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. - xla::XlaOp new_ms = b->Add( - ms, - b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), - b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); + xla::XlaOp new_ms = xla::Add( + ms, xla::Mul( + xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), + ms), + xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); xla::XlaOp new_mom = - b->Add(b->Mul(mom, momentum), - b->Mul(b->Mul(grad, lr), - b->Pow(b->Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::XlaOp new_var = b->Sub(var, new_mom); + xla::Add(xla::Mul(mom, momentum), + xla::Mul(xla::Mul(grad, lr), + xla::Pow(xla::Add(new_ms, epsilon), + XlaHelpers::FloatLiteral(b, type, -0.5)))); + xla::XlaOp new_var = xla::Sub(var, new_mom); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); @@ -424,21 +425,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); xla::XlaOp grad_to_use; if (has_l2_shrinkage) { - grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); + grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var))); } else { grad_to_use = grad; } - xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two)); - xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power)); - xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); - linear = b->Add( + xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two)); + xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power)); + xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power)); + linear = xla::Add( linear, - b->Sub(grad_to_use, - b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); - xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1); - xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); - var = b->Div(b->Sub(linear_clipped, linear), quadratic); + xla::Sub(grad_to_use, + xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr), + var))); + xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1); + xla::XlaOp quadratic = + xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2)); + var = xla::Div(xla::Sub(linear_clipped, linear), quadratic); accum = new_accum; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index ef5aae81a8..6c721c48fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -84,12 +85,12 @@ class TransposeOp : public XlaOpKernel { if (dims <= 1 || is_identity) { transposed = ctx->Input(0); } else { - transposed = ctx->builder()->Transpose(ctx->Input(0), transposed_order); + transposed = xla::Transpose(ctx->Input(0), transposed_order); } // Conjugate the transposed result if this is ConjugateTransposeOp. if (conjugate_) { - ctx->SetOutput(0, ctx->builder()->Conj(transposed)); + ctx->SetOutput(0, xla::Conj(transposed)); } else { ctx->SetOutput(0, transposed); } @@ -146,7 +147,7 @@ class InvertPermutationOp : public XlaOpKernel { output[d] = i; } - ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output)); + ctx->SetOutput(0, xla::ConstantR1<int32>(ctx->builder(), output)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a39e5dcfc5..3823f5c087 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -27,15 +27,13 @@ limitations under the License. namespace tensorflow { namespace { -// A subclass of a TlaUnaryOp must build the lambda computation that -// describes the scalar->scalar function to apply to each element of -// the input. #define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \ class NAME##Op : public XlaOpKernel { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ xla::XlaOp x = ctx->Input(0); \ xla::XlaOp y = COMPUTATION; \ ctx->SetOutput(0, y); \ @@ -43,84 +41,88 @@ namespace { }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op); -XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x)); +XLAJIT_MAKE_UNARY(ComplexAbs, xla::Abs(x)); -XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x))); +XLAJIT_MAKE_UNARY(Angle, xla::Atan2(xla::Imag(x), xla::Real(x))); -XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); +XLAJIT_MAKE_UNARY(Conj, xla::Conj(x)); // Return x if x>0, otherwise -x. -XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); +XLAJIT_MAKE_UNARY(Abs, xla::Abs(x)); // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) XLAJIT_MAKE_UNARY( Acos, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - b->Add(XlaHelpers::One(b, input_type(0)), x)))); + xla::Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + xla::Atan2(xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), + 0.5)), + xla::Add(XlaHelpers::One(b, input_type(0)), x)))); // acosh(x) = log(x + sqrt(x^2 - 1)) // = log(x + sqrt((x+1)*(x-1))) XLAJIT_MAKE_UNARY( Acosh, - b->Log(b->Add(x, - b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))), - b->Sub(x, XlaHelpers::One(b, input_type(0)))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + xla::Log(xla::Add( + x, xla::Pow(xla::Mul(xla::Add(x, XlaHelpers::One(b, input_type(0))), + xla::Sub(x, XlaHelpers::One(b, input_type(0)))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) XLAJIT_MAKE_UNARY( Asin, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), - b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), + xla::Mul( + XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + xla::Atan2(x, + xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(x, x)), XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))))); // asinh(x) = log(x + sqrt(x^2 + 1)) XLAJIT_MAKE_UNARY( Asinh, - b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), - XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + xla::Log(xla::Add( + x, xla::Pow(xla::Add(xla::Mul(x, x), XlaHelpers::One(b, input_type(0))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); -XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, XlaHelpers::One(b, input_type(0)))); // atanh(x) = 0.5 * log((1 + x) / (1 - x)) XLAJIT_MAKE_UNARY( - Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), - b->Sub(XlaHelpers::One(b, input_type(0)), x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); -XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); + Atanh, + xla::Mul(xla::Log(xla::Div(xla::Add(XlaHelpers::One(b, input_type(0)), x), + xla::Sub(XlaHelpers::One(b, input_type(0)), x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, xla::Cos(x)); XLAJIT_MAKE_UNARY(Cosh, - b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); -XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); - -XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); - -XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); -XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); -XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x), - XlaHelpers::FloatLiteral( - b, input_type(0), - std::numeric_limits<double>::infinity()))); -XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x)); + xla::Mul(xla::Add(xla::Exp(x), xla::Exp(xla::Neg(x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); +XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); + +XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); + +XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); +XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); +XLAJIT_MAKE_UNARY(IsInf, xla::Eq(xla::Abs(x), + XlaHelpers::FloatLiteral( + b, input_type(0), + std::numeric_limits<double>::infinity()))); +XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); // Return 1/x -XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Log, b->Log(x)); +XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Log, xla::Log(x)); XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); -XLAJIT_MAKE_UNARY(Invert, b->Not(x)); -XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); -XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); +XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); +XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); +XLAJIT_MAKE_UNARY(Neg, xla::Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. @@ -130,35 +132,35 @@ static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); - auto round_val = b->Floor(x); - auto fraction = b->Sub(x, round_val); + auto round_val = xla::Floor(x); + auto fraction = xla::Sub(x, round_val); auto nearest_even_int = - b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); - auto is_odd = b->Eq(nearest_even_int, one); - return b->Select( - b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), - b->Add(round_val, one), round_val); + xla::Sub(round_val, xla::Mul(two, xla::Floor(xla::Mul(half, x)))); + auto is_odd = xla::Eq(nearest_even_int, one); + return xla::Select(xla::Or(xla::Gt(fraction, half), + xla::And(xla::Eq(fraction, half), is_odd)), + xla::Add(round_val, one), round_val); } XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); -XLAJIT_MAKE_UNARY(Rsqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); +XLAJIT_MAKE_UNARY(Rsqrt, xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), + -0.5))); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); + return xla::Add(half, xla::Mul(half, xla::Tanh(xla::Mul(half, x)))); } XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); +XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); XLAJIT_MAKE_UNARY(Sinh, - b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); + xla::Mul(xla::Sub(xla::Exp(x), xla::Exp(xla::Neg(x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); // softplus(x) = log(1 + exp(x)) // @@ -169,21 +171,21 @@ XLAJIT_MAKE_UNARY(Sinh, // This is equivalent to: // max(x, 0) + log1p(exp(-abs(x))) XLAJIT_MAKE_UNARY(Softplus, - b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), - b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); + xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))), + b->Log1p(xla::Exp(xla::Neg(xla::Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, - b->Div(x, - b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0))))); + xla::Div(x, xla::Add(xla::Abs(x), + XlaHelpers::One(b, input_type(0))))); XLAJIT_MAKE_UNARY(Sqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); -XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); -XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); + xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Square, xla::Mul(x, x)); +XLAJIT_MAKE_UNARY(Tan, xla::Div(xla::Sin(x), xla::Cos(x))); +XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); -XLAJIT_MAKE_UNARY(Real, b->Real(x)); -XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); +XLAJIT_MAKE_UNARY(Real, xla::Real(x)); +XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); #undef XLAJIT_MAKE_UNARY @@ -197,13 +199,14 @@ class ErfOp : public XlaOpKernel { xla::PrimitiveType primitive_type; xla::XlaOp one = XlaHelpers::One(b, input_type(0)); xla::XlaOp x = ctx->Input(0); - xla::XlaOp abs_x = b->Abs(x); + xla::XlaOp abs_x = xla::Abs(x); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &primitive_type)); - auto y = b->Select(b->Gt(abs_x, one), b->Sub(one, Erfc(x, primitive_type)), - Erf(x, primitive_type)); + auto y = + xla::Select(xla::Gt(abs_x, one), xla::Sub(one, Erfc(x, primitive_type)), + Erf(x, primitive_type)); ctx->SetOutput(0, y); } }; @@ -216,14 +219,15 @@ class ErfcOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); xla::XlaOp one = XlaHelpers::One(b, input_type(0)); xla::XlaOp x = ctx->Input(0); - xla::XlaOp abs_x = b->Abs(x); + xla::XlaOp abs_x = xla::Abs(x); xla::PrimitiveType primitive_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &primitive_type)); - auto y = b->Select(b->Lt(abs_x, one), b->Sub(one, Erf(x, primitive_type)), - Erfc(x, primitive_type)); + auto y = + xla::Select(xla::Lt(abs_x, one), xla::Sub(one, Erf(x, primitive_type)), + Erfc(x, primitive_type)); ctx->SetOutput(0, y); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index f87586ba57..0e5d58ecba 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -74,10 +75,9 @@ class UnpackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { start_indices[axis] = i; limit_indices[axis] = i + 1; - auto slice = ctx->builder()->Slice(input, start_indices, limit_indices, - strides); + auto slice = xla::Slice(input, start_indices, limit_indices, strides); // Reshape to drop the 'axis' dimension. - auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes()); + auto result = xla::Reshape(slice, output_shape.dim_sizes()); ctx->SetOutput(i, result); } } diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index ad51396bdf..febac82873 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -33,8 +33,8 @@ class VarIsInitializedOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { XlaResource* variable; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); - ctx->SetOutput(0, - ctx->builder()->ConstantR0<bool>(variable->initialized())); + ctx->SetOutput( + 0, xla::ConstantR0<bool>(ctx->builder(), variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); @@ -96,7 +96,7 @@ class AssignAddVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Add(handle, ctx->Input(1)); + handle = xla::Add(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -112,7 +112,7 @@ class AssignSubVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Sub(handle, ctx->Input(1)); + handle = xla::Sub(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -191,7 +191,7 @@ class ResourceScatterAddOp : public ResourceScatterOp { private: static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, xla::XlaBuilder* builder) { - return builder->Add(x, y); + return xla::Add(x, y); } }; REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp); @@ -204,7 +204,7 @@ class ResourceScatterSubOp : public ResourceScatterOp { private: static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, xla::XlaBuilder* builder) { - return builder->Sub(x, y); + return xla::Sub(x, y); } }; REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp); @@ -217,7 +217,7 @@ class ResourceScatterMulOp : public ResourceScatterOp { private: static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, xla::XlaBuilder* builder) { - return builder->Mul(x, y); + return xla::Mul(x, y); } }; REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp); @@ -230,7 +230,7 @@ class ResourceScatterDivOp : public ResourceScatterOp { private: static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, xla::XlaBuilder* builder) { - return builder->Div(x, y); + return xla::Div(x, y); } }; REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp); @@ -243,7 +243,7 @@ class ResourceScatterMinOp : public ResourceScatterOp { private: static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, xla::XlaBuilder* builder) { - return builder->Min(x, y); + return xla::Min(x, y); } }; REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp); @@ -256,7 +256,7 @@ class ResourceScatterMaxOp : public ResourceScatterOp { private: static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, xla::XlaBuilder* builder) { - return builder->Max(x, y); + return xla::Max(x, y); } }; REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp); @@ -286,7 +286,7 @@ class ResourceScatterNdAddOp : public ResourceScatterOp { private: static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, xla::XlaBuilder* builder) { - return builder->Add(x, y); + return xla::Add(x, y); } }; REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 5467c5d994..340165bac6 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -246,7 +246,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp init = builder->Tuple(inputs); + xla::XlaOp init = xla::Tuple(builder, inputs); VLOG(1) << "Building while loop"; @@ -255,22 +255,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { { std::unique_ptr<xla::XlaBuilder> cb = builder->CreateSubBuilder("cond_wrapper"); - auto inputs = cb->Parameter(0, cond_input_shape, "inputs"); - auto outputs = cb->Call(*cond.computation, {inputs}); - cb->GetTupleElement(outputs, 0); + auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs"); + auto outputs = xla::Call(cb.get(), *cond.computation, {inputs}); + xla::GetTupleElement(outputs, 0); xla::StatusOr<xla::XlaComputation> result = cb->Build(); OP_REQUIRES_OK(ctx, result.status()); cond_wrapper = std::move(result.ValueOrDie()); } - xla::XlaOp while_result = - builder->While(cond_wrapper, *body.computation, init); + xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); // Sets non-variable outputs. for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], - builder->GetTupleElement(while_result, i)); + xla::GetTupleElement(while_result, i)); } } @@ -284,7 +283,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), builder)); + xla::GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name() << " modified: " << update.modified diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index ee0bb91a6b..dd29bafcd9 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <vector> +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -81,25 +82,26 @@ xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); dimensions.push_back(x_shape.dimensions(x_outer_dim)); dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())), + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::Literal::Zero(x_shape.element_type())), dimensions); } if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = builder->Conj(x); + x = xla::Conj(x); } if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = builder->Conj(y); + y = xla::Conj(y); } // If there are no batch dimensions, use a regular Dot. // TODO(b/69062148) Remove this code when Dot emitters can be passed // dimensions to transpose directly (i.e. without requiring a Transpose HLO). if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; - return builder->Dot(lhs, rhs); + auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; + return xla::Dot(lhs, rhs); } xla::DotDimensionNumbers dot_dnums; @@ -109,7 +111,7 @@ xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return builder->DotGeneral(x, y, dot_dnums); + return xla::DotGeneral(x, y, dot_dnums); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 20925118bf..397f0e3a72 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -80,21 +81,20 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder, std::vector<int32> mask_vector(n); std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = body_builder->ConstantR1<int32>(mask_vector); - auto mask_range_row = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {1, n}), major_dims); - auto mask_range_col = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims); + auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector); + auto mask_range_row = + xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = + xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; // row = l[..., i, :i] // select the whole i-th row, then mask out all columns past i-1 - auto zero = body_builder->ConstantR0<int32>(0); + auto zero = xla::ConstantR0<int32>(body_builder, 0); TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, {i, zero}, {1, n})); - auto row = body_builder->Select(body_builder->Ge(mask_range_row, i), - mask_zeros_row, l_i); + auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); // a[..., i, i] TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, {i, i}, {1, 1})); @@ -105,16 +105,15 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder, /*transpose_y=*/true)); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) - auto l_ii = body_builder->Pow( - body_builder->Sub(a_ii, diag_dot), - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + auto l_ii = + xla::Pow(xla::Sub(a_ii, diag_dot), + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); // a[..., i+1:, i] // select the whole i-th column, then mask out all rows above i+1 TF_ASSIGN_OR_RETURN( auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); - auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, a_0i); + auto a_ip1i = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i, i] @@ -125,11 +124,9 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder, /*transpose_x=*/false, /*transpose_y=*/true)); // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, dot); + auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); - auto col_update = - body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii); + auto col_update = xla::Div(xla::Sub(a_ip1i, dot_ip1), l_ii); TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( body_builder, body_l, col_update, {i})); // Assign the diagonal after the rest of the column because otherwise the @@ -191,9 +188,8 @@ xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, /*conjugate_y=*/false)); TF_ASSIGN_OR_RETURN(auto before, SliceInMinorDims(builder, a, {i, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN( - a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta), - {i, i})); + TF_ASSIGN_OR_RETURN(a, UpdateSliceInMinorDims( + builder, a, xla::Sub(before, delta), {i, i})); } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc index e4f195901e..3dfa66029c 100644 --- a/tensorflow/compiler/tf2xla/lib/random.cc +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" namespace tensorflow { @@ -49,9 +50,9 @@ xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) { XlaHelpers::FloatLiteral(builder, dtype, kAlphaNormalCdf); // probit(p) = sqrt(2) * erfinv(2*p-1) - auto p = builder->Add(alpha_normal_cdf, builder->Mul(z, uniform)); - auto erfinv_input = builder->Sub(builder->Mul(p, two), one); - return builder->Mul(sqrt_2, ErfInv(erfinv_input)); + auto p = xla::Add(alpha_normal_cdf, xla::Mul(z, uniform)); + auto erfinv_input = xla::Sub(xla::Mul(p, two), one); + return xla::Mul(sqrt_2, ErfInv(erfinv_input)); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index d5a27abb25..85e3d3ab85 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -97,8 +98,8 @@ xla::StatusOr<xla::XlaOp> XlaScatter( buffer_shape_post_axes.end()); // Construct the initial values of the loop-carried Tensors. - auto flat_indices = builder->Reshape(indices, flat_indices_shape); - auto flat_updates = builder->Reshape(updates, flat_updates_shape); + auto flat_indices = xla::Reshape(indices, flat_indices_shape); + auto flat_updates = xla::Reshape(updates, flat_updates_shape); auto init = {flat_indices, flat_updates, buffer}; // Constructs the loop body. The implementation of scatter is essentially: @@ -112,46 +113,44 @@ xla::StatusOr<xla::XlaOp> XlaScatter( auto updates = loop_vars[1]; auto buffer = loop_vars[2]; - auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape.element_type())); + auto zero_index = xla::ConstantLiteral( + body_builder, xla::Literal::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. xla::XlaOp index; - auto indices_offset = body_builder->Reshape(i, {1}); + auto indices_offset = xla::Reshape(i, {1}); if (indices_are_vectors) { - indices_offset = body_builder->Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); + indices_offset = xla::Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); - index = body_builder->DynamicSlice(indices, indices_offset, - {1, num_index_dims}); - index = body_builder->Collapse(index, {0, 1}); + index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); + index = xla::Collapse(index, {0, 1}); } else { - index = body_builder->DynamicSlice(indices, indices_offset, {1}); + index = xla::DynamicSlice(indices, indices_offset, {1}); } // Discard updates with negative indices, since some users expect this. - auto index_in_range = - body_builder->ReduceAll(body_builder->Le(zero_index, index), - body_builder->ConstantR0<bool>(true), - xla::CreateScalarAndComputation(body_builder)); + auto index_in_range = xla::ReduceAll( + xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true), + xla::CreateScalarAndComputation(body_builder)); // Make the index in bounds to prevent implementation defined behavior. - index = body_builder->Max(index, zero_index); - index = body_builder->Pad( + index = xla::Max(index, zero_index); + index = xla::Pad( index, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); // Slice the i-th index from the updates array. - auto updates_offset = body_builder->Reshape(i, {1}); - updates_offset = body_builder->Pad( + auto updates_offset = xla::Reshape(i, {1}); + updates_offset = xla::Pad( updates_offset, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); std::vector<int64> flat_updates_slice_shape({1}); flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - auto update = body_builder->DynamicSlice(updates, updates_offset, - flat_updates_slice_shape); + auto update = + xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); // Unflatten the major (iteration) dimensions of the slice to their // original shape. @@ -159,20 +158,19 @@ xla::StatusOr<xla::XlaOp> XlaScatter( updates_slice_shape.insert(updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - update = body_builder->Reshape(update, updates_slice_shape); + update = xla::Reshape(update, updates_slice_shape); // Apply the update to the buffer. If there is a combiner, use it to merge // the current values with the update. - auto current_value = - body_builder->DynamicSlice(buffer, index, updates_slice_shape); + auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); if (combiner) { update = combiner(current_value, update, body_builder); } // Use the current value instead of the update if the index is out of // bounds. - update = body_builder->Select(index_in_range, update, current_value); + update = xla::Select(index_in_range, update, current_value); // Apply the update. - buffer = body_builder->DynamicUpdateSlice(buffer, update, index); + buffer = xla::DynamicUpdateSlice(buffer, update, index); return std::vector<xla::XlaOp>{indices, updates, buffer}; }; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index b4503601f9..b9f695ac4b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -90,8 +91,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder( tensorflow::strings::StrCat("trsm_base_", k)); - auto a_param = sub->Parameter( - 0, + auto a_param = xla::Parameter( + sub.get(), 0, xla::ShapeUtil::MakeShape( b_shape.element_type(), PrependMajorDims(sub.get(), batch_dimensions, {k, k})), @@ -103,8 +104,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, } else { b_lastd = {m, k}; } - auto b_param = sub->Parameter( - 1, + auto b_param = xla::Parameter( + sub.get(), 1, xla::ShapeUtil::MakeShape( b_shape.element_type(), PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), @@ -165,11 +166,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, if (k > 1) { TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { TF_ASSIGN_OR_RETURN(auto a_slice_conj, MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); + update = xla::Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -196,7 +197,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, /*conjugate_y=*/conjugate_a)); TF_ASSIGN_OR_RETURN(auto b_slice_2, SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); + b_update = xla::Sub(b_slice_2, b_update); TF_ASSIGN_OR_RETURN( b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); } @@ -217,11 +218,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, if (k > 1) { TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { TF_ASSIGN_OR_RETURN(auto a_slice_conj, MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); + update = xla::Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -247,7 +248,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, /*conjugate_y=*/false)); TF_ASSIGN_OR_RETURN(auto b_slice_2, SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); + b_update = xla::Sub(b_slice_2, b_update); TF_ASSIGN_OR_RETURN( b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); } @@ -268,11 +269,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, if (k > 1) { TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { TF_ASSIGN_OR_RETURN(auto a_slice_conj, MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); + update = xla::Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -299,7 +300,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, /*conjugate_y=*/conjugate_a)); TF_ASSIGN_OR_RETURN(auto b_slice_2, SliceInMinorDims(builder, b, {0, 0}, {m, i})); - b_update = builder->Sub(b_slice_2, b_update); + b_update = xla::Sub(b_slice_2, b_update); TF_ASSIGN_OR_RETURN( b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); } @@ -320,11 +321,11 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, if (k > 1) { TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + update = xla::Call(builder, *solve, {a_slice, b_slice}); } else { TF_ASSIGN_OR_RETURN(auto a_slice_conj, MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); + update = xla::Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -350,7 +351,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder, /*conjugate_y=*/false)); TF_ASSIGN_OR_RETURN(auto b_slice_2, SliceInMinorDims(builder, b, {0, 0}, {i, n})); - b_update = builder->Sub(b_slice_2, b_update); + b_update = xla::Sub(b_slice_2, b_update); TF_ASSIGN_OR_RETURN( b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); } @@ -394,7 +395,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); TF_ASSIGN_OR_RETURN(auto a_slice_conj, MaybeConjugate(builder, a_slice, conjugate_a)); - auto update = builder->Div(b_slice, a_slice_conj); + auto update = xla::Div(b_slice, a_slice_conj); TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); } @@ -414,8 +415,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, // The right-hand-side matrix b is a loop invariant. b_shape}; xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1); - auto init = builder->Tuple({init_i, output, a, b}); + auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); // Construct the loop condition function, // def cond_fun(loop_carry): @@ -424,14 +425,14 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, std::unique_ptr<xla::XlaBuilder> condb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"), + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"), 0); if (transpose_a) { - condb->Ge(i, condb->ConstantR0<int32>(0)); + xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); } else { - condb->Lt(i, condb->ConstantR0<int32>(m)); + xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m)); } } TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); @@ -454,15 +455,15 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, std::unique_ptr<xla::XlaBuilder> bodyb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); { - auto input_tuple = bodyb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"); + auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"); // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0<int32>(0); + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); // We'd like to implement this: // if transpose_a: @@ -491,14 +492,14 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, TF_ASSIGN_OR_RETURN( auto result_row_slice, DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); - auto result_row = bodyb->Sub(result_row_slice, b_update); + auto result_row = xla::Sub(result_row_slice, b_update); // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, {i, i}, {1, 1})); TF_ASSIGN_OR_RETURN(auto a_elt_conj, MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_elt_conj); + auto div_result = xla::Div(result_row, a_elt_conj); TF_ASSIGN_OR_RETURN(body_out, DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, div_result, {i, zero})); @@ -507,15 +508,16 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder, // return (i - 1, body_out, a, b) // else: // return (i + 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); + auto next_i = + xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); } TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); // Construct the While loop and return the result, // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); } xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, @@ -553,8 +555,8 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, // The right-hand-side matrix b is a loop invariant. b_shape}; xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0<int32>(transpose_a ? 0 : n - 1); - auto init = builder->Tuple({init_i, output, a, b}); + auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1); + auto init = xla::Tuple(builder, {init_i, output, a, b}); // Construct the loop condition function, // def cond_fun(loop_carry): @@ -563,14 +565,14 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, std::unique_ptr<xla::XlaBuilder> condb = builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"), + auto i = xla::GetTupleElement( + xla::Parameter(condb.get(), 0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), 0); if (transpose_a) { - condb->Lt(i, condb->ConstantR0<int32>(n)); + xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n)); } else { - condb->Ge(i, condb->ConstantR0<int32>(0)); + xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0)); } } TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); @@ -593,15 +595,15 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, std::unique_ptr<xla::XlaBuilder> bodyb = builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); { - auto input_tuple = bodyb->Parameter( - 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"); // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0<int32>(0); + auto i = xla::GetTupleElement(input_tuple, 0); + auto body_out = xla::GetTupleElement(input_tuple, 1); + auto body_a = xla::GetTupleElement(input_tuple, 2); + auto body_b = xla::GetTupleElement(input_tuple, 3); + auto zero = xla::ConstantR0<int32>(bodyb.get(), 0); // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, // i:i+1]) But since we can't have intermediate array sizes depend on the @@ -613,7 +615,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, /*conjugate_x=*/false, /*conjugate_y=*/conjugate_a)); // result = b - np.matmul(output, a) - auto result = bodyb->Sub(body_b, b_update); + auto result = xla::Sub(body_b, b_update); // result_row = result[..., :, i:i+1] TF_ASSIGN_OR_RETURN( auto result_row, @@ -624,7 +626,7 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, {i, i}, {1, 1})); TF_ASSIGN_OR_RETURN(auto a_ii_conj, MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_ii_conj); + auto div_result = xla::Div(result_row, a_ii_conj); TF_ASSIGN_OR_RETURN(body_out, DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, div_result, {zero, i})); @@ -633,15 +635,16 @@ xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder, // return (i + 1, body_out, a, b) // else: // return (i - 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? 1 : -1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); + auto next_i = + xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1)); + xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); } TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); // Construct the While loop and return the result, // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); + auto triangular_solve_left_looking_while = xla::While(cond, body, init); + return xla::GetTupleElement(triangular_solve_left_looking_while, 1); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index d9ff7e6259..11774dde08 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <vector> +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -28,8 +29,8 @@ limitations under the License. namespace tensorflow { xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), + return xla::Broadcast( + xla::ConstantLiteral(builder, xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); } @@ -37,19 +38,19 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, double value) { switch (type) { case xla::F16: - return builder->ConstantR0<xla::half>(static_cast<xla::half>(value)); + return xla::ConstantR0<xla::half>(builder, static_cast<xla::half>(value)); break; case xla::BF16: - return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value)); + return xla::ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value)); break; case xla::F32: - return builder->ConstantR0<float>(static_cast<float>(value)); + return xla::ConstantR0<float>(builder, static_cast<float>(value)); break; case xla::F64: - return builder->ConstantR0<double>(value); + return xla::ConstantR0<double>(builder, value); break; case xla::C64: - return builder->ConstantR0<xla::complex64>(value); + return xla::ConstantR0<xla::complex64>(builder, value); break; default: LOG(FATAL) << "unhandled element type " << type; @@ -107,7 +108,7 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, default: LOG(FATAL) << "unhandled element type " << type; } - return builder->ConstantLiteral(literal); + return xla::ConstantLiteral(builder, literal); } xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder, @@ -136,7 +137,7 @@ xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder, std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); std::vector<int64> strides(n_dims, 1); - return builder->Slice(x, padded_start, padded_end, strides); + return xla::Slice(x, padded_start, padded_end, strides); } std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder, @@ -163,7 +164,7 @@ xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims( TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(builder, x, starts)); auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); - return builder->DynamicSlice(x, padded_starts, padded_sizes); + return xla::DynamicSlice(x, padded_starts, padded_sizes); } xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder, @@ -172,7 +173,7 @@ xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder, gtl::ArraySlice<int64> start) { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector<int32> start_as_int32(start.begin(), start.end()); - auto start_constant = builder->ConstantR1<int32>(start_as_int32); + auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32); TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, @@ -180,7 +181,7 @@ xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder, const int64 start_length = xla::ShapeUtil::GetDimension(start_constant_shape, -1); TF_RET_CHECK(start_length == n_dims); - return builder->DynamicUpdateSlice(x, update, start_constant); + return xla::DynamicUpdateSlice(x, update, start_constant); } xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder, @@ -202,7 +203,7 @@ xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims( const std::vector<xla::XlaOp>& starts) { TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(builder, x, starts)); - return builder->DynamicUpdateSlice(x, update, padded_starts); + return xla::DynamicUpdateSlice(x, update, padded_starts); } xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims( @@ -210,13 +211,12 @@ xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims( const std::vector<xla::XlaOp>& starts) { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1}); + auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1}); std::vector<xla::XlaOp> padded_starts(n_dims, zero); for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = - builder->Reshape(starts[i], {1}); + padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); } - return builder->ConcatInDim(padded_starts, 0); + return xla::ConcatInDim(builder, padded_starts, 0); } xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder, @@ -227,14 +227,14 @@ xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder, std::vector<int64> permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return builder->Transpose(x, permutation); + return xla::Transpose(x, permutation); } xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder, const xla::XlaOp& x, bool conjugate) { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? builder->Conj(x) : x; + return perform_conj ? xla::Conj(x) : x; } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 5f408f2ed0..2a332c933f 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -86,9 +86,10 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index); - TF_ASSERT_OK(DynamicSliceInMinorDims( - &builder, a, {index, builder.ConstantR0<int32>(0)}, {1, 4}) - .status()); + TF_ASSERT_OK( + DynamicSliceInMinorDims( + &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, 4}) + .status()); ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, {a_data.get(), index_data.get()}); @@ -129,8 +130,8 @@ XLA_TEST_F(UtilTest, RowBatchDot) { TF_ASSERT_OK_AND_ASSIGN( auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0<int32>(0)}, {1, n})); + DynamicSliceInMinorDims( + &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n})); TF_ASSERT_OK(BatchDot(&builder, l_index, row, /*transpose_x=*/false, /*transpose_y=*/true) .status()); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 09ce594930..7cc88f34d2 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,7 +40,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop( xla::XlaBuilder* builder) { std::vector<xla::XlaOp> elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = builder->GetTupleElement(tuple, i); + elements[i] = xla::GetTupleElement(tuple, i); } return elements; }; @@ -48,7 +49,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop( std::unique_ptr<xla::XlaBuilder> cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { - auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -61,7 +63,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop( std::unique_ptr<xla::XlaBuilder> body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { - auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -69,11 +72,11 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - body_builder->Tuple(result); + xla::Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); + auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } @@ -86,9 +89,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( auto while_cond_fn = [&](gtl::ArraySlice<xla::XlaOp> values, xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> { - return cond_builder->Lt( - values[0], - IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); + return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, + num_iterations)); }; auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values, xla::XlaBuilder* body_builder) @@ -97,9 +99,9 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( std::vector<xla::XlaOp> updated_values; updated_values.reserve(values.size()); - updated_values.push_back(body_builder->Add( - iteration, - body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); + updated_values.push_back(xla::Add( + iteration, xla::ConstantLiteral( + body_builder, xla::Literal::One(num_iterations_type)))); values.remove_prefix(1); TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs, @@ -112,7 +114,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( std::vector<xla::XlaOp> values; values.reserve(initial_values.size() + 1); values.push_back( - builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); + xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8a0fb41afb..82dd48d43d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -416,7 +417,7 @@ Status BuildComputation( // create a tuple/get-tuple-element combination so that sharding // assignment will be placed on this value, which will cause the resource // update to be returned from the same device that provided the resource. - handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); elems.push_back(handle); } @@ -426,7 +427,7 @@ Status BuildComputation( // Builds the XLA computation. if (always_return_tuple || elems.size() != 1) { - builder->Tuple(elems); + xla::Tuple(builder, elems); } builder->ClearOpMetadata(); @@ -554,16 +555,16 @@ Status XlaCompiler::BuildArguments( } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } else { - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = builder->GetTupleElement(tuple, i); + arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { @@ -571,8 +572,8 @@ Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = - builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); + arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], + strings::StrCat("arg", i)); } } @@ -603,7 +604,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. if (is_entry_computation) { arg_expression.set_handle( - builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); } else { arg_expression.set_handle(arg_handles[i]); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 67174b251d..d0b5606907 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -131,9 +131,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Max(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Max(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -145,9 +147,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Min(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Min(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -159,9 +163,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Add(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Add(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -173,9 +179,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Mul(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Mul(x, y); return b.Build().ConsumeValueOrDie(); }); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 31115eea60..917ef4037d 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -50,14 +50,14 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, xla::PrimitiveType xla_output_type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer, - /*dimensions_to_reduce=*/{axis}); + xla::XlaOp input_max = xla::Reduce(input, init_value, *reducer, + /*dimensions_to_reduce=*/{axis}); std::vector<int64> broadcast_dims(input_shape.dims() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); // Compute a mask that has 1s for elements equal to the maximum. - xla::XlaOp partial_mask = builder->ConvertElementType( - builder->Eq(input, input_max, broadcast_dims), xla_output_type); + xla::XlaOp partial_mask = xla::ConvertElementType( + xla::Eq(input, input_max, broadcast_dims), xla_output_type); // In order to make identity elements for a bitwise And, we: // Left shift the 1 to the leftmost bit, yielding 0x10...0 @@ -67,8 +67,8 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; xla::XlaOp shift_amount = XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::XlaOp full_mask = builder->ShiftRightArithmetic( - builder->ShiftLeft(partial_mask, shift_amount), shift_amount); + xla::XlaOp full_mask = xla::ShiftRightArithmetic( + xla::ShiftLeft(partial_mask, shift_amount), shift_amount); // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its // index. @@ -77,14 +77,14 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, const int64 axis_size = input_shape.dim_size(axis); TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); xla::XlaOp product = - builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); // If there are multiple maximum elements, choose the one with the highest // index. xla::XlaOp output = - builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), - *ctx->GetOrCreateMax(output_type), - /*dimensions_to_reduce=*/{axis}); + xla::Reduce(product, XlaHelpers::MinValue(builder, output_type), + *ctx->GetOrCreateMax(output_type), + /*dimensions_to_reduce=*/{axis}); *argminmax = output; return Status::OK(); } @@ -94,7 +94,7 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MinValue(type)); + return xla::ConstantLiteral(b, xla::Literal::MinValue(type)); } xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) { @@ -102,23 +102,23 @@ xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) { TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { case xla::F16: - return b->ConstantR0<Eigen::half>( - Eigen::NumTraits<Eigen::half>::lowest()); + return xla::ConstantR0<Eigen::half>( + b, Eigen::NumTraits<Eigen::half>::lowest()); case xla::BF16: - return b->ConstantR0<bfloat16>(bfloat16::lowest()); + return xla::ConstantR0<bfloat16>(b, bfloat16::lowest()); case xla::F32: - return b->ConstantR0<float>(-std::numeric_limits<float>::max()); + return xla::ConstantR0<float>(b, -std::numeric_limits<float>::max()); case xla::F64: - return b->ConstantR0<double>(-std::numeric_limits<double>::max()); + return xla::ConstantR0<double>(b, -std::numeric_limits<double>::max()); default: - return b->ConstantLiteral(xla::Literal::MinValue(type)); + return xla::ConstantLiteral(b, xla::Literal::MinValue(type)); } } xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MaxValue(type)); + return xla::ConstantLiteral(b, xla::Literal::MaxValue(type)); } xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) { @@ -126,42 +126,43 @@ xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) { TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { case xla::F16: - return b->ConstantR0<Eigen::half>( - Eigen::NumTraits<Eigen::half>::highest()); + return xla::ConstantR0<Eigen::half>( + b, Eigen::NumTraits<Eigen::half>::highest()); case xla::BF16: - return b->ConstantR0<bfloat16>(bfloat16::highest()); + return xla::ConstantR0<bfloat16>(b, bfloat16::highest()); case xla::F32: - return b->ConstantR0<float>(std::numeric_limits<float>::max()); + return xla::ConstantR0<float>(b, std::numeric_limits<float>::max()); case xla::F64: - return b->ConstantR0<double>(std::numeric_limits<double>::max()); + return xla::ConstantR0<double>(b, std::numeric_limits<double>::max()); default: - return b->ConstantLiteral(xla::Literal::MaxValue(type)); + return xla::ConstantLiteral(b, xla::Literal::MaxValue(type)); } } xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::Zero(type)); + return xla::ConstantLiteral(b, xla::Literal::Zero(type)); } xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::One(type)); + return xla::ConstantLiteral(b, xla::Literal::One(type)); } xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { switch (data_type) { case DT_HALF: - return b->ConstantR0<Eigen::half>( + return xla::ConstantR0<Eigen::half>( + b, static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon())); case DT_BFLOAT16: - return b->ConstantR0<bfloat16>(bfloat16::epsilon()); + return xla::ConstantR0<bfloat16>(b, bfloat16::epsilon()); case DT_FLOAT: - return b->ConstantR0<float>(std::numeric_limits<float>::epsilon()); + return xla::ConstantR0<float>(b, std::numeric_limits<float>::epsilon()); case DT_DOUBLE: - return b->ConstantR0<double>(std::numeric_limits<double>::epsilon()); + return xla::ConstantR0<double>(b, std::numeric_limits<double>::epsilon()); default: LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " << DataTypeString(data_type); @@ -250,7 +251,7 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, xla::BorrowingLiteral linspace_literal; TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - *iota = builder->ConstantLiteral(linspace_literal); + *iota = xla::ConstantLiteral(builder, linspace_literal); return Status::OK(); } @@ -292,13 +293,13 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, std::vector<int64> broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = builder->Eq( - indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); + xla::XlaOp one_hot_bool = xla::Eq( + indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. - *one_hot = builder->Select( - one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()), - builder->Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select(one_hot_bool, + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } @@ -316,7 +317,7 @@ xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); - return builder->ConvertElementType(operand, convert_to); + return xla::ConvertElementType(operand, convert_to); } } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index b58959bd6c..c2298b97e1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -128,7 +129,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( xla::XlaOp handle = expression->handle(); if (new_shape != tensor.shape()) { // Reshape the handle to the desired shape. - handle = builder()->Reshape(handle, new_shape.dim_sizes()); + handle = xla::Reshape(handle, new_shape.dim_sizes()); } // The XLA layout is specified minor to major, and TensorFlow's minor @@ -342,8 +343,7 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, if (representation_shape == variable->shape()) { *value = variable->value(); } else { - *value = - builder()->Reshape(variable->value(), variable->shape().dim_sizes()); + *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); } return Status::OK(); } @@ -394,7 +394,7 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { xla::BorrowingLiteral literal; OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); - xla::XlaOp handle = builder()->ConstantLiteral(literal); + xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); CHECK(handle.valid()); // Make the Tensor that will refer to the expression. @@ -462,7 +462,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, TensorShape representation_shape = xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { - handle = builder()->Reshape(handle, representation_shape.dim_sizes()); + handle = xla::Reshape(handle, representation_shape.dim_sizes()); } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 540c65c597..baea814965 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { @@ -89,16 +90,16 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } switch (kind_) { case kVariable: { - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - shape_.dim_sizes()); + value_ = + xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes()); break; } case kTensorArray: { TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()); + value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); break; } case kStack: { @@ -106,9 +107,9 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); value_ = - builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()), - builder->ConstantR0<int32>(0)}); + xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + xla::ConstantR0<int32>(builder, 0)}); break; } @@ -130,8 +131,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - xla::XlaOp gradient_value = builder->Broadcast( - XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); + xla::XlaOp gradient_value = + xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/strings::StrCat("TensorArrayGrad: ", name_), @@ -152,7 +153,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { for (const auto& gradient : tensor_array_gradients_) { elems.push_back(gradient.second->value_); } - *pack = builder->Tuple(elems); + *pack = xla::Tuple(builder, elems); } return Status::OK(); } @@ -168,7 +169,7 @@ Status XlaResource::SetFromPack(const std::set<string>& gradient_sources, } else { TF_RET_CHECK(kind_ == kTensorArray); int pos = 0; - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!initialized()) { initial_value_ = v; } @@ -178,7 +179,7 @@ Status XlaResource::SetFromPack(const std::set<string>& gradient_sources, XlaResource* gradient; TF_RETURN_IF_ERROR( GetOrCreateTensorArrayGradient(source, builder, &gradient)); - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!gradient->initialized()) { gradient->initial_value_ = v; } |