From 1bc5cff9dba8baff9c1476c1708adfeb7898d41a Mon Sep 17 00:00:00 2001
From: Tayo Oguntebi
Date: Wed, 30 Aug 2017 13:25:56 -0700
Subject: [XLA] Generalize pooling to support N-dimensional inputs.

Extends the use of the windowed pipeline emitter in pooling to support
arbitrary input dimensions > 3. Users should note that memory usage
compounds quickly with very high-dimensional tensors.

Adds an HLO evaluator test for the R6 pooling case.

PiperOrigin-RevId: 167040226
---
 tensorflow/compiler/xla/service/hlo_evaluator.cc |  4 +-
 .../compiler/xla/service/hlo_evaluator_test.cc   | 64 ++++++++++++++++++++++
 tensorflow/compiler/xla/tests/BUILD              |  1 +
 .../compiler/xla/tests/reduce_window_test.cc     | 63 +++++++++++++++++++++
 4 files changed, 130 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 31cb64d879..e09c9d3beb 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -912,9 +912,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     TF_RET_CHECK(
         ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
         << "return shape is set to: "
-        << ShapeUtil::HumanString(reduce_window->shape())
+        << ShapeUtil::HumanStringWithLayout(reduce_window->shape())
         << "but is inferred to be: "
-        << ShapeUtil::HumanString(inferred_return_shape);
+        << ShapeUtil::HumanStringWithLayout(inferred_return_shape);
 
     const Literal& operand_literal =
         parent_->GetEvaluatedLiteralFor(reduce_window->operand(0));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 9d33236df5..a826548349 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1207,6 +1207,70 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) {
   LiteralTestUtil::ExpectEqual(*expected, *result);
 }
 
+TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) {
+  HloComputation::Builder b(TestName());
+
+  // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
+  std::vector<int64> input_dims(6, 4);
+  std::unique_ptr<Literal> arg_literal =
+      Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
+
+  HloInstruction* arg_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+  auto init_value = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+
+  HloComputation::Builder add_computation("add");
+  Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+  auto param_lhs = add_computation.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+  auto param_rhs = add_computation.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+  add_computation.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
+  HloModule module(TestName());
+  auto add_func = module.AddEmbeddedComputation(add_computation.Build());
+
+  Window window;
+
+  WindowDimension trivial_dim;
+  trivial_dim.set_size(1);
+  trivial_dim.set_stride(1);
+  trivial_dim.set_padding_low(0);
+  trivial_dim.set_padding_high(0);
+  trivial_dim.set_window_dilation(1);
+  trivial_dim.set_base_dilation(1);
+
+  WindowDimension active_dim;
+  active_dim.set_size(2);
+  active_dim.set_stride(1);
+  active_dim.set_padding_low(0);
+  active_dim.set_padding_high(0);
+  active_dim.set_window_dilation(1);
+  active_dim.set_base_dilation(1);
+
+  *window.add_dimensions() = trivial_dim;
+  *window.add_dimensions() = active_dim;
+  *window.add_dimensions() = active_dim;
+  *window.add_dimensions() = active_dim;
+  *window.add_dimensions() = trivial_dim;
+  *window.add_dimensions() = trivial_dim;
+
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4});
+  b.AddInstruction(HloInstruction::CreateReduceWindow(
+      shape, arg_instruction, init_value, window, add_func));
+
+  auto computation = module.AddEntryComputation(b.Build());
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie();
+
+  std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
+  std::unique_ptr<Literal> result_literal =
+      Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 8.0f);
+  LiteralTestUtil::ExpectEqual(*result_literal, *result);
+}
+
 TEST_F(HloEvaluatorTest, StridedSlice) {
   HloComputation::Builder b(TestName());
 
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 0a2d337752..52b2027aec 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -899,6 +899,7 @@ xla_test_library(
         "//tensorflow/compiler/xla/client:padding",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 60d6d19ce1..6ef5c4a8c8 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -319,6 +320,68 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { ErrorSpec(1e-3, 1e-3)); } +XLA_TEST_F(HloTestBase, R6Add) { + auto b = HloComputation::Builder(TestName()); + + std::vector input_dims(6, 8); + std::unique_ptr arg_literal = + Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); + auto input = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + + auto module = CreateNewModule(); + auto add_func = module->AddEmbeddedComputation(add_computation.Build()); + + WindowDimension trivial_dim; + trivial_dim.set_size(1); + trivial_dim.set_stride(1); + trivial_dim.set_padding_low(0); + trivial_dim.set_padding_high(0); + trivial_dim.set_window_dilation(1); + trivial_dim.set_base_dilation(1); + + WindowDimension active_dim; + active_dim.set_size(3); + active_dim.set_stride(1); + active_dim.set_padding_low(0); + active_dim.set_padding_high(0); + active_dim.set_window_dilation(1); + active_dim.set_base_dilation(1); + + Window window; + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = active_dim; + *window.add_dimensions() = active_dim; + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = trivial_dim; + + Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8}); + b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value, + window, add_func)); + + std::vector output_dims = {8, 8, 6, 6, 8, 8}; + std::unique_ptr expected = + Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); + + module->AddEntryComputation(b.Build()); + auto actual = ExecuteAndTransfer(std::move(module), {}); + + LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); +} + XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); -- cgit v1.2.3