diff options
author | Bjarke Hammersholt Roune <broune@google.com> | 2017-04-28 16:55:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-28 18:11:22 -0700 |
commit | f8051d04d982b5877717f1e4145488263aa69c68 (patch) | |
tree | 6475a4d0ac49db9434011f5ccc24c8eb9d20e025 /tensorflow | |
parent | 0a9b39caefd437fec742ae48b25061abd6e2699b (diff) |
ReferenceUtil::ReduceWindow4DGeneric() now supports arbitrary edge padding.
Change: 154603866
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/xla/reference_util.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/reference_util.h | 6 |
2 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 5630033ac8..4194d5fc6b 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -180,14 +180,28 @@ ReferenceUtil::ReduceWindow4DGeneric( const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow4DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + +/* static */ std::unique_ptr<Array4D<float>> +ReferenceUtil::ReduceWindow4DGeneric( + const Array4D<float>& operand, float init, + const std::function<float(float, float)>& reduce_func, + const tensorflow::gtl::ArraySlice<int64>& window, + const tensorflow::gtl::ArraySlice<int64>& stride, + const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { + std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; std::vector<int64> window_counts(window.size(), 0); std::vector<int64> pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1], window_counts[2], window_counts[3]); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index eb1eea7fc4..cdcad08c33 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -162,6 +162,12 @@ class ReferenceUtil { const std::function<float(float, float)>& reduce_func, const tensorflow::gtl::ArraySlice<int64>& window, const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding); + static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric( + const Array4D<float>& operand, float init, + const std::function<float(float, float)>& reduce_func, + const tensorflow::gtl::ArraySlice<int64>& window, + const tensorflow::gtl::ArraySlice<int64>& stride, + const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding); // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. |