aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.cc
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-10-05 09:45:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-05 09:52:14 -0700
commitb0e751a73d211872f8d937e5778b9e0e0a7b950b (patch)
treea52ec41f282297e05e77083b8c5b3fa1419a14d8 /tensorflow/stream_executor/dnn.cc
parent8dc5e3718b85b72a8bc6e5a2ea8270eecfdf99a1 (diff)
Add dilation rates support for ConvolutionDescriptor...
...in stream executor. In preparation for the support of native cudnn dilated convolution. PiperOrigin-RevId: 171165137
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r--tensorflow/stream_executor/dnn.cc15
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index ed9bdf2bc2..2c40e18f5c 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -424,6 +424,7 @@ int64 FilterDescriptor::ComputeWeightCount() const {
ConvolutionDescriptor::ConvolutionDescriptor(int ndims)
: zero_padding_(ndims, 0),
filter_strides_(ndims, 1),
+ dilation_rates_(ndims, 1),
pad_alignment_(PadAlignment::kDefault),
ndims_(ndims) {}
@@ -435,15 +436,18 @@ ConvolutionDescriptor::~ConvolutionDescriptor() {}
string ConvolutionDescriptor::ToString() const {
string padding;
string strides;
+ string dilations;
for (int i = 0; i < ndims_; i++) {
port::Appendf(&padding, "%lld ", zero_padding_[i]);
port::Appendf(&strides, "%lld ", filter_strides_[i]);
+ port::Appendf(&dilations, "%lld ", dilation_rates_[i]);
}
- return port::Printf("{zero_padding: %s pad_alignment: %s filter_strides: %s}",
- padding.c_str(),
- PadAlignmentString(pad_alignment_).c_str(),
- strides.c_str());
+ return port::Printf(
+ "{zero_padding: %s pad_alignment: %s filter_strides: %s dilation_rates: "
+ "%s}",
+ padding.c_str(), PadAlignmentString(pad_alignment_).c_str(),
+ strides.c_str(), dilations.c_str());
}
string ConvolutionDescriptor::ToShortString() const {
@@ -455,6 +459,9 @@ string ConvolutionDescriptor::ToShortString() const {
for (int i = 0; i < ndims_; i++) {
port::Appendf(&desc, "_s%d:%lld", i, filter_strides_[i]);
}
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&desc, "_d%d:%lld", i, dilation_rates_[i]);
+ }
return desc;
}