diff options
author | 2018-01-18 14:24:50 -0800 | |
---|---|---|
committer | 2018-01-18 14:31:59 -0800 | |
commit | eb690b51d1a7fb0e14884a5074c545295b6c2b23 (patch) | |
tree | f3f870e05a737b3f3962929ab5a2c585ae90a83f /tensorflow/core/kernels/conv_grad_filter_ops.cc | |
parent | 89163b8c047aa7bdaa76a0a0e8a440313c76d406 (diff) |
Fixed Eigen-version Conv2DBackpropInput kernel, marked with "eigen_tensor".
LaunchConv2DBackpropInputOp, a template functor class used in
Conv2DFastBackpropInputOp class was specialized in a wrong place, in
conv_grad_filter_ops.cc (ODR violation). This caused wrong functor to be called
inside the op kernel class, because the correct functor definition was overriden
by a wrong specialization. The functor class that should have been specialized
in conv_grad_filter_ops.cc was LaunchConv2DBackpropFilterOp.
PiperOrigin-RevId: 182438025
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_filter_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_grad_filter_ops.cc | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 5e4feb2584..512bcc6c01 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -93,16 +93,15 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template <typename T> -struct LaunchConv2DBackpropInputOp<CPUDevice, T> { +struct LaunchConv2DBackpropFilterOp<CPUDevice, T> { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, const Tensor& out_backprop, const Tensor& input, int row_stride, int col_stride, const Padding& padding, Tensor* filter_backprop, TensorFormat data_format) { const CPUDevice& d = ctx->eigen_device<CPUDevice>(); - functor::SpatialConvolutionBackwardInput<CPUDevice, T>()( + functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()( d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(), - out_backprop.tensor<T, 4>(), filter_backprop->dim_size(0), - filter_backprop->dim_size(1), row_stride, col_stride); + out_backprop.tensor<T, 4>(), row_stride, col_stride); } }; @@ -273,7 +272,7 @@ class Conv2DFastBackpropFilterOp : public OpKernel { } #endif - LaunchConv2DBackpropInputOp<Device, T>()( + LaunchConv2DBackpropFilterOp<Device, T>()( context, false, false, out_backprop, input, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_); } @@ -603,7 +602,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { return; } - // For now we take the stride from the second and third dimensions only (we // do not support striding on the batch or depth dimension). const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); |