aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/conditional_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/conditional_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc138
1 files changed, 82 insertions, 56 deletions
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index ee3c83039b..d9d42bf061 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -172,88 +172,95 @@ class ConditionalOpTest : public ClientLibraryTestBase {
// Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest, Parameters0) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, true);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {});
auto true_computation = CreateR0ConstantComputation(56.0f);
auto false_computation = CreateR0ConstantComputation(12.0f);
Conditional(pred, operands, true_computation, operands, false_computation);
- ComputeAndCompareR0<float>(&builder, 56.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_);
}
// Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest, Parameters1) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.0f);
auto operand2 = ConstantR0<float>(&builder, 12.0f);
auto identity = CreateR0IdentityComputation();
Conditional(pred, operand1, identity, operand2, identity);
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test conditional with two different computations in the true and false cases
// that take in different arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
CreateR0FloorComputation());
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test conditional with two different computations in the true and false cases
// that take in the same arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand = ConstantR0<float>(&builder, 12.6f);
Conditional(pred, operand, CreateR0CeilComputation(), operand,
CreateR0FloorComputation());
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test conditional with the same computation in the true and false cases but
// take in different arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
auto floor = CreateR0FloorComputation();
Conditional(pred, operand1, floor, operand2, floor);
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test conditional with the same computation in the true and false cases that
// take in the same arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand = ConstantR0<float>(&builder, 12.6f);
auto floor = CreateR0FloorComputation();
Conditional(pred, operand, floor, operand, floor);
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test conditional with different instances of the same computation in the true
// and false cases.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
CreateR0FloorComputation());
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test the case when a call invokes a computation that contains a conditional.
@@ -268,75 +275,83 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
auto inner_builder_result = inner_builder.Build();
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.4f);
auto operand2 = ConstantR0<float>(&builder, 12.6f);
Call(&builder, inner_builder_result.ConsumeValueOrDie(),
{pred, operand1, operand2});
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test true and false computations that take in 2 parameters and predicate is
// true.
XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, true);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.0f);
auto operand2 = ConstantR0<float>(&builder, 12.0f);
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
CreateR0TupleSubComputation());
- ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 68.0f, {pred_arg.get()}, error_spec_);
}
// Test true and false computations that take in 2 parameters and predicate is
// false.
XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 56.0f);
auto operand2 = ConstantR0<float>(&builder, 12.0f);
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
CreateR0TupleSubComputation());
- ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 44.0f, {pred_arg.get()}, error_spec_);
}
// Test true and false computations that take in 2 array parameters and
// predicate is true.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, true);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
CreateR1TupleSubComputation());
- ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_);
+ ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {pred_arg.get()},
+ error_spec_);
}
// Test true and false computations that take in 2 array parameters and
// predicate is false.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
auto operands = Tuple(&builder, {operand1, operand2});
Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
CreateR1TupleSubComputation());
- ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_);
+ ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {pred_arg.get()},
+ error_spec_);
}
// Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f),
ConstantR0<float>(&builder, 25.6f)});
Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
@@ -344,15 +359,16 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<float>(12.0f).get(),
- Literal::CreateR0<float>(25.0f).get()}),
- {}, error_spec_);
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
+ LiteralUtil::CreateR0<float>(25.0f).get()}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, true);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operands =
Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}),
ConstantR1<float>(&builder, {25.6f, 29.2f})});
@@ -361,9 +377,10 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR1<float>({13.0f, 16.0f}).get(),
- Literal::CreateR1<float>({26.0f, 30.0f}).get()}),
- {}, error_spec_);
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
@@ -392,17 +409,19 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
EXPECT_IS_OK(false_builder_result.status());
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, true);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {});
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<bool>(true).get(),
- Literal::CreateR0<float>(12.2f).get(),
- Literal::CreateR1<float>({12.8f, 14.6f}).get()}),
- {}, error_spec_);
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<bool>(true).get(),
+ LiteralUtil::CreateR0<float>(12.2f).get(),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple.
@@ -436,21 +455,24 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
EXPECT_IS_OK(false_builder_result.status());
XlaBuilder builder(TestName());
- auto pred = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operands = Tuple(&builder, {});
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple(
- {Literal::MakeTuple({Literal::CreateR0<float>(46.6f).get(),
- Literal::CreateR1<float>({54.4f, 58.4f}).get()})
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(46.6f).get(),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
.get(),
- Literal::MakeTuple({Literal::CreateR1<float>({62.1f, 67.4f}).get(),
- Literal::CreateR0<float>(9.3f).get()})
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
+ LiteralUtil::CreateR0<float>(9.3f).get()})
.get()}),
- {}, error_spec_);
+ {pred_arg.get()}, error_spec_);
}
// Test conditional that takes in scalar operands in the form of external
@@ -511,8 +533,9 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
EXPECT_IS_OK(inner_builder_result.status());
XlaBuilder builder(TestName());
- auto pred1 = ConstantR0<bool>(&builder, true);
- auto pred2 = ConstantR0<bool>(&builder, false);
+ XlaOp pred1, pred2;
+ auto pred1_arg = CreateR0Parameter<bool>(true, 0, "pred1", &builder, &pred1);
+ auto pred2_arg = CreateR0Parameter<bool>(false, 1, "pred2", &builder, &pred2);
auto operand1 = ConstantR0<float>(&builder, 1.1f);
auto operand2 = ConstantR0<float>(&builder, 12.2f);
auto operand3 = ConstantR0<float>(&builder, 43.3f);
@@ -520,7 +543,8 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(),
operand3, CreateR0IdentityComputation());
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f,
+ {pred1_arg.get(), pred2_arg.get()}, error_spec_);
}
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
@@ -539,13 +563,14 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
EXPECT_IS_OK(inner_builder_result.status());
XlaBuilder builder(TestName());
- auto pred2 = ConstantR0<bool>(&builder, false);
+ XlaOp pred;
+ auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
auto operand1 = ConstantR0<float>(&builder, 1.1f);
auto operand2 = ConstantR0<float>(&builder, 12.2f);
- auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
+ auto tuple_operand = Tuple(&builder, {pred, operand1, operand2});
Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand});
- ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+ ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
// Test a mismatch in the shape of the true operand and true computation.
@@ -600,16 +625,17 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
auto test_swap = [&](float a, float b) {
XlaBuilder builder(TestName());
- auto x = ConstantR0<float>(&builder, a);
- auto y = ConstantR0<float>(&builder, b);
+ XlaOp x, y;
+ auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x);
+ auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y);
auto tuple_operand = Tuple(&builder, {x, y});
Call(&builder, main, {tuple_operand});
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<float>(a).get(),
- Literal::CreateR0<float>(b).get()}),
- {}, error_spec_);
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
+ LiteralUtil::CreateR0<float>(b).get()}),
+ {x_arg.get(), y_arg.get()}, error_spec_);
};
test_swap(3.11f, 9.4f);