aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/conditional_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/conditional_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc64
1 files changed, 30 insertions, 34 deletions
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 25d10ab00a..32cac499c7 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
- LiteralUtil::CreateR0<float>(25.0f).get()}),
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
+ LiteralUtil::CreateR0<float>(25.0f)}),
{pred_arg.get()}, error_spec_);
}
@@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
CreateR1TupleFloorComputation());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
- LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a tuple of a predicate, a
@@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
false_builder_result.ConsumeValueOrDie());
- ComputeAndCompareTuple(
- &builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<bool>(true).get(),
- LiteralUtil::CreateR0<float>(12.2f).get(),
- LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
- {pred_arg.get()}, error_spec_);
+ ComputeAndCompareTuple(&builder,
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(true),
+ LiteralUtil::CreateR0<float>(12.2f),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
+ {pred_arg.get()}, error_spec_);
}
// Test true and false computations that return a nested tuple.
@@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(46.6f).get(),
- LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
- .get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
- LiteralUtil::CreateR0<float>(9.3f).get()})
- .get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(46.6f),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
+ LiteralUtil::CreateR0<float>(9.3f)})}),
{pred_arg.get()}, error_spec_);
}
@@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
ComputeAndCompareTuple(
&builder,
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
- LiteralUtil::CreateR0<float>(b).get()}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
{x_arg.get(), y_arg.get()}, error_spec_);
};
@@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is true case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(true));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
@@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
{
// Pred is false case.
std::vector<Literal> args;
- args.push_back(std::move(
- *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
- LiteralUtil::CreateR0<int32>(-42).get()})));
- args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
+ args.push_back(
+ LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+ LiteralUtil::CreateR0<int32>(-42)}));
+ args.push_back(LiteralUtil::CreateR0<bool>(false));
XlaBuilder builder(TestName() + ".main");
auto p = Parameter(&builder, 0, tuple2, "p0");
auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");