aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/while_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/while_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc71
1 files changed, 36 insertions, 35 deletions
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index bbd67cd8d7..0a39778002 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -347,8 +347,8 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
// the sum will increase by 1.0. It will first be >15.5 when the elements
// have all reached 2.0.
auto expected_data =
- Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = Literal::MakeTuple({expected_data.get()});
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
+ auto expected = LiteralUtil::MakeTuple({expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -397,12 +397,13 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(N);
- auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
- auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f});
- auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
+ auto expected_counter = LiteralUtil::CreateR0<int32>(N);
+ auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
+ auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
+ auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
+ auto expected =
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
+ expected_w3.get(), expected_w1.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -506,11 +507,11 @@ TEST_F(WhileTest, WhileWithTupleResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR1<float>(
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -554,10 +555,10 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_predicate = Literal::CreateR0<bool>(true);
- auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_predicate.get()});
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
+ auto expected = LiteralUtil::MakeTuple(
+ {expected_counter.get(), expected_predicate.get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
}
@@ -599,10 +600,10 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR0<int32>(7);
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR0<int32>(7);
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -882,11 +883,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR1<float>(
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -974,12 +975,12 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
While(cond_computation, body_computation, t);
- auto expected_element = Literal::CreateR1<float>({1, 1});
+ auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- Literal::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1004,7 +1005,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1030,7 +1031,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR0<float>(42)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1069,11 +1070,11 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR0<int32>(1)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
- auto add1 = Literal::CreateR0<int32>(15);
- auto add2 = Literal::CreateR0<int32>(16);
- auto expected = Literal::MakeTuple({add1.get(), add2.get()});
+ auto add1 = LiteralUtil::CreateR0<int32>(15);
+ auto add2 = LiteralUtil::CreateR0<int32>(16);
+ auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1226,9 +1227,9 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
auto while_instruction = While(condition, body, init);
GetTupleElement(while_instruction, 3);
- TF_ASSERT_OK_AND_ASSIGN(auto param_value,
- client_->TransferToServer(*Literal::CreateR2<float>(
- {{1.0, 2.0}, {-1.0, -2.0}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ {{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
&builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},