diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/reduce_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/reduce_test.cc | 123 |
1 files changed, 56 insertions, 67 deletions
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 57f7fed61f..83997cdac2 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -81,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase { }, 4); // clang-format on CHECK(ShapeUtil::Equal( - literal_3d_->shape(), + literal_3d_.shape(), ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3}))) - << literal_3d_->shape().ShortDebugString(); + << literal_3d_.shape().ShortDebugString(); } // Runs an R1 => R0 reduction test with the given number of elements. @@ -102,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase { input_data[i] *= -1; } } - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (float item : input_data) { @@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase { Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data); + Literal input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); bool expected = and_reduce; for (bool item : input_data) { @@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<uint8> input_data(rows, cols); input_data.FillRandom(0, 1); - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::array<bool, cols> expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (int64 rowno = 0; rowno < rows; ++rowno) { @@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector<float> expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -295,12 +291,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D<NativeT> input_data(rows, cols); input_data.FillUnique(initial_value); - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector<bool> does not convert to // Span. @@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase { reference_reduction_function_for_uints, unsigned_int_identity); } - std::unique_ptr<Literal> literal_2d_; - std::unique_ptr<Literal> literal_3d_; + Literal literal_2d_; + Literal literal_3d_; uint32 seed_ = 0xdeadbeef; }; @@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector<float> expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D<float> input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector<float> expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data, - MakeFakeLiteral(input_shape)); + TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { @@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D<float> input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector<float> expected; for (int64 major = 0; major < 2; ++major) { @@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D<float> input(300, 250); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0<float>(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( @@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D<float> input(150, 130); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0<float>(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; @@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto initial_value = ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0<uint32>(&builder, 1, {}); } @@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto initial_value = ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0<uint32>(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1}); @@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1}); @@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0}); @@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2}); @@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1}); @@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1, 2}); @@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0}); @@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1}); @@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {2}); @@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); + input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr<GlobalData> input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add, GetParam().reduce_dims); @@ -873,11 +864,10 @@ XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { auto a = ConstantR0<float>(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr<Literal> b_literal = - LiteralUtil::CreateR1<float>({1.0f, 4.0f}); + Literal b_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f}); std::unique_ptr<GlobalData> b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b = Parameter(&builder, 0, b_literal.shape(), "b"); Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()}); @@ -904,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest { std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest()); auto input_literal = LiteralUtil::CreateR1<T>(input_arr); auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, - max_fn, {0}); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn, + {0}); ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()}); } @@ -952,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr<Literal> input_literal = - LiteralUtil::CreateR1<float>(operand); + Literal input_literal = LiteralUtil::CreateR1<float>(operand); std::unique_ptr<GlobalData> input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Literal input_literal2 = LiteralUtil::CreateR0<float>(init); std::unique_ptr<GlobalData> input_global_data2 = - client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + client_->TransferToServer(input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0<float>( &builder, expected, {input_global_data.get(), input_global_data2.get()}, ErrorSpec(0.0001)); |