diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/reduce_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/reduce_test.cc | 87 |
1 files changed, 60 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index c9f57cbb16..638b0825a1 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -67,12 +67,12 @@ class ReduceTest : public ClientLibraryTestBase { ReduceTest() { // Implementation note: laid out z >> y >> x by default. // clang-format off - literal_2d_ = Literal::CreateR2<float>({ + literal_2d_ = LiteralUtil::CreateR2<float>({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 }); - literal_3d_ = Literal::CreateR3Projected<float>({ + literal_3d_ = LiteralUtil::CreateR3Projected<float>({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 @@ -101,7 +101,7 @@ class ReduceTest : public ClientLibraryTestBase { } } std::unique_ptr<Literal> input_literal = - Literal::CreateR1(AsSlice(input_data)); + LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr<GlobalData> input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -125,15 +125,15 @@ class ReduceTest : public ClientLibraryTestBase { XlaComputation reduce; if (and_reduce) { init_value = ConstantR0<bool>(&builder, true); - reduce = CreateScalarAndComputation(&builder); + reduce = CreateScalarAndComputation(PRED, &builder); } else { init_value = ConstantR0<bool>(&builder, false); - reduce = CreateScalarOrComputation(&builder); + reduce = CreateScalarOrComputation(PRED, &builder); } Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr<Literal> input_literal = Literal::CreateR1(input_data); + std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr<GlobalData> input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -163,10 +163,10 @@ class ReduceTest : public ClientLibraryTestBase { XlaComputation reduce_op; if (and_reduce) { init_value = ConstantR0<bool>(&builder, true); - reduce_op = CreateScalarAndComputation(&builder); + reduce_op = CreateScalarAndComputation(PRED, &builder); } else { init_value = ConstantR0<bool>(&builder, false); - reduce_op = CreateScalarOrComputation(&builder); + reduce_op = CreateScalarOrComputation(PRED, &builder); } Reduce(input_pred, init_value, reduce_op, @@ -175,7 +175,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<uint8> input_data(rows, cols); input_data.FillRandom(0, 1); std::unique_ptr<Literal> input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = @@ -209,7 +209,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr<Literal> input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = @@ -237,7 +237,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr<Literal> input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = @@ -295,7 +295,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<NativeT> input_data(rows, cols); input_data.FillUnique(initial_value); std::unique_ptr<Literal> input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = @@ -450,7 +450,7 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr<Literal> input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr<GlobalData> input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -482,7 +482,7 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr<Literal> input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr<GlobalData> input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -531,7 +531,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D<float> input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); std::unique_ptr<Literal> input_literal = - Literal::CreateR3FromArray3D(input_data); + LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr<GlobalData> input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -594,7 +594,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { auto max = CreateScalarMaxComputation(F32, &builder); Array2D<float> input(300, 250); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); Reduce(ConstantLiteral(&builder, *input_literal), ConstantR0<float>(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; @@ -609,7 +609,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { auto min = CreateScalarMinComputation(F32, &builder); Array2D<float> input(150, 130); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); Reduce(ConstantLiteral(&builder, *input_literal), ConstantR0<float>(&builder, FLT_MAX), min, {0, 1}); @@ -623,7 +623,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { XlaBuilder builder(TestName()); Array2D<uint32> input({{1}, {2}}); auto min = CreateScalarMinComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max()); @@ -635,7 +635,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { XlaBuilder builder(TestName()); Array2D<uint32> input({{1}, {2}}); auto max = CreateScalarMaxComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min()); @@ -798,13 +798,17 @@ XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { RunVectorizedReduceTestForType<bool>( - static_cast<FuncGenerator>(CreateScalarAndComputation), + static_cast<FuncGenerator>([](XlaBuilder* builder) { + return CreateScalarAndComputation(PRED, builder); + }), [](bool a, bool b) { return a && b; }, true); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { RunVectorizedReduceTestForType<bool>( - static_cast<FuncGenerator>(CreateScalarOrComputation), + static_cast<FuncGenerator>([](XlaBuilder* builder) { + return CreateScalarOrComputation(PRED, builder); + }), [](bool a, bool b) { return a || b; }, false); } @@ -818,7 +822,7 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { // input_array.FillRandom(3.14f, 0.05); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR3FromArray3D(input_array); + auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr<GlobalData> input_data = @@ -872,7 +876,8 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { auto a = ConstantR0<float>(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({1.0f, 4.0f}); + std::unique_ptr<Literal> b_literal = + LiteralUtil::CreateR1<float>({1.0f, 4.0f}); std::unique_ptr<GlobalData> b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b = Parameter(&builder, 0, b_literal->shape(), "b"); @@ -900,7 +905,7 @@ class ReduceInitializerTest : public ReduceTest { auto init = ConstantR0<T>(&builder, initializer); std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest()); - auto input_literal = Literal::CreateR1<T>(input_arr); + auto input_literal = LiteralUtil::CreateR1<T>(input_arr); auto input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, @@ -950,10 +955,11 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(operand); + std::unique_ptr<Literal> input_literal = + LiteralUtil::CreateR1<float>(operand); std::unique_ptr<GlobalData> input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr<Literal> input_literal2 = Literal::CreateR0<float>(init); + std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init); std::unique_ptr<GlobalData> input_global_data2 = client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0<float>( @@ -961,5 +967,32 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { ErrorSpec(0.0001)); } +XLA_TEST_F(ReduceTest, AndReduceU64) { + XlaBuilder builder(TestName()); + Array2D<uint64> initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL}, + {0XFFFFFFFFFFFFFFD6LL, 101}, + {1, 0XFFFFFFFFFFFFFFFFLL}}; + auto reducer = CreateScalarAndComputation(U64, &builder); + auto m = ConstantR2FromArray2D(&builder, initializer); + Reduce(m, ConstantR0<uint64>(&builder, 0xFFFFFFFFFFFFFFFFLL), reducer, {1}); + + std::vector<uint64> expected = {0x1204461080145890LL, 68, 1}; + ComputeAndCompareR1<uint64>(&builder, expected, {}); +} + +XLA_TEST_F(ReduceTest, OrReduceU64) { + XlaBuilder builder(TestName()); + Array2D<uint64> initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL}, + {0xFFFFFFFFFFFFFFD6LL, 101}, + {1, 0xCAFEBEEFABABABABLL}}; + auto reducer = CreateScalarOrComputation(U64, &builder); + auto m = ConstantR2FromArray2D(&builder, initializer); + Reduce(m, ConstantR0<uint64>(&builder, 0), reducer, {1}); + + std::vector<uint64> expected = {0X3BFDFF7ABEFEFEF0LL, 0XFFFFFFFFFFFFFFF7LL, + 0xCAFEBEEFABABABABLL}; + ComputeAndCompareR1<uint64>(&builder, expected, {}); +} + } // namespace } // namespace xla |