diff options
author | 2017-10-05 09:45:14 -0700 | |
---|---|---|
committer | 2017-10-05 09:52:14 -0700 | |
commit | b0e751a73d211872f8d937e5778b9e0e0a7b950b (patch) | |
tree | a52ec41f282297e05e77083b8c5b3fa1419a14d8 /tensorflow/stream_executor/dnn.h | |
parent | 8dc5e3718b85b72a8bc6e5a2ea8270eecfdf99a1 (diff) |
Add dilation rates support for ConvolutionDescriptor...
...in stream executor.
In preparation for the support of native cudnn dilated convolution.
PiperOrigin-RevId: 171165137
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r-- | tensorflow/stream_executor/dnn.h | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 4beb46090c..5fe523602a 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -487,6 +487,10 @@ string PadAlignmentString(PadAlignment alignment); // window is moved in the "y dimension" according to this stride value. // - horizontal_filter_stride: analogous to the vertical stride above, but in // the "x dimension". +// - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped +// cells between each filter element in the "y dimension". +// - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1) +// skipped cells between each filter element in the "x dimension". class ConvolutionDescriptor { public: // By default construction, there is no zero-padding and the filter stride is @@ -523,6 +527,18 @@ class ConvolutionDescriptor { SetDim(&filter_strides_, dim, value); return *this; } + ConvolutionDescriptor& set_vertical_dilation_rate(int64 value) { + SetDim(&dilation_rates_, DimIndex::Y, value); + return *this; + } + ConvolutionDescriptor& set_horizontal_dilation_rate(int64 value) { + SetDim(&dilation_rates_, DimIndex::X, value); + return *this; + } + ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64 value) { + SetDim(&dilation_rates_, dim, value); + return *this; + } ConvolutionDescriptor& set_pad_alignment(PadAlignment pad_alignment) { pad_alignment_ = pad_alignment; return *this; @@ -539,19 +555,28 @@ class ConvolutionDescriptor { int64 horizontal_filter_stride() const { return GetDim(filter_strides_, DimIndex::X); } + int64 vertical_dilation_rate() const { + return GetDim(dilation_rates_, DimIndex::Y); + } + int64 horizontal_dilation_rate() const { + return GetDim(dilation_rates_, DimIndex::X); + } int zero_padding(DimIndex dim) const { return GetDim(zero_padding_, dim); } int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); } + int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); } PadAlignment pad_alignment() const { return pad_alignment_; } int ndims() const { return ndims_; } std::vector<int64> strides() const { return filter_strides_; } + std::vector<int64> dilations() const { return dilation_rates_; } std::vector<int64> padding() const { return zero_padding_; } private: // Stored as: .. y, x. std::vector<int64> zero_padding_; std::vector<int64> filter_strides_; + std::vector<int64> dilation_rates_; PadAlignment pad_alignment_; int ndims_; // TODO(leary) cudnn provides these fields, but need to characterize what |