diff options
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 47 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/conditional_test.cc | 192 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/gather_operation_test.cc | 8 |
5 files changed, 160 insertions, 105 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 3b96bc72be..c3c824a231 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1390,14 +1390,57 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape, + GetShape(gather_indices)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferGatherShape(input_shape, gather_indices_shape, + dimension_numbers, window_bounds)); + + *instr.mutable_gather_dimension_numbers() = dimension_numbers; + for (int64 bound : window_bounds) { + instr.add_gather_window_bounds(bound); + } + + return AddInstruction(std::move(instr), HloOpcode::kGather, + {input, gather_indices}); + }); } XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); + TF_ASSIGN_OR_RETURN(const Shape& true_operand_shape, + GetShape(true_operand)); + TF_ASSIGN_OR_RETURN(const ProgramShape& true_computation_shape, + true_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const Shape& false_operand_shape, + GetShape(false_operand)); + TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, + false_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConditionalShape( + predicate_shape, true_operand_shape, false_operand_shape, + true_computation_shape, false_computation_shape)); + + // The index of true_computation must be 0 and that of false computation + // must be 1. + AddCalledComputation(true_computation, &instr); + AddCalledComputation(false_computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kConditional, + {predicate, true_operand, false_operand}); + }); } XlaOp XlaBuilder::Reduce( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8149e47cb5..3629106a25 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -159,6 +159,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->fft_length_.push_back(fft_len); } + if (proto.has_gather_dimension_numbers()) { + instruction->gather_dimension_numbers_ = + MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); + } + for (int64 bound : proto.gather_window_bounds()) { + instruction->gather_window_bounds_.push_back(bound); + } + return std::move(instruction); } @@ -2416,6 +2424,13 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_fft_length(fft_len); } + if (gather_dimension_numbers_ != nullptr) { + *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; + } + for (int64 bound : gather_window_bounds_) { + proto.add_gather_window_bounds(bound); + } + return proto; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 74ea1a0f39..1f90a44d8b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -492,9 +492,10 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index b917dee77b..7ff6706935 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -23,8 +24,8 @@ namespace { class ConditionalOpTest : public ClientLibraryTestBase { protected: - Computation CreateR0ConstantComputation(float value) { - ComputationBuilder builder(client_, "Constant"); + XlaComputation CreateR0ConstantComputation(float value) { + XlaBuilder builder("Constant"); builder.Parameter(0, empty_tuple_, "tuple"); builder.ConstantR0<float>(value); auto build_status = builder.Build(); @@ -32,16 +33,16 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0IdentityComputation() { - ComputationBuilder builder(client_, "Identity"); + XlaComputation CreateR0IdentityComputation() { + XlaBuilder builder("Identity"); builder.Parameter(0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateCeilComputation(const Shape& shape) { - ComputationBuilder builder(client_, "Ceil"); + XlaComputation CreateCeilComputation(const Shape& shape) { + XlaBuilder builder("Ceil"); auto param = builder.Parameter(0, shape, "param"); builder.Ceil(param); auto build_status = builder.Build(); @@ -49,16 +50,16 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0CeilComputation() { + XlaComputation CreateR0CeilComputation() { return CreateCeilComputation(r0f32_); } - Computation CreateR1CeilComputation() { + XlaComputation CreateR1CeilComputation() { return CreateCeilComputation(r1s2f32_); } - Computation CreateFloorComputation(const Shape& shape) { - ComputationBuilder builder(client_, "Floor"); + XlaComputation CreateFloorComputation(const Shape& shape) { + XlaBuilder builder("Floor"); auto param = builder.Parameter(0, shape, "param"); builder.Floor(param); auto build_status = builder.Build(); @@ -66,17 +67,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0FloorComputation() { + XlaComputation CreateR0FloorComputation() { return CreateFloorComputation(r0f32_); } - Computation CreateR1FloorComputation() { + XlaComputation CreateR1FloorComputation() { return CreateFloorComputation(r1s2f32_); } - Computation CreateTupleCeilComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleCeilComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -88,17 +89,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleCeilComputation() { + XlaComputation CreateR0TupleCeilComputation() { return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_); } - Computation CreateR1TupleCeilComputation() { + XlaComputation CreateR1TupleCeilComputation() { return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_); } - Computation CreateTupleFloorComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleFloorComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -110,17 +111,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleFloorComputation() { + XlaComputation CreateR0TupleFloorComputation() { return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_); } - Computation CreateR1TupleFloorComputation() { + XlaComputation CreateR1TupleFloorComputation() { return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_); } - Computation CreateTupleAddComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleAddComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -130,17 +131,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleAddComputation() { + XlaComputation CreateR0TupleAddComputation() { return CreateTupleAddComputation("AddR0", tuple_2_r0f32_); } - Computation CreateR1TupleAddComputation() { + XlaComputation CreateR1TupleAddComputation() { return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_); } - Computation CreateTupleSubComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleSubComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -150,11 +151,11 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleSubComputation() { + XlaComputation CreateR0TupleSubComputation() { return CreateTupleSubComputation("SubR0", tuple_2_r0f32_); } - Computation CreateR1TupleSubComputation() { + XlaComputation CreateR1TupleSubComputation() { return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_); } @@ -170,26 +171,25 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(true); auto operands = builder.Tuple({}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); - auto result = builder.Conditional(pred, operands, true_computation, operands, - false_computation); + builder.Conditional(pred, operands, true_computation, operands, + false_computation); ComputeAndCompareR0<float>(&builder, 56.0f, {}, error_spec_); } // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); auto identity = CreateR0IdentityComputation(); - auto result = - builder.Conditional(pred, operand1, identity, operand2, identity); + builder.Conditional(pred, operand1, identity, operand2, identity); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -197,12 +197,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { // Test conditional with two different computations in the true and false cases // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.4f); auto operand2 = builder.ConstantR0<float>(12.6f); - auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -210,11 +210,11 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { // Test conditional with two different computations in the true and false cases // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand = builder.ConstantR0<float>(12.6f); - auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(), - operand, CreateR0FloorComputation()); + builder.Conditional(pred, operand, CreateR0CeilComputation(), operand, + CreateR0FloorComputation()); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -222,12 +222,12 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { // Test conditional with the same computation in the true and false cases but // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.4f); auto operand2 = builder.ConstantR0<float>(12.6f); auto floor = CreateR0FloorComputation(); - auto result = builder.Conditional(pred, operand1, floor, operand2, floor); + builder.Conditional(pred, operand1, floor, operand2, floor); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -235,11 +235,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { // Test conditional with the same computation in the true and false cases that // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand = builder.ConstantR0<float>(12.6f); auto floor = CreateR0FloorComputation(); - auto result = builder.Conditional(pred, operand, floor, operand, floor); + builder.Conditional(pred, operand, floor, operand, floor); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -247,12 +247,12 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { // Test conditional with different instances of the same computation in the true // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.4f); auto operand2 = builder.ConstantR0<float>(12.6f); - auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -260,7 +260,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { // Test the case when a call invokes a computation that contains a conditional. XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); @@ -268,7 +268,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { false_operand, CreateR0FloorComputation()); auto inner_builder_result = inner_builder.Build(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.4f); auto operand2 = builder.ConstantR0<float>(12.6f); @@ -281,14 +281,13 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { // Test true and false computations that take in 2 parameters and predicate is // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(true); auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), - operands, CreateR0TupleSubComputation()); + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_); } @@ -296,14 +295,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { // Test true and false computations that take in 2 parameters and predicate is // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), - operands, CreateR0TupleSubComputation()); + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_); } @@ -311,14 +309,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { // Test true and false computations that take in 2 array parameters and // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(true); auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f}); auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), - operands, CreateR1TupleSubComputation()); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_); } @@ -326,21 +323,20 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { // Test true and false computations that take in 2 array parameters and // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f}); auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), - operands, CreateR1TupleSubComputation()); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_); } // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operands = builder.Tuple( {builder.ConstantR0<float>(12.2f), builder.ConstantR0<float>(25.6f)}); @@ -356,7 +352,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(true); auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}), builder.ConstantR1<float>({25.6f, 29.2f})}); @@ -373,7 +369,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { // Test true and false computations that return a tuple of a predicate, a // scalar, and an array. XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { - ComputationBuilder true_builder(client_, TestName() + ".true"); + XlaBuilder true_builder(TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); auto true_pred = true_builder.ConstantR0<bool>(true); @@ -384,7 +380,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); - ComputationBuilder false_builder(client_, TestName() + ".false"); + XlaBuilder false_builder(TestName() + ".false"); { false_builder.Parameter(0, empty_tuple_, "tuple"); auto false_pred = false_builder.ConstantR0<bool>(false); @@ -395,7 +391,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(true); auto operands = builder.Tuple({}); builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), @@ -411,7 +407,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { // Test true and false computations that return a nested tuple. XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { - ComputationBuilder true_builder(client_, TestName() + ".true"); + XlaBuilder true_builder(TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); auto true_constant1 = true_builder.ConstantR0<float>(12.2f); @@ -424,7 +420,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); - ComputationBuilder false_builder(client_, TestName() + ".false"); + XlaBuilder false_builder(TestName() + ".false"); { false_builder.Parameter(0, empty_tuple_, "tuple"); auto false_constant1 = false_builder.ConstantR0<float>(46.6f); @@ -438,7 +434,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(false); auto operands = builder.Tuple({}); builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), @@ -460,16 +456,16 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { // params. XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle pred, operand1, operand2; + XlaOp pred, operand1, operand2; auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred); auto operand1_param = CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2); - auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0<float>( &builder, 57.0f, @@ -480,16 +476,16 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { // Test conditional that takes in array operands in the form of external params. XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle pred, operand1, operand2; + XlaOp pred, operand1, operand2; auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2", &builder, &operand2); - auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(), - operand2, CreateR1FloorComputation()); + builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2, + CreateR1FloorComputation()); ComputeAndCompareR1<float>( &builder, {10.0f, 11.0f}, @@ -499,7 +495,7 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { // Test the case where one conditional is nested within another. XLA_TEST_F(ConditionalOpTest, NestedConditionals) { - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); @@ -514,7 +510,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred1 = builder.ConstantR0<bool>(true); auto pred2 = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(1.1f); @@ -529,7 +525,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { } XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); @@ -544,7 +540,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred2 = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(1.1f); auto operand2 = builder.ConstantR0<float>(12.2f); @@ -556,7 +552,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { // Test a mismatch in the shape of the true operand and true computation. XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0<bool>(true); auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); @@ -573,27 +569,27 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_}); - Computation swapper; + XlaComputation swapper; { - ComputationBuilder builder(client_, TestName() + ".swapper"); + XlaBuilder builder(TestName() + ".swapper"); auto param0 = builder.Parameter(0, tuple_shape, "sp0"); auto x = builder.GetTupleElement(param0, 0); auto y = builder.GetTupleElement(param0, 1); builder.Tuple({y, x}); swapper = builder.Build().ConsumeValueOrDie(); } - Computation forwarder; + XlaComputation forwarder; { - ComputationBuilder builder(client_, TestName() + ".forwarder"); + XlaBuilder builder(TestName() + ".forwarder"); auto param0 = builder.Parameter(0, tuple_shape, "fp0"); auto x = builder.GetTupleElement(param0, 0); auto y = builder.GetTupleElement(param0, 1); builder.Tuple({x, y}); forwarder = builder.Build().ConsumeValueOrDie(); } - Computation main; + XlaComputation main; { - ComputationBuilder builder(client_, TestName() + ".main"); + XlaBuilder builder(TestName() + ".main"); auto param0 = builder.Parameter(0, tuple_shape, "mp0"); auto x = builder.GetTupleElement(param0, 0); auto y = builder.GetTupleElement(param0, 1); @@ -605,7 +601,7 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { } auto test_swap = [&](float a, float b) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR0<float>(a); auto y = builder.ConstantR0<float>(b); auto tuple_operand = builder.Tuple({x, y}); diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 9db68ff7a6..90496d55e6 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -405,7 +405,7 @@ class GatherClientLibraryTest : public ClientLibraryTestBase {}; // GPU and CPU_PARALLEL. XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) { - // We create this HLO, but using the ComputationBuilder API. + // We create this HLO, but using the XlaBuilder API. // // ENTRY main { // operand = s32[3,3] parameter(0) @@ -418,7 +418,7 @@ XLA_TEST_F(GatherClientLibraryTest, // window_bounds={1, 3} // } - ComputationBuilder builder(client_, "gather_basic"); + XlaBuilder builder("gather_basic"); Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); @@ -443,8 +443,8 @@ XLA_TEST_F(GatherClientLibraryTest, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); *execution_options.add_device_handles() = devices[0]; - TF_ASSERT_OK_AND_ASSIGN(Computation computation, builder.Build()); - std::vector<xla::Client::ComputationInstance> computation_instances = { + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + std::vector<xla::Client::XlaComputationInstance> computation_instances = { {computation, {operand_arg.get(), indices_arg.get()}, execution_options, |