aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/params_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/params_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc62
1 files changed, 32 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index 2620063aa4..bf3b5f2b65 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -22,9 +22,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -42,7 +42,8 @@ class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(3.14159f);
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -54,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -67,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({3.14f, -100.25f});
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -80,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XlaBuilder builder(TestName());
string str("hello world");
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1U8(str);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -94,7 +95,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
+ LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -106,7 +107,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>(
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -122,12 +123,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XLA_TEST_F(ParamsTest, TwoParameters) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20});
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
@@ -153,7 +154,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
XLA_TEST_F(ParamsTest, MissingParameter) {
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is executed.
- std::unique_ptr<Literal> literal = Literal::CreateR0<float>(3.14159f);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -167,12 +168,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
XLA_TEST_F(ParamsTest, UnusedParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
Parameter(&builder, 0, literal0->shape(), "param0");
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20});
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
Parameter(&builder, 1, literal1->shape(), "param1");
@@ -187,11 +188,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// unused expression.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20, 30});
+ std::unique_ptr<Literal> literal1 =
+ LiteralUtil::CreateR1<float>({10, 20, 30});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
@@ -231,7 +233,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> sum_value = {{entry0, entry1}};
sum_value.resize(size);
- std::unique_ptr<Literal> literal = Literal::CreateR1<float>(sum_value);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
param_data_owner.push_back(
client_->TransferToServer(*literal).ConsumeValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -266,7 +268,7 @@ XLA_TEST_F(ParamsTest,
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = Literal::CreateR0<float>(i);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -298,7 +300,7 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -322,10 +324,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<std::unique_ptr<Literal>> elements;
std::vector<const Literal*> ptrs;
for (int i = 0; i < kParamCount; ++i) {
- elements.push_back(Literal::CreateR1<int32>({target + i, target + i}));
+ elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
ptrs.push_back(elements.back().get());
}
- ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
}
// Test large number of parameters flowing into a while-loop.
@@ -354,7 +356,7 @@ XLA_TEST_F(ParamsTest,
std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
XlaOp param = Parameter(&builder, i, literal->shape(), "param");
@@ -364,7 +366,7 @@ XLA_TEST_F(ParamsTest,
// Add bool parameter for the loop condition. Use a parameter HLO instead of a
// constant because DCE may eliminate the while-body otherwise.
- std::unique_ptr<Literal> bool_literal = Literal::CreateR0<bool>(false);
+ std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
param_data_owner.push_back(
std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
XlaOp bool_param =
@@ -421,10 +423,10 @@ XLA_TEST_F(ParamsTest,
std::vector<std::unique_ptr<Literal>> elements;
std::vector<const Literal*> ptrs;
for (int i = 0; i < kParamCount; ++i) {
- elements.push_back(Literal::CreateR1<int32>({i, i}));
+ elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
ptrs.push_back(elements.back().get());
}
- ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
}
#endif
@@ -441,9 +443,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*Literal::MakeTuple({
- Literal::CreateR1<float>({1, 2, 3}).get(),
- Literal::CreateR1<float>({4, 5, 6}).get(),
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
}))
.ConsumeValueOrDie();
@@ -455,7 +457,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
XlaBuilder builder(TestName());
Parameter(&builder, 0, literal->shape(), "input");
@@ -467,7 +469,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
XlaBuilder builder(TestName());
Parameter(&builder, 0, literal->shape(), "input");
@@ -478,7 +480,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
}
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
- std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
{1, 3},
{2, 4},
});