aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/eigen_backward_spatial_convolutions.h')
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h41
1 files changed, 18 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index 8d06107553..960920c55b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -238,8 +238,8 @@ SpatialConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // the kernelRows and the kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, the kernelRows and the kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -332,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic,
- const Input> > > > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward> > > >::type
@@ -456,12 +449,16 @@ SpatialConvolutionBackwardKernel(
eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
-
+ // We will contract along the collapsed dimension that contains the
+ // outputCols, outputRows and OTHERS.
array<IndexPair<TensorIndex>, 1> contract_dims;
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ if (isColMajor) {
+ // col-major: output_backward.contract(input.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ } else {
+ // row-major: input.patches.contract(output_backward)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ }
// After the contraction, the kernel will have the desired shape
// out_depth X in_shape X kernel_rows X kernel_cols
@@ -487,8 +484,7 @@ SpatialConvolutionBackwardKernel(
kernelRows, kernelCols, row_stride, col_stride,
row_in_stride, col_in_stride, 1, 1, padding_top,
padding_bottom, padding_left, padding_right, OutScalar(0))
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
+ .reshape(pre_contract_dims),
contract_dims)
.reshape(kernel_dims),
input
@@ -497,7 +493,6 @@ SpatialConvolutionBackwardKernel(
padding_top, padding_bottom, padding_left,
padding_right, OutScalar(0))
.reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
.contract(output_backward.reshape(output_dims), contract_dims)
.reshape(kernel_dims));
}