aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/while_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/while_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc66
1 files changed, 31 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1bdf1867b9..7abd8651d5 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// have all reached 2.0.
auto expected_data =
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = LiteralUtil::MakeTuple({expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
@@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple(
+ {&expected_counter, &expected_w2, &expected_w3, &expected_w1});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
@@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
TEST_F(WhileTest, WhileWithPredicateTupleResult) {
@@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
- auto expected = LiteralUtil::MakeTuple(
- {expected_counter.get(), expected_predicate.get()});
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
+ auto expected =
+ LiteralUtil::MakeTuple({&expected_counter, &expected_predicate});
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
@@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR0<int32>(7);
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests two while nodes when the result type T is a Tuple and the second
@@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
auto expected_counter = LiteralUtil::CreateR0<int32>(5);
auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
- auto expected =
- LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
- VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
- ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+ auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+ ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
// Tests a while node when the result type T is a vector of S32.
@@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({&expected_element, &expected_element});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
+ client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
+ client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
+ client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
auto add1 = LiteralUtil::CreateR0<int32>(15);
auto add2 = LiteralUtil::CreateR0<int32>(16);
- auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
- ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+ auto expected = LiteralUtil::MakeTuple({&add1, &add2});
+ ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
GetTupleElement(while_instruction, 3);
TF_ASSERT_OK_AND_ASSIGN(
- auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
{{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
@@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
XlaBuilder builder(TestName());
While(condition, body, ConstantR0<int32>(&builder, 0));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(false)));
ComputeAndCompareR0<int32>(&builder, 2, {});
}