author    A. Unique TensorFlower <nobody@tensorflow.org>    2016-05-10 12:30:20 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-05-10 13:32:14 -0700
commit    6b373bb149396beec11347881b2d6dedfbcc83c4 (patch)
tree      9eddf82919c16becd71b55243f48a0967e9b5db4 /tensorflow/core/kernels/sparse_matmul_op.cc
parent    49d25eae890216f15833adfdd1e668479470745d (diff)
Allow bfloat16 inputs to SparseMatMul.
Change: 121980920
Diffstat (limited to 'tensorflow/core/kernels/sparse_matmul_op.cc')
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op.cc  1019
1 file changed, 714 insertions, 305 deletions
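
For context before the diff: a bfloat16 is just the upper 16 bits of an IEEE-754 float32 (sign, 8-bit exponent, 7-bit mantissa), so widening one back to float is a 16-bit shift into the high half of a 32-bit word. A minimal standalone sketch of that round trip (the names FloatToBfloat16Bits/Bfloat16BitsToFloat are illustrative, not from this commit; the kernel itself uses the helpers from bfloat16.h and Eigen packet intrinsics):

    #include <cstdint>
    #include <cstring>

    // Truncating float -> bfloat16 conversion: keep only the top 16 bits.
    inline uint16_t FloatToBfloat16Bits(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      return static_cast<uint16_t>(bits >> 16);
    }

    // bfloat16 -> float: place the 16 bits in the high half of a zeroed word.
    inline float Bfloat16BitsToFloat(uint16_t b) {
      const uint32_t bits = static_cast<uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

The ConvertBfloat16ToFloat helper added below does the same widening by storing the 16 bits into the upper half of a zero-initialized float, which assumes a little-endian layout.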
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 067aa2624d..6cfaa3b958 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -18,9 +18,11 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/sparse_matmul_op.h"
+
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -35,12 +37,16 @@ namespace tensorflow {
namespace {
+using Eigen::operator==;
typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix;
typedef Eigen::DSizes<Eigen::DenseIndex, 2> DSizes;
-typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>,
- Eigen::Aligned> MatrixMap;
typedef Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>,
- Eigen::Aligned> ConstMatrixMap;
+ Eigen::Aligned>
+ ConstMatrixMap;
+typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ MatrixMap;
+
typedef Eigen::ThreadPoolDevice CPUDevice;
// Blocksizes
@@ -52,8 +58,7 @@ static const int N = 128;
// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
-// have
-// fewer columns.
+// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing
// different indices as uint8.
@@ -71,7 +76,12 @@ static const int N = 128;
// are the values in the following range:
// [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
// index_offset.
+template <typename T>
struct SparseSlice {
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMap;
+
public:
// Indices of three elements on the same row.
struct Index3 {
@@ -105,13 +115,13 @@ struct SparseSlice {
// See comments above.
std::vector<int> index3_offset;
std::vector<Index3> index3;
- std::vector<float> data3;
+ std::vector<T> data3;
// See comments above. Similar to "index3" except that each element in "index"
// corresponds to one element in data.
std::vector<int> index_offset;
std::vector<Index> index;
- std::vector<float> data;
+ std::vector<T> data;
// Number of rows and columns for the slice.
const int num_rows;
@@ -121,8 +131,10 @@ struct SparseSlice {
const int block_size;
};
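
A worked example of this encoding (values are illustrative, not from the kernel): with block_size = 4, suppose a block's row 0 is [9, 0, 8, 7] and row 1 is [0, 0, 6, 0]. Row 0 has three nonzeros and fills one triple; row 1 leaves a single nonzero:

    // index3 = { {m=0, k1=0, k2=2, k3=3} }   data3 = { 9, 8, 7 }
    // index  = { {m=1, k=2} }                data  = { 6 }

Rows whose nonzero count is not a multiple of three spill the remainder into index/data one element at a time.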
+template <typename T>
template <bool Transpose>
-void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
+void SparseSlice<T>::Initialize(const SparseSlice<T>::ConstMatrixMap& mat,
+ int col_offset) {
const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
DCHECK_LE(num_rows, mat_rows);
@@ -142,6 +154,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
Index3 idx3;
Index idx;
int data3_size = 0;
+ static const T zero(0);
for (int i = 0; i < num_blocks; ++i) {
int num_block_cols = std::min(block_size, num_cols - block_size * i);
for (int row = 0; row < num_rows; ++row) {
@@ -150,17 +163,17 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
// *curr is nonzero and then reads it again on use. However, the result
// of the race is only that some of the "nonzeros" in the resulting sparse
// representation may actually be zero, which is harmless.
- const float* start =
+ const auto* start =
Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
- const float* curr = start;
+ const auto* curr = start;
const int stride = Transpose ? mat.dimension(1) : 1;
- const float* end = start + stride * num_block_cols;
+ const auto* end = start + stride * num_block_cols;
uint8 k = 0;
#define NEXT_ELEM \
curr += stride; \
++k;
while (true) {
- while (curr < end && (*curr == 0)) {
+ while (curr < end && (*curr == zero)) {
NEXT_ELEM;
}
if (curr >= end) break;
@@ -168,7 +181,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
data3.push_back(*curr);
NEXT_ELEM;
- while (curr < end && (*curr == 0)) {
+ while (curr < end && (*curr == zero)) {
NEXT_ELEM;
}
if (curr >= end) break;
@@ -176,7 +189,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
data3.push_back(*curr);
NEXT_ELEM;
- while (curr < end && (*curr == 0)) {
+ while (curr < end && (*curr == zero)) {
NEXT_ELEM;
}
if (curr >= end) break;
@@ -213,7 +226,8 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
DCHECK_EQ(index.size(), data.size());
}
-void SparseSlice::Clear() {
+template <typename T>
+void SparseSlice<T>::Clear() {
index3_offset.clear();
index3.clear();
data3.clear();
@@ -222,107 +236,356 @@ void SparseSlice::Clear() {
data.clear();
}
-#define SCALAR_MULADD(a, inp, out) *out++ += *a * *inp++;
-
-#define SCALAR_MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \
- *out++ += *a1 * *inp1++ + *a2 * *inp2++ + *a3 * *inp3++;
-
typedef Eigen::internal::packet_traits<float>::type Packet;
static const int kNumOperands = (sizeof(Packet) / sizeof(float));
#define LOAD(x) Eigen::internal::pload<Packet>(x);
+#define EXPAND_BFLOAT_L(x, y) \
+ const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
+#define EXPAND_BFLOAT_U(x, y) \
+ const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
-#define LOAD_SCALAR(x, y) const auto y = Eigen::internal::pload1<Packet>(x);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
-// Vectorized version of SCALAR_MULADD.
-#define MULADD(a, inp, out) \
- do { \
- const auto b = LOAD(inp); \
- inp += kNumOperands; \
- auto c = LOAD(out); \
- FMA(a, b, c, c); \
- STORE(out, c); \
- out += kNumOperands; \
- } while (false)
-
-// Vectorized version of SCALAR_MULADD3WAY.
-#define MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \
- do { \
- auto c = LOAD(out); \
- const auto b1 = LOAD(inp1); \
- inp1 += kNumOperands; \
- const auto b2 = LOAD(inp2); \
- inp2 += kNumOperands; \
- const auto b3 = LOAD(inp3); \
- inp3 += kNumOperands; \
- FMA(a1, b1, c, c); \
- FMA(a2, b2, c, c); \
- FMA(a3, b3, c, c); \
- STORE(out, c); \
- out += kNumOperands; \
- } while (false)
-
-#ifdef EIGEN_VECTORIZE_AVX2
-// Unroll MULADD3WAY for two iterations
-#define MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out) \
- do { \
- auto c1 = LOAD(out); \
- const auto b1 = LOAD(inp1); \
- const auto b2 = LOAD(inp2); \
- const auto b3 = LOAD(inp3); \
- \
- auto c2 = LOAD(out + kNumOperands); \
- const auto b4 = LOAD(inp1 + kNumOperands); \
- const auto b5 = LOAD(inp2 + kNumOperands); \
- const auto b6 = LOAD(inp3 + kNumOperands); \
- \
- FMA(a1, b1, c1, c1); \
- FMA(a1, b4, c2, c2); \
- FMA(a2, b2, c1, c1); \
- FMA(a2, b5, c2, c2); \
- FMA(a3, b3, c1, c1); \
- FMA(a3, b6, c2, c2); \
- STORE(out, c1); \
- STORE(out + kNumOperands, c2); \
- out += 2 * kNumOperands; \
- inp1 += 2 * kNumOperands; \
- inp2 += 2 * kNumOperands; \
- inp3 += 2 * kNumOperands; \
- } while (false)
-// Further unroll MULADD3WAY.
-#define MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out) \
- MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out);
-#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out);
-#else
-#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \
- for (int __i = 0; __i < 128 / (4 * kNumOperands); ++__i) { \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
+#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE
+
+ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
+ float out = 0;
+ auto tmp = reinterpret_cast<bfloat16*>(&out);
+ tmp[1] = *src;
+ return out;
+}
+
+ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
+ return Eigen::internal::pload4bf16<Packet>(
+ reinterpret_cast<const float*>(src));
+}
+
+ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
+ return Eigen::internal::pload2bf16<Packet>(
+ reinterpret_cast<const float*>(src));
+}
+
+ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
+ **out += a * **inp;
+ ++*inp;
+ ++*out;
+}
+
+ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
+ float** out) {
+ float inp_f = ConvertBfloat16ToFloat(*inp);
+ **out += a * inp_f;
+ ++*inp;
+ ++*out;
+}
+ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
+ const float a3, const bfloat16** inp1,
+ const bfloat16** inp2,
+ const bfloat16** inp3, float** out) {
+ float inp1_f = ConvertBfloat16ToFloat(*inp1);
+ float inp2_f = ConvertBfloat16ToFloat(*inp2);
+ float inp3_f = ConvertBfloat16ToFloat(*inp3);
+ **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
+ ++*out;
+ ++*inp1;
+ ++*inp2;
+ ++*inp3;
+}
+
+ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
+ const float a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
+ ++*out;
+ ++*inp1;
+ ++*inp2;
+ ++*inp3;
+}
+
+ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
+ auto tmp = ConvertBfloat16ToFloat(*data);
+ *l = Eigen::internal::pset1<Packet>(tmp);
+ ++*data;
+}
+
+ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
+ Packet* l2) {
+ if (kNumOperands >= 2) {
+ auto tmp = ConvertTwoBfloat16ToFloat(*data);
+ *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
+ *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
+ *data += 2;
+ } else {
+ LoadSingleScalar(data, l1);
+ LoadSingleScalar(data, l2);
+ }
+}
+
+ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
+ Packet* l2, Packet* l3, Packet* l4) {
+ if (kNumOperands >= 4) {
+ auto tmp = ConvertFourBfloat16ToFloat(*data);
+ *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
+ *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
+ *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
+ *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
+ *data += 4;
+ } else {
+ LoadTwoScalars(data, l1, l2);
+ LoadTwoScalars(data, l3, l4);
+ }
+}
+
+ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
+ *l = Eigen::internal::pload1<Packet>(*data);
+ ++(*data);
+}
+
+ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
+ LoadSingleScalar(data, l1);
+ LoadSingleScalar(data, l2);
+}
+
+ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
+ Packet* l3, Packet* l4) {
+ LoadTwoScalars(data, l1, l2);
+ LoadTwoScalars(data, l3, l4);
+}
+
+template <typename T>
+ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
+ Packet* l3) {
+ LoadTwoScalars(data, l1, l2);
+ LoadSingleScalar(data, l3);
+}
+
+template <typename T>
+ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
+ Packet* l3, Packet* l4, Packet* l5,
+ Packet* l6) {
+ LoadFourScalars(data, l1, l2, l3, l4);
+ LoadTwoScalars(data, l5, l6);
+}
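
Each Load*Scalars helper broadcasts one left-matrix scalar across a full packet (with AVX, Packet is eight floats, so kNumOperands == 8); a single FMA then applies that scalar to kNumOperands right-matrix values at once. A scalar model of one broadcast-FMA step, to pin down the semantics (FmaModel is an illustrative name):

    // Models FMA(a_bcast, b, c, c) where a_bcast = pset1(a).
    inline void FmaModel(float a, const float* b, float* c, int num_operands) {
      for (int i = 0; i < num_operands; ++i) c[i] += a * b[i];
    }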
+
+// Vectorized version of ScalarMulAdd.
+ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
+ auto inp = reinterpret_cast<const float*>(*binp);
+ const auto b = LOAD(inp);
+ EXPAND_BFLOAT_L(b, b_0);
+ EXPAND_BFLOAT_U(b, b_1);
+ *binp += 2 * kNumOperands;
+ auto c1 = LOAD(*out);
+ auto c2 = LOAD(*out + kNumOperands);
+ FMA(a, b_0, c1, c1);
+ FMA(a, b_1, c2, c2);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ *out += 2 * kNumOperands;
+}
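
One packet load over bfloat16 memory picks up 2 * kNumOperands half-width values, and EXPAND_BFLOAT_L / EXPAND_BFLOAT_U widen the two halves into two float packets; that is why the bfloat16 pointers here advance by 2 * kNumOperands per call. In scalar terms the net effect is as sketched below, assuming the interleaved layout produced by CopyAndMayBeInterleaveBfloat16 further down restores original element order across the two expansions (Bfloat16BitsToFloat is the illustrative helper sketched before the diff):

    inline void MulAddModel(float a, const uint16_t** binp, float** out, int k) {
      for (int i = 0; i < 2 * k; ++i) (*out)[i] += a * Bfloat16BitsToFloat((*binp)[i]);
      *binp += 2 * k;
      *out += 2 * k;
    }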
+
+// Vectorized version of ScalarMulAdd3Way.
+ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
+ const bfloat16** binp1, const bfloat16** binp2,
+ const bfloat16** binp3, float** out) {
+ auto inp1 = reinterpret_cast<const float*>(*binp1);
+ auto inp2 = reinterpret_cast<const float*>(*binp2);
+ auto inp3 = reinterpret_cast<const float*>(*binp3);
+ auto c1 = LOAD(*out);
+ auto c2 = LOAD(*out + kNumOperands);
+ const auto b1 = LOAD(inp1);
+ EXPAND_BFLOAT_L(b1, b1_0);
+ EXPAND_BFLOAT_U(b1, b1_1);
+ *binp1 += 2 * kNumOperands;
+ const auto b2 = LOAD(inp2);
+ EXPAND_BFLOAT_L(b2, b2_0);
+ EXPAND_BFLOAT_U(b2, b2_1);
+ *binp2 += 2 * kNumOperands;
+ const auto b3 = LOAD(inp3);
+ EXPAND_BFLOAT_L(b3, b3_0);
+ EXPAND_BFLOAT_U(b3, b3_1);
+ *binp3 += 2 * kNumOperands;
+ FMA(a1, b1_0, c1, c1);
+ FMA(a1, b1_1, c2, c2);
+ FMA(a2, b2_0, c1, c1);
+ FMA(a2, b2_1, c2, c2);
+ FMA(a3, b3_0, c1, c1);
+ FMA(a3, b3_1, c2, c2);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ *out += 2 * kNumOperands;
+}
+
+// Unroll MulAdd3Way for two iterations
+ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
+ const Packet a3, const bfloat16** binp1,
+ const bfloat16** binp2, const bfloat16** binp3,
+ float** out) {
+ auto inp1 = reinterpret_cast<const float*>(*binp1);
+ auto inp2 = reinterpret_cast<const float*>(*binp2);
+ auto inp3 = reinterpret_cast<const float*>(*binp3);
+ auto c1 = LOAD(*out);
+ auto c2 = LOAD(*out + kNumOperands);
+ const auto b1 = LOAD(inp1);
+ const auto b2 = LOAD(inp2);
+ const auto b3 = LOAD(inp3);
+
+ EXPAND_BFLOAT_L(b1, b1_0);
+ EXPAND_BFLOAT_U(b1, b1_1);
+ EXPAND_BFLOAT_L(b2, b2_0);
+ EXPAND_BFLOAT_U(b2, b2_1);
+ EXPAND_BFLOAT_L(b3, b3_0);
+ EXPAND_BFLOAT_U(b3, b3_1);
+ auto c3 = LOAD(*out + 2 * kNumOperands);
+ auto c4 = LOAD(*out + 3 * kNumOperands);
+ const auto b4 = LOAD(inp1 + kNumOperands);
+ const auto b5 = LOAD(inp2 + kNumOperands);
+ const auto b6 = LOAD(inp3 + kNumOperands);
+
+ EXPAND_BFLOAT_L(b4, b4_0);
+ EXPAND_BFLOAT_U(b4, b4_1);
+ EXPAND_BFLOAT_L(b5, b5_0);
+ EXPAND_BFLOAT_U(b5, b5_1);
+ EXPAND_BFLOAT_L(b6, b6_0);
+ EXPAND_BFLOAT_U(b6, b6_1);
+
+ FMA(a1, b1_0, c1, c1);
+ FMA(a1, b1_1, c2, c2);
+ FMA(a1, b4_0, c3, c3);
+ FMA(a1, b4_1, c4, c4);
+ FMA(a2, b2_0, c1, c1);
+ FMA(a2, b2_1, c2, c2);
+ FMA(a2, b5_0, c3, c3);
+ FMA(a2, b5_1, c4, c4);
+ FMA(a3, b3_0, c1, c1);
+ FMA(a3, b3_1, c2, c2);
+ FMA(a3, b6_0, c3, c3);
+ FMA(a3, b6_1, c4, c4);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ STORE(*out + 2 * kNumOperands, c3);
+ STORE(*out + 3 * kNumOperands, c4);
+ *out += 4 * kNumOperands;
+ *binp1 += 4 * kNumOperands;
+ *binp2 += 4 * kNumOperands;
+ *binp3 += 4 * kNumOperands;
+}
+
+// Apply MulAdd3Way on 128 operands.
+ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
+ const Packet a3, const bfloat16** inp1,
+ const bfloat16** inp2, const bfloat16** inp3,
+ float** out) {
+ for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}
-#endif
+}
+// Vectorized version of ScalarMulAdd
+ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
+ const auto b = LOAD(*inp);
+ *inp += kNumOperands;
+ auto c = LOAD(*out);
+ FMA(a, b, c, c);
+ STORE(*out, c);
+ *out += kNumOperands;
+}
+
+// Vectorized version of ScalarMulAdd3Way
+ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
+ const float** inp1, const float** inp2,
+ const float** inp3, float** out) {
+ auto c = LOAD(*out);
+ const auto b1 = LOAD(*inp1);
+ *inp1 += kNumOperands;
+ const auto b2 = LOAD(*inp2);
+ *inp2 += kNumOperands;
+ const auto b3 = LOAD(*inp3);
+ *inp3 += kNumOperands;
+ FMA(a1, b1, c, c);
+ FMA(a2, b2, c, c);
+ FMA(a3, b3, c, c);
+ STORE(*out, c);
+ *out += kNumOperands;
+}
+
+// Unroll MulAdd3Way for two iterations
+ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
+ const Packet a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ auto c1 = LOAD(*out);
+ const auto b1 = LOAD(*inp1);
+ const auto b2 = LOAD(*inp2);
+ const auto b3 = LOAD(*inp3);
+
+ auto c2 = LOAD(*out + kNumOperands);
+ const auto b4 = LOAD(*inp1 + kNumOperands);
+ const auto b5 = LOAD(*inp2 + kNumOperands);
+ const auto b6 = LOAD(*inp3 + kNumOperands);
+
+ FMA(a1, b1, c1, c1);
+ FMA(a1, b4, c2, c2);
+ FMA(a2, b2, c1, c1);
+ FMA(a2, b5, c2, c2);
+ FMA(a3, b3, c1, c1);
+ FMA(a3, b6, c2, c2);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ *out += 2 * kNumOperands;
+ *inp1 += 2 * kNumOperands;
+ *inp2 += 2 * kNumOperands;
+ *inp3 += 2 * kNumOperands;
+}
+
+// Unroll MulAdd3Way for four iterations
+ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
+ const Packet a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+}
+
+// Apply MulAdd3Way on 128 operands.
+ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
+ const Packet a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ if (kNumOperands == 8) {
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ } else {
+ DCHECK_LE(4 * kNumOperands, 128);
+ for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ }
+ }
+}
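
Note on the unroll arithmetic: the float TwoMulAdd3Way advances by 2 * kNumOperands columns per call, so FourMulAdd3Way covers 4 * kNumOperands; with kNumOperands == 8 that is 32 columns, and the four calls in the special case above cover the full 128-column block exactly.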
// Computes product of "left_slices" with "num_cols" columns of "right", and
// stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also each SparseSlice
 // is encoded as a list of blocks with up to N columns. See SparseSlice for more
// details.
-template <int Cols>
-inline void GEPP(const std::vector<SparseSlice*>& left_slices,
- const ConstMatrixMap& right, const int num_cols,
- Matrix* output) {
+template <typename TL, typename TR, int Cols>
+inline void GEPP(
+ const std::vector<SparseSlice<TL>*>& left_slices,
+ const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>& right,
+ const int num_cols, Matrix* output) {
const int cols = (Cols == -1) ? num_cols : Cols;
DCHECK_EQ(num_cols, cols);
const int right_num_cols = right.dimension(1);
const int output_num_cols = output->dimension(1);
- const int cols_mod = cols % kNumOperands;
+ static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
+ const int cols_mod = cols % kNumOperandsR;
int k_offset = 0;
// Pre-compute pointers for output matrix.
float* out_ptrs[M];
@@ -332,15 +595,15 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices,
}
for (const auto* left_slice : left_slices) {
const auto& left = *left_slice;
- const float* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr;
- const float* data = (left.data.size() > 0) ? &left.data[0] : nullptr;
+ const auto* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr;
+ const auto* data = (left.data.size() > 0) ? &left.data[0] : nullptr;
const int num_blocks = left.index3_offset.size();
int begin3 = 0;
int begin = 0;
for (int i = 0; i < num_blocks; ++i) {
// Pre-compute pointers for right matrix
- const float* right_ptrs[K];
- const float* const right_start = &right(k_offset, 0);
+ const TR* right_ptrs[K];
+ const auto* const right_start = &right(k_offset, 0);
DCHECK_LT(k_offset, right.dimension(0));
for (int j = 0; j < K; ++j) {
right_ptrs[j] = right_start + right_num_cols * j;
@@ -350,79 +613,116 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices,
int j = begin3;
// Loop unrolled for 2 iterations.
for (; j + 1 < end3; j += 2) {
- const float* sl1 = data3++;
- LOAD_SCALAR(sl1, l1);
- const float* sl2 = data3++;
- LOAD_SCALAR(sl2, l2);
- const float* sl3 = data3++;
- LOAD_SCALAR(sl3, l3);
- const float* nsl1 = data3++;
- LOAD_SCALAR(nsl1, nl1);
- const float* nsl2 = data3++;
- LOAD_SCALAR(nsl2, nl2);
- const float* nsl3 = data3++;
- LOAD_SCALAR(nsl3, nl3);
- const SparseSlice::Index3& index = left.index3[j];
- const SparseSlice::Index3& nindex = left.index3[j + 1];
+ Packet l1, l2, l3, nl1, nl2, nl3;
+ LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
+ const auto& index = left.index3[j];
+ const auto& nindex = left.index3[j + 1];
float* out = out_ptrs[index.m];
float* nout = out_ptrs[nindex.m];
- const float* r1 = right_ptrs[index.k1];
- const float* r2 = right_ptrs[index.k2];
- const float* r3 = right_ptrs[index.k3];
- const float* nr1 = right_ptrs[nindex.k1];
- const float* nr2 = right_ptrs[nindex.k2];
- const float* nr3 = right_ptrs[nindex.k3];
+ const auto* r1 = right_ptrs[index.k1];
+ const auto* r2 = right_ptrs[index.k2];
+ const auto* r3 = right_ptrs[index.k3];
+
+ const auto* nr1 = right_ptrs[nindex.k1];
+ const auto* nr2 = right_ptrs[nindex.k2];
+ const auto* nr3 = right_ptrs[nindex.k3];
if (cols == 128) {
- MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out);
- MULADD3WAY_128(nl1, nl2, nl3, nr1, nr2, nr3, nout);
+ MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
+ MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
} else {
- for (int n = 0; n < cols / kNumOperands; ++n) {
- MULADD3WAY(l1, l2, l3, r1, r2, r3, out);
- MULADD3WAY(nl1, nl2, nl3, nr1, nr2, nr3, nout);
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
+ MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
}
+
+ const float sl1 = Eigen::internal::pfirst<Packet>(l1);
+ const float sl2 = Eigen::internal::pfirst<Packet>(l2);
+ const float sl3 = Eigen::internal::pfirst<Packet>(l3);
+ const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
+ const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
+ const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
for (int k = 0; k < cols_mod; ++k) {
- SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out);
- SCALAR_MULADD3WAY(nsl1, nsl2, nsl3, nr1, nr2, nr3, nout);
+ ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
+ ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
}
}
}
if (j < end3) {
- const float* sl1 = data3++;
- LOAD_SCALAR(sl1, l1);
- const float* sl2 = data3++;
- LOAD_SCALAR(sl2, l2);
- const float* sl3 = data3++;
- LOAD_SCALAR(sl3, l3);
- const SparseSlice::Index3& index = left.index3[j];
+ Packet l1, l2, l3;
+ LoadThreeScalars(&data3, &l1, &l2, &l3);
+
+ const auto& index = left.index3[j];
float* out = out_ptrs[index.m];
- const float* r1 = right_ptrs[index.k1];
- const float* r2 = right_ptrs[index.k2];
- const float* r3 = right_ptrs[index.k3];
+ const auto* r1 = right_ptrs[index.k1];
+ const auto* r2 = right_ptrs[index.k2];
+ const auto* r3 = right_ptrs[index.k3];
if (cols == 128) {
- MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out);
+ MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
} else {
- for (int n = 0; n < cols / kNumOperands; ++n) {
- MULADD3WAY(l1, l2, l3, r1, r2, r3, out);
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
}
+ const float sl1 = Eigen::internal::pfirst<Packet>(l1);
+ const float sl2 = Eigen::internal::pfirst<Packet>(l2);
+ const float sl3 = Eigen::internal::pfirst<Packet>(l3);
for (int k = 0; k < cols_mod; ++k) {
- SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out);
+ ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
}
}
}
begin3 = end3;
int end = left.index_offset[i];
- for (int j = begin; j < end; ++j) {
- const float* sl = data++;
- LOAD_SCALAR(sl, l);
- const SparseSlice::Index& index = left.index[j];
- const float* r = right_ptrs[index.k];
+ // Loop unrolled for 4 iterations.
+ j = begin;
+ for (; j + 3 < end; j += 4) {
+ Packet l, nl, n2l, n3l;
+ LoadFourScalars(&data, &l, &nl, &n2l, &n3l);
+
+ const auto& index = left.index[j];
+ const auto& nindex = left.index[j + 1];
+ const auto& n2index = left.index[j + 2];
+ const auto& n3index = left.index[j + 3];
+ const auto* r = right_ptrs[index.k];
+ const auto* nr = right_ptrs[nindex.k];
+ const auto* n2r = right_ptrs[n2index.k];
+ const auto* n3r = right_ptrs[n3index.k];
+ float* out = out_ptrs[index.m];
+ float* nout = out_ptrs[nindex.m];
+ float* n2out = out_ptrs[n2index.m];
+ float* n3out = out_ptrs[n3index.m];
+
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd(l, &r, &out);
+ MulAdd(nl, &nr, &nout);
+ MulAdd(n2l, &n2r, &n2out);
+ MulAdd(n3l, &n3r, &n3out);
+ }
+
+ const float sl1 = Eigen::internal::pfirst<Packet>(l);
+ const float sl2 = Eigen::internal::pfirst<Packet>(nl);
+ const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
+ const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
+ for (int k = 0; k < cols_mod; ++k) {
+ ScalarMulAdd(sl1, &r, &out);
+ ScalarMulAdd(sl2, &nr, &nout);
+ ScalarMulAdd(sl3, &n2r, &n2out);
+ ScalarMulAdd(sl4, &n3r, &n3out);
+ }
+ }
+ while (j < end) {
+ Packet l;
+ LoadSingleScalar(&data, &l);
+ const auto& index = left.index[j];
+ const auto* r = right_ptrs[index.k];
float* out = out_ptrs[index.m];
- for (int n = 0; n < cols / kNumOperands; ++n) {
- MULADD(l, r, out);
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd(l, &r, &out);
}
+ const float sl = Eigen::internal::pfirst<Packet>(l);
for (int k = 0; k < cols_mod; ++k) {
- SCALAR_MULADD(sl, r, out);
+ ScalarMulAdd(sl, &r, &out);
}
+ j++;
}
k_offset += left.block_size;
begin = end;
@@ -430,21 +730,101 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices,
}
}
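
For reference, each Index3 entry processed in the loops above contributes the following scalar computation (a naive model that ignores the packet broadcasts, the unrolling, and the bfloat16 expansion):

    // out = output row index.m; r1/r2/r3 = right-matrix rows index.k1/k2/k3;
    // a1/a2/a3 = the three nonzero left-matrix values from data3.
    for (int c = 0; c < cols; ++c) {
      out[c] += a1 * r1[c] + a2 * r2[c] + a3 * r3[c];
    }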
-#undef SCALAR_MULADD
-#undef SCALAR_MULADD3WAY
#undef LOAD
+#undef EXPAND_BFLOAT_L
+#undef EXPAND_BFLOAT_U
#undef STORE
-#undef LOAD_SCALAR
#undef FMA
-#undef MULADD
-#undef MULADD3WAY
-#undef MULADD3WAY_16
-#undef MULADD3WAY_32
-#undef MULADD3WAY_128
} // namespace
+template <typename TL, typename TR>
+class SparseMatMul {
+ typedef Eigen::Tensor<TL, 2, Eigen::RowMajor> MatrixL;
+ typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
+ typedef Eigen::TensorMap<Eigen::Tensor<const TL, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMapL;
+ typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMapR;
+ typedef Eigen::TensorMap<Eigen::Tensor<TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ MatrixMapR;
+ // Perform matrix multiplication of "left" and "right", and store the result
+ // in *"output".
+ public:
+ static inline void Compute(const ConstMatrixMapL& left,
+ const ConstMatrixMapR& right, bool transpose_left,
+ const DeviceBase::CpuWorkerThreads* thread_pool,
+ bool transpose_output, MatrixMap* output);
+
+ private:
+ // Computes multiplication of left and num_cols columns of right, and stores
+ // the output block in *"output" at offsets "output_row_offset" and
+ // "output_col_offset". If assign is true, assigns the value to that block,
+ // else adds the values to the existing values.
+ static inline void ComputeOutputBlock(
+ const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
+ int num_cols, int output_row_offset, int output_col_offset, bool assign,
+ bool transpose_output, MatrixMap* output);
+
+ // Encodes "mat" using a sparse representation and stores that in
+ // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
+ // "slice_num_cols", each grid element is converted into a SparseSlice and
+ // stored in mat_slices. "slice_block_size" is used to perform further column
+ // blocking of each slice.
+ static inline BlockingCounter* CreateSparseSlices(
+ const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
+ int slice_block_size, int slice_num_cols,
+ std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
+ const DeviceBase::CpuWorkerThreads* thread_pool);
+
+  // This function chops "mat" along the column dimension into pieces with at
+  // most N columns, concatenates the pieces one after the other in "buffer",
+  // and returns the list of pieces in "slices". The returned BlockingCounter
+  // should be used to wait for the shuffle operations to complete.
+ static inline BlockingCounter* CreateDenseSlices(
+ const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
+ int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
+ MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);
+
+ // Helper function for CreateDenseSlices to move the data around. It returns a
+ // BlockingCounter which should be used to wait for the shuffle operations to
+ // complete.
+ static inline BlockingCounter* ShuffleMatrix(
+ const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
+ int slice_col_start, int slice_num_cols, const int N,
+ const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);
+
+ // Helper function for CreateDenseSlices to create slices.
+ static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
+ const int num_slices,
+ std::vector<ConstMatrixMapR*>* slices);
+
+ // Heuristics to compute various block sizes.
+ // KR, NR: block sizes for "right". We run blocking iterations that operate on
+ // matrices with at most this size.
+ // KL: grid size along the column dimension used while encoding left.
+ // IB, JB: number of left and right slices to multiply together. This is used
+ // for ordering different ComputeBlockOutput operations inside each blocking
+ // iteration so as to potentially reduce the working set size.
+ static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
+ const ConstMatrixMapR& right,
+ bool transpose_left, int num_threads,
+ int* KR, int* NR, int* KL, int* JB,
+ int* IB);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
+};
+
+template <typename TL, typename TR>
class SparseMatMulOp : public OpKernel {
+ typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
+ typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMapR;
+
public:
explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
@@ -461,12 +841,10 @@ class SparseMatMulOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
errors::InvalidArgument("b is not a matrix"));
- auto left = a.matrix<float>();
- auto right = b.matrix<float>();
- const int m = transpose_a_ ? left.dimension(1) : left.dimension(0);
- const int k = transpose_a_ ? left.dimension(0) : left.dimension(1);
- const int n = transpose_b_ ? right.dimension(0) : right.dimension(1);
- const int k2 = transpose_b_ ? right.dimension(1) : right.dimension(0);
+ const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
+ const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
+ const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
+ const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);
OP_REQUIRES(ctx, k == k2,
errors::InvalidArgument("Matrix size incompatible: a: ",
@@ -476,127 +854,94 @@ class SparseMatMulOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
auto out = output->matrix<float>();
+ std::unique_ptr<Tensor> a_float;
+ std::unique_ptr<Tensor> b_float;
if (!a_is_sparse_ && !b_is_sparse_) {
- // Fallback to Eigen contract.
- // Note that we currently don't optimize the case where only right is
- // sparse. That can generally be handled by transposing the order of the
- // matmul.
+ auto left = &a;
+ auto right = &b;
+ // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
+ if (std::is_same<TL, bfloat16>::value) {
+ a_float.reset(new Tensor(DT_FLOAT, a.shape()));
+ BFloat16ToFloat(a.flat<bfloat16>().data(),
+ a_float->flat<float>().data(), a.NumElements());
+ left = a_float.get();
+ }
+ if (std::is_same<TR, bfloat16>::value) {
+ b_float.reset(new Tensor(DT_FLOAT, b.shape()));
+ BFloat16ToFloat(b.flat<bfloat16>().data(),
+ b_float->flat<float>().data(), b.NumElements());
+ right = b_float.get();
+ }
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0].first = transpose_a_ ? 0 : 1;
dim_pair[0].second = transpose_b_ ? 1 : 0;
+
out.device(ctx->template eigen_device<CPUDevice>()) =
- left.contract(right, dim_pair);
+ left->matrix<float>().contract(right->matrix<float>(), dim_pair);
return;
}
- auto left_mat = &left;
- auto right_mat = &right;
+
+ auto left = &a;
+ auto right = &b;
bool transpose_output = false;
bool transpose_a = transpose_a_;
bool transpose_b = transpose_b_;
if (!a_is_sparse_) {
// Swap the order of multiplications using the identity:
// A * B = (B' * A')'.
- std::swap(left_mat, right_mat);
+ std::swap(left, right);
std::swap(transpose_a, transpose_b);
transpose_a = !transpose_a;
transpose_b = !transpose_b;
transpose_output = !transpose_output;
}
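
Spelled out: if A is m x k and B is k x n, then (B' * A')' is m x n and equals A * B, so swapping the operands, flipping both transpose flags, and transposing the output keeps the sparse operand in the left, sparse-encoded position.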
- std::unique_ptr<Matrix> right_tr_mat;
- std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map;
+
+ std::unique_ptr<Tensor> right_tr;
if (transpose_b) {
// TODO(agarwal): avoid transposing the matrix here and directly handle
// transpose in CreateDenseSlices.
- right_tr_mat.reset(
- new Matrix(right_mat->dimension(1), right_mat->dimension(0)));
+ right_tr.reset(
+ new Tensor(right->dtype(),
+ TensorShape({right->dim_size(1), right->dim_size(0)})));
+
Eigen::array<int, 2> perm({1, 0});
- right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) =
- right_mat->shuffle(perm);
- right_tr_map.reset(new TTypes<float>::ConstMatrix(
- right_tr_mat->data(), right_tr_mat->dimensions()));
- right_mat = right_tr_map.get();
+ if (transpose_output) {
+ right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
+ right->matrix<TL>().shuffle(perm);
+ } else {
+ right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
+ right->matrix<TR>().shuffle(perm);
+ }
+ right = right_tr.get();
}
- SparseMatMul(*left_mat, *right_mat, transpose_a,
- ctx->device()->tensorflow_cpu_worker_threads(),
- transpose_output, &out);
+ if (transpose_output) {
+ SparseMatMul<TR, TL>::Compute(
+ left->matrix<TR>(), right->matrix<TL>(), transpose_a,
+ ctx->device()->tensorflow_cpu_worker_threads(), transpose_output,
+ &out);
+ } else {
+ SparseMatMul<TL, TR>::Compute(
+ left->matrix<TL>(), right->matrix<TR>(), transpose_a,
+ ctx->device()->tensorflow_cpu_worker_threads(), transpose_output,
+ &out);
+ }
}
private:
- // Perform matrix multiplication of "left" and "right", and store the result
- // in *"output".
- static inline void SparseMatMul(
- const ConstMatrixMap& left, const ConstMatrixMap& right,
- bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
- bool transpose_output, MatrixMap* output);
-
- // Computes multiplication of left and num_cols columns of right, and stores
- // the output block in *"output" at offsets "output_row_offset" and
- // "output_col_offset". If assign is true, assigns the value to that block,
- // else adds the values to the existing values.
- static inline void ComputeOutputBlock(const std::vector<SparseSlice*>& left,
- const ConstMatrixMap& right,
- int num_cols, int output_row_offset,
- int output_col_offset, bool assign,
- bool transpose_output,
- MatrixMap* output);
-
- // Encodes "mat" using a sparse representation and stores that in
- // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
- // "slice_num_cols", each grid element is converted into a SparseSlice and
- // stored in mat_slices. "slice_block_size" is used to perform further column
- // blocking of each slice.
- static inline BlockingCounter* CreateSparseSlices(
- const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
- int slice_block_size, int slice_num_cols,
- std::vector<std::vector<SparseSlice*>>* mat_slices,
- const DeviceBase::CpuWorkerThreads* thread_pool);
-
- // This function chops "mat" along column dimension into pieces with at most N
- // columns, and concatenates the pieces one after the other in "buffer". It
- // returns the list of the pieces in "slices". It returns a BlockingCounter
- // which should be used to wait for the shuffle operations to complete.
- static inline BlockingCounter* CreateDenseSlices(
- const ConstMatrixMap& mat, int row_start, int num_rows, int col_start,
- int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
- Matrix* buffer, std::vector<ConstMatrixMap*>* slices);
-
- // Helper function for CreateDenseSlices to move the data around. It returns a
- // BlockingCounter which should be used to wait for the shuffle operations to
- // complete.
- static inline BlockingCounter* ShuffleMatrix(
- const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows,
- int slice_col_start, int slice_num_cols, const int N,
- const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer);
-
- // Helper function for CreateDenseSlices to create slices.
- static inline void SliceMatrix(const Matrix& mat, const int num_rows,
- const int num_slices,
- std::vector<ConstMatrixMap*>* slices);
-
- // Heuristics to compute various block sizes.
- // KR, NR: block sizes for "right". We run blocking iterations that operate on
- // matrices with at most this size.
- // KL: grid size along the column dimension used while encoding left.
- // IB, JB: number of left and right slices to multiply together. This is used
- // for ordering different ComputeBlockOutput operations inside each blocking
- // iteration so as to potentially reduce the working set size.
- static inline void ComputeBlockSizes(const ConstMatrixMap& left,
- const ConstMatrixMap& right,
- bool transpose_left, int num_threads,
- int* KR, int* NR, int* KL, int* JB,
- int* IB);
-
bool transpose_a_;
bool transpose_b_;
bool a_is_sparse_;
bool b_is_sparse_;
+
TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};
-inline void SparseMatMulOp::ComputeOutputBlock(
- const std::vector<SparseSlice*>& left, const ConstMatrixMap& right,
- int num_cols, int output_row_offset, int output_col_offset, bool assign,
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
+ const std::vector<SparseSlice<TL>*>& left,
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
+ int output_row_offset, int output_col_offset, bool assign,
bool transpose_output, MatrixMap* output) {
static const Eigen::array<int, 2> perm({1, 0});
int num_rows = left[0]->num_rows;
@@ -605,9 +950,9 @@ inline void SparseMatMulOp::ComputeOutputBlock(
Matrix out(num_rows, rhs_num_cols);
out.setZero();
if (num_cols == N) {
- GEPP<N>(left, right, num_cols, &out);
+ GEPP<TL, TR, N>(left, right, num_cols, &out);
} else {
- GEPP<-1>(left, right, num_cols, &out);
+ GEPP<TL, TR, -1>(left, right, num_cols, &out);
}
if (!assign) {
const Eigen::array<int, 2> begin = {output_row_offset, output_col_offset};
@@ -643,10 +988,11 @@ inline void SparseMatMulOp::ComputeOutputBlock(
}
}
-inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
- const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
- int slice_block_size, int slice_num_cols,
- std::vector<std::vector<SparseSlice*>>* mat_slices,
+template <typename TL, typename TR>
+inline BlockingCounter* SparseMatMul<TL, TR>::CreateSparseSlices(
+ const SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
+ int slice_num_rows, int slice_block_size, int slice_num_cols,
+ std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
const DeviceBase::CpuWorkerThreads* thread_pool) {
const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
@@ -657,12 +1003,13 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
mat_slices->resize(num_slices_dim0);
BlockingCounter* counter =
new BlockingCounter(num_slices_dim0 * num_slices_dim1);
- auto work = [counter, transpose](SparseSlice* sparse_slice,
- ConstMatrixMap* slice, int col_offset) {
+ auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
+ SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
+ int col_offset) {
if (transpose) {
- sparse_slice->Initialize<true>(*slice, col_offset);
+ sparse_slice->template Initialize<true>(*slice, col_offset);
} else {
- sparse_slice->Initialize<false>(*slice, col_offset);
+ sparse_slice->template Initialize<false>(*slice, col_offset);
}
delete slice;
counter->DecrementCount();
@@ -674,16 +1021,17 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
for (int j = 0; j < num_slices_dim1; ++j) {
int num_cols =
std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
- ConstMatrixMap* slice = nullptr;
+ SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
if (transpose) {
- slice =
- new ConstMatrixMap(&mat(0, i * slice_num_rows), mat.dimensions());
+ slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
+ &mat(0, i * slice_num_rows), mat.dimensions());
} else {
DSizes d(num_rows, mat_num_cols);
- slice = new ConstMatrixMap(&mat(i * slice_num_rows, 0), d);
+ slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
+ &mat(i * slice_num_rows, 0), d);
}
- SparseSlice* sparse_slice =
- new SparseSlice(num_rows, num_cols, slice_block_size);
+ auto* sparse_slice =
+ new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
(*mat_slices)[i][j] = sparse_slice;
thread_pool->workers->Schedule(
std::bind(work, sparse_slice, slice, slice_num_cols * j));
@@ -691,11 +1039,58 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
}
return counter;
}
+#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
+#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
+#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);
+
+template <int NUM_ELEM = -1>
+ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
+ int num_elements) {
+ DCHECK_GE(kNumOperands, 8);
+ static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
+ const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
+ DCHECK_EQ(num, num_elements);
+ const float* src = reinterpret_cast<const float*>(bsrc);
+ float* dst = reinterpret_cast<float*>(bdst);
+ for (int index = 0; index + kStep <= num; index += kStep) {
+ auto in = LOAD(src);
+ auto tmp = INTERLEAVE(in);
+ STORE(dst, tmp);
+ src += kNumOperands;
+ dst += kNumOperands;
+ }
+ if (num % kStep != 0) {
+ memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
+ }
+}
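
The interleave exists so that the LOAD + EXPAND_BFLOAT_L / EXPAND_BFLOAT_U sequence in the bfloat16 MulAdd paths reads elements back in their original order. Under that reading of the intrinsics, the contract being arranged is roughly (P = the pinterleave4x64 permutation, k = kNumOperands):

    // expand_l(P(x)) widens elements x[0 .. k-1]  to k floats
    // expand_u(P(x)) widens elements x[k .. 2k-1] to k floats

When kNumOperands < 8 the caller skips interleaving and falls back to a plain memcpy, as CopyAndMayBeInterleave below shows.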
-inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
- const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows,
- int slice_col_start, int slice_num_cols, const int N,
- const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer) {
+template <typename T>
+ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
+ int num_elements) {
+ if (std::is_same<T, float>::value || kNumOperands < 8) {
+ memcpy(dst, src, num_elements * sizeof(T));
+ } else if (std::is_same<T, bfloat16>::value) {
+ if (num_elements == N) {
+ CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
+ } else {
+ CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
+ }
+ } else {
+ LOG(FATAL) << "Unsupported type";
+ }
+}
+
+#undef LOAD
+#undef INTERLEAVE
+#undef STORE
+
+template <typename TL, typename TR>
+inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int slice_row_start,
+ int slice_num_rows, int slice_col_start, int slice_num_cols, const int N,
+ const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
+ DCHECK_EQ(N % 2, 0);
+ DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
int num_threads = std::min(thread_pool->num_threads, 16);
BlockingCounter* counter = new BlockingCounter(num_threads);
DCHECK_EQ(N, buffer->dimension(1));
@@ -703,17 +1098,17 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
slice_num_cols, N, buffer, counter](int s, int e) {
const int row_start = s % slice_num_rows + slice_row_start;
const int col_start = s / slice_num_rows * N + slice_col_start;
- float* out_start = &(*buffer)(s, 0);
- const float* input_start = &mat(row_start, col_start);
- const float* input_end = &mat(slice_row_start + slice_num_rows - 1,
- slice_col_start + slice_num_cols - 1);
+ auto* out_start = &(*buffer)(s, 0);
+ const auto* input_start = &mat(row_start, col_start);
+ const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
+ slice_col_start + slice_num_cols - 1);
const int mat_num_cols = mat.dimension(1);
const int row_slice_size = slice_num_rows * mat_num_cols;
const int aligned_end = slice_num_cols / N * slice_num_rows;
const int e1 = std::min(e, aligned_end);
while (s < e1) {
- memcpy(out_start, input_start, N * sizeof(float));
+ CopyAndMayBeInterleave<TR>(out_start, input_start, N);
out_start += N;
input_start += mat_num_cols;
if (input_start > input_end) {
@@ -724,7 +1119,7 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
int s1 = std::max(s, aligned_end);
const int copy_num_cols = slice_num_cols % N;
while (s1 < e) {
- memcpy(out_start, input_start, copy_num_cols * sizeof(float));
+ CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
out_start += N;
input_start += mat_num_cols;
++s1;
@@ -745,21 +1140,24 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
return counter;
}
-inline void SparseMatMulOp::SliceMatrix(const Matrix& mat, const int num_rows,
- const int num_slices,
- std::vector<ConstMatrixMap*>* slices) {
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::SliceMatrix(
+ const MatrixR& mat, const int num_rows, const int num_slices,
+ std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
slices->resize(num_slices);
DSizes d(num_rows, mat.dimension(1));
DCHECK_LE(num_rows * num_slices, mat.dimension(0));
for (int i = 0; i < num_slices; ++i) {
- (*slices)[i] = new ConstMatrixMap(&mat(i * num_rows, 0), d);
+ (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
}
}
-inline BlockingCounter* SparseMatMulOp::CreateDenseSlices(
- const ConstMatrixMap& mat, int row_start, int num_rows, int col_start,
- int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
- Matrix* buffer, std::vector<ConstMatrixMap*>* slices) {
+template <typename TL, typename TR>
+inline BlockingCounter* SparseMatMul<TL, TR>::CreateDenseSlices(
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
+ int num_rows, int col_start, int num_cols,
+ const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
+ std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
BlockingCounter* shuffle_counter = ShuffleMatrix(
mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer);
const int num_slices = (num_cols + N - 1) / N;
@@ -767,11 +1165,11 @@ inline BlockingCounter* SparseMatMulOp::CreateDenseSlices(
return shuffle_counter;
}
-inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
- const ConstMatrixMap& right,
- bool transpose_left,
- int num_threads, int* KR, int* NR,
- int* KL, int* JB, int* IB) {
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
+ const SparseMatMul<TL, TR>::ConstMatrixMapL& left,
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
+ int num_threads, int* KR, int* NR, int* KL, int* JB, int* IB) {
// Heuristics for calculating block sizes
// Assume two hyperthreads per core.
const int est_num_cores = std::max(1, (num_threads + 1) / 2);
@@ -838,20 +1236,21 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
// and update the output block o_ij. These calls are further blocked to
// reduce the working set size. In each iteration we take IB elements from
// {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
-inline void SparseMatMulOp::SparseMatMul(
- const ConstMatrixMap& left, const ConstMatrixMap& right,
- bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
- bool transpose_output, MatrixMap* output) {
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::Compute(
+ const SparseMatMul<TL, TR>::ConstMatrixMapL& left,
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
+ const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output,
+ MatrixMap* output) {
const int num_threads = thread_pool->num_threads;
int KR, NR, KL, JB, IB;
ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
&JB, &IB);
-
// Slice the left matrix
- std::vector<std::vector<SparseSlice*>> left_slices;
+ std::vector<std::vector<SparseSlice<TL>*>> left_slices;
std::unique_ptr<BlockingCounter> sparse_slice_counter;
sparse_slice_counter.reset(
- CreateSparseSlices(ConstMatrixMap(left.data(), left.dimensions()),
+ CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
transpose_left, M, K, KL, &left_slices, thread_pool));
const int num_left_slices = left_slices.size();
@@ -862,10 +1261,10 @@ inline void SparseMatMulOp::SparseMatMul(
// is the block size per iteration.
const int buffer_num_rows =
std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
- Matrix buffer(buffer_num_rows, N);
- std::vector<ConstMatrixMap*> right_slices;
+ MatrixR buffer(buffer_num_rows, N);
+ std::vector<ConstMatrixMapR*> right_slices;
- std::vector<SparseSlice*> block_left_slices;
+ std::vector<SparseSlice<TL>*> block_left_slices;
std::vector<std::function<void(void)>> tasks;
// Number of blocks based on block sizes of KR * NR.
const int num_k_blocks = (right_dim0 + KR - 1) / KR;
@@ -890,7 +1289,6 @@ inline void SparseMatMulOp::SparseMatMul(
const int num_cols = std::min(N, right_num_cols - N * j_inner);
for (int i_inner = i_outer;
i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
- // Figure out which left slices to use.
block_left_slices.clear();
int begin = kb * KR / KL;
int end = std::min<int>((kb + 1) * KR / KL,
@@ -900,7 +1298,7 @@ inline void SparseMatMulOp::SparseMatMul(
left_slices[i_inner].begin() + begin,
left_slices[i_inner].begin() + end);
tasks.push_back(std::bind(
- &SparseMatMulOp::ComputeOutputBlock, block_left_slices,
+ &ComputeOutputBlock, block_left_slices,
std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
N * j_inner + nb * NR, kb == 0, transpose_output, output));
}
@@ -933,7 +1331,18 @@ inline void SparseMatMulOp::SparseMatMul(
}
}
-REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU),
- SparseMatMulOp);
+#define REGISTER_SPARSE_MATMUL(TA, TB) \
+ REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TA>("Ta") \
+ .TypeConstraint<TB>("Tb"), \
+ SparseMatMulOp<TA, TB>);
+
+REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
+REGISTER_SPARSE_MATMUL(float, bfloat16);
+REGISTER_SPARSE_MATMUL(bfloat16, float);
+REGISTER_SPARSE_MATMUL(float, float);
+
+#undef REGISTER_SPARSE_MATMUL
} // end namespace tensorflow