diff options
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/dnn.cc | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 07fe8a85f4..29fd6d0e87 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -472,7 +472,8 @@ PoolingDescriptor::PoolingDescriptor(int ndims) ndims_(ndims), window_(ndims, 0), padding_(ndims, 0), - strides_(ndims, 1) {} + strides_(ndims, 1), + propagate_nans_(false) {} PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {} @@ -482,6 +483,7 @@ void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) { window_ = other.window_; padding_ = other.padding_; strides_ = other.strides_; + propagate_nans_ = other.propagate_nans_; } string PoolingDescriptor::ToString() const { @@ -495,9 +497,12 @@ string PoolingDescriptor::ToString() const { port::Appendf(&padding, "%lld", padding_[i]); } - return port::Printf("{mode: %s window: %s strides: %s padding: %s}", - mode_string, window.c_str(), strides.c_str(), - padding.c_str()); + const char* propagate_string = propagate_nans_ ? "Yes" : "No"; + + return port::Printf( + "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}", + mode_string, window.c_str(), strides.c_str(), padding.c_str(), + propagate_string); } string PoolingDescriptor::ToShortString() const { @@ -508,7 +513,8 @@ string PoolingDescriptor::ToShortString() const { port::Appendf(&padding, "_p%d:%lld", i, padding_[i]); } return port::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg", - window, strides, padding); + window, strides, padding, + propagate_nans_ ? "propagate_nans" : "ignore_nans"); } // -- NormalizeDescriptor |