diff options
Diffstat (limited to 'tensorflow/core/kernels/eigen_spatial_convolutions.h')
-rw-r--r-- | tensorflow/core/kernels/eigen_spatial_convolutions.h | 182 |
1 file changed, 182 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h index 1acbe3a658..a4dff4b91c 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h @@ -797,6 +797,188 @@ struct gemm_pack_rhs< } };

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
//
// NOTE(review): this gemm_pack_rhs specialization is selected when the
// TensorContractionSubMapper over an image-patch (im2col) expression reports
// a packet size of 2 (presumably Packet2d / double on 128-bit SIMD — confirm
// against packet_traits for the build target). It packs a depth x cols tile
// of the virtual patch matrix into the contiguous `block` buffer in the
// layout the GEBP kernel consumes: for each group of 4 columns, the 4 column
// values for each depth index are stored consecutively.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  // Alias for the mapper type; DataMapper is the name the gemm_pack_rhs
  // contract uses for the second operator() argument.
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  // Integer ceiling division: smallest q such that q * b >= a (a, b > 0).
  EIGEN_DEVICE_FUNC
  static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }

  // Packs a (depth x cols) tile addressed by `rhs` into `block`.
  //   block:  output buffer, written sequentially; never read.
  //   rhs:    sub-mapper over the image-patch expression.
  //   depth:  number of rows along the contraction dimension to pack.
  //   cols:   number of columns to pack.
  //   stride, offset: unsupported by this specialization; asserted zero.
  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    // This routine is written for a register-block width of exactly 4.
    EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
    typedef typename packet_traits<Scalar>::type Packet;

    const int packet_size = 2;
    // Columns are handled four at a time; the tail (cols % 4) is copied one
    // column at a time at the bottom of this function.
    const Index packet_cols4 = (cols / 4) * 4;
    // Depth rounded down to a multiple of the packet size; the scalar loops
    // below finish the remaining (depth - peeled_k) coefficients.
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      // One linear mapper per column of this 4-wide strip.
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        const Index patch_depth = rhs.patchDepth();
        if ((patch_depth % packet_size) == 0) {
          // Fastest path: walk the patch coordinates (col, row, depth)
          // directly, so a base offset into the underlying tensor can be
          // computed once per (row, col) and packets loaded without any
          // per-coefficient padding checks.
          const Index patch_cols = rhs.patchCols();
          const Index patch_rows = rhs.patchRows();

          // Range of patch columns touched by the first peeled_k
          // coefficients of this strip.
          const Index startCol = rhs.colOffset();
          const Index max_cols = std::min<Index>(
              ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
              patch_cols);

          for (Index c = startCol; c < max_cols; ++c) {
            eigen_assert(k < peeled_k);
            // Only the first patch column starts at the mapper's row offset.
            const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
            const Index max_rows = std::min<Index>(
                ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
                    startRow,
                patch_rows);

            // Column-level padding is invariant across rows: hoist it.
            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);
            for (Index r = startRow; r < max_rows; ++r) {
              eigen_assert(k < peeled_k);
              // Padded coefficients are materialized as zeros below.
              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              // Base offsets of (r, c) in the underlying tensor, one per
              // column; depth is addressed contiguously from these.
              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              // Only the very first (row, col) position starts at the
              // mapper's depth offset.
              const Index startDepth =
                  ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
              const Index max_depth =
                  std::min<Index>(peeled_k - c * patch_rows * patch_depth -
                                          r * patch_depth + startDepth,
                                  patch_depth);
              eigen_assert((max_depth - startDepth) % packet_size == 0);
              for (Index d = startDepth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                // Load two depth values for each of the four columns, then
                // 2x2-transpose each pair of columns so the stores below
                // emit the four column values for depth d followed by the
                // four values for depth d + 1.
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // Pack whatever of the peeled region the coordinate walk above did
          // not cover, using packet loads through the linear mappers.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketFast(k);
            kernel0.packet[1] = dm1.loadPacketFast(k);
            kernel1.packet[0] = dm2.loadPacketFast(k);
            kernel1.packet[1] = dm3.loadPacketFast(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        } else {
          // Standard patches, but patch depth is not a multiple of the
          // packet size: fall back to loadPacketStandard, which handles
          // packets that straddle patch boundaries.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }
      // Scalar tail: the remaining (depth - peeled_k) coefficients — or the
      // whole depth when patches are non-standard, since both packet paths
      // above are skipped in that case and k is still 0 here.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
          typename ArgType, typename Device, typename Scalar, typename Index, |