aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r--tensorflow/stream_executor/dnn.cc16
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 07fe8a85f4..29fd6d0e87 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -472,7 +472,8 @@ PoolingDescriptor::PoolingDescriptor(int ndims)
ndims_(ndims),
window_(ndims, 0),
padding_(ndims, 0),
- strides_(ndims, 1) {}
+ strides_(ndims, 1),
+ propagate_nans_(false) {}
PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}
@@ -482,6 +483,7 @@ void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
window_ = other.window_;
padding_ = other.padding_;
strides_ = other.strides_;
+ propagate_nans_ = other.propagate_nans_;
}
string PoolingDescriptor::ToString() const {
@@ -495,9 +497,12 @@ string PoolingDescriptor::ToString() const {
port::Appendf(&padding, "%lld", padding_[i]);
}
- return port::Printf("{mode: %s window: %s strides: %s padding: %s}",
- mode_string, window.c_str(), strides.c_str(),
- padding.c_str());
+ const char* propagate_string = propagate_nans_ ? "Yes" : "No";
+
+ return port::Printf(
+ "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}",
+ mode_string, window.c_str(), strides.c_str(), padding.c_str(),
+ propagate_string);
}
string PoolingDescriptor::ToShortString() const {
@@ -508,7 +513,8 @@ string PoolingDescriptor::ToShortString() const {
port::Appendf(&padding, "_p%d:%lld", i, padding_[i]);
}
return port::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
- window, strides, padding);
+ window, strides, padding,
+ propagate_nans_ ? "propagate_nans" : "ignore_nans");
}
// -- NormalizeDescriptor