author    David Majnemer <majnemer@google.com>  2018-10-06 10:04:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-10-06 10:09:36 -0700
commit    5c0a6bdfeb1848b0146a36706d921dde06ba160a (patch)
tree      e549be74d1f90165865102536d45cc1b4a2a75a0 /tensorflow/compiler/xla/service
parent    262f22f9eeee1ee00a9a92318d9a567a25c76696 (diff)
[XLA] Add base and window dilation support to ReduceWindow
PiperOrigin-RevId: 216041507
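In ReduceWindow terms, window dilation spreads the taps of the window apart (a tap at window position w reads base offset w * window_dilation), while base dilation inserts implicit holes between operand elements, the same convention convolution uses for lhs/rhs dilation. The following single-dimension reference captures the semantics this patch wires through the evaluator and the CPU/GPU emitters; it is a sketch for reading the diff, and Dim and ReduceWindowMax1D are illustrative names, not XLA APIs.

#include <algorithm>
#include <cstdint>
#include <vector>

struct Dim {
  int64_t size;             // window size
  int64_t stride;
  int64_t padding_low, padding_high;
  int64_t window_dilation, base_dilation;
};

// Max-reduce a 1-D window over `in`, honoring both dilations.
std::vector<float> ReduceWindowMax1D(const std::vector<float>& in,
                                     const Dim& d, float init) {
  const int64_t base = static_cast<int64_t>(in.size());
  const int64_t dilated_base = (base - 1) * d.base_dilation + 1;
  const int64_t dilated_window = (d.size - 1) * d.window_dilation + 1;
  const int64_t padded = d.padding_low + dilated_base + d.padding_high;
  const int64_t out_count = (padded - dilated_window) / d.stride + 1;
  std::vector<float> out(out_count, init);
  for (int64_t o = 0; o < out_count; ++o) {
    for (int64_t w = 0; w < d.size; ++w) {
      int64_t idx = o * d.stride + w * d.window_dilation - d.padding_low;
      if (idx % d.base_dilation != 0) continue;  // base-dilation hole
      idx /= d.base_dilation;
      if (idx < 0 || idx >= base) continue;      // padding
      out[o] = std::max(out[o], in[idx]);
    }
  }
  return out;
}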
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc        |  6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc              | 27
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc    | 26
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc          | 52
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h  | 13
5 files changed, 105 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 75dae7a714..86d9dbea90 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2057,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return Status::OK();
}
+ // Bail on dilation.
+ if (window_util::HasDilation(window)) {
+ VLOG(10) << "Not folding pad into reduce-window as there is dilation.";
+ return Status::OK();
+ }
+
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
<< (convert != nullptr
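The bail-out is conservative: the pad-folding rewrite translates a kPad producer directly into the reduce-window's edge padding, and those amounts no longer correspond one-for-one once either dilation is in play. window_util::HasDilation flags any dimension whose base or window dilation differs from 1; a sketch of that predicate for reference (the real one lives in tensorflow/compiler/xla/window_util.h):

bool HasDilationSketch(const Window& window) {
  for (const WindowDimension& dim : window.dimensions()) {
    // 1 is the identity dilation factor in both positions.
    if (dim.base_dilation() != 1 || dim.window_dilation() != 1) return true;
  }
  return false;
}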
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a70abb117a..b2abdb39a5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -688,8 +688,25 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ input_index[i] = NSWSub(
+ NSWAdd(strided_index,
+ NSWMul(window_index[i],
+ b_.getInt64(window.dimensions(i).window_dilation()))),
+ b_.getInt64(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())),
+ b_.getInt64(0));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = dilation_condition;
+ } else {
+ in_bounds_condition = And(in_bounds_condition, dilation_condition);
+ }
+
+ // Apply base dilation to the index.
+ input_index[i] =
+ SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
@@ -728,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
/*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32, F16}));
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(reduce_window->window())) {
- return Unimplemented(
- "Dilation for ReduceWindow is not implemented on CPU.");
- }
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
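Restated in plain C++, the per-dimension mapping the loop body above assembles as LLVM IR (a sketch for reading the diff, not code from the patch; the GPU elemental emitter below builds the same expressions):

#include <cstdint>

// Returns true iff the window tap lands on a real operand element; holes and
// padding are skipped, so the init value stands in for them.
bool MapToBaseIndex(int64_t out_index, int64_t window_index, int64_t stride,
                    int64_t window_dilation, int64_t base_dilation,
                    int64_t padding_low, int64_t bound, int64_t* base_index) {
  int64_t idx =
      out_index * stride + window_index * window_dilation - padding_low;
  // SRem + ICmpEQ in the IR: a nonzero remainder means a dilation hole. A
  // negative idx with remainder zero divides to a negative base index and is
  // caught by the bounds test, matching the emitted 0 <= idx < bound check.
  if (idx % base_dilation != 0) return false;
  *base_index = idx / base_dilation;  // SDiv in the IR
  return *base_index >= 0 && *base_index < bound;
}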
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index c1aaa4bf04..6dcdaf1cfe 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
const Window& window = hlo->window();
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
- return Unimplemented(
- "Dilation for reduce-window not implemented on GPU. "
- "See b/31410564.");
- }
-
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
@@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(
+ NSWAdd(stridden_index,
+ NSWMul(window_index[i],
+ index_typed_const(
+ window.dimensions(i).window_dilation()))),
+ index_typed_const(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation())),
+ index_typed_const(0));
+ in_bounds = And(in_bounds, dilation_condition);
+
+ // Apply base dilation to the index.
input_index[i] =
- NSWSub(NSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ SDiv(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
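One asymmetry worth noting between the two emitters: the CPU path above starts from a null in_bounds_condition and creates it on first use, while the GPU path folds the dilation test into an in_bounds value it has already seeded. The lazy form amounts to a helper like this (a sketch, not code from the patch):

llvm::Value* AndMaybeNull(llvm::IRBuilder<>* b, llvm::Value* acc,
                          llvm::Value* cond) {
  // With no prior condition, the new one becomes the accumulator.
  return acc == nullptr ? cond : b->CreateAnd(acc, cond);
}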
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index cee11a8a21..608a42bb60 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1463,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
+TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) {
+ HloComputation::Builder b(TestName());
+
+ // arg:
+ // f32[3,3] {
+ // { 1, 2, 3 },
+ // { 5, 6, 7 },
+ // { 9, 10, 11 },
+ // }
+ auto arg_array = absl::make_unique<Array2D<float>>(3, 3);
+ arg_array->FillUnique(1.0f);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
+
+ HloInstruction* arg_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
+
+ HloComputation::Builder max_computation("max");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ max_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
+ auto max_func = module().AddEmbeddedComputation(max_computation.Build());
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(2);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1});
+ b.AddInstruction(HloInstruction::CreateReduceWindow(
+ shape, arg_instruction, init_value, window, max_func));
+
+ module().AddEntryComputation(b.Build());
+
+ Literal result = Evaluate();
+
+ auto expected = LiteralUtil::CreateR2<float>({{11}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
HloComputation::Builder b(TestName());
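Hand-checking ReduceWindowMaxWindowDilation: with size = 2 and window_dilation = 2, the window taps base offsets {0, 2} in each dimension and the dilated window extent is (2 - 1) * 2 + 1 = 3, so the 3x3 input admits exactly one output position, reducing max over {1, 3, 9, 11} = 11. A standalone replay of that arithmetic (plain C++, reusing the array from the test's comment):

#include <algorithm>
#include <cstdio>
#include <initializer_list>

int main() {
  const float arg[3][3] = {{1, 2, 3}, {5, 6, 7}, {9, 10, 11}};
  float acc = 0.f;                // the test's init_value
  for (int i : {0, 2}) {          // window index * window_dilation
    for (int j : {0, 2}) {
      acc = std::max(acc, arg[i][j]);
    }
  }
  std::printf("%g\n", acc);       // prints 11, the expected literal
}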
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index b2d12c94b8..a450dc6ff5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -2613,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> base_index(rank);
bool out_of_bound = false;
for (int64 i = 0; i < rank; ++i) {
- base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
+ base_index[i] =
+ window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] * window.dimensions(i).window_dilation() -
+ window.dimensions(i).padding_low();
+ // We are not in the base area if the dilation placed us out of bounds.
+ if (base_index[i] % window.dimensions(i).base_dilation() != 0) {
+ out_of_bound = true;
+ break;
+ }
+ // Apply the dilation to the base area.
+ base_index[i] /= window.dimensions(i).base_dilation();
if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
out_of_bound = true;
break;
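The ordering in the evaluator matters: the remainder test must run before the division, because a hole position truncates to the index of a neighboring real element under integer division. For base_dilation = 2 over a bound of 3, the dilated axis has (3 - 1) * 2 + 1 = 5 positions, and the mapping works out as below (a quick illustration, not test code):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t base_dilation = 2, bound = 3;
  const int64_t dilated_bound = (bound - 1) * base_dilation + 1;  // 5
  for (int64_t pos = 0; pos < dilated_bound; ++pos) {
    if (pos % base_dilation != 0) {
      // Without the remainder test, pos / base_dilation would silently
      // alias a real neighbor (e.g. 1 / 2 == 0).
      std::printf("pos %lld -> hole, skipped\n", (long long)pos);
    } else {
      std::printf("pos %lld -> base index %lld\n", (long long)pos,
                  (long long)(pos / base_dilation));
    }
  }
}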