aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tayo Oguntebi <tayo@google.com>2017-08-30 13:25:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 13:30:14 -0700
commit1bc5cff9dba8baff9c1476c1708adfeb7898d41a (patch)
tree2ac748aaf1276891c1d8a10b93852d8e0777b5d3
parente565d1f1fced69789feb10f1ea1241157ec95f93 (diff)
[XLA] Generalize pooling to support N-dimensional inputs.
Extends the usage of windowed pipeline emitter in pooling to support arbitrary input dimensions > 3. Users should note that memory usage compounds quickly with very-high dimensional tensors. Adds an HLO evaluator test for the R6 pooling case. PiperOrigin-RevId: 167040226
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc64
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc63
4 files changed, 130 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 31cb64d879..e09c9d3beb 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -912,9 +912,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_RET_CHECK(
ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
<< "return shape is set to: "
- << ShapeUtil::HumanString(reduce_window->shape())
+ << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
<< "but is inferred to be: "
- << ShapeUtil::HumanString(inferred_return_shape);
+ << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
const Literal& operand_literal =
parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 9d33236df5..a826548349 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1207,6 +1207,70 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) {
LiteralTestUtil::ExpectEqual(*expected, *result);
}
+TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) {
+ HloComputation::Builder b(TestName());
+
+ // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
+ std::vector<int64> input_dims(6, 4);
+ std::unique_ptr<Literal> arg_literal =
+ Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
+
+ HloInstruction* arg_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+
+ HloComputation::Builder add_computation("add");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ add_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
+ HloModule module(TestName());
+ auto add_func = module.AddEmbeddedComputation(add_computation.Build());
+
+ Window window;
+
+ WindowDimension trivial_dim;
+ trivial_dim.set_size(1);
+ trivial_dim.set_stride(1);
+ trivial_dim.set_padding_low(0);
+ trivial_dim.set_padding_high(0);
+ trivial_dim.set_window_dilation(1);
+ trivial_dim.set_base_dilation(1);
+
+ WindowDimension active_dim;
+ active_dim.set_size(2);
+ active_dim.set_stride(1);
+ active_dim.set_padding_low(0);
+ active_dim.set_padding_high(0);
+ active_dim.set_window_dilation(1);
+ active_dim.set_base_dilation(1);
+
+ *window.add_dimensions() = trivial_dim;
+ *window.add_dimensions() = active_dim;
+ *window.add_dimensions() = active_dim;
+ *window.add_dimensions() = active_dim;
+ *window.add_dimensions() = trivial_dim;
+ *window.add_dimensions() = trivial_dim;
+
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4});
+ b.AddInstruction(HloInstruction::CreateReduceWindow(
+ shape, arg_instruction, init_value, window, add_func));
+
+ auto computation = module.AddEntryComputation(b.Build());
+ std::unique_ptr<Literal> result =
+ evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();
+
+ std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
+ std::unique_ptr<Literal> result_literal =
+ Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 8.0f);
+ LiteralTestUtil::ExpectEqual(*result_literal, *result);
+}
+
TEST_F(HloEvaluatorTest, StridedSlice) {
HloComputation::Builder b(TestName());
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 0a2d337752..52b2027aec 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -899,6 +899,7 @@ xla_test_library(
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 60d6d19ce1..6ef5c4a8c8 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -319,6 +320,68 @@ TEST_F(ReduceWindowTest, R4UnitWindow) {
ErrorSpec(1e-3, 1e-3));
}
+XLA_TEST_F(HloTestBase, R6Add) {
+ auto b = HloComputation::Builder(TestName());
+
+ std::vector<int64> input_dims(6, 8);
+ std::unique_ptr<Literal> arg_literal =
+ Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
+ auto input =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+
+ HloComputation::Builder add_computation("add");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = add_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ add_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
+
+ auto module = CreateNewModule();
+ auto add_func = module->AddEmbeddedComputation(add_computation.Build());
+
+ WindowDimension trivial_dim;
+ trivial_dim.set_size(1);
+ trivial_dim.set_stride(1);
+ trivial_dim.set_padding_low(0);
+ trivial_dim.set_padding_high(0);
+ trivial_dim.set_window_dilation(1);
+ trivial_dim.set_base_dilation(1);
+
+ WindowDimension active_dim;
+ active_dim.set_size(3);
+ active_dim.set_stride(1);
+ active_dim.set_padding_low(0);
+ active_dim.set_padding_high(0);
+ active_dim.set_window_dilation(1);
+ active_dim.set_base_dilation(1);
+
+ Window window;
+ *window.add_dimensions() = trivial_dim;
+ *window.add_dimensions() = trivial_dim;
+ *window.add_dimensions() = active_dim;
+ *window.add_dimensions() = active_dim;
+ *window.add_dimensions() = trivial_dim;
+ *window.add_dimensions() = trivial_dim;
+
+ Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8});
+ b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value,
+ window, add_func));
+
+ std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
+ std::unique_ptr<Literal> expected =
+ Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 9.0f);
+
+ module->AddEntryComputation(b.Build());
+ auto actual = ExecuteAndTransfer(std::move(module), {});
+
+ LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3));
+}
+
XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) {
Array4D<float> input_array(2, 1, 27, 119);
input_array.FillRandom(2.0f);