author    David Majnemer <majnemer@google.com>           2018-10-06 10:04:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-06 10:09:36 -0700
commit 5c0a6bdfeb1848b0146a36706d921dde06ba160a (patch)
tree   e549be74d1f90165865102536d45cc1b4a2a75a0 /tensorflow/compiler/xla
parent 262f22f9eeee1ee00a9a92318d9a567a25c76696 (diff)
[XLA] Add base and window dilation support to ReduceWindow
PiperOrigin-RevId: 216041507
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc                 | 15
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h                  |  6
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.cc   |  5
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.h    |  2
-rw-r--r--  tensorflow/compiler/xla/python/xla_client.py                  | 25
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc       |  6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc             | 27
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc   | 26
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc         | 52
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 13
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc           | 12
11 files changed, 161 insertions(+), 28 deletions(-)
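
For context on the two new parameters: window dilation stretches the window itself, inserting gaps between the elements it reads (the reduce-window analogue of rhs_dilation in dilated/atrous convolution), while base dilation logically upsamples the operand by inserting holes between its elements (the analogue of lhs_dilation, needed e.g. for gradients of strided reductions). A minimal 1-D reference sketch of these semantics, written for this note and not part of the commit:

def reduce_window_1d(operand, init, reducer, window_size, stride,
                     base_dilation, window_dilation, pad_low, pad_high):
    """Pure-Python reference for 1-D reduce-window with dilation (sketch)."""
    # Size of the operand after base dilation (holes between elements).
    dilated_base = (len(operand) - 1) * base_dilation + 1
    # Span of the window after window dilation (holes between window taps).
    effective_window = (window_size - 1) * window_dilation + 1
    out_size = (pad_low + dilated_base + pad_high - effective_window) // stride + 1
    output = []
    for o in range(out_size):
        acc = init
        for w in range(window_size):
            # Position in the padded, base-dilated space.
            idx = o * stride + w * window_dilation - pad_low
            # Skip padding and the holes introduced by base dilation.
            if idx < 0 or idx >= dilated_base or idx % base_dilation != 0:
                continue
            acc = reducer(acc, operand[idx // base_dilation])
        output.append(acc)
    return output

For example, reduce_window_1d([1, 2, 3, 4], 0, max, window_size=2, stride=1, base_dilation=1, window_dilation=2, pad_low=0, pad_high=0) reads index pairs {0, 2} and {1, 3} and returns [3, 4].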
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index d196252db1..6b31831010 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1789,9 +1789,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
std::vector<std::pair<int64, int64>> padding_values =
MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
window_strides, padding);
- return ReduceWindowWithGeneralPadding(operand, init_value, computation,
- window_dimensions, window_strides,
- padding_values);
+ return ReduceWindowWithGeneralPadding(
+ operand, init_value, computation, window_dimensions, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
});
}
@@ -1800,6 +1800,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1810,7 +1812,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
MakeWindow(window_dimensions, window_strides, padding,
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
+ /*lhs_dilation=*/base_dilations,
+ /*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
@@ -2800,10 +2803,12 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
- padding);
+ base_dilations, window_dilations, padding);
}
XlaOp CrossReplicaSum(const XlaOp& operand,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cd0d5ca5d3..2e14e47a35 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -671,6 +671,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
@@ -1245,6 +1247,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
friend XlaOp CrossReplicaSum(const XlaOp& operand,
absl::Span<const ReplicaGroup> replica_groups);
@@ -1818,6 +1822,8 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index cd5fd33029..ffa336f304 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -532,10 +532,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
}
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 2166bb6721..43332e0abd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -278,6 +278,8 @@ class LocalComputationBuilder {
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64> > padding);
LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index bb303c5678..f8197488fb 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -995,7 +995,30 @@ class ComputationBuilder(object):
window_strides)
return self._client.ReduceWindowWithGeneralPadding(
operand, init_value, computation_to_apply.c_local_computation,
- window_dimensions, window_strides, pads)
+ window_dimensions, window_strides, (), (), pads)
+
+ def ReduceWindowWithGeneralPadding(
+ self, operand, init_value, computation_to_apply, window_dimensions,
+ window_strides, base_dilations, window_dilations, padding):
+ """Enqueues a windowed reduction operation onto the computation.
+
+ Args:
+ operand: reduction operand (LocalOp).
+ init_value: reduction initial value (LocalOp).
+ computation_to_apply: a binary reduction function (Computation).
+ window_dimensions: dimensions of window (sequence of integers).
+ window_strides: strides for window (sequence of integers).
+ base_dilations: dilations for the base (sequence of integers).
+ window_dilations: dilations for window (sequence of integers).
+ padding: length-N array-like of pairs of integers of (low, high) padding.
+
+ Returns:
+ A LocalOp representing the added ReduceWindow op.
+ """
+ return self._client.ReduceWindowWithGeneralPadding(
+ operand, init_value, computation_to_apply.c_local_computation,
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding)
def RngNormal(self, mu, sigma, dims):
"""Enqueues an RngNormal operation onto the computation.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 75dae7a714..86d9dbea90 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2057,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return Status::OK();
}
+ // Bail on dilation.
+ if (window_util::HasDilation(window)) {
+ VLOG(10) << "Not folding pad into reduce-window as there is dilation.";
+ return Status::OK();
+ }
+
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
<< (convert != nullptr
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a70abb117a..b2abdb39a5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -688,8 +688,25 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ input_index[i] = NSWSub(
+ NSWAdd(strided_index,
+ NSWMul(window_index[i],
+ b_.getInt64(window.dimensions(i).window_dilation()))),
+ b_.getInt64(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())),
+ b_.getInt64(0));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = dilation_condition;
+ } else {
+ in_bounds_condition = And(in_bounds_condition, dilation_condition);
+ }
+
+ // Apply base dilation to the index.
+ input_index[i] =
+ SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
@@ -728,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
/*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32, F16}));
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(reduce_window->window())) {
- return Unimplemented(
- "Dilation for ReduceWindow is not implemented on CPU.");
- }
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
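
Both emitters (CPU above, GPU below) build the same per-dimension index logic in LLVM IR; in scalar form it is roughly the following (a sketch with descriptive names, not code from the commit):

input_index = (output_index * stride
               + window_index * window_dilation
               - padding_low)
# A point falls in a base-dilation hole unless it is a multiple of the factor.
in_bounds = (input_index % base_dilation == 0)
input_index //= base_dilation  # map back to the undilated operand
in_bounds = in_bounds and (0 <= input_index < bound)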
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index c1aaa4bf04..6dcdaf1cfe 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
const Window& window = hlo->window();
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
- return Unimplemented(
- "Dilation for reduce-window not implemented on GPU. "
- "See b/31410564.");
- }
-
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
@@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(
+ NSWAdd(stridden_index,
+ NSWMul(window_index[i],
+ index_typed_const(
+ window.dimensions(i).window_dilation()))),
+ index_typed_const(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation())),
+ index_typed_const(0));
+ in_bounds = And(in_bounds, dilation_condition);
+
+ // Apply base dilation to the index.
input_index[i] =
- NSWSub(NSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ SDiv(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index cee11a8a21..608a42bb60 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1463,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
+TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) {
+ HloComputation::Builder b(TestName());
+
+ // arg:
+ // f32[3,3] {
+ // { 1, 2, 3 },
+ // { 5, 6, 7 },
+ // { 9, 10, 11 },
+ // }
+ auto arg_array = absl::make_unique<Array2D<float>>(3, 3);
+ arg_array->FillUnique(1.0f);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
+
+ HloInstruction* arg_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
+
+ HloComputation::Builder max_computation("max");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ max_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
+ auto max_func = module().AddEmbeddedComputation(max_computation.Build());
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(2);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1});
+ b.AddInstruction(HloInstruction::CreateReduceWindow(
+ shape, arg_instruction, init_value, window, max_func));
+
+ module().AddEntryComputation(b.Build());
+
+ Literal result = Evaluate();
+
+ auto expected = LiteralUtil::CreateR2<float>({{11}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
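
A quick check of the expected literal in this test: with window size 2 and window_dilation 2, the effective window spans (2 - 1) * 2 + 1 = 3 elements per dimension, so the 3x3 input yields a single 1x1 output; that window reads base indices {0, 2} x {0, 2}, i.e. the values {1, 3, 9, 11}, whose maximum is 11.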
+
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
HloComputation::Builder b(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index b2d12c94b8..a450dc6ff5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -2613,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> base_index(rank);
bool out_of_bound = false;
for (int64 i = 0; i < rank; ++i) {
- base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
+ base_index[i] =
+ window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] * window.dimensions(i).window_dilation() -
+ window.dimensions(i).padding_low();
+ // We are not in the base area if the dilation placed us out of bounds.
+ if (base_index[i] % window.dimensions(i).base_dilation() != 0) {
+ out_of_bound = true;
+ break;
+ }
+ // Apply the dilation to the base area.
+ base_index[i] /= window.dimensions(i).base_dilation();
if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
out_of_bound = true;
break;
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index c25ccafaf8..22fe4a2670 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
/*padding=*/padding);
CHECK(reducer == kAdd || reducer == kMax);
@@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
@@ -1369,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }