aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/reduce_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/reduce_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc506
1 files changed, 506 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
new file mode 100644
index 0000000000..f3d8da5c8c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that multi-dimensional arrays can be reduced among various
+// user-provided dimensions.
+//
+// Note that comments for these tests are white-box in that they talk about the
+// default data layout.
+//
+// The test space for reductions is the cartesian product of:
+//
+// <possible ranks> x
+// <possible layouts for chosen rank> x
+// <possible subsets of dimensions in chosen rank>
+
+#include <stdlib.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// Test fixture for reduction tests. Provides canned 2-D and 3-D float
+// literals plus parameterized helpers that build a Reduce computation, run it
+// on the service, and compare the result against a host-computed reference.
+class ReduceTest : public ClientLibraryTestBase {
+ protected:
+  ReduceTest() {
+    // Implementation note: laid out z >> y >> x by default.
+    // clang-format off
+    literal_2d_ = LiteralUtil::CreateR2<float>({
+        // x0 x1 x2
+        { 1.f, 2.f, 3.f}, // y0
+        { 4.f, 5.f, 6.f}, // y1
+    });
+    // literal_3d_ replicates the 2x3 matrix above along a new z dimension of
+    // extent 4, so every z-slice is identical.
+    literal_3d_ = LiteralUtil::CreateR3Projected<float>({
+        // x0 x1 x2
+        { 1.f, 2.f, 3.f}, // y0
+        { 4.f, 5.f, 6.f}, // y1
+    }, 4);
+    // clang-format on
+    CHECK(ShapeUtil::Equal(
+        literal_3d_->shape(),
+        ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
+        << literal_3d_->shape().ShortDebugString();
+  }
+
+  // Runs an R1 => R0 reduction test with the given number of elements.
+  void RunR1ToR0Test(int64 element_count) {
+    ComputationBuilder builder(client_, TestName());
+    Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+    const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
+    auto input = builder.Parameter(0, input_shape, "input");
+    auto zero = builder.ConstantR0<float>(0.0);
+    builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+    // Deterministic pseudo-random values in [-2, 2]; small integral
+    // magnitudes keep the exact sum well inside float precision.
+    std::vector<float> input_data(element_count);
+    for (int64 i = 0; i < element_count; ++i) {
+      input_data[i] = rand_r(&seed_) % 3;
+      if (rand_r(&seed_) % 2 == 0) {
+        input_data[i] *= -1;
+      }
+    }
+    std::unique_ptr<Literal> input_literal =
+        LiteralUtil::CreateR1(AsSlice(input_data));
+    std::unique_ptr<GlobalData> input_global_data =
+        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+    // Host reference: plain sequential sum.
+    float expected = 0.0;
+    for (float item : input_data) {
+      expected += item;
+    }
+    ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
+                               ErrorSpec(0.001));
+  }
+
+  // Runs an R2 => R0 reduction test with the given number of (rows, cols).
+  // (minor, major) specifies the layout the input literal is transferred with.
+  void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
+    ComputationBuilder builder(client_, TestName());
+    Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+    const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+    auto input = builder.Parameter(0, input_shape, "input");
+    auto zero = builder.ConstantR0<float>(0.0);
+    builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
+
+    Array2D<float> input_data(rows, cols);
+    input_data.FillRandom(3.14f, 0.04);
+    std::unique_ptr<Literal> input_literal =
+        LiteralUtil::CreateR2FromArray2D(input_data);
+    // Re-lay-out the literal so the server sees the requested (minor, major)
+    // layout rather than the default.
+    input_literal = LiteralUtil::Relayout(
+        *input_literal, LayoutUtil::MakeLayout({minor, major}));
+    std::unique_ptr<GlobalData> input_global_data =
+        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+    // Host reference: sum of every element.
+    float expected = 0.0;
+    for (int64 rowno = 0; rowno < rows; ++rowno) {
+      for (int64 colno = 0; colno < cols; ++colno) {
+        expected += input_data(rowno, colno);
+      }
+    }
+    ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
+                               ErrorSpec(0.01, 1e-4));
+  }
+
+  // Runs an R2 => R1 reduction test with the given number of (rows, cols).
+  // Reduces dimension 0, so the result is the vector of column sums.
+  void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
+    ComputationBuilder builder(client_, TestName());
+    Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+    const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+    auto input = builder.Parameter(0, input_shape, "input");
+    auto zero = builder.ConstantR0<float>(0.0);
+    builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+    Array2D<float> input_data(rows, cols);
+    input_data.FillRandom(3.14f, 0.04);
+    std::unique_ptr<Literal> input_literal =
+        LiteralUtil::CreateR2FromArray2D(input_data);
+    input_literal = LiteralUtil::Relayout(
+        *input_literal, LayoutUtil::MakeLayout({minor, major}));
+    std::unique_ptr<GlobalData> input_global_data =
+        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+    // Host reference: per-column sums.
+    std::vector<float> expected;
+    for (int64 colno = 0; colno < cols; ++colno) {
+      float column_sum = 0;
+      for (int64 rowno = 0; rowno < rows; ++rowno) {
+        column_sum += input_data(rowno, colno);
+      }
+      expected.push_back(column_sum);
+    }
+    ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+                               ErrorSpec(0.01, 1e-4));
+  }
+
+  std::unique_ptr<Literal> literal_2d_;  // 2x3 (y, x) literal; see ctor.
+  std::unique_ptr<Literal> literal_3d_;  // 4x2x3 (z, y, x) replication of it.
+  uint32 seed_ = 0xdeadbeef;  // Fixed seed so the random inputs reproduce.
+};
+
+// R1 -> R0: add-reduce a vector to a scalar, across sizes covering the empty
+// and singleton cases, powers of two, and off-by-one neighbors of them.
+XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); }
+XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); }
+XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); }
+XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); }
+XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
+XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); }
+XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); }
+XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); }
+XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); }
+XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
+  RunR1ToR0Test(16 * 1024 + 1);
+}
+
+// R2 -> R0: add-reduce both dimensions of a matrix down to a scalar,
+// including zero-sized shapes and one non-default ({0, 1}) layout.
+XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) { RunR2ToR0Test(1, 1); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) { RunR2ToR0Test(2, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R0) { RunR2ToR0Test(2, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) { RunR2ToR0Test(8, 8); }
+XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) { RunR2ToR0Test(9, 9); }
+XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) { RunR2ToR0Test(50, 111); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) { RunR2ToR0Test(111, 50); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) {
+  RunR2ToR0Test(111, 50, 0, 1);
+}
+XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) { RunR2ToR0Test(1024, 1024); }
+XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) { RunR2ToR0Test(1000, 1500); }
+
+// R2 -> R1: add-reduce dimension 0 of a matrix, producing column sums.
+// Disabled due to b/33245142. Failed on 2016-11-30.
+// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) { RunR2ToR1Test(0, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) { RunR2ToR1Test(1, 1); }
+// Disabled due to b/33245142. Failed on 2016-11-30.
+// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) { RunR2ToR1Test(2, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) { RunR2ToR1Test(8, 8); }
+XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) { RunR2ToR1Test(9, 9); }
+XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) { RunR2ToR1Test(50, 111); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) { RunR2ToR1Test(111, 50); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
+  RunR2ToR1Test(111, 50, 0, 1);
+}
+XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); }
+XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); }
+
+// Add-reduces log(input) along dimension 0 of a {0, 1}-layout 111x50 matrix,
+// checking that an elementwise op feeding a reduction is handled correctly.
+XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
+  const int64 rows = 111, cols = 50;
+
+  ComputationBuilder builder(client_, TestName());
+  Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+  const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+  auto input = builder.Parameter(0, input_shape, "input");
+  auto zero = builder.ConstantR0<float>(0.0);
+  auto logged = builder.Log(input);
+  builder.Reduce(logged, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+  Array2D<float> input_data(rows, cols);
+  input_data.FillRandom(3.14f, 0.04);
+  std::unique_ptr<Literal> input_literal =
+      LiteralUtil::CreateR2FromArray2D(input_data);
+  input_literal =
+      LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1}));
+  std::unique_ptr<GlobalData> input_global_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+  // Host reference: per-column sums of log(x), accumulating rows in
+  // increasing order so the float addition order matches a column-at-a-time
+  // computation.
+  std::vector<float> expected(cols, 0.0f);
+  for (int64 colno = 0; colno < cols; ++colno) {
+    for (int64 rowno = 0; rowno < rows; ++rowno) {
+      expected[colno] += log(input_data(rowno, colno));
+    }
+  }
+  ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+                             ErrorSpec(0.01, 1e-4));
+}
+
+// Parameter pack for the parameterized R3 -> R2 reduction tests.
+struct BoundsLayout {
+  std::vector<int64> bounds;       // Dimension sizes of the R3 input.
+  std::vector<int64> layout;       // Minor-to-major layout to transfer with.
+  std::vector<int64> reduce_dims;  // Dimension numbers to reduce away.
+};
+
+// gtest pretty-printer for BoundsLayout; shows up in parameterized test
+// failure messages, e.g. "R3ToR2...".
+void PrintTo(const BoundsLayout& spec, std::ostream* os) {
+  // size() returns size_t, which is not necessarily unsigned long on every
+  // platform (e.g. LLP64 Windows), so passing it straight to a "%lu"
+  // conversion is undefined behavior there. Cast explicitly.
+  *os << tensorflow::strings::Printf(
+      "R%luToR%lu%s_%s_Reduce%s",
+      static_cast<unsigned long>(spec.bounds.size()),
+      static_cast<unsigned long>(spec.bounds.size() - spec.reduce_dims.size()),
+      tensorflow::str_util::Join(spec.bounds, "x").c_str(),
+      tensorflow::str_util::Join(spec.layout, "").c_str(),
+      tensorflow::str_util::Join(spec.reduce_dims, "").c_str());
+}
+
+// Add-reduces a broadcasted scalar matrix among dimension 1 and 0.
+// Broadcasts the scalar 42 into a 500x500 matrix, then add-reduces both
+// dimensions back down to a scalar.
+XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto add = CreateScalarAddComputation(F32, &builder);
+  auto broadcasted =
+      builder.Broadcast(builder.ConstantR0<float>(42.0), {500, 500});
+  builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+  // Each of the 500*500 elements contributes 42 to the sum.
+  float expected = 42.0f * static_cast<float>(500 * 500);
+  ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a broadcasted scalar matrix among dimension 1 and 0.
+// Broadcasts the scalar 42 into a 500x500 matrix, then max-reduces both
+// dimensions back down to a scalar.
+XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto max = CreateScalarMaxComputation(F32, &builder);
+  auto broadcasted =
+      builder.Broadcast(builder.ConstantR0<float>(42.0), {500, 500});
+  builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), max, {0, 1});
+
+  // Every element is 42, so the max is 42.
+  float expected = 42.0f;
+  ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a matrix among dimension 1 and 0.
+// Max-reduces a random 300x250 matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto max = CreateScalarMaxComputation(F32, &builder);
+  Array2D<float> input(300, 250);
+  input.FillRandom(214.0f);
+  auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+  // The identity for max is the most negative float, -FLT_MAX. (FLT_MIN is
+  // the smallest *positive* normalized float, so seeding with it would give a
+  // wrong answer whenever every input element is negative.)
+  builder.Reduce(builder.ConstantLiteral(*input_literal),
+                 builder.ConstantR0<float>(-FLT_MAX), max, {0, 1});
+
+  // Host reference: elementwise max over the same data, same seed value.
+  float input_max = -FLT_MAX;
+  input.Each(
+      [&](int64, int64, float* v) { input_max = std::max(input_max, *v); });
+  ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
+}
+
+// Min-reduces matrix among dimension 1 and 0.
+// Min-reduces a random 150x130 matrix among dimensions 1 and 0. FLT_MAX is
+// the identity for min.
+XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto min = CreateScalarMinComputation(F32, &builder);
+  Array2D<float> input(150, 130);
+  input.FillRandom(214.0f);
+  auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+  builder.Reduce(builder.ConstantLiteral(*input_literal),
+                 builder.ConstantR0<float>(FLT_MAX), min, {0, 1});
+
+  // Host reference: elementwise min over the same data.
+  float input_min = FLT_MAX;
+  input.Each([&](int64, int64, float* cell) {
+    input_min = std::min(input_min, *cell);
+  });
+  ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
+}
+
+// Reduces a matrix among dimension 1.
+// Reduces the 2x3 literal among dimension 1 (sums each row).
+XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_2d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+
+  // Row sums: {1+2+3, 4+5+6}.
+  std::vector<float> expected = {6.f, 15.f};
+  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
+  // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_2d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+  // 1+2+3+4+5+6 == 21.
+  ComputeAndCompareR0<float>(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4));
+}
+
+// Tests 2D matrix ReduceToRow operation.
+// Tests 2D matrix ReduceToRow operation (reduce among dimension 0 / y).
+XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
+  // Use TestName() for the computation name, consistent with every other test
+  // in this file (previously a hard-coded "reduce_among_y").
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_2d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+
+  // Column sums of the 2x3 literal: {1+4, 2+5, 3+6}.
+  std::vector<float> expected = {5.f, 7.f, 9.f};
+  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Reduces y and x away; each of the 4 identical z-slices sums to 21.
+XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_3d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1, 2});
+
+  std::vector<float> expected = {21.f, 21.f, 21.f, 21.f};
+  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Reduces z and y away: column sums {5, 7, 9} times the 4 z-slices.
+XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_3d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+  std::vector<float> expected = {20.f, 28.f, 36.f};
+  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Reduces all three dimensions of the 4x2x3 literal down to a scalar.
+XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_3d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1, 2});
+
+  // Each of the 4 identical z-slices sums to 21. Use a float literal (4.0f)
+  // so the expression stays in float instead of a float*double product that
+  // is implicitly narrowed back.
+  float expected = 21.0f * 4.0f;
+  ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Reduces z away: the 4 identical z-slices collapse to 4x the 2x3 literal.
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_3d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+
+  // clang-format off
+  Array2D<float> expected({
+      {4.f, 8.f, 12.f},
+      {16.f, 20.f, 24.f},
+  });
+  // clang-format on
+  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Reduces y away: each z-slice yields its column sums {1+4, 2+5, 3+6}.
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_3d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+
+  // clang-format off
+  Array2D<float> expected({
+      {5.f, 7.f, 9.f},
+      {5.f, 7.f, 9.f},
+      {5.f, 7.f, 9.f},
+      {5.f, 7.f, 9.f},
+  });
+  // clang-format on
+  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Reduces x away: each z-slice yields its row sums {1+2+3, 4+5+6}.
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
+  ComputationBuilder builder(client_, TestName());
+  auto m = builder.ConstantLiteral(*literal_3d_);
+  auto add = CreateScalarAddComputation(F32, &builder);
+  builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {2});
+
+  // clang-format off
+  Array2D<float> expected({
+      {6.f, 15.f},
+      {6.f, 15.f},
+      {6.f, 15.f},
+      {6.f, 15.f},
+  });
+  // clang-format on
+  ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Parameterized fixture: reduces an R3 array to R2 under each
+// bounds/layout/reduce-dims combination supplied via BoundsLayout.
+class ReduceR3ToR2Test : public ReduceTest,
+                         public ::testing::WithParamInterface<BoundsLayout> {};
+
+// Builds an add-reduction over GetParam().reduce_dims of a random R3 input
+// transferred with GetParam().layout, and checks it against a host reference.
+XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
+  ComputationBuilder builder(client_, TestName());
+  const auto& bounds = GetParam().bounds;
+  Array3D<float> input_array(bounds[0], bounds[1], bounds[2]);
+  input_array.FillRandom(3.14f, 0.05);
+
+  auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
+  input_literal = LiteralUtil::Relayout(
+      *input_literal, LayoutUtil::MakeLayout(GetParam().layout));
+  std::unique_ptr<GlobalData> input_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+  auto input_activations =
+      builder.Parameter(0, input_literal->shape(), "input");
+  Computation add = CreateScalarAddComputation(F32, &builder);
+  // The reduction becomes the root of the computation; its handle was
+  // previously stored in an unused local, now dropped.
+  builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
+                 GetParam().reduce_dims);
+
+  // Host reference with the same init value, dimensions, and add reducer.
+  auto expected =
+      ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims,
+                                  [](float a, float b) { return a + b; });
+
+  ComputeAndCompareR2<float>(&builder, *expected, {input_data.get()},
+                             ErrorSpec(1e-3, 1e-3));
+}
+
+// Instantiates ReduceR3ToR2Test over the BoundsLayout combinations below.
+INSTANTIATE_TEST_CASE_P(
+    ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test,
+    // Specifies (shape, layout, reduction dimensions).
+    ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}},
+                      BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}},
+                      BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}},
+                      // These should be simplified into a reshape.
+                      BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}},
+                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}},
+                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}},
+                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}},
+                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}},
+                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}},
+                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}},
+                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}},
+                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}},
+                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}},
+                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}},
+                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}},
+                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}},
+                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}},
+                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
+                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
+
+} // namespace
+} // namespace xla
+
+// Standalone test entry point: XLA's legacy CPU-compiler flags must be parsed
+// before gtest consumes its own arguments.
+int main(int argc, char** argv) {
+  std::vector<tensorflow::Flag> flag_list;
+  xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  if (!parse_result) {
+    LOG(ERROR) << "\n" << usage;
+    return 2;
+  }
+  testing::InitGoogleTest(&argc, argv);
+  // Both parsers remove the flags they recognize from argv; anything left
+  // over is unrecognized and therefore an error.
+  if (argc > 1) {
+    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+    return 2;
+  }
+  return RUN_ALL_TESTS();
+}