author    Brian Patton <bjp@google.com>  2018-03-22 14:11:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-22 14:15:25 -0700
commit    e3468b56d323783fdfb79fa2d6c24effc58bcaa9 (patch)
tree      ec0cb8d1fb766c591b8f66808c283644d6d9cb2d /tensorflow/core/kernels/eigen_spatial_convolutions.h
parent    48b0fb7a524425d57547dc23093d869538b888db (diff)
Adds float64 support for Conv2d, Conv2dBackpropInput, and Conv2dBackpropFilter
PiperOrigin-RevId: 190123191
Diffstat (limited to 'tensorflow/core/kernels/eigen_spatial_convolutions.h')
-rw-r--r--  tensorflow/core/kernels/eigen_spatial_convolutions.h | 182
1 file changed, 182 insertions(+), 0 deletions(-)
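To make the diff easier to read, here is a scalar model of the layout the new packer produces. This is a minimal sketch for orientation only, not part of the patch: pack_rhs_reference and the rhs(k, j) accessor are hypothetical stand-ins for the TensorContractionSubMapper machinery. For each group of nr = 4 columns, the four column values for a given depth are stored contiguously, depth after depth; leftover columns are appended one at a time.

#include <cstddef>

// Hypothetical scalar reference: rhs(k, j) stands in for "element at depth
// k of patch column j" of the virtual right-hand-side matrix.
template <typename Scalar, typename Rhs>
void pack_rhs_reference(Scalar* block, const Rhs& rhs,
                        std::ptrdiff_t depth, std::ptrdiff_t cols) {
  const std::ptrdiff_t packet_cols4 = (cols / 4) * 4;
  // Column groups of nr = 4: depth-major panels in which the four column
  // values for each depth are contiguous.
  for (std::ptrdiff_t j2 = 0; j2 < packet_cols4; j2 += 4) {
    for (std::ptrdiff_t k = 0; k < depth; ++k) {
      *block++ = rhs(k, j2 + 0);
      *block++ = rhs(k, j2 + 1);
      *block++ = rhs(k, j2 + 2);
      *block++ = rhs(k, j2 + 3);
    }
  }
  // Leftover columns are packed one at a time.
  for (std::ptrdiff_t j2 = packet_cols4; j2 < cols; ++j2) {
    for (std::ptrdiff_t k = 0; k < depth; ++k) {
      *block++ = rhs(k, j2);
    }
  }
}

The specialization added below produces the same layout, but for float64 it fills two depth values per column at a time, transposing pairs of two-lane double packets rather than one four-lane block.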
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index 1acbe3a658..a4dff4b91c 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -797,6 +797,188 @@ struct gemm_pack_rhs<
}
};
+// Template specialization for packet_size = 2. We must special-case packet
+// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
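+// With only two scalars per packet we cannot transpose one 4-row packet
+// block in a single step, so the loops below combine two 2x2 ptranspose
+// calls instead.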
+template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
+ typename ArgType, typename Device, typename Scalar, typename Index,
+ typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
+ bool inner_dim_reordered, int Alignment, int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
+ Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
+ Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_DEVICE_FUNC
+ static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ typedef typename packet_traits<Scalar>::type Packet;
+
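+ // Walk columns in groups of nr = 4 and peel the depth dimension down to
+ // a multiple of the packet size; scalar tails below pack the remainder.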
+ const int packet_size = 2;
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if (!non_standard_patches) {
+ const Index patch_depth = rhs.patchDepth();
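+ // If the patch depth is a multiple of the packet size, whole patches can
+ // be traversed with base-index packet loads, checking padding only per
+ // patch row and column.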
+ if ((patch_depth % packet_size) == 0) {
+ const Index patch_cols = rhs.patchCols();
+ const Index patch_rows = rhs.patchRows();
+
+ const Index startCol = rhs.colOffset();
+ const Index max_cols = std::min<Index>(
+ ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
+ patch_cols);
+
+ for (Index c = startCol; c < max_cols; ++c) {
+ eigen_assert(k < peeled_k);
+ const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+ const Index max_rows = std::min<Index>(
+ ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
+ startRow,
+ patch_rows);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+ for (Index r = startRow; r < max_rows; ++r) {
+ eigen_assert(k < peeled_k);
+ const bool pad0 = pad_col0 || dm0.padRow(r);
+ const bool pad1 = pad_col1 || dm1.padRow(r);
+ const bool pad2 = pad_col2 || dm2.padRow(r);
+ const bool pad3 = pad_col3 || dm3.padRow(r);
+
+ const Index idx0 = dm0.baseIndex(r, c);
+ const Index idx1 = dm1.baseIndex(r, c);
+ const Index idx2 = dm2.baseIndex(r, c);
+ const Index idx3 = dm3.baseIndex(r, c);
+
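+ // The block may start mid-patch, so honor the depth offset at the first
+ // (row, column) position and clamp the depth extent to the peeled range.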
+ const Index startDepth =
+ ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
+ const Index max_depth =
+ std::min<Index>(peeled_k - c * patch_rows * patch_depth -
+ r * patch_depth + startDepth,
+ patch_depth);
+ eigen_assert((max_depth - startDepth) % packet_size == 0);
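+ // Load two consecutive depth values from each of the four columns, then
+ // transpose the 2x2 pairs so the four column values for each depth land
+ // contiguously in the packed block.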
+ for (Index d = startDepth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+
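+ // Finish any part of the peeled range the indexed path above did not
+ // cover; loadPacketFast is valid here because the patch depth is a
+ // multiple of the packet size.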
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketFast(k);
+ kernel0.packet[1] = dm1.loadPacketFast(k);
+ kernel1.packet[0] = dm2.loadPacketFast(k);
+ kernel1.packet[1] = dm3.loadPacketFast(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ } else {
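+ // Patch depth is not a multiple of the packet size: use the general
+ // standard-patch packet load instead.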
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketStandard(k);
+ kernel0.packet[1] = dm1.loadPacketStandard(k);
+ kernel1.packet[0] = dm2.loadPacketStandard(k);
+ kernel1.packet[1] = dm3.loadPacketStandard(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ }
+ }
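+ // Scalar tail: pack the remaining depth values (and, for non-standard
+ // patches, the entire depth) one coefficient at a time.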
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
// Special case for non-vectorized types such as float16.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar, typename Index,