diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/params_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/params_test.cc | 62 |
1 files changed, 32 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 2620063aa4..bf3b5f2b65 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -42,7 +42,8 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(3.14159f); + std::unique_ptr<Literal> param0_literal = + LiteralUtil::CreateR0<float>(3.14159f); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -54,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({}); + std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({}); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -67,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); std::unique_ptr<Literal> param0_literal = - Literal::CreateR1<float>({3.14f, -100.25f}); + LiteralUtil::CreateR1<float>({3.14f, -100.25f}); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -80,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr<Literal> param0_literal = Literal::CreateR1U8(str); + std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -94,7 +95,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); std::unique_ptr<Literal> param0_literal = - Literal::CreateR2FromArray2D<float>(Array2D<float>(3, 0)); + LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0)); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -106,7 +107,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>( + std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -122,12 +123,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2}); + std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2}); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20}); + std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20}); std::unique_ptr<GlobalData> param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); @@ -153,7 +154,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr<Literal> literal = Literal::CreateR0<float>(3.14159f); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f); std::unique_ptr<GlobalData> data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -167,12 +168,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2}); + std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2}); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20}); + std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20}); std::unique_ptr<GlobalData> param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); Parameter(&builder, 1, literal1->shape(), "param1"); @@ -187,11 +188,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2}); + std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2}); std::unique_ptr<GlobalData> param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20, 30}); + std::unique_ptr<Literal> literal1 = + LiteralUtil::CreateR1<float>({10, 20, 30}); std::unique_ptr<GlobalData> param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); @@ -231,7 +233,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector<float> sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr<Literal> literal = Literal::CreateR1<float>(sum_value); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -266,7 +268,7 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr<Literal> literal = Literal::CreateR0<float>(i); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -298,7 +300,7 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector<XlaOp> params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i}); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -322,10 +324,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector<std::unique_ptr<Literal>> elements; std::vector<const Literal*> ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1<int32>({target + i, target + i})); + elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -354,7 +356,7 @@ XLA_TEST_F(ParamsTest, std::vector<XlaOp> params; std::vector<Shape> parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i}); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -364,7 +366,7 @@ XLA_TEST_F(ParamsTest, // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. - std::unique_ptr<Literal> bool_literal = Literal::CreateR0<bool>(false); + std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false); param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); XlaOp bool_param = @@ -421,10 +423,10 @@ XLA_TEST_F(ParamsTest, std::vector<std::unique_ptr<Literal>> elements; std::vector<const Literal*> ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1<int32>({i, i})); + elements.push_back(LiteralUtil::CreateR1<int32>({i, i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -441,9 +443,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr<GlobalData> data = client_ - ->TransferToServer(*Literal::MakeTuple({ - Literal::CreateR1<float>({1, 2, 3}).get(), - Literal::CreateR1<float>({4, 5, 6}).get(), + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::CreateR1<float>({1, 2, 3}).get(), + LiteralUtil::CreateR1<float>({4, 5, 6}).get(), })) .ConsumeValueOrDie(); @@ -455,7 +457,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>( + std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); Parameter(&builder, 0, literal->shape(), "input"); @@ -467,7 +469,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>( + std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); Parameter(&builder, 0, literal->shape(), "input"); @@ -478,7 +480,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr<Literal> literal = Literal::CreateR2<float>({ + std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({ {1, 3}, {2, 4}, }); |