diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/while_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/while_test.cc | 368 |
1 files changed, 182 insertions, 186 deletions
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 33d457c70b..89ce2ce797 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -54,29 +54,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0<int32>(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int32>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int32>(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int32>(&builder, 5, {}); } @@ -91,29 +90,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0<int64>(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int64>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int64>(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int64>(&builder, 5, {}); } @@ -123,31 +121,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto orig_shape = ShapeUtil::MakeShape(S32, {2}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0<int32>(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int32>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1), builder.ConstantR0<int32>(0), CreateScalarAddComputation(S32, &builder), {0}); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int32>(&builder, 5, {}); } @@ -156,28 +153,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Ne(builder.ConstantR0<bool>(true), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: or condition with true. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Or(prev, builder.ConstantR0<bool>(true)); + builder.Or(prev, builder.ConstantR0<bool>(true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Ne(builder.ConstantR0<bool>(false), builder.ConstantR0<bool>(true)); - auto result = builder.While(condition, body, init); + builder.While(condition, body, init); ComputeAndCompareR0<bool>(&builder, true, {}); } @@ -194,9 +191,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -205,33 +202,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 15.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); + builder.Gt(builder.ConstantR0<float>(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1<float>({}); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1<float>({}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001)); } @@ -247,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -258,33 +256,34 @@ TEST_F(WhileTest, WhileWithVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); + builder.Gt(builder.ConstantR0<float>(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1<float>(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1<float>(8, 0.f); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements @@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -317,34 +316,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum); + builder.Gt(builder.ConstantR0<float>(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1<float>(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1<float>(8, 0.f); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); builder.Tuple({result}); // Individual elements with increase by 1/8 each time through the loop, so @@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(N), iteration); @@ -377,28 +376,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f), builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(N); auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f}); @@ -419,9 +418,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(N), iteration); @@ -430,21 +429,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the body. // Add 1 to the iteration variable permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f), builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)}); @@ -455,7 +454,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); std::vector<float> expected = {6.f, 6.f, 6.f}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -474,9 +473,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -486,26 +485,27 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_data = Literal::CreateR1<float>( @@ -523,9 +523,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -534,27 +534,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and or the predicate with true - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true)); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple({builder.ConstantR0<int32>(0), builder.Ne(builder.ConstantR0<bool>(false), builder.ConstantR0<bool>(true))}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_predicate = Literal::CreateR0<bool>(true); @@ -570,9 +570,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -582,25 +582,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the body. // Add 1 to the iteration variable and set the other tuple element to a // constant. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); - auto result = - builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)), - builder.ConstantR0<int32>(7)}); + builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)), + builder.ConstantR0<int32>(7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_data = Literal::CreateR0<int32>(7); @@ -631,20 +630,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); @@ -654,34 +653,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - Computation body2; + XlaComputation body2; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -692,11 +691,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector<float> expected(10, sum); ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -710,20 +709,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); @@ -733,21 +732,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -758,11 +757,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector<float> expected(10, sum); ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -777,20 +776,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(c2)); @@ -800,21 +799,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1<float>(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -824,11 +823,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector<float> expected(10, sum); ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); @@ -844,9 +843,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0<int32>(5), iteration); @@ -856,9 +855,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -873,18 +872,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0<int32>(5); auto expected_data = Literal::CreateR1<float>( @@ -915,18 +914,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prev = builder.Reshape( builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); + {}); builder.Gt(builder.ConstantR0<int32>(count), prev); return builder.Build().ConsumeValueOrDie(); }; // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, v6s32, "prev"); auto inc = builder.ConcatInDim( {builder.ConstantR1<int32>({1}), @@ -934,16 +933,15 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { builder.ConstantR0<int32>(100), ShapeUtil::MakeShape(S32, {5}))}, 0); - auto result = builder.Add(inc, prev); + builder.Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0}); - auto result = builder.While(build_condition(count), body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(build_condition(count), body, init); return builder.Build(); }; @@ -1107,9 +1105,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { auto inner_result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - Computation inner_condition; + XlaComputation inner_condition; { - ComputationBuilder builder(client_, "inner_condition"); + XlaBuilder builder("inner_condition"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); builder.Lt(i, builder.ConstantR0<int32>(7)); @@ -1118,9 +1116,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the outer loop condition: // repeat while result < 30. - Computation outer_condition; + XlaComputation outer_condition; { - ComputationBuilder builder(client_, "outer_condition"); + XlaBuilder builder("outer_condition"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); builder.Lt(prev, builder.ConstantR0<int32>(30)); outer_condition = builder.Build().ConsumeValueOrDie(); @@ -1128,34 +1126,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to // `result`. - Computation inner_body; + XlaComputation inner_body; { - ComputationBuilder builder(client_, "inner_body"); + XlaBuilder builder("inner_body"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); auto result = builder.GetTupleElement(params, 1); i = builder.Add(builder.ConstantR0<int32>(1), i); result = builder.Add(builder.ConstantR0<int32>(2), result); - auto output = builder.Tuple({i, result}); + builder.Tuple({i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the outer loop: run the inner loop with i = 0. - Computation outer_body; + XlaComputation outer_body; { - ComputationBuilder builder(client_, "outer_body"); + XlaBuilder builder("outer_body"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev}); auto result = builder.While(inner_condition, inner_body, init); - auto output = builder.GetTupleElement(result, 1); + builder.GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int32>(0); - auto result = builder.While(outer_condition, outer_body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(outer_condition, outer_body, init); ComputeAndCompareR0<int32>(&builder, 42, {}); } @@ -1170,18 +1167,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition_callee; + XlaComputation condition_callee; { - ComputationBuilder builder(client_, "condition_callee"); + XlaBuilder builder("condition_callee"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto result = builder.Call(condition_callee, {prev}); builder.GetTupleElement(result, 0); @@ -1189,20 +1186,19 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0<int32>(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0<int32>(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0<int32>(&builder, 5, {}); } @@ -1214,28 +1210,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_shape, "state"); builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_shape, "state"); auto indvar = builder.GetTupleElement(state, 0); auto input_0 = builder.GetTupleElement(state, 1); auto input_1 = builder.GetTupleElement(state, 2); auto output = builder.Tanh(builder.Dot(input_0, input_1)); auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1)); - auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + builder.Tuple({indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); auto init = builder.Tuple( {builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input}); @@ -1268,9 +1264,9 @@ void BM_WhileLoop(int num_iters) { // Create while condition computation with 'loop_limit'. const int32 loop_limit = 100; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit)); @@ -1278,9 +1274,9 @@ void BM_WhileLoop(int num_iters) { } // Create while body computation with unit loop increment. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -1294,12 +1290,12 @@ void BM_WhileLoop(int num_iters) { auto starts = builder.ConstantR1<int32>({0, 0, 0}); // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. - ComputationBuilder builder(client, "while"); + XlaBuilder builder("while"); auto zero = builder.ConstantR0<float>(0.0); auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); auto init = builder.Tuple({builder.ConstantR0<int32>(0), input}); |