/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include #include #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace se = ::perftools::gputools; namespace xla { namespace { class WhileTest : public ClientLibraryTestBase {}; // Tests a while node when the result type T is S32. 
// // int32 result = 0; // while (result < 5) { // result = result + 1; // } TEST_F(WhileTest, WhileWithScalarResult) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. Computation condition; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); auto result = builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. ComputationBuilder builder(client_, TestName()); auto init = builder.ConstantR0(0); auto result = builder.While(condition, body, init); auto shape = builder.GetShape(result).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, 5, {}); } TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto result_shape = ShapeUtil::MakeShape(S32, {}); auto orig_shape = ShapeUtil::MakeShape(S32, {2}); // Create a computation for the condition: repeat for 5 iterations. Computation condition; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); auto result = builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. 
ComputationBuilder builder(client_, TestName()); auto init = builder.Reduce(builder.ConstantR1(2, 1), builder.ConstantR0(0), CreateScalarAddComputation(S32, &builder), {0}); auto result = builder.While(condition, body, init); auto shape = builder.GetShape(result).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, 5, {}); } TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. Computation condition; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Ne(builder.ConstantR0(true), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: or condition with true. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. ComputationBuilder builder(client_, TestName()); auto init = builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true)); auto result = builder.While(condition, body, init); ComputeAndCompareR0(&builder, true, {}); } // Tests a while node when the result type T is a vector. // // All constants are chosen to produce exact results. // vector result(0); // while (result.sum() < 15.5f) { // result = result + vector(0); // } // TODO(b/29185393): does not terminate on CPU. TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. 
Computation add; { ComputationBuilder builder(client_, "add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); add = builder.Build().ConsumeValueOrDie(); } // Create a computation for the condition. // Repeat until the sum of the result vector is less than 15.5f. Computation condition; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); auto test = builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1({}); auto result = builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. ComputationBuilder builder(client_, "while"); auto init = builder.ConstantR1({}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.0001)); } // Tests a while node when the result type T is a vector. // // All constants are chosen to produce exact results. // vector result(8, 0.0f); // while (result.sum() < 15.5f) { // result = result + vector(8, 0.125f); // } TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. 
Computation add; { ComputationBuilder builder(client_, "add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); add = builder.Build().ConsumeValueOrDie(); } // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. Computation condition; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); auto test = builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1(8, 0.125f); auto result = builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. ComputationBuilder builder(client_, "while"); auto init = builder.ConstantR1(8, 0.f); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements // have all reached 2.0. std::vector expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests a while node when the result type T is a Tuple. 
// // tuple> result(0, vector(10, 0.0f)); // while (get<0>(result) < 5) { // get<0>(result) = get<0>(result) + 1; // get<1>(result) = get<1>(result) + vector(10, 1.0f); // } TEST_F(WhileTest, WhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); // Create a computation for the condition. // Repeat for 5 iterations. Computation condition; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. 
ComputationBuilder builder(client_, "while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = Literal::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPredicateTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(PRED, {})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); // Create a computation for the condition. // Repeat for 5 iterations. Computation condition; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add 1 to the iteration variable and or the predicate with true Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. 
ComputationBuilder builder(client_, "while"); auto init = builder.Tuple({builder.ConstantR0(0), builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true))}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_predicate = Literal::CreateR0(true); auto expected = Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes. // tuple> w0(0, vector(10, 0.0f)); // w0 = while (get<0>(w0) < c1) { // get<0>(w0) = get<0>(w0) + 1; // get<1>(w0) = get<1>(w0) + vector(10, 1.0f); // } // tuple> w1(get<0>(w0), get<1>(w0)); // w1 = while (get<0>(w1) < c2) { // get<0>(w1) = get<0>(w1) + 1; // get<1>(w1) = get<1>(w1) + vector(10, 1.0f); // } // result = get<1>(w0) + get<1>(w1) TEST_F(WhileTest, TwoWhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); // Create a computation for the condition. // Repeat for 5 iterations. Computation condition; const int c1 = 5; { ComputationBuilder builder(client_, "condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); } Computation condition2; const int c2 = 7; { ComputationBuilder builder(client_, "condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); } // Create a computation for the body. 
// Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); } Computation body2; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSIGN_OR_ASSERT_OK(body2, builder.Build()); } // Create a While node with computations for the condition and the body. ComputationBuilder builder(client_, "while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); auto while2 = builder.While(condition2, body2, while1); auto while_result1 = builder.GetTupleElement(while1, 1); auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( *builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } // Test while nodes that share the while body computation. 
TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
  std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(F32, {10})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Create a computation for the condition.
  // Repeat while the counter is below c1 (5 iterations from 0).
  Computation condition;
  const int c1 = 5;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    builder.Lt(iteration, builder.ConstantR0<int32>(c1));
    TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build());
  }

  // Repeat while the counter is below c2.
  Computation condition2;
  const int c2 = 7;
  {
    ComputationBuilder builder(client_, "condition2");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    builder.Lt(iteration, builder.ConstantR0<int32>(c2));
    TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build());
  }

  // Create a computation for the body.
  // Add 1 to the iteration variable and add a constant vector of 1.0f to
  // the weight variable, both of which are tuple elements.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    auto weights = builder.GetTupleElement(prev, 1);
    auto input = builder.ConstantR1<float>(10, 1.f);
    auto new_weights = builder.Add(weights, input);
    auto result = builder.Tuple(
        {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
    TF_ASSIGN_OR_ASSERT_OK(body, builder.Build());
  }

  // Create two While nodes that use the same body computation; the second
  // loop consumes the result of the first.
  ComputationBuilder builder(client_, "while");
  auto init = builder.Tuple(
      {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
  auto while1 = builder.While(condition, body, init);
  auto while2 = builder.While(condition2, body, while1);

  auto while_result1 = builder.GetTupleElement(while1, 1);
  auto while_result2 = builder.GetTupleElement(while2, 1);
  VLOG(2) << "while_result2 = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(while_result2).ConsumeValueOrDie());
  auto result = builder.Add(while_result1, while_result2);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(result).ConsumeValueOrDie());
  // while1 leaves weights at c1; while2 continues from there up to c2.
  const float sum = c1 + c2;
  std::vector<float> expected(10, sum);
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}

// Test while nodes that share both the while body computation and the init
// operand: the two loops run independently from the same initial tuple.
// TODO(b/37245345): Fails on GPU backend.
TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
  std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(F32, {10})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Create a computation for the condition.
  // Repeat while the counter is below c1 (5 iterations from 0).
  Computation condition;
  const int c1 = 5;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    builder.Lt(iteration, builder.ConstantR0<int32>(c1));
    TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build());
  }

  // Repeat while the counter is below c2 (7 iterations from 0).
  Computation condition2;
  const int c2 = 7;
  {
    ComputationBuilder builder(client_, "condition2");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    builder.Lt(iteration, builder.ConstantR0<int32>(c2));
    TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build());
  }

  // Create a computation for the body.
  // Add 1 to the iteration variable and add a constant vector of 1.0f to
  // the weight variable, both of which are tuple elements.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    auto weights = builder.GetTupleElement(prev, 1);
    auto input = builder.ConstantR1<float>(10, 1.f);
    auto new_weights = builder.Add(weights, input);
    auto result = builder.Tuple(
        {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
    TF_ASSIGN_OR_ASSERT_OK(body, builder.Build());
  }

  // Create two While nodes that share both `body` and `init`.
  ComputationBuilder builder(client_, "while");
  auto init = builder.Tuple(
      {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
  auto while1 = builder.While(condition, body, init);
  auto while2 = builder.While(condition2, body, init);

  auto while_result1 = builder.GetTupleElement(while1, 1);
  auto while_result2 = builder.GetTupleElement(while2, 1);
  VLOG(2) << "while_result2 = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(while_result2).ConsumeValueOrDie());
  auto result = builder.Add(while_result1, while_result2);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(result).ConsumeValueOrDie());
  // Each loop starts from zero, so the weights end at c1 and c2 respectively.
  const float sum = c1 + c2;
  std::vector<float> expected(10, sum);
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}

// WhileTest that uses DynamicUpdateSlice instruction in body computation.
// Loop state tuple element 1 has as its single user operand(0) of
// DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU.
XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
  std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(F32, {10})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Create a computation for the condition.
  // Repeat for 5 iterations.
  Computation condition;
  {
    ComputationBuilder builder(client_, "condition");
    auto prev = builder.Parameter(0, result_shape, "prev");
    auto iteration = builder.GetTupleElement(prev, 0);
    builder.Gt(builder.ConstantR0<int32>(5), iteration);
    condition = builder.Build().ConsumeValueOrDie();
  }

  // Create a computation for the body.
  // Increment the iteration variable, and write the new iteration value (as
  // F32) into the two elements [2 * iteration, 2 * iteration + 2) of the data
  // vector using DynamicUpdateSlice.
  Computation body;
  {
    ComputationBuilder builder(client_, "body");
    auto prev = builder.Parameter(0, result_shape, "prev");
    // TupleElement 0
    auto iteration = builder.GetTupleElement(prev, 0);
    auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1));
    // TupleElement 1
    auto input = builder.GetTupleElement(prev, 1);
    // Update: the incremented counter broadcast to a 2-element F32 vector.
    auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32);
    // Starts = iteration * 2, reshaped to the rank-1 start-index operand.
    auto starts = builder.Reshape(
        builder.Mul(iteration, builder.ConstantR0<int32>(2)), {1});
    // UpdateSlice.
    auto out1 = builder.DynamicUpdateSlice(input, update, starts);

    auto result = builder.Tuple({out0, out1});
    body = builder.Build().ConsumeValueOrDie();
  }

  // Create a While node with computations for the condition and the body.
  ComputationBuilder builder(client_, "while");
  auto init = builder.Tuple(
      {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
  auto result = builder.While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(
                 *builder.GetShape(result).ConsumeValueOrDie());

  auto expected_counter = Literal::CreateR0<int32>(5);
  auto expected_data = Literal::CreateR1<float>(
      {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
  auto expected =
      Literal::MakeTuple({expected_counter.get(), expected_data.get()});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}

// Tests a while node when the result type T is a vector of S32.
// // int32 result = (0, 0, 0, 0, 0, 0); // while (result[0] < count) { // result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]); // } // // This test misuses a vector WhileTest.WhileLoopsWithSharedBodyto represent a // pair: // ((iteration, (random vector))). // // Note: this test currently only tests generating random values within a loop. // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. TEST_F(WhileTest, WhileWithPrngScalarResult) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { ComputationBuilder builder(client_, TestName()); auto prev = builder.Reshape( builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {}); builder.Gt(builder.ConstantR0(count), prev); return builder.Build().ConsumeValueOrDie(); }; // Create a computation for the body: add 1 to the result variable. Computation body; { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, v6s32, "prev"); auto inc = builder.ConcatInDim( {builder.ConstantR1({1}), builder.RngUniform(builder.ConstantR0(0), builder.ConstantR0(100), ShapeUtil::MakeShape(S32, {5}))}, 0); auto result = builder.Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. 
auto while_loop = [this, &body, build_condition](int count) { ComputationBuilder builder(client_, TestName()); auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); auto result = builder.While(build_condition(count), body, init); auto shape = builder.GetShape(result).ConsumeValueOrDie(); return builder.Build(); }; for (int i = 1; i < 4; ++i) { TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i)); ExecutionOptions execution_options = execution_options_; execution_options.set_seed(65); TF_ASSIGN_OR_ASSERT_OK( auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); } } // Tests nested while loops. // // int32 result = 0; // while (result < 30) { // int i = 0; // while (i < 7) { // result = result + 2; // i = i + 1; // } // } XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { auto outer_result_shape = ShapeUtil::MakeShape(S32, {}); auto inner_result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); Computation inner_condition; { ComputationBuilder builder(client_, "inner_condition"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); builder.Lt(i, builder.ConstantR0(7)); inner_condition = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the outer loop condition: // repeat while result < 30. Computation outer_condition; { ComputationBuilder builder(client_, "outer_condition"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); builder.Lt(prev, builder.ConstantR0(30)); outer_condition = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to // `result`. 
Computation inner_body; { ComputationBuilder builder(client_, "inner_body"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); auto result = builder.GetTupleElement(params, 1); i = builder.Add(builder.ConstantR0(1), i); result = builder.Add(builder.ConstantR0(2), result); auto output = builder.Tuple({i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the outer loop: run the inner loop with i = 0. Computation outer_body; { ComputationBuilder builder(client_, "outer_body"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); auto init = builder.Tuple({builder.ConstantR0(0), prev}); auto result = builder.While(inner_condition, inner_body, init); auto output = builder.GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. ComputationBuilder builder(client_, TestName()); auto init = builder.ConstantR0(0); auto result = builder.While(outer_condition, outer_body, init); auto shape = builder.GetShape(result).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, 42, {}); } void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); StreamExecutorMemoryAllocator allocator(platform, executors); LocalClient* client = ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); Shape loop_state_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}); // Create while condition computation with 'loop_limit'. 
const int32 loop_limit = 100; Computation condition; { ComputationBuilder builder(client, "condition"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(loop_limit)); condition = builder.Build().ConsumeValueOrDie(); } // Create while body computation with unit loop increment. Computation body; { ComputationBuilder builder(client, "body"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto one = builder.ConstantR0(1); auto next_iteration = builder.Add(iteration, one); auto one_vec = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, one_vec); auto result = builder.Tuple({next_iteration, new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. ComputationBuilder builder(client, "while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); builder.While(condition, body, init); auto computation = builder.Build().ConsumeValueOrDie(); std::unique_ptr executable = client->Compile(computation, {}, ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. ExecutableRunOptions options; options.set_allocator(&allocator); const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { auto result = executable->Run({}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { auto result = executable->Run({}, options); ASSERT_TRUE(result.ok()); } } // TODO(b/32470510): Benchmark fails on parallel CPU backend. 
#ifndef XLA_TEST_BACKEND_CPU_PARALLEL BENCHMARK(BM_WhileLoop); #endif } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { LOG(ERROR) << "\n" << usage; return 2; } testing::InitGoogleTest(&argc, argv); if (argc > 1) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; } tensorflow::testing::RunBenchmarks(); return RUN_ALL_TESTS(); }