diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/conditional_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/conditional_test.cc | 138 |
1 files changed, 82 insertions, 56 deletions
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index ee3c83039b..d9d42bf061 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -172,88 +172,95 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); Conditional(pred, operands, true_computation, operands, false_computation); - ComputeAndCompareR0<float>(&builder, 56.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 56.0f); auto operand2 = ConstantR0<float>(&builder, 12.0f); auto identity = CreateR0IdentityComputation(); Conditional(pred, operand1, identity, operand2, identity); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with two different computations in the true and false cases // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 56.4f); auto operand2 = ConstantR0<float>(&builder, 12.6f); Conditional(pred, operand1, CreateR0CeilComputation(), operand2, CreateR0FloorComputation()); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with two different computations in the true and false cases // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand = ConstantR0<float>(&builder, 12.6f); Conditional(pred, operand, CreateR0CeilComputation(), operand, CreateR0FloorComputation()); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with the same computation in the true and false cases but // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 56.4f); auto operand2 = ConstantR0<float>(&builder, 12.6f); auto floor = CreateR0FloorComputation(); Conditional(pred, operand1, floor, operand2, floor); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with the same computation in the true and false cases that // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand = ConstantR0<float>(&builder, 12.6f); auto floor = CreateR0FloorComputation(); Conditional(pred, operand, floor, operand, floor); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with different instances of the same computation in the true // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 56.4f); auto operand2 = ConstantR0<float>(&builder, 12.6f); Conditional(pred, operand1, CreateR0FloorComputation(), operand2, CreateR0FloorComputation()); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test the case when a call invokes a computation that contains a conditional. @@ -268,75 +275,83 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { auto inner_builder_result = inner_builder.Build(); XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 56.4f); auto operand2 = ConstantR0<float>(&builder, 12.6f); Call(&builder, inner_builder_result.ConsumeValueOrDie(), {pred, operand1, operand2}); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 parameters and predicate is // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 56.0f); auto operand2 = ConstantR0<float>(&builder, 12.0f); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR0TupleAddComputation(), operands, CreateR0TupleSubComputation()); - ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 68.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 parameters and predicate is // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 56.0f); auto operand2 = ConstantR0<float>(&builder, 12.0f); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR0TupleAddComputation(), operands, CreateR0TupleSubComputation()); - ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 44.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 array parameters and // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred); auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f}); auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f}); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR1TupleAddComputation(), operands, CreateR1TupleSubComputation()); - ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_); + ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {pred_arg.get()}, + error_spec_); } // Test true and false computations that take in 2 array parameters and // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f}); auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f}); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR1TupleAddComputation(), operands, CreateR1TupleSubComputation()); - ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_); + ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {pred_arg.get()}, + error_spec_); } // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f), ConstantR0<float>(&builder, 25.6f)}); Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, @@ -344,15 +359,16 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0<float>(12.0f).get(), - Literal::CreateR0<float>(25.0f).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(), + LiteralUtil::CreateR0<float>(25.0f).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}), ConstantR1<float>(&builder, {25.6f, 29.2f})}); @@ -361,9 +377,10 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR1<float>({13.0f, 16.0f}).get(), - Literal::CreateR1<float>({26.0f, 30.0f}).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(), + LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -392,17 +409,19 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {}); Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0<bool>(true).get(), - Literal::CreateR0<float>(12.2f).get(), - Literal::CreateR1<float>({12.8f, 14.6f}).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0<bool>(true).get(), + LiteralUtil::CreateR0<float>(12.2f).get(), + LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. @@ -436,21 +455,24 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {}); Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple( - {Literal::MakeTuple({Literal::CreateR0<float>(46.6f).get(), - Literal::CreateR1<float>({54.4f, 58.4f}).get()}) + *LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0<float>(46.6f).get(), + LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()}) .get(), - Literal::MakeTuple({Literal::CreateR1<float>({62.1f, 67.4f}).get(), - Literal::CreateR0<float>(9.3f).get()}) + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(), + LiteralUtil::CreateR0<float>(9.3f).get()}) .get()}), - {}, error_spec_); + {pred_arg.get()}, error_spec_); } // Test conditional that takes in scalar operands in the form of external @@ -511,8 +533,9 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred1 = ConstantR0<bool>(&builder, true); - auto pred2 = ConstantR0<bool>(&builder, false); + XlaOp pred1, pred2; + auto pred1_arg = CreateR0Parameter<bool>(true, 0, "pred1", &builder, &pred1); + auto pred2_arg = CreateR0Parameter<bool>(false, 1, "pred2", &builder, &pred2); auto operand1 = ConstantR0<float>(&builder, 1.1f); auto operand2 = ConstantR0<float>(&builder, 12.2f); auto operand3 = ConstantR0<float>(&builder, 43.3f); @@ -520,7 +543,8 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(), operand3, CreateR0IdentityComputation()); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, + {pred1_arg.get(), pred2_arg.get()}, error_spec_); } XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { @@ -539,13 +563,14 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred2 = ConstantR0<bool>(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0<float>(&builder, 1.1f); auto operand2 = ConstantR0<float>(&builder, 12.2f); - auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2}); + auto tuple_operand = Tuple(&builder, {pred, operand1, operand2}); Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); - ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test a mismatch in the shape of the true operand and true computation. @@ -600,16 +625,17 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { auto test_swap = [&](float a, float b) { XlaBuilder builder(TestName()); - auto x = ConstantR0<float>(&builder, a); - auto y = ConstantR0<float>(&builder, b); + XlaOp x, y; + auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x); + auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y); auto tuple_operand = Tuple(&builder, {x, y}); Call(&builder, main, {tuple_operand}); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0<float>(a).get(), - Literal::CreateR0<float>(b).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(), + LiteralUtil::CreateR0<float>(b).get()}), + {x_arg.get(), y_arg.get()}, error_spec_); }; test_swap(3.11f, 9.4f); |