diff options
author | 2018-10-06 10:04:16 -0700 | |
---|---|---|
committer | 2018-10-06 10:09:36 -0700 | |
commit | 5c0a6bdfeb1848b0146a36706d921dde06ba160a (patch) | |
tree | e549be74d1f90165865102536d45cc1b4a2a75a0 /tensorflow/compiler/xla/service | |
parent | 262f22f9eeee1ee00a9a92318d9a567a25c76696 (diff) |
[XLA] Add base and window dilation support to ReduceWindow
PiperOrigin-RevId: 216041507
Diffstat (limited to 'tensorflow/compiler/xla/service')
5 files changed, 105 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 75dae7a714..86d9dbea90 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -2057,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return Status::OK(); } + // Bail on dilation. + if (window_util::HasDilation(window)) { + VLOG(10) << "Not folding pad into reduce-window as there is dilation."; + return Status::OK(); + } + VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() << (convert != nullptr diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index a70abb117a..b2abdb39a5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -688,8 +688,25 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + input_index[i] = NSWSub( + NSWAdd(strided_index, + NSWMul(window_index[i], + b_.getInt64(window.dimensions(i).window_dilation()))), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); + if (in_bounds_condition == nullptr) { + in_bounds_condition = dilation_condition; + } else { + in_bounds_condition = And(in_bounds_condition, dilation_condition); + } + + // Apply base dilation to the index. + input_index[i] = + SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to @@ -728,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { /*operands=*/{reduce_window->operand(0)}, /*supported_types=*/{F32, BF16, S32, F16})); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(reduce_window->window())) { - return Unimplemented( - "Dilation for ReduceWindow is not implemented on CPU."); - } - // Pseudo code for reduce window: // // for (coordinates O in the output) diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c1aaa4bf04..6dcdaf1cfe 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); const Window& window = hlo->window(); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { - return Unimplemented( - "Dilation for reduce-window not implemented on GPU. " - "See b/31410564."); - } - PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), @@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); + input_index[i] = NSWSub( + NSWAdd(stridden_index, + NSWMul(window_index[i], + index_typed_const( + window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. input_index[i] = - NSWSub(NSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + SDiv(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index cee11a8a21..608a42bb60 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1463,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[3,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // } + auto arg_array = absl::make_unique<Array2D<float>>(3, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f))); + + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(2); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, max_func)); + + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = LiteralUtil::CreateR2<float>({{11}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b2d12c94b8..a450dc6ff5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -2613,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); + base_index[i] = + window_count_index[i] * window.dimensions(i).stride() + + window_index[i] * window.dimensions(i).window_dilation() - + window.dimensions(i).padding_low(); + // We are not in the base area if the dilation placed us out of bounds. + if (base_index[i] % window.dimensions(i).base_dilation() != 0) { + out_of_bound = true; + break; + } + // Apply the dilation to the base area. + base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; break; |