// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_NEURAL_NETWORKS_SPATIAL_CONVOLUTIONS_H
#define EIGEN_CXX11_NEURAL_NETWORKS_SPATIAL_CONVOLUTIONS_H

namespace Eigen {

namespace internal {

// These optimizations require vector instructions
#ifdef EIGEN_VECTORIZE

// TODO: Consolidate this part of the code with the image patch extraction code
// since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t,
          int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<Scalar, Index, Side,
                                   TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>,
                                   nocontract_t, contract_t, packet_size,
                                   inner_dim_contiguous, inner_dim_reordered, Alignment>
{
 public:
  typedef TensorContractionInputMapper<Scalar, Index, Side,
                                       TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>,
                                       nocontract_t, contract_t, packet_size,
                                       inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
  typedef Self SubMapper;
  typedef Self VectorMapper;
  typedef Self LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  TensorContractionInputMapper(const TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>& tensor,
                               const nocontract_t&, const nocontract_t&,
                               const contract_t&, const contract_t&,
                               const Index depth_offset = 0, const Index col_offset = 0)
      : m_depth_offset(depth_offset), m_col_offset(col_offset), m_impl(tensor.impl().impl())
  {
    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_patch_depth = tensor.impl().dimensions()[0];
      m_patch_rows = tensor.impl().dimensions()[1];
      m_patch_cols = tensor.impl().dimensions()[2];
      m_num_patches = tensor.impl().dimensions()[3];
    } else {
      static const int NumDims = tensor.impl().dimensions().size();
      m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
      m_patch_rows = tensor.impl().dimensions()[NumDims - 2];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
      m_num_patches = tensor.impl().dimensions()[NumDims - 4];
    }
    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    m_colStride = m_patch_rows;

    m_outputRows = tensor.impl().outputRows();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputRows = tensor.impl().impl().dimensions()[1];
      m_inputCols = tensor.impl().impl().dimensions()[2];
    } else {
      static const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    }

    m_rowInputStride = m_patch_depth;
    m_colInputStride = m_patch_depth * m_inputRows;
    m_patchInputStride = m_patch_depth * m_inputRows * m_inputCols;

    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride = internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);

    computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
  }

  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper,
                               const Index depth_offset, const Index col_offset) :
      m_depth_offset(depth_offset), m_col_offset(col_offset), m_impl(base_mapper.m_impl)
  {
    m_patch_depth = base_mapper.m_patch_depth;
    m_patch_rows = base_mapper.m_patch_rows;
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;
    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_colStride = base_mapper.m_colStride;

    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;

    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputRows = base_mapper.m_outputRows;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastDimZero = base_mapper.m_fastDimZero;

    computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex, m_otherIndex);
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard", e.g. there are non-trivial strides or
  // inflations in the input.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_row_strides != 1 || m_in_col_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, m_depth_offset + i, m_col_offset + j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, m_depth_offset + i, m_col_offset + j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    return loadCoeff(row + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
  }

  // Load the coefficient at the patchIndex location instead of the usual m_rowIndex,
  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    checkZeroOffsets();
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    return loadPacket(row + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
  }

  // Load the packet at the patchIndex location instead of the usual m_rowIndex,
  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    checkZeroOffsets();
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    const Index r = m_rowIndex + row;
    return r < 0 | r >= m_inputRows;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    const Index c = m_colIndex + col;
    return c < 0 | c >= m_inputCols;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
    const Index r = m_rowIndex + row;
    const Index c = m_colIndex + col;
    return r * m_rowInputStride + c * m_colInputStride + m_otherIndex;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowOffset() const {
    const Index patchOffset = m_depth_offset / m_fastDimZero;
    const Index colOffset = patchOffset / m_fastColStride;
    return patchOffset - colOffset*m_colStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colOffset() const {
    const Index patchOffset = m_depth_offset / m_fastDimZero;
    const Index colOffset = patchOffset / m_fastColStride;
    return colOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index depthOffset() const {
    const Index patchOffset = m_depth_offset % m_patch_depth;
    return patchOffset;
  }

 private:
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol = (m_patch_col_inflate_strides == 1) ? inputCol :
        ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);

    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow = (m_patch_row_inflate_strides == 1) ? inputRow :
        ((inputRow >= 0) ?
            (inputRow / m_fastInputRowStride) : 0);

    if (origInputCol < 0 | origInputRow < 0 |
        origInputCol >= m_inputCols | origInputRow >= m_inputRows |
        (inputCol != origInputCol * m_patch_col_inflate_strides) |
        (inputRow != origInputRow * m_patch_row_inflate_strides)) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * m_patch_depth;
    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < m_patch_depth*m_patch_rows*m_patch_cols);

    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    }

    if ((m_patch_depth % packetSize) == 0) {
      // Find the offset of the element wrt the location of the first element.
      const Index patchOffset = patchId / m_fastDimZero;
      eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

      const Index colOffset = patchOffset / m_fastColStride;
      const Index inputCol = colIndex + colOffset;
      const Index rowOffset = patchOffset - colOffset*m_colStride;
      const Index inputRow = rowIndex + rowOffset;
      if (inputCol < 0 | inputRow < 0 | inputCol >= m_inputCols | inputRow >= m_inputRows) {
        // all zeros
        return internal::pset1<Packet>(Scalar(0));
      }
      // no padding
      const Index depth = patchId - patchOffset * m_patch_depth;
      const Index inputIndex = depth + inputRow * m_rowInputStride + inputCol * m_colInputStride + otherIndex;
      return m_impl.template packet<Unaligned>(inputIndex);
    }
    else {
      const Index patchOffsets[2] = {patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};

      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, patchOffsets[1] / m_fastColStride};

      const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]};
      if (inputCols[0] >= m_inputCols | inputCols[1] < 0) {
        // all zeros
        return internal::pset1<Packet>(Scalar(0));
      }

      if (inputCols[0] == inputCols[1]) {
        const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0]*m_colStride, patchOffsets[1] - colOffsets[1]*m_colStride};
        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
        const Index inputRows[2] = {rowIndex + rowOffsets[0], rowIndex + rowOffsets[1]};

        if (inputRows[0] >= m_inputRows | inputRows[1] < 0) {
          // all zeros
          return internal::pset1<Packet>(Scalar(0));
        }

        if (inputRows[0] >= 0 & inputRows[1] < m_inputRows) {
          // no padding
          const Index depth = patchId - patchOffsets[0] * m_patch_depth;
          const Index inputIndex = depth + inputRows[0] * m_rowInputStride + inputCols[0] * m_colInputStride + otherIndex;
          return m_impl.template packet<Unaligned>(inputIndex);
        }
      }
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX typename internal::remove_const<Scalar>::type values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = loadCoeff(patchId+i, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void computeBaseIndices(Index patchIndex, Index& rowIndex, Index& colIndex, Index& otherIndex) const {
    const int NumInputDims = array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    otherIndex = (NumInputDims == 3) ?
        0 : patchIndex / m_fastNumPatches;
    const Index patch2DIndex = (NumInputDims == 3) ? patchIndex : (patchIndex - otherIndex * m_num_patches);
    otherIndex *= m_patchInputStride;
    colIndex = patch2DIndex / m_fastOutputRows;
    rowIndex = patch2DIndex - colIndex * m_outputRows;
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE void checkZeroOffsets() const {
    eigen_assert(m_col_offset == 0);
    eigen_assert(m_depth_offset == 0);
    eigen_assert(m_rowIndex == 0);
    eigen_assert(m_colIndex == 0);
    eigen_assert(m_otherIndex == 0);
  }

  Index m_depth_offset;  // First row in the input matrix
  Index m_col_offset;    // First col in the input matrix

  Index m_patch_depth;   // patch depth, which is equal to the input depth
  Index m_patch_rows;    // number of rows in the patch
  Index m_patch_cols;    // number of columns in the patch
  Index m_num_patches;   // number of patches to extract.
  Index m_patch_row_inflate_strides;  // the strides for row inflation in the image patch
  Index m_patch_col_inflate_strides;  // the strides for col inflation in the image patch
  // Fast representation of inflation strides.
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  Index m_otherStride;
  Index m_colStride;
  internal::TensorIntDivisor<Index> m_fastNumPatches;
  internal::TensorIntDivisor<Index> m_fastColStride;

  Index m_rowInputStride;    // row stride in the input tensor
  Index m_colInputStride;    // col stride in the input tensor
  Index m_patchInputStride;  // patch stride in the input tensor

  Index m_inputRows;  // Number of rows in the input tensor
  Index m_inputCols;  // Number of cols in the input tensor

  Index m_outputRows;  // Number of patch rows

  Index m_row_strides;  // User specified row stride
  Index m_col_strides;  // User specified col stride

  Index m_in_row_strides;  // User specified input row stride
  Index m_in_col_strides;  // User specified input col stride

  Index m_rowPaddingTop;   // Row padding
  Index m_colPaddingLeft;  // Column padding

  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastDimZero;

  Index m_rowIndex;    // precomputed row index corresponding to the col offset
  Index m_colIndex;    // precomputed col index corresponding to the col offset
  Index m_otherIndex;  // precomputed other index corresponding to the col offset

  const TensorEvaluator<ArgType, Device> m_impl;
};


template <typename NewDimension, DenseIndex Rows, DenseIndex Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<Scalar, Index,
                     TensorContractionInputMapper<Scalar, Index, Rhs,
                                                  TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>,
                                                  nocontract_t, contract_t, packet_size,
                                                  inner_dim_contiguous, inner_dim_reordered, Alignment>,
                     nr, ColMajor, false, false>
{
  typedef TensorContractionInputMapper<Scalar, Index, Rhs,
                                       TensorEvaluator<const TensorReshapingOp<NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >, Device>,
                                       nocontract_t, contract_t, packet_size,
                                       inner_dim_contiguous, inner_dim_reordered, Alignment> DataMapper;

  static inline Index ceil_div(Index a, Index b) {
    return (a + b - 1) / b;
  }

  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride=0, Index offset=0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
    typedef typename DataMapper::LinearMapper LinearMapper;
    typedef typename packet_traits<Scalar>::type Packet;

    const Index packet_cols4 = (cols/4) * 4;
    const Index peeled_k = (depth/packet_size) * packet_size;
    for(Index j2=0; j2<packet_cols4; j2+=4)
    {
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!rhs.nonStandardPatches()) {
        const Index patch_depth = rhs.patchDepth();
        if ((patch_depth % packet_size) == 0) {
          const Index patch_cols = rhs.patchCols();
          const Index patch_rows = rhs.patchRows();

          const Index startCol = rhs.colOffset();
          const Index max_cols = std::min<Index>(ceil_div(peeled_k, patch_rows*patch_depth)+startCol, patch_cols);

          for (Index c = startCol; c < max_cols; ++c) {
            eigen_assert(k < peeled_k);
            const Index startRow = (c == startCol) ?
                rhs.rowOffset() : 0;
            const Index max_rows = std::min<Index>(ceil_div(peeled_k-c*patch_rows*patch_depth, patch_depth)+startRow, patch_rows);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);
            for (Index r = startRow; r < max_rows; ++r) {
              eigen_assert(k < peeled_k);
              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index startDepth = ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
              const Index max_depth = std::min<Index>(peeled_k-c*patch_rows*patch_depth-r*patch_depth+startDepth, patch_depth);
              eigen_assert(max_depth % packet_size == 0);
              for (Index d = startDepth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = pad0 ? pset1<Packet>(0) : dm0.packetNoPadding(d, idx0);
                kernel.packet[1] = pad1 ? pset1<Packet>(0) : dm1.packetNoPadding(d, idx1);
                kernel.packet[2] = pad2 ? pset1<Packet>(0) : dm2.packetNoPadding(d, idx2);
                kernel.packet[3] = pad3 ? pset1<Packet>(0) : dm3.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block+0*packet_size, kernel.packet[0]);
                pstoreu(block+1*packet_size, kernel.packet[1]);
                pstoreu(block+2*packet_size, kernel.packet[2]);
                pstoreu(block+3*packet_size, kernel.packet[3]);
                block+=4*packet_size;
                k += packet_size;
              }
            }
          }
        }

        for(; k<peeled_k; k+=packet_size) {
          PacketBlock<Packet, 4> kernel;
          kernel.packet[0] = dm0.loadPacket(k);
          kernel.packet[1] = dm1.loadPacket(k);
          kernel.packet[2] = dm2.loadPacket(k);
          kernel.packet[3] = dm3.loadPacket(k);
          ptranspose(kernel);
          pstoreu(block+0*packet_size, kernel.packet[0]);
          pstoreu(block+1*packet_size, kernel.packet[1]);
          pstoreu(block+2*packet_size, kernel.packet[2]);
          pstoreu(block+3*packet_size, kernel.packet[3]);
          block+=4*packet_size;
        }
      }

      // Copy the remaining coefficients of the column block one at a time.
      // This also covers the non-standard patch case, where the packet loads
      // above are skipped entirely.
      for(; k<depth; k++) {
        block[0] = dm0(k);
        block[1] = dm1(k);
        block[2] = dm2(k);
        block[3] = dm3(k);
        block += 4;
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for(Index j2=packet_cols4; j2<cols; ++j2) {
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
      for(Index k=0; k<depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

#endif  // EIGEN_VECTORIZE

} // end namespace internal

/** SpatialConvolution
  * \ingroup CXX11_NeuralNetworks_Module
  *
  * \brief Applies a 2D convolution over a multichannel input image.
  *
  * The input parameter is expected to be a tensor with a rank of 3 or more (channels, height, width, and optionally others).
  * The kernel parameter is expected to be a 4D tensor (filters, channels, kernel_height, kernel_width).
  * The input and the kernel must both be in col-major layout. The result will also be in col-major layout.
  *
  * If in_stride > 1, then applies convolution with holes (aka atrous convolution), sampling every in_stride input pixels.
  *
  * The result can be assigned to a tensor of rank equal to the rank of the input. The dimensions of the result will be filters, height, width (and others if applicable).
  *
  * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
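  *
  * A minimal usage sketch follows; the tensor sizes are illustrative only and
  * assume the default stride of 1 with PADDING_SAME:
  *
  * \code
  * // ColMajor input: channels, rows, cols, batch.
  * Eigen::Tensor<float, 4> input(3, 32, 32, 16);
  * // Kernel: filters, channels, kernel_rows, kernel_cols.
  * Eigen::Tensor<float, 4> kernel(4, 3, 5, 5);
  * input.setRandom();
  * kernel.setRandom();
  * // With stride 1 and PADDING_SAME the spatial size is preserved, so the
  * // result has dimensions (4, 32, 32, 16) = (filters, rows, cols, batch).
  * Eigen::Tensor<float, 4> result = Eigen::SpatialConvolution(input, kernel);
  * \endcode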
  *
  */
template <typename Input, typename Kernel>
EIGEN_ALWAYS_INLINE
static const typename internal::conditional<
  internal::traits<Input>::Layout == ColMajor,
  TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const Kernel>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
  TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, internal::traits<Input>::NumDimensions>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 2>, const Kernel> > > >::type
SpatialConvolution(const Input& input, const Kernel& kernel, const DenseIndex stride = 1, const PaddingType padding_type = PADDING_SAME, const DenseIndex in_stride = 1)
{
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, internal::traits<Input>::Layout, TensorIndex> > in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, internal::traits<Kernel>::Layout, TensorIndex> > kern(kernel);

  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == internal::traits<Kernel>::Layout, YOU_MADE_A_PROGRAMMING_MISTAKE);
  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  static const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the result.
  const TensorIndex kernelFilters = isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels = isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows = isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols = isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];

  const DenseIndex kernelRowsEff = kernelRows + (kernelRows - 1) * (in_stride - 1);
  const DenseIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  const TensorIndex InputRows = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex InputCols = isColMajor ?
      in.dimension(2) : in.dimension(NumDims - 3);

  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID:
      out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / static_cast<float>(stride));
      out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / static_cast<float>(stride));
      break;
    case PADDING_SAME:
      out_height = numext::ceil(InputRows / static_cast<float>(stride));
      out_width = numext::ceil(InputCols / static_cast<float>(stride));
      break;
    default:
      eigen_assert(false && "unexpected padding");
  }

  // Molds the output of the patch extraction code into a 2d tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = out_height * out_width;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[0] = out_height * out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  // Molds the output of the contraction into the shape expected by the user
  // (assuming this is ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output height
  // - 3rd dim: output width
  // - 4th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_height;
    post_contract_dims[2] = out_width;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_height;
    post_contract_dims[NumDims - 3] = out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }

  // TODO(yangke): choose() is defined in TensorContraction.h -- consider
  // moving it to somewhere more "common".
  return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
                kernel.reshape(kernel_dims).contract(input.extract_image_patches(kernelRows, kernelCols, stride, stride, in_stride, in_stride, padding_type).reshape(pre_contract_dims), contract_dims).reshape(post_contract_dims),
                input.extract_image_patches(kernelRows, kernelCols, stride, stride, in_stride, in_stride, padding_type).reshape(pre_contract_dims).contract(kernel.reshape(kernel_dims), contract_dims).reshape(post_contract_dims));
}

} // end namespace Eigen

#endif // EIGEN_CXX11_NEURAL_NETWORKS_SPATIAL_CONVOLUTIONS_H