diff options
author | Tayo Oguntebi <tayo@google.com> | 2017-02-09 13:17:37 -0800 |
---|---|---|
committer | gunan <gunan@google.com> | 2017-02-09 17:14:08 -0800 |
commit | 9f3eefc41c97a19cb9a38d8c0280d17dd2ea3b97 (patch) | |
tree | 2cdc897d238d01647abee042fe663cbf4c36aa22 | |
parent | 31ff2c6029b887f35b546df75cc927f571be47d3 (diff) |
[XLA] ReferenceUtil support for generic reduction functions in ReduceWindow.
Change: 147072336
-rw-r--r-- | tensorflow/compiler/xla/reference_util.cc | 18 | ||||
-rw-r--r-- | tensorflow/compiler/xla/reference_util.h | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/reduce_window_test.cc | 14 |
3 files changed, 32 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 7b1a361b93..86c9c3b1ac 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -172,8 +172,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, return result; } -/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( +/* static */ std::unique_ptr<Array4D<float>> +ReferenceUtil::ReduceWindow4DGeneric( const Array4D<float>& operand, float init, + const std::function<float(float, float)>& reduce_func, const tensorflow::gtl::ArraySlice<int64>& window, const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), @@ -210,8 +212,9 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, i1_base + i1_win < operand.n2() && i2_base + i2_win < operand.n3() && i3_base + i3_win < operand.n4()) { - val += operand(i0_base + i0_win, i1_base + i1_win, - i2_base + i2_win, i3_base + i3_win); + val = reduce_func( + val, operand(i0_base + i0_win, i1_base + i1_win, + i2_base + i2_win, i3_base + i3_win)); } } } @@ -225,6 +228,15 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, return result; } +/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( + const Array4D<float>& operand, float init, + const tensorflow::gtl::ArraySlice<int64>& window, + const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { + const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; + return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, + padding); +} + /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::SelectAndScatter4DGePlus( const Array4D<float>& operand, const Array4D<float>& source, float init, diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 52578a1b23..9e0f247203 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -156,6 +156,13 @@ class ReferenceUtil { const tensorflow::gtl::ArraySlice<int64>& window, const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding); + // Performs a 4D window reduction with a generic reduce function. + static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric( + const Array4D<float>& operand, float init, + const std::function<float(float, float)>& reduce_func, + const tensorflow::gtl::ArraySlice<int64>& window, + const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding); + // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus( diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 149a75c8e1..56501e43b5 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -446,10 +446,16 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); - Array4D<float> expected(1, 2, 1, 1); - expected(0, 0, 0, 0) = 6; - expected(0, 1, 0, 0) = 8; - ComputeAndCompareR4<float>(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3)); + const auto reduce_func = [](float arg1, float arg2) { + return std::min<float>(arg1 + arg2, 8.0f); + }; + + auto expected = + ReferenceUtil::ReduceWindow4DGeneric(input_array, 3.0f, reduce_func, + /*window=*/{1, 1, 2, 1}, + /*stride=*/{1, 1, 1, 1}, padding); + + ComputeAndCompareR4<float>(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); } } // namespace |