diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/while_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/while_test.cc | 71 |
1 files changed, 36 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index bbd67cd8d7..0a39778002 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -347,8 +347,8 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // the sum will increase by 1.0. It will first be >15.5 when the elements // have all reached 2.0. auto expected_data = - Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = Literal::MakeTuple({expected_data.get()}); + LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto expected = LiteralUtil::MakeTuple({expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -397,12 +397,13 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(N); - auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f}); - auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f}); - auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f}); - auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); + auto expected_counter = LiteralUtil::CreateR0<int32>(N); + auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f}); + auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f}); + auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f}); + auto expected = + LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -506,11 +507,11 @@ TEST_F(WhileTest, WhileWithTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_data = Literal::CreateR1<float>( + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_data = LiteralUtil::CreateR1<float>( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -554,10 +555,10 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_predicate = Literal::CreateR0<bool>(true); - auto expected = - Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_predicate = LiteralUtil::CreateR0<bool>(true); + auto expected = LiteralUtil::MakeTuple( + {expected_counter.get(), expected_predicate.get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } @@ -599,10 +600,10 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_data = Literal::CreateR0<int32>(7); + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_data = LiteralUtil::CreateR0<int32>(7); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -882,11 +883,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_data = Literal::CreateR1<float>( + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_data = LiteralUtil::CreateR1<float>( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -974,12 +975,12 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); While(cond_computation, body_computation, t); - auto expected_element = Literal::CreateR1<float>({1, 1}); + auto expected_element = LiteralUtil::CreateR1<float>({1, 1}); auto expected = - Literal::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR1<float>({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42}))); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1004,7 +1005,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR1<float>({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42}))); ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1030,7 +1031,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR0<float>(42))); + client_->TransferToServer(*LiteralUtil::CreateR0<float>(42))); ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1069,11 +1070,11 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR0<int32>(1))); + client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1))); - auto add1 = Literal::CreateR0<int32>(15); - auto add2 = Literal::CreateR0<int32>(16); - auto expected = Literal::MakeTuple({add1.get(), add2.get()}); + auto add1 = LiteralUtil::CreateR0<int32>(15); + auto add2 = LiteralUtil::CreateR0<int32>(16); + auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1226,9 +1227,9 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { auto while_instruction = While(condition, body, init); GetTupleElement(while_instruction, 3); - TF_ASSERT_OK_AND_ASSIGN(auto param_value, - client_->TransferToServer(*Literal::CreateR2<float>( - {{1.0, 2.0}, {-1.0, -2.0}}))); + TF_ASSERT_OK_AND_ASSIGN( + auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>( + {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2<float>( &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, |