aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/fused_conv
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-11-29 16:38:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-29 16:42:20 -0800
commitcb4ef362e4a18b3c42a2c90bdad8754d5ead4caf (patch)
treecca762531469418ade8c8d5bdf2a9ba6719ec111 /tensorflow/contrib/fused_conv
parent4ada275eed7472ae32c67a1ec0b9b1dc8d80d1f0 (diff)
Add native dilated support for conv2d and its gradients in cudnn v>=6.
PiperOrigin-RevId: 177382431
Diffstat (limited to 'tensorflow/contrib/fused_conv')
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc2
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h9
-rw-r--r--tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc6
3 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 88306094ab..5fec69ea43 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -493,6 +493,8 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
{{conv_input_rows, conv_input_cols}},
output_depth,
{{filter_rows, filter_cols}},
+ // TODO(yangzihao): Add support for arbitrary dilations for fused conv.
+ {{1, 1}}, // dilation_rows, dilation_cols
{{row_stride, col_stride}},
{{padding_rows, padding_cols}},
conv_input->dtype(),
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
index dc43af1158..fa7a3c03aa 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
@@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters {
public:
FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
int64 out_depths, const SpatialArray& filter,
- const SpatialArray& stride, const SpatialArray& padding,
- DataType dtype, int device_id, bool has_side_input,
+ const SpatialArray& dilation, const SpatialArray& stride,
+ const SpatialArray& padding, DataType dtype,
+ int device_id, bool has_side_input,
ActivationMode activation_mode)
- : ConvParameters(batch, in_depths, in, out_depths, filter, stride,
- padding, dtype, device_id),
+ : ConvParameters(batch, in_depths, in, out_depths, filter, dilation,
+ stride, padding, dtype, device_id),
activation_mode_(activation_mode),
has_side_input_(has_side_input) {
hash_code_ = Hash64Combine(hash_code_, has_side_input);
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
index 887ebc5a6c..6a56237f67 100644
--- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -52,6 +52,7 @@ REGISTER_OP("FusedConv2DBiasActivation")
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
.Attr("activation_mode: {'Relu'} = 'Relu'")
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](shape_inference::InferenceContext* c) {
using shape_inference::ShapeHandle;
using shape_inference::DimensionHandle;
@@ -151,6 +152,11 @@ REGISTER_OP("FusedConv2DBiasActivation")
kernel_height, kernel_width, input_channels % 4 ]`
activation_mode: The activation applied to the output.
Currently must be "Relu".
+ dilations: 1-D tensor of length 4. The dilation factor for each dimension
+ of `input`. If set to k > 1, there will be k-1 skipped cells between
+ each filter element on that dimension. The dimension order is determined
+ by the value of `data_format`, see above for details. Dilations in the
+ batch and depth dimensions must be 1.
)doc");
} // namespace tensorflow