diff options
author | David Majnemer <majnemer@google.com> | 2018-10-06 10:04:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-06 10:09:36 -0700 |
commit | 5c0a6bdfeb1848b0146a36706d921dde06ba160a (patch) | |
tree | e549be74d1f90165865102536d45cc1b4a2a75a0 /tensorflow/compiler/xla/service/gpu | |
parent | 262f22f9eeee1ee00a9a92318d9a567a25c76696 (diff) |
[XLA] Add base and window dilation support to ReduceWindow
PiperOrigin-RevId: 216041507
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c1aaa4bf04..6dcdaf1cfe 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); const Window& window = hlo->window(); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { - return Unimplemented( - "Dilation for reduce-window not implemented on GPU. " - "See b/31410564."); - } - PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), @@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); + input_index[i] = NSWSub( + NSWAdd(stridden_index, + NSWMul(window_index[i], + index_typed_const( + window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. input_index[i] = - NSWSub(NSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + SDiv(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This |