aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Bjarke Hammersholt Roune <broune@google.com>2017-04-28 16:55:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-28 18:11:22 -0700
commitf8051d04d982b5877717f1e4145488263aa69c68 (patch)
tree6475a4d0ac49db9434011f5ccc24c8eb9d20e025 /tensorflow
parent0a9b39caefd437fec742ae48b25061abd6e2699b (diff)
ReferenceUtil::ReduceWindow4DGeneric() now supports arbitrary edge padding.
Change: 154603866
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/reference_util.cc20
-rw-r--r--tensorflow/compiler/xla/reference_util.h6
2 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 5630033ac8..4194d5fc6b 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -180,14 +180,28 @@ ReferenceUtil::ReduceWindow4DGeneric(
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
- auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
+ return ReduceWindow4DGeneric(
+ operand, init, reduce_func, window, stride,
+ xla::MakePadding(dim_lengths, window, stride, padding));
+}
+
+/* static */ std::unique_ptr<Array4D<float>>
+ReferenceUtil::ReduceWindow4DGeneric(
+ const Array4D<float>& operand, float init,
+ const std::function<float(float, float)>& reduce_func,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride,
+ const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
+ operand.n4()};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
for (int64 i = 0; i < window.size(); ++i) {
+ int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
window_counts[i] =
- WindowCount(dim_lengths[i], window[i], stride[i], padding);
- pad_low[i] = padding_both[i].first;
+ window_util::StridedBound(padded_width, window[i], stride[i]);
+ pad_low[i] = padding[i].first;
}
auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
window_counts[2], window_counts[3]);
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index eb1eea7fc4..cdcad08c33 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -162,6 +162,12 @@ class ReferenceUtil {
const std::function<float(float, float)>& reduce_func,
const tensorflow::gtl::ArraySlice<int64>& window,
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
+ const Array4D<float>& operand, float init,
+ const std::function<float(float, float)>& reduce_func,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride,
+ const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
// Performs select and scatter with Greater Than or equal as the select, plus
// as the scatter, and Same Padding.