aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/while_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/while_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc368
1 files changed, 182 insertions, 186 deletions
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 33d457c70b..89ce2ce797 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -18,10 +18,10 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -54,29 +54,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int32>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -91,29 +90,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) {
auto result_shape = ShapeUtil::MakeShape(S64, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int64>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int64>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int64>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int64>(&builder, 5, {});
}
@@ -123,31 +121,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
auto orig_shape = ShapeUtil::MakeShape(S32, {2});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int32>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
builder.ConstantR0<int32>(0),
CreateScalarAddComputation(S32, &builder), {0});
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -156,28 +153,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) {
auto result_shape = ShapeUtil::MakeShape(PRED, {});
// Create a computation for the condition: run until condition is true.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Ne(builder.ConstantR0<bool>(true), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: or condition with true.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
- auto result = builder.Or(prev, builder.ConstantR0<bool>(true));
+ builder.Or(prev, builder.ConstantR0<bool>(true));
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true));
- auto result = builder.While(condition, body, init);
+ builder.While(condition, body, init);
ComputeAndCompareR0<bool>(&builder, true, {});
}
@@ -194,9 +191,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
@@ -205,33 +202,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
// Create a computation for the condition.
// Repeat until the sum of the result vector is less than 15.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>({});
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>({});
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
}
@@ -247,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) {
Shape result_shape = ShapeUtil::MakeShape(F32, {8});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
@@ -258,33 +256,34 @@ TEST_F(WhileTest, WhileWithVectorResult) {
// Create a computation for the condition.
// Repeat until the sum of the result vector is less than 5.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>(8, 0.125f);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>(8, 0.f);
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
// Individual elements with increase by 1/8 each time through the loop, so
// the sum will increase by 1.0. It will first be >15.5 when the elements
@@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
Shape result_shape = ShapeUtil::MakeShape(F32, {8});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
@@ -317,34 +316,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// Create a computation for the condition.
// Repeat until the sum of the result vector is less than 5.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>(8, 0.125f);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>(8, 0.f);
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
builder.Tuple({result});
// Individual elements with increase by 1/8 each time through the loop, so
@@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
// Create a computation for the condition.
// Repeat for N iterations.
const int N = 2;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(N), iteration);
@@ -377,28 +376,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and permute the weights.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto w1 = builder.GetTupleElement(prev, 1);
auto w2 = builder.GetTupleElement(prev, 2);
auto w3 = builder.GetTupleElement(prev, 3);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
auto result = builder.While(condition, body, init);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(N);
auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
@@ -419,9 +418,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
// Create a computation for the condition.
// Repeat for N iterations.
const int N = 2;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(N), iteration);
@@ -430,21 +429,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
// Create a computation for the body.
// Add 1 to the iteration variable permute the weights.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto w1 = builder.GetTupleElement(prev, 1);
auto w2 = builder.GetTupleElement(prev, 2);
auto w3 = builder.GetTupleElement(prev, 3);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
@@ -455,7 +454,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3));
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
std::vector<float> expected = {6.f, 6.f, 6.f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
@@ -474,9 +473,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -486,26 +485,27 @@ TEST_F(WhileTest, WhileWithTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR1<float>(
@@ -523,9 +523,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -534,27 +534,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and or the predicate with true
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto pred = builder.GetTupleElement(prev, 1);
auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true));
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple({builder.ConstantR0<int32>(0),
builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true))});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_predicate = Literal::CreateR0<bool>(true);
@@ -570,9 +570,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -582,25 +582,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and set the other tuple element to a
// constant.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
- auto result =
- builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
- builder.ConstantR0<int32>(7)});
+ builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
+ builder.ConstantR0<int32>(7)});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR0<int32>(7);
@@ -631,20 +630,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -654,34 +653,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
- Computation body2;
+ XlaComputation body2;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
@@ -692,11 +691,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -710,20 +709,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -733,21 +732,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
@@ -758,11 +757,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -777,20 +776,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
@@ -800,21 +799,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
@@ -824,11 +823,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -844,9 +843,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
@@ -856,9 +855,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
// TupleElement 0
auto iteration = builder.GetTupleElement(prev, 0);
@@ -873,18 +872,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
// UpdateSlice.
auto out1 = builder.DynamicUpdateSlice(input, update, starts);
- auto result = builder.Tuple({out0, out1});
+ builder.Tuple({out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR1<float>(
@@ -915,18 +914,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
// Create a computation for the condition: repeat for count iterations.
auto build_condition = [this, v6s32](int count) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto prev = builder.Reshape(
builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
- {});
+ {});
builder.Gt(builder.ConstantR0<int32>(count), prev);
return builder.Build().ConsumeValueOrDie();
};
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, v6s32, "prev");
auto inc = builder.ConcatInDim(
{builder.ConstantR1<int32>({1}),
@@ -934,16 +933,15 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
builder.ConstantR0<int32>(100),
ShapeUtil::MakeShape(S32, {5}))},
0);
- auto result = builder.Add(inc, prev);
+ builder.Add(inc, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
auto while_loop = [this, &body, build_condition](int count) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
- auto result = builder.While(build_condition(count), body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(build_condition(count), body, init);
return builder.Build();
};
@@ -1107,9 +1105,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
auto inner_result_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
- Computation inner_condition;
+ XlaComputation inner_condition;
{
- ComputationBuilder builder(client_, "inner_condition");
+ XlaBuilder builder("inner_condition");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
builder.Lt(i, builder.ConstantR0<int32>(7));
@@ -1118,9 +1116,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
// Creates a computation for the outer loop condition:
// repeat while result < 30.
- Computation outer_condition;
+ XlaComputation outer_condition;
{
- ComputationBuilder builder(client_, "outer_condition");
+ XlaBuilder builder("outer_condition");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
builder.Lt(prev, builder.ConstantR0<int32>(30));
outer_condition = builder.Build().ConsumeValueOrDie();
@@ -1128,34 +1126,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
// Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
// `result`.
- Computation inner_body;
+ XlaComputation inner_body;
{
- ComputationBuilder builder(client_, "inner_body");
+ XlaBuilder builder("inner_body");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
auto result = builder.GetTupleElement(params, 1);
i = builder.Add(builder.ConstantR0<int32>(1), i);
result = builder.Add(builder.ConstantR0<int32>(2), result);
- auto output = builder.Tuple({i, result});
+ builder.Tuple({i, result});
inner_body = builder.Build().ConsumeValueOrDie();
}
// Creates a computation for the outer loop: run the inner loop with i = 0.
- Computation outer_body;
+ XlaComputation outer_body;
{
- ComputationBuilder builder(client_, "outer_body");
+ XlaBuilder builder("outer_body");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev});
auto result = builder.While(inner_condition, inner_body, init);
- auto output = builder.GetTupleElement(result, 1);
+ builder.GetTupleElement(result, 1);
outer_body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(outer_condition, outer_body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(outer_condition, outer_body, init);
ComputeAndCompareR0<int32>(&builder, 42, {});
}
@@ -1170,18 +1167,18 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition_callee;
+ XlaComputation condition_callee;
{
- ComputationBuilder builder(client_, "condition_callee");
+ XlaBuilder builder("condition_callee");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
condition_callee = builder.Build().ConsumeValueOrDie();
}
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto result = builder.Call(condition_callee, {prev});
builder.GetTupleElement(result, 0);
@@ -1189,20 +1186,19 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -1214,28 +1210,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
{scalar_s32, matrix_shape, matrix_shape, matrix_shape});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto state = builder.Parameter(0, while_shape, "state");
builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto state = builder.Parameter(0, while_shape, "state");
auto indvar = builder.GetTupleElement(state, 0);
auto input_0 = builder.GetTupleElement(state, 1);
auto input_1 = builder.GetTupleElement(state, 2);
auto output = builder.Tanh(builder.Dot(input_0, input_1));
auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
- auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output});
+ builder.Tuple({indvar_next, input_0, input_1, output});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
@@ -1268,9 +1264,9 @@ void BM_WhileLoop(int num_iters) {
// Create while condition computation with 'loop_limit'.
const int32 loop_limit = 100;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, loop_state_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
@@ -1278,9 +1274,9 @@ void BM_WhileLoop(int num_iters) {
}
// Create while body computation with unit loop increment.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, loop_state_shape, "prev");
// TupleElement 0
auto iteration = builder.GetTupleElement(prev, 0);
@@ -1294,12 +1290,12 @@ void BM_WhileLoop(int num_iters) {
auto starts = builder.ConstantR1<int32>({0, 0, 0});
// UpdateSlice.
auto out1 = builder.DynamicUpdateSlice(input, update, starts);
- auto result = builder.Tuple({out0, out1});
+ builder.Tuple({out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While instruction.
- ComputationBuilder builder(client, "while");
+ XlaBuilder builder("while");
auto zero = builder.ConstantR0<float>(0.0);
auto input = builder.Broadcast(zero, {seq_len, 1024, 1024});
auto init = builder.Tuple({builder.ConstantR0<int32>(0), input});