diff options
author | Yangzihao Wang <yangzihao@google.com> | 2017-10-05 09:45:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-05 09:52:14 -0700 |
commit | b0e751a73d211872f8d937e5778b9e0e0a7b950b (patch) | |
tree | a52ec41f282297e05e77083b8c5b3fa1419a14d8 /tensorflow/stream_executor/dnn.cc | |
parent | 8dc5e3718b85b72a8bc6e5a2ea8270eecfdf99a1 (diff) |
Add dilation rates support for ConvolutionDescriptor...
...in stream executor.
In preparation for the support of native cudnn dilated convolution.
PiperOrigin-RevId: 171165137
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/dnn.cc | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index ed9bdf2bc2..2c40e18f5c 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -424,6 +424,7 @@ int64 FilterDescriptor::ComputeWeightCount() const { ConvolutionDescriptor::ConvolutionDescriptor(int ndims) : zero_padding_(ndims, 0), filter_strides_(ndims, 1), + dilation_rates_(ndims, 1), pad_alignment_(PadAlignment::kDefault), ndims_(ndims) {} @@ -435,15 +436,18 @@ ConvolutionDescriptor::~ConvolutionDescriptor() {} string ConvolutionDescriptor::ToString() const { string padding; string strides; + string dilations; for (int i = 0; i < ndims_; i++) { port::Appendf(&padding, "%lld ", zero_padding_[i]); port::Appendf(&strides, "%lld ", filter_strides_[i]); + port::Appendf(&dilations, "%lld ", dilation_rates_[i]); } - return port::Printf("{zero_padding: %s pad_alignment: %s filter_strides: %s}", - padding.c_str(), - PadAlignmentString(pad_alignment_).c_str(), - strides.c_str()); + return port::Printf( + "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: " + "%s}", + padding.c_str(), PadAlignmentString(pad_alignment_).c_str(), + strides.c_str(), dilations.c_str()); } string ConvolutionDescriptor::ToShortString() const { @@ -455,6 +459,9 @@ string ConvolutionDescriptor::ToShortString() const { for (int i = 0; i < ndims_; i++) { port::Appendf(&desc, "_s%d:%lld", i, filter_strides_[i]); } + for (int i = 0; i < ndims_; i++) { + port::Appendf(&desc, "_d%d:%lld", i, dilation_rates_[i]); + } return desc; } |