aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-10-05 09:45:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-05 09:52:14 -0700
commitb0e751a73d211872f8d937e5778b9e0e0a7b950b (patch)
treea52ec41f282297e05e77083b8c5b3fa1419a14d8 /tensorflow/stream_executor/dnn.h
parent8dc5e3718b85b72a8bc6e5a2ea8270eecfdf99a1 (diff)
Add dilation rates support for ConvolutionDescriptor...
...in stream executor. In preparation for the support of native cudnn dilated convolution. PiperOrigin-RevId: 171165137
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 4beb46090c..5fe523602a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -487,6 +487,10 @@ string PadAlignmentString(PadAlignment alignment);
// window is moved in the "y dimension" according to this stride value.
// - horizontal_filter_stride: analogous to the vertical stride above, but in
// the "x dimension".
+// - vertical_dilation_rate: there will be (vertical_dilation_rate - 1) skipped
+// cells between each filter element in the "y dimension".
+// - horizontal_dilation_rate: there will be (horizontal_dilation_rate - 1)
+// skipped cells between each filter element in the "x dimension".
class ConvolutionDescriptor {
public:
// By default construction, there is no zero-padding and the filter stride is
@@ -523,6 +527,18 @@ class ConvolutionDescriptor {
SetDim(&filter_strides_, dim, value);
return *this;
}
+ ConvolutionDescriptor& set_vertical_dilation_rate(int64 value) {
+ SetDim(&dilation_rates_, DimIndex::Y, value);
+ return *this;
+ }
+ ConvolutionDescriptor& set_horizontal_dilation_rate(int64 value) {
+ SetDim(&dilation_rates_, DimIndex::X, value);
+ return *this;
+ }
+ ConvolutionDescriptor& set_dilation_rate(DimIndex dim, int64 value) {
+ SetDim(&dilation_rates_, dim, value);
+ return *this;
+ }
ConvolutionDescriptor& set_pad_alignment(PadAlignment pad_alignment) {
pad_alignment_ = pad_alignment;
return *this;
@@ -539,19 +555,28 @@ class ConvolutionDescriptor {
int64 horizontal_filter_stride() const {
return GetDim(filter_strides_, DimIndex::X);
}
+ int64 vertical_dilation_rate() const {
+ return GetDim(dilation_rates_, DimIndex::Y);
+ }
+ int64 horizontal_dilation_rate() const {
+ return GetDim(dilation_rates_, DimIndex::X);
+ }
int zero_padding(DimIndex dim) const { return GetDim(zero_padding_, dim); }
int filter_stride(DimIndex dim) const { return GetDim(filter_strides_, dim); }
+ int dilation_rate(DimIndex dim) const { return GetDim(dilation_rates_, dim); }
PadAlignment pad_alignment() const { return pad_alignment_; }
int ndims() const { return ndims_; }
std::vector<int64> strides() const { return filter_strides_; }
+ std::vector<int64> dilations() const { return dilation_rates_; }
std::vector<int64> padding() const { return zero_padding_; }
private:
// Stored as: .. y, x.
std::vector<int64> zero_padding_;
std::vector<int64> filter_strides_;
+ std::vector<int64> dilation_rates_;
PadAlignment pad_alignment_;
int ndims_;
// TODO(leary) cudnn provides these fields, but need to characterize what