diff options
Diffstat (limited to 'tensorflow/core/kernels/eigen_backward_spatial_convolutions.h')
-rw-r--r-- | tensorflow/core/kernels/eigen_backward_spatial_convolutions.h | 41 |
1 files changed, 18 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index 8d06107553..960920c55b 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -238,8 +238,8 @@ SpatialConvolutionBackwardInput( } } - // We will contract along the fused dimension that contains the kernelFilters, - // the kernelRows and the kernelCols. + // We will contract along the collapsed dimension that contains the + // kernelFilters, the kernelRows and the kernelCols. array<IndexPair<TensorIndex>, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) @@ -332,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 2>, const OutputBackward>, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, - 2>, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, 2>, - const TensorImagePatchOp<Dynamic, Dynamic, - const Input> > > > >, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >, TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 4>, const TensorContractionOp< const array<IndexPair<typename internal::traits<Input>::Index>, 1>, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, - 2>, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, 2>, - const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 2>, const OutputBackward> > > >::type @@ -456,12 +449,16 @@ SpatialConvolutionBackwardKernel( eigen_assert(output_dims[0] == pre_contract_dims[0]); } - array<TensorIndex, 2> shuffle_dims; - shuffle_dims[0] = 1; - shuffle_dims[1] = 0; - + // We will contract along the collapsed dimension that contains the + // outputCols, outputRows and OTHERS. array<IndexPair<TensorIndex>, 1> contract_dims; - contract_dims[0] = IndexPair<TensorIndex>(1, 0); + if (isColMajor) { + // col-major: output_backward.contract(input.patches) + contract_dims[0] = IndexPair<TensorIndex>(1, 1); + } else { + // row-major: input.patches.contract(output_backward) + contract_dims[0] = IndexPair<TensorIndex>(0, 0); + } // After the contraction, the kernel will have the desired shape // out_depth X in_shape X kernel_rows X kernel_cols @@ -487,8 +484,7 @@ SpatialConvolutionBackwardKernel( kernelRows, kernelCols, row_stride, col_stride, row_in_stride, col_in_stride, 1, 1, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) - .reshape(pre_contract_dims) - .shuffle(shuffle_dims), + .reshape(pre_contract_dims), contract_dims) .reshape(kernel_dims), input @@ -497,7 +493,6 @@ SpatialConvolutionBackwardKernel( padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) .reshape(pre_contract_dims) - .shuffle(shuffle_dims) .contract(output_backward.reshape(output_dims), contract_dims) .reshape(kernel_dims)); } |