diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/while_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/while_test.cc | 102 |
1 files changed, 66 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index bbd67cd8d7..1bdf1867b9 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -347,8 +347,8 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // the sum will increase by 1.0. It will first be >15.5 when the elements // have all reached 2.0. auto expected_data = - Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = Literal::MakeTuple({expected_data.get()}); + LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto expected = LiteralUtil::MakeTuple({expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -397,12 +397,13 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(N); - auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f}); - auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f}); - auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f}); - auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); + auto expected_counter = LiteralUtil::CreateR0<int32>(N); + auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f}); + auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f}); + auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f}); + auto expected = + LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -506,11 +507,11 @@ TEST_F(WhileTest, WhileWithTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_data = Literal::CreateR1<float>( + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_data = LiteralUtil::CreateR1<float>( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -554,10 +555,10 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_predicate = Literal::CreateR0<bool>(true); - auto expected = - Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_predicate = LiteralUtil::CreateR0<bool>(true); + auto expected = LiteralUtil::MakeTuple( + {expected_counter.get(), expected_predicate.get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } @@ -599,10 +600,10 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_data = Literal::CreateR0<int32>(7); + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_data = LiteralUtil::CreateR0<int32>(7); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -882,11 +883,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0<int32>(5); - auto expected_data = Literal::CreateR1<float>( + auto expected_counter = LiteralUtil::CreateR0<int32>(5); + auto expected_data = LiteralUtil::CreateR1<float>( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -974,12 +975,12 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); While(cond_computation, body_computation, t); - auto expected_element = Literal::CreateR1<float>({1, 1}); + auto expected_element = LiteralUtil::CreateR1<float>({1, 1}); auto expected = - Literal::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR1<float>({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42}))); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1004,7 +1005,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR1<float>({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42}))); ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1030,7 +1031,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR0<float>(42))); + client_->TransferToServer(*LiteralUtil::CreateR0<float>(42))); ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1069,11 +1070,11 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> parameter_data, - client_->TransferToServer(*Literal::CreateR0<int32>(1))); + client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1))); - auto add1 = Literal::CreateR0<int32>(15); - auto add2 = Literal::CreateR0<int32>(16); - auto expected = Literal::MakeTuple({add1.get(), add2.get()}); + auto add1 = LiteralUtil::CreateR0<int32>(15); + auto add2 = LiteralUtil::CreateR0<int32>(16); + auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1226,15 +1227,44 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { auto while_instruction = While(condition, body, init); GetTupleElement(while_instruction, 3); - TF_ASSERT_OK_AND_ASSIGN(auto param_value, - client_->TransferToServer(*Literal::CreateR2<float>( - {{1.0, 2.0}, {-1.0, -2.0}}))); + TF_ASSERT_OK_AND_ASSIGN( + auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>( + {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2<float>( &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, {param_value.get()}, ErrorSpec(4e-5)); } +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { + auto while_shape = ShapeUtil::MakeShape(S32, {}); + + XlaComputation condition; + { + XlaBuilder builder("condition"); + Parameter(&builder, 0, while_shape, "state"); + Infeed(&builder, ShapeUtil::MakeShape(PRED, {})); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + XlaComputation body; + { + XlaBuilder builder("body"); + auto indvar = Parameter(&builder, 0, while_shape, "state"); + Add(indvar, ConstantR0<int32>(&builder, 1)); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + XlaBuilder builder(TestName()); + While(condition, body, ConstantR0<int32>(&builder, 0)); + + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false))); + + ComputeAndCompareR0<int32>(&builder, 2, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); |