aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tayo Oguntebi <tayo@google.com>2017-02-09 13:17:37 -0800
committerGravatar gunan <gunan@google.com>2017-02-09 17:14:08 -0800
commit9f3eefc41c97a19cb9a38d8c0280d17dd2ea3b97 (patch)
tree2cdc897d238d01647abee042fe663cbf4c36aa22
parent31ff2c6029b887f35b546df75cc927f571be47d3 (diff)
[XLA] ReferenceUtil support for generic reduction functions in ReduceWindow.
Change: 147072336
-rw-r--r--tensorflow/compiler/xla/reference_util.cc18
-rw-r--r--tensorflow/compiler/xla/reference_util.h7
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc14
3 files changed, 32 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 7b1a361b93..86c9c3b1ac 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -172,8 +172,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
return result;
}
-/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
+/* static */ std::unique_ptr<Array4D<float>>
+ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
+ const std::function<float(float, float)>& reduce_func,
const tensorflow::gtl::ArraySlice<int64>& window,
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
@@ -210,8 +212,9 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
i1_base + i1_win < operand.n2() &&
i2_base + i2_win < operand.n3() &&
i3_base + i3_win < operand.n4()) {
- val += operand(i0_base + i0_win, i1_base + i1_win,
- i2_base + i2_win, i3_base + i3_win);
+ val = reduce_func(
+ val, operand(i0_base + i0_win, i1_base + i1_win,
+ i2_base + i2_win, i3_base + i3_win));
}
}
}
@@ -225,6 +228,15 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
return result;
}
+/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
+ const Array4D<float>& operand, float init,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
+ return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
+ padding);
+}
+
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SelectAndScatter4DGePlus(
const Array4D<float>& operand, const Array4D<float>& source, float init,
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 52578a1b23..9e0f247203 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -156,6 +156,13 @@ class ReferenceUtil {
const tensorflow::gtl::ArraySlice<int64>& window,
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ // Performs a 4D window reduction with a generic reduce function.
+ static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
+ const Array4D<float>& operand, float init,
+ const std::function<float(float, float)>& reduce_func,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+
// Performs select and scatter with Greater Than or equal as the select, plus
// as the scatter, and Same Padding.
static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 149a75c8e1..56501e43b5 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -446,10 +446,16 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) {
/*window_dimensions=*/{1, 1, 2, 1},
/*window_strides=*/{1, 1, 1, 1}, padding);
- Array4D<float> expected(1, 2, 1, 1);
- expected(0, 0, 0, 0) = 6;
- expected(0, 1, 0, 0) = 8;
- ComputeAndCompareR4<float>(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3));
+ const auto reduce_func = [](float arg1, float arg2) {
+ return std::min<float>(arg1 + arg2, 8.0f);
+ };
+
+ auto expected =
+ ReferenceUtil::ReduceWindow4DGeneric(input_array, 3.0f, reduce_func,
+ /*window=*/{1, 1, 2, 1},
+ /*stride=*/{1, 1, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3));
}
} // namespace