aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc47
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc15
-rw-r--r--tensorflow/compiler/xla/tests/BUILD3
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc192
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc8
5 files changed, 160 insertions, 105 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 3b96bc72be..c3c824a231 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1390,14 +1390,57 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
+ TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape,
+ GetShape(gather_indices));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferGatherShape(input_shape, gather_indices_shape,
+ dimension_numbers, window_bounds));
+
+ *instr.mutable_gather_dimension_numbers() = dimension_numbers;
+ for (int64 bound : window_bounds) {
+ instr.add_gather_window_bounds(bound);
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kGather,
+ {input, gather_indices});
+ });
}
XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate));
+ TF_ASSIGN_OR_RETURN(const Shape& true_operand_shape,
+ GetShape(true_operand));
+ TF_ASSIGN_OR_RETURN(const ProgramShape& true_computation_shape,
+ true_computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(const Shape& false_operand_shape,
+ GetShape(false_operand));
+ TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape,
+ false_computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferConditionalShape(
+ predicate_shape, true_operand_shape, false_operand_shape,
+ true_computation_shape, false_computation_shape));
+
+ // The index of true_computation must be 0 and that of false computation
+ // must be 1.
+ AddCalledComputation(true_computation, &instr);
+ AddCalledComputation(false_computation, &instr);
+
+ return AddInstruction(std::move(instr), HloOpcode::kConditional,
+ {predicate, true_operand, false_operand});
+ });
}
XlaOp XlaBuilder::Reduce(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8149e47cb5..3629106a25 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -159,6 +159,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->fft_length_.push_back(fft_len);
}
+ if (proto.has_gather_dimension_numbers()) {
+ instruction->gather_dimension_numbers_ =
+ MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers());
+ }
+ for (int64 bound : proto.gather_window_bounds()) {
+ instruction->gather_window_bounds_.push_back(bound);
+ }
+
return std::move(instruction);
}
@@ -2416,6 +2424,13 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.add_fft_length(fft_len);
}
+ if (gather_dimension_numbers_ != nullptr) {
+ *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_;
+ }
+ for (int64 bound : gather_window_bounds_) {
+ proto.add_gather_window_bounds(bound);
+ }
+
return proto;
}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 74ea1a0f39..1f90a44d8b 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -492,9 +492,10 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index b917dee77b..7ff6706935 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -23,8 +24,8 @@ namespace {
class ConditionalOpTest : public ClientLibraryTestBase {
protected:
- Computation CreateR0ConstantComputation(float value) {
- ComputationBuilder builder(client_, "Constant");
+ XlaComputation CreateR0ConstantComputation(float value) {
+ XlaBuilder builder("Constant");
builder.Parameter(0, empty_tuple_, "tuple");
builder.ConstantR0<float>(value);
auto build_status = builder.Build();
@@ -32,16 +33,16 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0IdentityComputation() {
- ComputationBuilder builder(client_, "Identity");
+ XlaComputation CreateR0IdentityComputation() {
+ XlaBuilder builder("Identity");
builder.Parameter(0, r0f32_, "x");
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
}
- Computation CreateCeilComputation(const Shape& shape) {
- ComputationBuilder builder(client_, "Ceil");
+ XlaComputation CreateCeilComputation(const Shape& shape) {
+ XlaBuilder builder("Ceil");
auto param = builder.Parameter(0, shape, "param");
builder.Ceil(param);
auto build_status = builder.Build();
@@ -49,16 +50,16 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0CeilComputation() {
+ XlaComputation CreateR0CeilComputation() {
return CreateCeilComputation(r0f32_);
}
- Computation CreateR1CeilComputation() {
+ XlaComputation CreateR1CeilComputation() {
return CreateCeilComputation(r1s2f32_);
}
- Computation CreateFloorComputation(const Shape& shape) {
- ComputationBuilder builder(client_, "Floor");
+ XlaComputation CreateFloorComputation(const Shape& shape) {
+ XlaBuilder builder("Floor");
auto param = builder.Parameter(0, shape, "param");
builder.Floor(param);
auto build_status = builder.Build();
@@ -66,17 +67,17 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0FloorComputation() {
+ XlaComputation CreateR0FloorComputation() {
return CreateFloorComputation(r0f32_);
}
- Computation CreateR1FloorComputation() {
+ XlaComputation CreateR1FloorComputation() {
return CreateFloorComputation(r1s2f32_);
}
- Computation CreateTupleCeilComputation(const string& computation_name,
- const Shape& tuple_shape) {
- ComputationBuilder builder(client_, computation_name);
+ XlaComputation CreateTupleCeilComputation(const string& computation_name,
+ const Shape& tuple_shape) {
+ XlaBuilder builder(computation_name);
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
auto x = builder.GetTupleElement(tuple, 0);
auto y = builder.GetTupleElement(tuple, 1);
@@ -88,17 +89,17 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0TupleCeilComputation() {
+ XlaComputation CreateR0TupleCeilComputation() {
return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
}
- Computation CreateR1TupleCeilComputation() {
+ XlaComputation CreateR1TupleCeilComputation() {
return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
}
- Computation CreateTupleFloorComputation(const string& computation_name,
- const Shape& tuple_shape) {
- ComputationBuilder builder(client_, computation_name);
+ XlaComputation CreateTupleFloorComputation(const string& computation_name,
+ const Shape& tuple_shape) {
+ XlaBuilder builder(computation_name);
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
auto x = builder.GetTupleElement(tuple, 0);
auto y = builder.GetTupleElement(tuple, 1);
@@ -110,17 +111,17 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0TupleFloorComputation() {
+ XlaComputation CreateR0TupleFloorComputation() {
return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
}
- Computation CreateR1TupleFloorComputation() {
+ XlaComputation CreateR1TupleFloorComputation() {
return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
}
- Computation CreateTupleAddComputation(const string& computation_name,
- const Shape& tuple_shape) {
- ComputationBuilder builder(client_, computation_name);
+ XlaComputation CreateTupleAddComputation(const string& computation_name,
+ const Shape& tuple_shape) {
+ XlaBuilder builder(computation_name);
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
auto x = builder.GetTupleElement(tuple, 0);
auto y = builder.GetTupleElement(tuple, 1);
@@ -130,17 +131,17 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0TupleAddComputation() {
+ XlaComputation CreateR0TupleAddComputation() {
return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
}
- Computation CreateR1TupleAddComputation() {
+ XlaComputation CreateR1TupleAddComputation() {
return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
}
- Computation CreateTupleSubComputation(const string& computation_name,
- const Shape& tuple_shape) {
- ComputationBuilder builder(client_, computation_name);
+ XlaComputation CreateTupleSubComputation(const string& computation_name,
+ const Shape& tuple_shape) {
+ XlaBuilder builder(computation_name);
auto tuple = builder.Parameter(0, tuple_shape, "tuple");
auto x = builder.GetTupleElement(tuple, 0);
auto y = builder.GetTupleElement(tuple, 1);
@@ -150,11 +151,11 @@ class ConditionalOpTest : public ClientLibraryTestBase {
return build_status.ConsumeValueOrDie();
}
- Computation CreateR0TupleSubComputation() {
+ XlaComputation CreateR0TupleSubComputation() {
return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
}
- Computation CreateR1TupleSubComputation() {
+ XlaComputation CreateR1TupleSubComputation() {
return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
}
@@ -170,26 +171,25 @@ class ConditionalOpTest : public ClientLibraryTestBase {
// Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest, Parameters0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operands = builder.Tuple({});
auto true_computation = CreateR0ConstantComputation(56.0f);
auto false_computation = CreateR0ConstantComputation(12.0f);
- auto result = builder.Conditional(pred, operands, true_computation, operands,
- false_computation);
+ builder.Conditional(pred, operands, true_computation, operands,
+ false_computation);
ComputeAndCompareR0<float>(&builder, 56.0f, {}, error_spec_);
}
// Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest, Parameters1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
auto identity = CreateR0IdentityComputation();
- auto result =
- builder.Conditional(pred, operand1, identity, operand2, identity);
+ builder.Conditional(pred, operand1, identity, operand2, identity);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -197,12 +197,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) {
// Test conditional with two different computations in the true and false cases
// that take in different arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.4f);
auto operand2 = builder.ConstantR0<float>(12.6f);
- auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(),
- operand2, CreateR0FloorComputation());
+ builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -210,11 +210,11 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
// Test conditional with two different computations in the true and false cases
// that take in the same arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand = builder.ConstantR0<float>(12.6f);
- auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(),
- operand, CreateR0FloorComputation());
+ builder.Conditional(pred, operand, CreateR0CeilComputation(), operand,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -222,12 +222,12 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
// Test conditional with the same computation in the true and false cases but
// take in different arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.4f);
auto operand2 = builder.ConstantR0<float>(12.6f);
auto floor = CreateR0FloorComputation();
- auto result = builder.Conditional(pred, operand1, floor, operand2, floor);
+ builder.Conditional(pred, operand1, floor, operand2, floor);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -235,11 +235,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
// Test conditional with the same computation in the true and false cases that
// take in the same arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand = builder.ConstantR0<float>(12.6f);
auto floor = CreateR0FloorComputation();
- auto result = builder.Conditional(pred, operand, floor, operand, floor);
+ builder.Conditional(pred, operand, floor, operand, floor);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -247,12 +247,12 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
// Test conditional with different instances of the same computation in the true
// and false cases.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.4f);
auto operand2 = builder.ConstantR0<float>(12.6f);
- auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(),
- operand2, CreateR0FloorComputation());
+ builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -260,7 +260,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
// Test the case when a call invokes a computation that contains a conditional.
XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
- ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
+ XlaBuilder inner_builder(TestName() + ".inner_conditional");
auto pred_cond = inner_builder.Parameter(0, r0bool, "param0");
auto true_operand = inner_builder.Parameter(1, r0f32_, "param1");
auto false_operand = inner_builder.Parameter(2, r0f32_, "param2");
@@ -268,7 +268,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
false_operand, CreateR0FloorComputation());
auto inner_builder_result = inner_builder.Build();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.4f);
auto operand2 = builder.ConstantR0<float>(12.6f);
@@ -281,14 +281,13 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
// Test true and false computations that take in 2 parameters and predicate is
// true.
XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
auto operands = builder.Tuple({operand1, operand2});
- auto result =
- builder.Conditional(pred, operands, CreateR0TupleAddComputation(),
- operands, CreateR0TupleSubComputation());
+ builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
+ CreateR0TupleSubComputation());
ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_);
}
@@ -296,14 +295,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
// Test true and false computations that take in 2 parameters and predicate is
// false.
XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
auto operands = builder.Tuple({operand1, operand2});
- auto result =
- builder.Conditional(pred, operands, CreateR0TupleAddComputation(),
- operands, CreateR0TupleSubComputation());
+ builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
+ CreateR0TupleSubComputation());
ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_);
}
@@ -311,14 +309,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
// Test true and false computations that take in 2 array parameters and
// predicate is true.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
auto operands = builder.Tuple({operand1, operand2});
- auto result =
- builder.Conditional(pred, operands, CreateR1TupleAddComputation(),
- operands, CreateR1TupleSubComputation());
+ builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
+ CreateR1TupleSubComputation());
ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_);
}
@@ -326,21 +323,20 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
// Test true and false computations that take in 2 array parameters and
// predicate is false.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
auto operands = builder.Tuple({operand1, operand2});
- auto result =
- builder.Conditional(pred, operands, CreateR1TupleAddComputation(),
- operands, CreateR1TupleSubComputation());
+ builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
+ CreateR1TupleSubComputation());
ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_);
}
// Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operands = builder.Tuple(
{builder.ConstantR0<float>(12.2f), builder.ConstantR0<float>(25.6f)});
@@ -356,7 +352,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
// Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}),
builder.ConstantR1<float>({25.6f, 29.2f})});
@@ -373,7 +369,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
// Test true and false computations that return a tuple of a predicate, a
// scalar, and an array.
XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
- ComputationBuilder true_builder(client_, TestName() + ".true");
+ XlaBuilder true_builder(TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
auto true_pred = true_builder.ConstantR0<bool>(true);
@@ -384,7 +380,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
auto true_builder_result = true_builder.Build();
EXPECT_IS_OK(true_builder_result.status());
- ComputationBuilder false_builder(client_, TestName() + ".false");
+ XlaBuilder false_builder(TestName() + ".false");
{
false_builder.Parameter(0, empty_tuple_, "tuple");
auto false_pred = false_builder.ConstantR0<bool>(false);
@@ -395,7 +391,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
auto false_builder_result = false_builder.Build();
EXPECT_IS_OK(false_builder_result.status());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operands = builder.Tuple({});
builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
@@ -411,7 +407,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
// Test true and false computations that return a nested tuple.
XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
- ComputationBuilder true_builder(client_, TestName() + ".true");
+ XlaBuilder true_builder(TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
auto true_constant1 = true_builder.ConstantR0<float>(12.2f);
@@ -424,7 +420,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
auto true_builder_result = true_builder.Build();
EXPECT_IS_OK(true_builder_result.status());
- ComputationBuilder false_builder(client_, TestName() + ".false");
+ XlaBuilder false_builder(TestName() + ".false");
{
false_builder.Parameter(0, empty_tuple_, "tuple");
auto false_constant1 = false_builder.ConstantR0<float>(46.6f);
@@ -438,7 +434,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
auto false_builder_result = false_builder.Build();
EXPECT_IS_OK(false_builder_result.status());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(false);
auto operands = builder.Tuple({});
builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
@@ -460,16 +456,16 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
// params.
XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
- ComputationDataHandle pred, operand1, operand2;
+ XlaOp pred, operand1, operand2;
auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operand1_param =
CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
auto operand2_param =
CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
- auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(),
- operand2, CreateR0FloorComputation());
+ builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(
&builder, 57.0f,
@@ -480,16 +476,16 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
// Test conditional that takes in array operands in the form of external params.
XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
- ComputationDataHandle pred, operand1, operand2;
+ XlaOp pred, operand1, operand2;
auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
&builder, &operand1);
auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
&builder, &operand2);
- auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(),
- operand2, CreateR1FloorComputation());
+ builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
+ CreateR1FloorComputation());
ComputeAndCompareR1<float>(
&builder, {10.0f, 11.0f},
@@ -499,7 +495,7 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
// Test the case where one conditional is nested within another.
XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
- ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
+ XlaBuilder inner_builder(TestName() + ".inner_conditional");
{
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
@@ -514,7 +510,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
auto inner_builder_result = inner_builder.Build();
EXPECT_IS_OK(inner_builder_result.status());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred1 = builder.ConstantR0<bool>(true);
auto pred2 = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(1.1f);
@@ -529,7 +525,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
}
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
- ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
+ XlaBuilder inner_builder(TestName() + ".inner_conditional");
{
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
@@ -544,7 +540,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
auto inner_builder_result = inner_builder.Build();
EXPECT_IS_OK(inner_builder_result.status());
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred2 = builder.ConstantR0<bool>(false);
auto operand1 = builder.ConstantR0<float>(1.1f);
auto operand2 = builder.ConstantR0<float>(12.2f);
@@ -556,7 +552,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
// Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operand1 = builder.ConstantR0<float>(56.0f);
auto operand2 = builder.ConstantR0<float>(12.0f);
@@ -573,27 +569,27 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});
- Computation swapper;
+ XlaComputation swapper;
{
- ComputationBuilder builder(client_, TestName() + ".swapper");
+ XlaBuilder builder(TestName() + ".swapper");
auto param0 = builder.Parameter(0, tuple_shape, "sp0");
auto x = builder.GetTupleElement(param0, 0);
auto y = builder.GetTupleElement(param0, 1);
builder.Tuple({y, x});
swapper = builder.Build().ConsumeValueOrDie();
}
- Computation forwarder;
+ XlaComputation forwarder;
{
- ComputationBuilder builder(client_, TestName() + ".forwarder");
+ XlaBuilder builder(TestName() + ".forwarder");
auto param0 = builder.Parameter(0, tuple_shape, "fp0");
auto x = builder.GetTupleElement(param0, 0);
auto y = builder.GetTupleElement(param0, 1);
builder.Tuple({x, y});
forwarder = builder.Build().ConsumeValueOrDie();
}
- Computation main;
+ XlaComputation main;
{
- ComputationBuilder builder(client_, TestName() + ".main");
+ XlaBuilder builder(TestName() + ".main");
auto param0 = builder.Parameter(0, tuple_shape, "mp0");
auto x = builder.GetTupleElement(param0, 0);
auto y = builder.GetTupleElement(param0, 1);
@@ -605,7 +601,7 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
}
auto test_swap = [&](float a, float b) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR0<float>(a);
auto y = builder.ConstantR0<float>(b);
auto tuple_operand = builder.Tuple({x, y});
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 9db68ff7a6..90496d55e6 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -405,7 +405,7 @@ class GatherClientLibraryTest : public ClientLibraryTestBase {};
// GPU and CPU_PARALLEL.
XLA_TEST_F(GatherClientLibraryTest,
DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) {
- // We create this HLO, but using the ComputationBuilder API.
+ // We create this HLO, but using the XlaBuilder API.
//
// ENTRY main {
// operand = s32[3,3] parameter(0)
@@ -418,7 +418,7 @@ XLA_TEST_F(GatherClientLibraryTest,
// window_bounds={1, 3}
// }
- ComputationBuilder builder(client_, "gather_basic");
+ XlaBuilder builder("gather_basic");
Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
@@ -443,8 +443,8 @@ XLA_TEST_F(GatherClientLibraryTest,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
*execution_options.add_device_handles() = devices[0];
- TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build());
- std::vector<xla::Client::ComputationInstance> computation_instances = {
+ TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build());
+ std::vector<xla::Client::XlaComputationInstance> computation_instances = {
{computation,
{operand_arg.get(), indices_arg.get()},
execution_options,