diff options
author | 2018-01-08 16:38:01 -0800 | |
---|---|---|
committer | 2018-01-08 16:41:51 -0800 | |
commit | 8398ece2cf59c6a4596b3bdf79d7e8553b953afa (patch) | |
tree | fe74d7f3195b8bcca26fd99420369d7749d707f4 | |
parent | fffcacb1d0537a353c4e9303f98d1e9b2e83bf23 (diff) |
[XLA] Add tests for conditionals that return tuples and whose predicate operands are external parameters.
PiperOrigin-RevId: 181237427
-rw-r--r-- | tensorflow/compiler/xla/tests/conditional_test.cc | 334 |
1 files changed, 281 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index c8c4932be8..0016b6cc61 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -23,7 +23,7 @@ namespace { class ConditionalOpTest : public ClientLibraryTestBase { protected: - Computation CreateR0F32ConstantComputation(float value) { + Computation CreateR0ConstantComputation(float value) { ComputationBuilder builder(client_, "Constant"); builder.Parameter(0, empty_tuple_, "tuple"); builder.ConstantR0<float>(value); @@ -32,7 +32,7 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0F32IdentityComputation() { + Computation CreateR0IdentityComputation() { ComputationBuilder builder(client_, "Identity"); builder.Parameter(0, r0f32_, "x"); auto build_status = builder.Build(); @@ -40,25 +40,85 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0F32CeilComputation() { + Computation CreateCeilComputation(const Shape& shape) { ComputationBuilder builder(client_, "Ceil"); - auto param = builder.Parameter(0, r0f32_, "param"); + auto param = builder.Parameter(0, shape, "param"); builder.Ceil(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateR0F32FloorComputation() { - ComputationBuilder builder(client_, "Ceil"); - auto param = builder.Parameter(0, r0f32_, "param"); + Computation CreateR0CeilComputation() { + return CreateCeilComputation(r0f32_); + } + + Computation CreateR1CeilComputation() { + return CreateCeilComputation(r1s2f32_); + } + + Computation CreateFloorComputation(const Shape& shape) { + ComputationBuilder builder(client_, "Floor"); + auto param = builder.Parameter(0, shape, "param"); builder.Floor(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateAddTupleComputation(const string& computation_name, + Computation CreateR0FloorComputation() { + return CreateFloorComputation(r0f32_); + } + + Computation CreateR1FloorComputation() { + return CreateFloorComputation(r1s2f32_); + } + + Computation CreateTupleCeilComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + auto x_ceil = builder.Ceil(x); + auto y_ceil = builder.Ceil(y); + builder.Tuple({x_ceil, y_ceil}); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleCeilComputation() { + return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleCeilComputation() { + return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_); + } + + Computation CreateTupleFloorComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + auto x_floor = builder.Floor(x); + auto y_floor = builder.Floor(y); + builder.Tuple({x_floor, y_floor}); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleFloorComputation() { + return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleFloorComputation() { + return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_); + } + + Computation CreateTupleAddComputation(const string& computation_name, const Shape& tuple_shape) { ComputationBuilder builder(client_, computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); @@ -70,15 +130,15 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateAddR0Computation() { - return CreateAddTupleComputation("AddR0", tuple_2_r0f32_); + Computation CreateR0TupleAddComputation() { + return CreateTupleAddComputation("AddR0", tuple_2_r0f32_); } - Computation CreateAddR1Computation() { - return CreateAddTupleComputation("AddR1", tuple_2_r1s2f32_); + Computation CreateR1TupleAddComputation() { + return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_); } - Computation CreateSubTupleComputation(const string& computation_name, + Computation CreateTupleSubComputation(const string& computation_name, const Shape& tuple_shape) { ComputationBuilder builder(client_, computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); @@ -90,15 +150,16 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateSubR0Computation() { - return CreateSubTupleComputation("SubR0", tuple_2_r0f32_); + Computation CreateR0TupleSubComputation() { + return CreateTupleSubComputation("SubR0", tuple_2_r0f32_); } - Computation CreateSubR1Computation() { - return CreateSubTupleComputation("SubR1", tuple_2_r1s2f32_); + Computation CreateR1TupleSubComputation() { + return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_); } Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2}); Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape( @@ -112,8 +173,8 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) { ComputationBuilder builder(client_, TestName()); auto pred = builder.ConstantR0<bool>(true); auto operands = builder.Tuple({}); - auto true_computation = CreateR0F32ConstantComputation(56.0f); - auto false_computation = CreateR0F32ConstantComputation(12.0f); + auto true_computation = CreateR0ConstantComputation(56.0f); + auto false_computation = CreateR0ConstantComputation(12.0f); auto result = builder.Conditional(pred, operands, true_computation, operands, false_computation); @@ -126,7 +187,7 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); - auto identity = CreateR0F32IdentityComputation(); + auto identity = CreateR0IdentityComputation(); auto result = builder.Conditional(pred, operand1, identity, operand2, identity); @@ -140,9 +201,8 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.4f); auto operand2 = builder.ConstantR0<float>(12.6f); - auto result = - builder.Conditional(pred, operand1, CreateR0F32CeilComputation(), - operand2, CreateR0F32FloorComputation()); + auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), + operand2, CreateR0FloorComputation()); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -153,8 +213,8 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { ComputationBuilder builder(client_, TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand = builder.ConstantR0<float>(12.6f); - auto result = builder.Conditional(pred, operand, CreateR0F32CeilComputation(), - operand, CreateR0F32FloorComputation()); + auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(), + operand, CreateR0FloorComputation()); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -166,7 +226,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.4f); auto operand2 = builder.ConstantR0<float>(12.6f); - auto floor = CreateR0F32FloorComputation(); + auto floor = CreateR0FloorComputation(); auto result = builder.Conditional(pred, operand1, floor, operand2, floor); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); @@ -178,7 +238,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { ComputationBuilder builder(client_, TestName()); auto pred = builder.ConstantR0<bool>(false); auto operand = builder.ConstantR0<float>(12.6f); - auto floor = CreateR0F32FloorComputation(); + auto floor = CreateR0FloorComputation(); auto result = builder.Conditional(pred, operand, floor, operand, floor); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); @@ -191,9 +251,8 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { auto pred = builder.ConstantR0<bool>(false); auto operand1 = builder.ConstantR0<float>(56.4f); auto operand2 = builder.ConstantR0<float>(12.6f); - auto result = - builder.Conditional(pred, operand1, CreateR0F32FloorComputation(), - operand2, CreateR0F32FloorComputation()); + auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(), + operand2, CreateR0FloorComputation()); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -205,9 +264,8 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0F32CeilComputation(), false_operand, - CreateR0F32FloorComputation()); + inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); auto inner_builder_result = inner_builder.Build(); ComputationBuilder builder(client_, TestName()); @@ -228,8 +286,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), - operands, CreateSubR0Computation()); + auto result = + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), + operands, CreateR0TupleSubComputation()); ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_); } @@ -242,8 +301,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), - operands, CreateSubR0Computation()); + auto result = + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), + operands, CreateR0TupleSubComputation()); ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_); } @@ -256,8 +316,9 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f}); auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), - operands, CreateSubR1Computation()); + auto result = + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), + operands, CreateR1TupleSubComputation()); ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_); } @@ -270,25 +331,192 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f}); auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), - operands, CreateSubR1Computation()); + auto result = + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), + operands, CreateR1TupleSubComputation()); ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_); } +// Test true and false computations that return a tuple of scalars. +XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0<bool>(false); + auto operands = builder.Tuple( + {builder.ConstantR0<float>(12.2f), builder.ConstantR0<float>(25.6f)}); + builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, + CreateR0TupleFloorComputation()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0<float>(12.0f).get(), + Literal::CreateR0<float>(25.0f).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a tuple of arrays. +// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. +XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0<bool>(true); + auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}), + builder.ConstantR1<float>({25.6f, 29.2f})}); + builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, + CreateR1TupleFloorComputation()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR1<float>({13.0f, 16.0f}).get(), + Literal::CreateR1<float>({26.0f, 30.0f}).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a tuple of a predicate, a +// scalar, and an array. +// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. +XLA_TEST_F(ConditionalOpTest, + DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) { + ComputationBuilder true_builder(client_, TestName() + ".true"); + { + true_builder.Parameter(0, empty_tuple_, "tuple"); + auto true_pred = true_builder.ConstantR0<bool>(true); + auto true_scalar = true_builder.ConstantR0<float>(12.2f); + auto true_array = true_builder.ConstantR1<float>({12.8f, 14.6f}); + true_builder.Tuple({true_pred, true_scalar, true_array}); + } + auto true_builder_result = true_builder.Build(); + EXPECT_IS_OK(true_builder_result.status()); + + ComputationBuilder false_builder(client_, TestName() + ".false"); + { + false_builder.Parameter(0, empty_tuple_, "tuple"); + auto false_pred = false_builder.ConstantR0<bool>(false); + auto false_scalar = false_builder.ConstantR0<float>(25.6f); + auto false_array = false_builder.ConstantR1<float>({26.4f, 32.6f}); + false_builder.Tuple({false_pred, false_scalar, false_array}); + } + auto false_builder_result = false_builder.Build(); + EXPECT_IS_OK(false_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0<bool>(true); + auto operands = builder.Tuple({}); + builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), + operands, false_builder_result.ConsumeValueOrDie()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0<bool>(true).get(), + Literal::CreateR0<float>(12.2f).get(), + Literal::CreateR1<float>({12.8f, 14.6f}).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a nested tuple. +// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. +XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) { + ComputationBuilder true_builder(client_, TestName() + ".true"); + { + true_builder.Parameter(0, empty_tuple_, "tuple"); + auto true_constant1 = true_builder.ConstantR0<float>(12.2f); + auto true_constant2 = true_builder.ConstantR1<float>({12.8f, 14.6f}); + auto true_constant3 = true_builder.ConstantR1<float>({25.4f, 29.8f}); + auto true_constant4 = true_builder.ConstantR0<float>(35.6f); + true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}), + true_builder.Tuple({true_constant3, true_constant4})}); + } + auto true_builder_result = true_builder.Build(); + EXPECT_IS_OK(true_builder_result.status()); + + ComputationBuilder false_builder(client_, TestName() + ".false"); + { + false_builder.Parameter(0, empty_tuple_, "tuple"); + auto false_constant1 = false_builder.ConstantR0<float>(46.6f); + auto false_constant2 = false_builder.ConstantR1<float>({54.4f, 58.4f}); + auto false_constant3 = false_builder.ConstantR1<float>({62.1f, 67.4f}); + auto false_constant4 = false_builder.ConstantR0<float>(9.3f); + false_builder.Tuple( + {false_builder.Tuple({false_constant1, false_constant2}), + false_builder.Tuple({false_constant3, false_constant4})}); + } + auto false_builder_result = false_builder.Build(); + EXPECT_IS_OK(false_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0<bool>(false); + auto operands = builder.Tuple({}); + builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), + operands, false_builder_result.ConsumeValueOrDie()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple( + {Literal::MakeTuple({Literal::CreateR0<float>(46.6f).get(), + Literal::CreateR1<float>({54.4f, 58.4f}).get()}) + .get(), + Literal::MakeTuple({Literal::CreateR1<float>({62.1f, 67.4f}).get(), + Literal::CreateR0<float>(9.3f).get()}) + .get()}), + {}, error_spec_); +} + +// Test conditional that takes in scalar operands in the form of external +// params. +XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle pred, operand1, operand2; + auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred); + auto operand1_param = + CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1); + auto operand2_param = + CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2); + auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), + operand2, CreateR0FloorComputation()); + + ComputeAndCompareR0<float>( + &builder, 57.0f, + {pred_arg.get(), operand1_param.get(), operand2_param.get()}, + error_spec_); +} + +// Test conditional that takes in array operands in the form of external params. +XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle pred, operand1, operand2; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); + auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1", + &builder, &operand1); + auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2", + &builder, &operand2); + auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(), + operand2, CreateR1FloorComputation()); + + ComputeAndCompareR1<float>( + &builder, {10.0f, 11.0f}, + {pred_arg.get(), operand1_param.get(), operand2_param.get()}, + error_spec_); +} + // Test the case where one conditional is nested within another. XLA_TEST_F(ConditionalOpTest, NestedConditionals) { - Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); - auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); - auto pred_cond = inner_builder.GetTupleElement(param0, 0); - auto true_operand = inner_builder.GetTupleElement(param0, 1); - auto false_operand = inner_builder.GetTupleElement(param0, 2); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0F32CeilComputation(), false_operand, - CreateR0F32FloorComputation()); + { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); + } auto inner_builder_result = inner_builder.Build(); + EXPECT_IS_OK(inner_builder_result.status()); ComputationBuilder builder(client_, TestName()); auto pred1 = builder.ConstantR0<bool>(true); @@ -299,7 +527,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); builder.Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(), operand3, - CreateR0F32IdentityComputation()); + CreateR0IdentityComputation()); ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); } @@ -311,8 +539,8 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { auto operand1 = builder.ConstantR0<float>(56.0f); auto operand2 = builder.ConstantR0<float>(12.0f); auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateAddR1Computation(), operands, - CreateSubR0Computation()); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR0TupleSubComputation()); auto result = builder.Build(); EXPECT_FALSE(result.ok()); |