author | 2016-05-10 12:30:20 -0800
committer | 2016-05-10 13:32:14 -0700
commit | 6b373bb149396beec11347881b2d6dedfbcc83c4
tree | 9eddf82919c16becd71b55243f48a0967e9b5db4 /tensorflow/core/kernels/sparse_matmul_op.cc
parent | 49d25eae890216f15833adfdd1e668479470745d
Allow bfloat16 inputs to SparseMatMul.
Change: 121980920
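
For context, a bfloat16 value keeps only the sign bit, the 8-bit exponent, and the top 7 mantissa bits of an IEEE float32, so widening back to float32 just restores the discarded low half as zeros. The sketch below is a minimal, standalone illustration of that bit layout; the helper names are hypothetical, and TensorFlow's own type and conversion routines live in tensorflow/core/framework/bfloat16.h (newly included by this change):

    #include <cstdint>
    #include <cstring>

    // bfloat16 is the high half of a float32 bit pattern (illustration only).
    static float BfloatBitsToFloat(uint16_t bits) {
      uint32_t word = static_cast<uint32_t>(bits) << 16;  // zero-fill the low half
      float out;
      std::memcpy(&out, &word, sizeof(out));
      return out;
    }

    static uint16_t FloatToBfloatBits(float value) {
      uint32_t word;
      std::memcpy(&word, &value, sizeof(word));
      return static_cast<uint16_t>(word >> 16);  // drop the low mantissa bits
    }

Only 7 mantissa bits survive the narrowing, so the op trades precision for roughly half the memory traffic on whichever operand is stored as bfloat16.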
Diffstat (limited to 'tensorflow/core/kernels/sparse_matmul_op.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_matmul_op.cc | 1019
1 file changed, 714 insertions(+), 305 deletions(-)
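
The diff below templatizes the kernel's internals (SparseSlice, the GEPP micro-kernel, and a new SparseMatMul helper class) on the left and right element types TL and TR, so either operand may be float or bfloat16. The sparse left operand keeps its existing encoding: within each column block, the nonzeros of a row are packed three at a time as Index3{m, k1, k2, k3} with their values in data3, and the one or two leftovers per row go into Index{m, k} and data; m is the row and k the column offset inside the block, both small enough (<= 256) to fit in a uint8. As a rough illustration of that layout only (the shipped kernel never materializes a dense block), decoding one block could look like:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Mirrors of SparseSlice<T>::Index3 and SparseSlice<T>::Index, used only
    // for this illustration.
    struct Index3Ref { uint8_t m, k1, k2, k3; };
    struct IndexRef { uint8_t m, k; };

    // Expands one column block of a SparseSlice back into a dense row-major
    // (num_rows x block_cols) array. Hypothetical helper: the real GEPP kernel
    // consumes index3/data3 and index/data directly and never builds this.
    void DecodeBlock(const std::vector<Index3Ref>& index3,
                     const std::vector<float>& data3,
                     const std::vector<IndexRef>& index,
                     const std::vector<float>& data, int block_cols,
                     std::vector<float>* dense) {
      size_t v = 0;
      for (const auto& t : index3) {  // packed triples: three nonzeros of a row
        (*dense)[t.m * block_cols + t.k1] = data3[v++];
        (*dense)[t.m * block_cols + t.k2] = data3[v++];
        (*dense)[t.m * block_cols + t.k3] = data3[v++];
      }
      size_t w = 0;
      for (const auto& s : index) {   // leftovers: one or two nonzeros per row
        (*dense)[s.m * block_cols + s.k] = data[w++];
      }
    }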
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index 067aa2624d..6cfaa3b958 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -18,9 +18,11 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/sparse_matmul_op.h" + #include <vector> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -35,12 +37,16 @@ namespace tensorflow { namespace { +using Eigen::operator==; typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix; typedef Eigen::DSizes<Eigen::DenseIndex, 2> DSizes; -typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>, - Eigen::Aligned> MatrixMap; typedef Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>, - Eigen::Aligned> ConstMatrixMap; + Eigen::Aligned> + ConstMatrixMap; +typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>, + Eigen::Aligned> + MatrixMap; + typedef Eigen::ThreadPoolDevice CPUDevice; // Blocksizes @@ -52,8 +58,7 @@ static const int N = 128; // This stores a sparse representation of a slice of a matrix with size // (num_rows, num_cols). The slice is represented as a series of blocks of size // (num_rows, b), where b = block_size for all but the last block, which may -// have -// fewer columns. +// have fewer columns. // // num_rows and block_size are assumed to be <= 256. This allows storing // different indices as uint8. @@ -71,7 +76,12 @@ static const int N = 128; // are the values in the following range: // [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for // index_offset. +template <typename T> struct SparseSlice { + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Aligned> + ConstMatrixMap; + public: // Indices of three elements on the same row. struct Index3 { @@ -105,13 +115,13 @@ struct SparseSlice { // See comments above. std::vector<int> index3_offset; std::vector<Index3> index3; - std::vector<float> data3; + std::vector<T> data3; // See comments above. Similar to "index3" except that each element in "index" // corresponds to one element in data. std::vector<int> index_offset; std::vector<Index> index; - std::vector<float> data; + std::vector<T> data; // Number of rows and columns for the slice. const int num_rows; @@ -121,8 +131,10 @@ struct SparseSlice { const int block_size; }; +template <typename T> template <bool Transpose> -void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { +void SparseSlice<T>::Initialize(const SparseSlice<T>::ConstMatrixMap& mat, + int col_offset) { const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0); const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1); DCHECK_LE(num_rows, mat_rows); @@ -142,6 +154,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { Index3 idx3; Index idx; int data3_size = 0; + static const T zero(0); for (int i = 0; i < num_blocks; ++i) { int num_block_cols = std::min(block_size, num_cols - block_size * i); for (int row = 0; row < num_rows; ++row) { @@ -150,17 +163,17 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { // *curr is nonzero and then reads it again on use. 
However, the result // of the race is only that some of the "nonzeros" in the resulting sparse // representation may actually be zero, which is harmless. - const float* start = + const auto* start = Transpose ? &mat(col_offset, row) : &mat(row, col_offset); - const float* curr = start; + const auto* curr = start; const int stride = Transpose ? mat.dimension(1) : 1; - const float* end = start + stride * num_block_cols; + const auto* end = start + stride * num_block_cols; uint8 k = 0; #define NEXT_ELEM \ curr += stride; \ ++k; while (true) { - while (curr < end && (*curr == 0)) { + while (curr < end && (*curr == zero)) { NEXT_ELEM; } if (curr >= end) break; @@ -168,7 +181,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { data3.push_back(*curr); NEXT_ELEM; - while (curr < end && (*curr == 0)) { + while (curr < end && (*curr == zero)) { NEXT_ELEM; } if (curr >= end) break; @@ -176,7 +189,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { data3.push_back(*curr); NEXT_ELEM; - while (curr < end && (*curr == 0)) { + while (curr < end && (*curr == zero)) { NEXT_ELEM; } if (curr >= end) break; @@ -213,7 +226,8 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { DCHECK_EQ(index.size(), data.size()); } -void SparseSlice::Clear() { +template <typename T> +void SparseSlice<T>::Clear() { index3_offset.clear(); index3.clear(); data3.clear(); @@ -222,107 +236,356 @@ void SparseSlice::Clear() { data.clear(); } -#define SCALAR_MULADD(a, inp, out) *out++ += *a * *inp++; - -#define SCALAR_MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \ - *out++ += *a1 * *inp1++ + *a2 * *inp2++ + *a3 * *inp3++; - typedef Eigen::internal::packet_traits<float>::type Packet; static const int kNumOperands = (sizeof(Packet) / sizeof(float)); #define LOAD(x) Eigen::internal::pload<Packet>(x); +#define EXPAND_BFLOAT_L(x, y) \ + const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x); +#define EXPAND_BFLOAT_U(x, y) \ + const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x); #define STORE(x, y) Eigen::internal::pstore<float>(x, y); -#define LOAD_SCALAR(x, y) const auto y = Eigen::internal::pload1<Packet>(x); #define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c); -// Vectorized version of SCALAR_MULADD. -#define MULADD(a, inp, out) \ - do { \ - const auto b = LOAD(inp); \ - inp += kNumOperands; \ - auto c = LOAD(out); \ - FMA(a, b, c, c); \ - STORE(out, c); \ - out += kNumOperands; \ - } while (false) - -// Vectorized version of SCALAR_MULADD3WAY. 
-#define MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \ - do { \ - auto c = LOAD(out); \ - const auto b1 = LOAD(inp1); \ - inp1 += kNumOperands; \ - const auto b2 = LOAD(inp2); \ - inp2 += kNumOperands; \ - const auto b3 = LOAD(inp3); \ - inp3 += kNumOperands; \ - FMA(a1, b1, c, c); \ - FMA(a2, b2, c, c); \ - FMA(a3, b3, c, c); \ - STORE(out, c); \ - out += kNumOperands; \ - } while (false) - -#ifdef EIGEN_VECTORIZE_AVX2 -// Unroll MULADD3WAY for two iterations -#define MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out) \ - do { \ - auto c1 = LOAD(out); \ - const auto b1 = LOAD(inp1); \ - const auto b2 = LOAD(inp2); \ - const auto b3 = LOAD(inp3); \ - \ - auto c2 = LOAD(out + kNumOperands); \ - const auto b4 = LOAD(inp1 + kNumOperands); \ - const auto b5 = LOAD(inp2 + kNumOperands); \ - const auto b6 = LOAD(inp3 + kNumOperands); \ - \ - FMA(a1, b1, c1, c1); \ - FMA(a1, b4, c2, c2); \ - FMA(a2, b2, c1, c1); \ - FMA(a2, b5, c2, c2); \ - FMA(a3, b3, c1, c1); \ - FMA(a3, b6, c2, c2); \ - STORE(out, c1); \ - STORE(out + kNumOperands, c2); \ - out += 2 * kNumOperands; \ - inp1 += 2 * kNumOperands; \ - inp2 += 2 * kNumOperands; \ - inp3 += 2 * kNumOperands; \ - } while (false) -// Further unroll MULADD3WAY. -#define MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out) \ - MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); -#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); -#else -#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \ - for (int __i = 0; __i < 128 / (4 * kNumOperands); ++__i) { \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ +#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE + +ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) { + float out = 0; + auto tmp = reinterpret_cast<bfloat16*>(&out); + tmp[1] = *src; + return out; +} + +ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) { + return Eigen::internal::pload4bf16<Packet>( + reinterpret_cast<const float*>(src)); +} + +ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) { + return Eigen::internal::pload2bf16<Packet>( + reinterpret_cast<const float*>(src)); +} + +ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) { + **out += a * **inp; + ++*inp; + ++*out; +} + +ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp, + float** out) { + float inp_f = ConvertBfloat16ToFloat(*inp); + **out += a * inp_f; + ++*inp; + ++*out; +} +ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2, + const float a3, const bfloat16** inp1, + const bfloat16** inp2, + const bfloat16** inp3, float** out) { + float inp1_f = ConvertBfloat16ToFloat(*inp1); + float inp2_f = ConvertBfloat16ToFloat(*inp2); + float inp3_f = ConvertBfloat16ToFloat(*inp3); + **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f; + ++*out; + ++*inp1; + ++*inp2; + ++*inp3; +} + +ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2, + const float a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3; + ++*out; + ++*inp1; + ++*inp2; + ++*inp3; +} + +ALWAYS_INLINE void 
LoadSingleScalar(const bfloat16** data, Packet* l) { + auto tmp = ConvertBfloat16ToFloat(*data); + *l = Eigen::internal::pset1<Packet>(tmp); + ++*data; +} + +ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1, + Packet* l2) { + if (kNumOperands >= 2) { + auto tmp = ConvertTwoBfloat16ToFloat(*data); + *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp); + *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp); + *data += 2; + } else { + LoadSingleScalar(data, l1); + LoadSingleScalar(data, l2); + } +} + +ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1, + Packet* l2, Packet* l3, Packet* l4) { + if (kNumOperands >= 4) { + auto tmp = ConvertFourBfloat16ToFloat(*data); + *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp); + *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp); + *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp); + *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp); + *data += 4; + } else { + LoadTwoScalars(data, l1, l2); + LoadTwoScalars(data, l3, l4); + } +} + +ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) { + *l = Eigen::internal::pload1<Packet>(*data); + ++(*data); +} + +ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) { + LoadSingleScalar(data, l1); + LoadSingleScalar(data, l2); +} + +ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2, + Packet* l3, Packet* l4) { + LoadTwoScalars(data, l1, l2); + LoadTwoScalars(data, l3, l4); +} + +template <typename T> +ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2, + Packet* l3) { + LoadTwoScalars(data, l1, l2); + LoadSingleScalar(data, l3); +} + +template <typename T> +ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2, + Packet* l3, Packet* l4, Packet* l5, + Packet* l6) { + LoadFourScalars(data, l1, l2, l3, l4); + LoadTwoScalars(data, l5, l6); +} + +// Vectorized version of ScalarMulAdd. +ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) { + auto inp = reinterpret_cast<const float*>(*binp); + const auto b = LOAD(inp); + EXPAND_BFLOAT_L(b, b_0); + EXPAND_BFLOAT_U(b, b_1); + *binp += 2 * kNumOperands; + auto c1 = LOAD(*out); + auto c2 = LOAD(*out + kNumOperands); + FMA(a, b_0, c1, c1); + FMA(a, b_1, c2, c2); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + *out += 2 * kNumOperands; +} + +// Vectorized version of ScalarMulAdd3Way. 
+ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3, + const bfloat16** binp1, const bfloat16** binp2, + const bfloat16** binp3, float** out) { + auto inp1 = reinterpret_cast<const float*>(*binp1); + auto inp2 = reinterpret_cast<const float*>(*binp2); + auto inp3 = reinterpret_cast<const float*>(*binp3); + auto c1 = LOAD(*out); + auto c2 = LOAD(*out + kNumOperands); + const auto b1 = LOAD(inp1); + EXPAND_BFLOAT_L(b1, b1_0); + EXPAND_BFLOAT_U(b1, b1_1); + *binp1 += 2 * kNumOperands; + const auto b2 = LOAD(inp2); + EXPAND_BFLOAT_L(b2, b2_0); + EXPAND_BFLOAT_U(b2, b2_1); + *binp2 += 2 * kNumOperands; + const auto b3 = LOAD(inp3); + EXPAND_BFLOAT_L(b3, b3_0); + EXPAND_BFLOAT_U(b3, b3_1); + *binp3 += 2 * kNumOperands; + FMA(a1, b1_0, c1, c1); + FMA(a1, b1_1, c2, c2); + FMA(a2, b2_0, c1, c1); + FMA(a2, b2_1, c2, c2); + FMA(a3, b3_0, c1, c1); + FMA(a3, b3_1, c2, c2); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + *out += 2 * kNumOperands; +} + +// Unroll MulAdd3Way for two iterations +ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2, + const Packet a3, const bfloat16** binp1, + const bfloat16** binp2, const bfloat16** binp3, + float** out) { + auto inp1 = reinterpret_cast<const float*>(*binp1); + auto inp2 = reinterpret_cast<const float*>(*binp2); + auto inp3 = reinterpret_cast<const float*>(*binp3); + auto c1 = LOAD(*out); + auto c2 = LOAD(*out + kNumOperands); + const auto b1 = LOAD(inp1); + const auto b2 = LOAD(inp2); + const auto b3 = LOAD(inp3); + + EXPAND_BFLOAT_L(b1, b1_0); + EXPAND_BFLOAT_U(b1, b1_1); + EXPAND_BFLOAT_L(b2, b2_0); + EXPAND_BFLOAT_U(b2, b2_1); + EXPAND_BFLOAT_L(b3, b3_0); + EXPAND_BFLOAT_U(b3, b3_1); + auto c3 = LOAD(*out + 2 * kNumOperands); + auto c4 = LOAD(*out + 3 * kNumOperands); + const auto b4 = LOAD(inp1 + kNumOperands); + const auto b5 = LOAD(inp2 + kNumOperands); + const auto b6 = LOAD(inp3 + kNumOperands); + + EXPAND_BFLOAT_L(b4, b4_0); + EXPAND_BFLOAT_U(b4, b4_1); + EXPAND_BFLOAT_L(b5, b5_0); + EXPAND_BFLOAT_U(b5, b5_1); + EXPAND_BFLOAT_L(b6, b6_0); + EXPAND_BFLOAT_U(b6, b6_1); + + FMA(a1, b1_0, c1, c1); + FMA(a1, b1_1, c2, c2); + FMA(a1, b4_0, c3, c3); + FMA(a1, b4_1, c4, c4); + FMA(a2, b2_0, c1, c1); + FMA(a2, b2_1, c2, c2); + FMA(a2, b5_0, c3, c3); + FMA(a2, b5_1, c4, c4); + FMA(a3, b3_0, c1, c1); + FMA(a3, b3_1, c2, c2); + FMA(a3, b6_0, c3, c3); + FMA(a3, b6_1, c4, c4); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + STORE(*out + 2 * kNumOperands, c3); + STORE(*out + 3 * kNumOperands, c4); + *out += 4 * kNumOperands; + *binp1 += 4 * kNumOperands; + *binp2 += 4 * kNumOperands; + *binp3 += 4 * kNumOperands; +} + +// Apply MulAdd3Way on 128 operands. 
+ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2, + const Packet a3, const bfloat16** inp1, + const bfloat16** inp2, const bfloat16** inp3, + float** out) { + for (int k = 0; k < 128 / (8 * kNumOperands); ++k) { + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); } -#endif +} +// Vectorized version of ScalarMulAdd +ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) { + const auto b = LOAD(*inp); + *inp += kNumOperands; + auto c = LOAD(*out); + FMA(a, b, c, c); + STORE(*out, c); + *out += kNumOperands; +} + +// Vectorized version of ScalarMulAdd3Way +ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3, + const float** inp1, const float** inp2, + const float** inp3, float** out) { + auto c = LOAD(*out); + const auto b1 = LOAD(*inp1); + *inp1 += kNumOperands; + const auto b2 = LOAD(*inp2); + *inp2 += kNumOperands; + const auto b3 = LOAD(*inp3); + *inp3 += kNumOperands; + FMA(a1, b1, c, c); + FMA(a2, b2, c, c); + FMA(a3, b3, c, c); + STORE(*out, c); + *out += kNumOperands; +} + +// Unroll MulAdd3Way for two iterations +ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2, + const Packet a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + auto c1 = LOAD(*out); + const auto b1 = LOAD(*inp1); + const auto b2 = LOAD(*inp2); + const auto b3 = LOAD(*inp3); + + auto c2 = LOAD(*out + kNumOperands); + const auto b4 = LOAD(*inp1 + kNumOperands); + const auto b5 = LOAD(*inp2 + kNumOperands); + const auto b6 = LOAD(*inp3 + kNumOperands); + + FMA(a1, b1, c1, c1); + FMA(a1, b4, c2, c2); + FMA(a2, b2, c1, c1); + FMA(a2, b5, c2, c2); + FMA(a3, b3, c1, c1); + FMA(a3, b6, c2, c2); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + *out += 2 * kNumOperands; + *inp1 += 2 * kNumOperands; + *inp2 += 2 * kNumOperands; + *inp3 += 2 * kNumOperands; +} + +// Unroll MulAdd3Way for four iterations +ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2, + const Packet a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); +} + +// Apply MulAdd3Way on 128 operands. +ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2, + const Packet a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + if (kNumOperands == 8) { + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + } else { + DCHECK_LE(4 * kNumOperands, 128); + for (int i = 0; i < 128 / (4 * kNumOperands); ++i) { + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + } + } +} // Computes product of "left_slices" with "num_cols" columns of "right", and // stores the output in *"output". // Note that left_slices is a list of SparseSlices, which are conceptually // assumed to be concatenated along the column dimension. Also each SparseSlice // is encoded as a list of blocks with upto N columns. See SparseSlice for more // details. 
-template <int Cols> -inline void GEPP(const std::vector<SparseSlice*>& left_slices, - const ConstMatrixMap& right, const int num_cols, - Matrix* output) { +template <typename TL, typename TR, int Cols> +inline void GEPP( + const std::vector<SparseSlice<TL>*>& left_slices, + const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>, + Eigen::Aligned>& right, + const int num_cols, Matrix* output) { const int cols = (Cols == -1) ? num_cols : Cols; DCHECK_EQ(num_cols, cols); const int right_num_cols = right.dimension(1); const int output_num_cols = output->dimension(1); - const int cols_mod = cols % kNumOperands; + static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR); + const int cols_mod = cols % kNumOperandsR; int k_offset = 0; // Pre-compute pointers for output matrix. float* out_ptrs[M]; @@ -332,15 +595,15 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices, } for (const auto* left_slice : left_slices) { const auto& left = *left_slice; - const float* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr; - const float* data = (left.data.size() > 0) ? &left.data[0] : nullptr; + const auto* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr; + const auto* data = (left.data.size() > 0) ? &left.data[0] : nullptr; const int num_blocks = left.index3_offset.size(); int begin3 = 0; int begin = 0; for (int i = 0; i < num_blocks; ++i) { // Pre-compute pointers for right matrix - const float* right_ptrs[K]; - const float* const right_start = &right(k_offset, 0); + const TR* right_ptrs[K]; + const auto* const right_start = &right(k_offset, 0); DCHECK_LT(k_offset, right.dimension(0)); for (int j = 0; j < K; ++j) { right_ptrs[j] = right_start + right_num_cols * j; @@ -350,79 +613,116 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices, int j = begin3; // Loop unrolled for 2 iterations. 
for (; j + 1 < end3; j += 2) { - const float* sl1 = data3++; - LOAD_SCALAR(sl1, l1); - const float* sl2 = data3++; - LOAD_SCALAR(sl2, l2); - const float* sl3 = data3++; - LOAD_SCALAR(sl3, l3); - const float* nsl1 = data3++; - LOAD_SCALAR(nsl1, nl1); - const float* nsl2 = data3++; - LOAD_SCALAR(nsl2, nl2); - const float* nsl3 = data3++; - LOAD_SCALAR(nsl3, nl3); - const SparseSlice::Index3& index = left.index3[j]; - const SparseSlice::Index3& nindex = left.index3[j + 1]; + Packet l1, l2, l3, nl1, nl2, nl3; + LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3); + const auto& index = left.index3[j]; + const auto& nindex = left.index3[j + 1]; float* out = out_ptrs[index.m]; float* nout = out_ptrs[nindex.m]; - const float* r1 = right_ptrs[index.k1]; - const float* r2 = right_ptrs[index.k2]; - const float* r3 = right_ptrs[index.k3]; - const float* nr1 = right_ptrs[nindex.k1]; - const float* nr2 = right_ptrs[nindex.k2]; - const float* nr3 = right_ptrs[nindex.k3]; + const auto* r1 = right_ptrs[index.k1]; + const auto* r2 = right_ptrs[index.k2]; + const auto* r3 = right_ptrs[index.k3]; + + const auto* nr1 = right_ptrs[nindex.k1]; + const auto* nr2 = right_ptrs[nindex.k2]; + const auto* nr3 = right_ptrs[nindex.k3]; if (cols == 128) { - MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out); - MULADD3WAY_128(nl1, nl2, nl3, nr1, nr2, nr3, nout); + MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out); + MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout); } else { - for (int n = 0; n < cols / kNumOperands; ++n) { - MULADD3WAY(l1, l2, l3, r1, r2, r3, out); - MULADD3WAY(nl1, nl2, nl3, nr1, nr2, nr3, nout); + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out); + MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout); } + + const float sl1 = Eigen::internal::pfirst<Packet>(l1); + const float sl2 = Eigen::internal::pfirst<Packet>(l2); + const float sl3 = Eigen::internal::pfirst<Packet>(l3); + const float nsl1 = Eigen::internal::pfirst<Packet>(nl1); + const float nsl2 = Eigen::internal::pfirst<Packet>(nl2); + const float nsl3 = Eigen::internal::pfirst<Packet>(nl3); for (int k = 0; k < cols_mod; ++k) { - SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out); - SCALAR_MULADD3WAY(nsl1, nsl2, nsl3, nr1, nr2, nr3, nout); + ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out); + ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout); } } } if (j < end3) { - const float* sl1 = data3++; - LOAD_SCALAR(sl1, l1); - const float* sl2 = data3++; - LOAD_SCALAR(sl2, l2); - const float* sl3 = data3++; - LOAD_SCALAR(sl3, l3); - const SparseSlice::Index3& index = left.index3[j]; + Packet l1, l2, l3; + LoadThreeScalars(&data3, &l1, &l2, &l3); + + const auto& index = left.index3[j]; float* out = out_ptrs[index.m]; - const float* r1 = right_ptrs[index.k1]; - const float* r2 = right_ptrs[index.k2]; - const float* r3 = right_ptrs[index.k3]; + const auto* r1 = right_ptrs[index.k1]; + const auto* r2 = right_ptrs[index.k2]; + const auto* r3 = right_ptrs[index.k3]; if (cols == 128) { - MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out); + MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out); } else { - for (int n = 0; n < cols / kNumOperands; ++n) { - MULADD3WAY(l1, l2, l3, r1, r2, r3, out); + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out); } + const float sl1 = Eigen::internal::pfirst<Packet>(l1); + const float sl2 = Eigen::internal::pfirst<Packet>(l2); + const float sl3 = Eigen::internal::pfirst<Packet>(l3); for (int k = 0; k < cols_mod; ++k) { - SCALAR_MULADD3WAY(sl1, 
sl2, sl3, r1, r2, r3, out); + ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out); } } } begin3 = end3; int end = left.index_offset[i]; - for (int j = begin; j < end; ++j) { - const float* sl = data++; - LOAD_SCALAR(sl, l); - const SparseSlice::Index& index = left.index[j]; - const float* r = right_ptrs[index.k]; + // Loop unrolled for 4 iterations. + j = begin; + for (; j + 3 < end; j += 4) { + Packet l, nl, n2l, n3l; + LoadFourScalars(&data, &l, &nl, &n2l, &n3l); + + const auto& index = left.index[j]; + const auto& nindex = left.index[j + 1]; + const auto& n2index = left.index[j + 2]; + const auto& n3index = left.index[j + 3]; + const auto* r = right_ptrs[index.k]; + const auto* nr = right_ptrs[nindex.k]; + const auto* n2r = right_ptrs[n2index.k]; + const auto* n3r = right_ptrs[n3index.k]; + float* out = out_ptrs[index.m]; + float* nout = out_ptrs[nindex.m]; + float* n2out = out_ptrs[n2index.m]; + float* n3out = out_ptrs[n3index.m]; + + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd(l, &r, &out); + MulAdd(nl, &nr, &nout); + MulAdd(n2l, &n2r, &n2out); + MulAdd(n3l, &n3r, &n3out); + } + + const float sl1 = Eigen::internal::pfirst<Packet>(l); + const float sl2 = Eigen::internal::pfirst<Packet>(nl); + const float sl3 = Eigen::internal::pfirst<Packet>(n2l); + const float sl4 = Eigen::internal::pfirst<Packet>(n3l); + for (int k = 0; k < cols_mod; ++k) { + ScalarMulAdd(sl1, &r, &out); + ScalarMulAdd(sl2, &nr, &nout); + ScalarMulAdd(sl3, &n2r, &n2out); + ScalarMulAdd(sl4, &n3r, &n3out); + } + } + while (j < end) { + Packet l; + LoadSingleScalar(&data, &l); + const auto& index = left.index[j]; + const auto* r = right_ptrs[index.k]; float* out = out_ptrs[index.m]; - for (int n = 0; n < cols / kNumOperands; ++n) { - MULADD(l, r, out); + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd(l, &r, &out); } + const float sl = Eigen::internal::pfirst<Packet>(l); for (int k = 0; k < cols_mod; ++k) { - SCALAR_MULADD(sl, r, out); + ScalarMulAdd(sl, &r, &out); } + j++; } k_offset += left.block_size; begin = end; @@ -430,21 +730,101 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices, } } -#undef SCALAR_MULADD -#undef SCALAR_MULADD3WAY #undef LOAD +#undef EXPAND_BFLOAT_L +#undef EXPAND_BFLOAT_U #undef STORE -#undef LOAD_SCALAR #undef FMA -#undef MULADD -#undef MULADD3WAY -#undef MULADD3WAY_16 -#undef MULADD3WAY_32 -#undef MULADD3WAY_128 } // namespace +template <typename TL, typename TR> +class SparseMatMul { + typedef Eigen::Tensor<TL, 2, Eigen::RowMajor> MatrixL; + typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR; + typedef Eigen::TensorMap<Eigen::Tensor<const TL, 2, Eigen::RowMajor>, + Eigen::Aligned> + ConstMatrixMapL; + typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>, + Eigen::Aligned> + ConstMatrixMapR; + typedef Eigen::TensorMap<Eigen::Tensor<TR, 2, Eigen::RowMajor>, + Eigen::Aligned> + MatrixMapR; + // Perform matrix multiplication of "left" and "right", and store the result + // in *"output". + public: + static inline void Compute(const ConstMatrixMapL& left, + const ConstMatrixMapR& right, bool transpose_left, + const DeviceBase::CpuWorkerThreads* thread_pool, + bool transpose_output, MatrixMap* output); + + private: + // Computes multiplication of left and num_cols columns of right, and stores + // the output block in *"output" at offsets "output_row_offset" and + // "output_col_offset". If assign is true, assigns the value to that block, + // else adds the values to the existing values. 
+ static inline void ComputeOutputBlock( + const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right, + int num_cols, int output_row_offset, int output_col_offset, bool assign, + bool transpose_output, MatrixMap* output); + + // Encodes "mat" using a sparse representation and stores that in + // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and + // "slice_num_cols", each grid element is converted into a SparseSlice and + // stored in mat_slices. "slice_block_size" is used to perform further column + // blocking of each slice. + static inline BlockingCounter* CreateSparseSlices( + const ConstMatrixMapL& mat, bool transpose, int slice_num_rows, + int slice_block_size, int slice_num_cols, + std::vector<std::vector<SparseSlice<TL>*>>* mat_slices, + const DeviceBase::CpuWorkerThreads* thread_pool); + + // This function chops "mat" along column dimension into pieces with at most N + // columns, and concatenates the pieces one after the other in "buffer". It + // returns the list of the pieces in "slices". It returns a BlockingCounter + // which should be used to wait for the shuffle operations to complete. + static inline BlockingCounter* CreateDenseSlices( + const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start, + int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, + MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices); + + // Helper function for CreateDenseSlices to move the data around. It returns a + // BlockingCounter which should be used to wait for the shuffle operations to + // complete. + static inline BlockingCounter* ShuffleMatrix( + const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows, + int slice_col_start, int slice_num_cols, const int N, + const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer); + + // Helper function for CreateDenseSlices to create slices. + static inline void SliceMatrix(const MatrixR& mat, const int num_rows, + const int num_slices, + std::vector<ConstMatrixMapR*>* slices); + + // Heuristics to compute various block sizes. + // KR, NR: block sizes for "right". We run blocking iterations that operate on + // matrices with at most this size. + // KL: grid size along the column dimension used while encoding left. + // IB, JB: number of left and right slices to multiply together. This is used + // for ordering different ComputeBlockOutput operations inside each blocking + // iteration so as to potentially reduce the working set size. + static inline void ComputeBlockSizes(const ConstMatrixMapL& left, + const ConstMatrixMapR& right, + bool transpose_left, int num_threads, + int* KR, int* NR, int* KL, int* JB, + int* IB); + + TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul); +}; + +template <typename TL, typename TR> class SparseMatMulOp : public OpKernel { + typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR; + typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>, + Eigen::Aligned> + ConstMatrixMapR; + public: explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); @@ -461,12 +841,10 @@ class SparseMatMulOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), errors::InvalidArgument("b is not a matrix")); - auto left = a.matrix<float>(); - auto right = b.matrix<float>(); - const int m = transpose_a_ ? left.dimension(1) : left.dimension(0); - const int k = transpose_a_ ? left.dimension(0) : left.dimension(1); - const int n = transpose_b_ ? 
right.dimension(0) : right.dimension(1); - const int k2 = transpose_b_ ? right.dimension(1) : right.dimension(0); + const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0); + const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1); + const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1); + const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0); OP_REQUIRES(ctx, k == k2, errors::InvalidArgument("Matrix size incompatible: a: ", @@ -476,127 +854,94 @@ class SparseMatMulOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output)); auto out = output->matrix<float>(); + std::unique_ptr<Tensor> a_float; + std::unique_ptr<Tensor> b_float; if (!a_is_sparse_ && !b_is_sparse_) { - // Fallback to Eigen contract. - // Note that we currently don't optimize the case where only right is - // sparse. That can generally be handled by transposing the order of the - // matmul. + auto left = &a; + auto right = &b; + // TODO(agarwal): multi-thread the conversions from bfloat16 to float. + if (std::is_same<TL, bfloat16>::value) { + a_float.reset(new Tensor(DT_FLOAT, a.shape())); + BFloat16ToFloat(a.flat<bfloat16>().data(), + a_float->flat<float>().data(), a.NumElements()); + left = a_float.get(); + } + if (std::is_same<TR, bfloat16>::value) { + b_float.reset(new Tensor(DT_FLOAT, b.shape())); + BFloat16ToFloat(b.flat<bfloat16>().data(), + b_float->flat<float>().data(), b.NumElements()); + right = b_float.get(); + } Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair; dim_pair[0].first = transpose_a_ ? 0 : 1; dim_pair[0].second = transpose_b_ ? 1 : 0; + out.device(ctx->template eigen_device<CPUDevice>()) = - left.contract(right, dim_pair); + left->matrix<float>().contract(right->matrix<float>(), dim_pair); return; } - auto left_mat = &left; - auto right_mat = &right; + + auto left = &a; + auto right = &b; bool transpose_output = false; bool transpose_a = transpose_a_; bool transpose_b = transpose_b_; if (!a_is_sparse_) { // Swap the order of multiplications using the identity: // A * B = (B' * A')'. - std::swap(left_mat, right_mat); + std::swap(left, right); std::swap(transpose_a, transpose_b); transpose_a = !transpose_a; transpose_b = !transpose_b; transpose_output = !transpose_output; } - std::unique_ptr<Matrix> right_tr_mat; - std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map; + + std::unique_ptr<Tensor> right_tr; if (transpose_b) { // TODO(agarwal): avoid transposing the matrix here and directly handle // transpose in CreateDenseSlices. 
- right_tr_mat.reset( - new Matrix(right_mat->dimension(1), right_mat->dimension(0))); + right_tr.reset( + new Tensor(right->dtype(), + TensorShape({right->dim_size(1), right->dim_size(0)}))); + Eigen::array<int, 2> perm({1, 0}); - right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) = - right_mat->shuffle(perm); - right_tr_map.reset(new TTypes<float>::ConstMatrix( - right_tr_mat->data(), right_tr_mat->dimensions())); - right_mat = right_tr_map.get(); + if (transpose_output) { + right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) = + right->matrix<TL>().shuffle(perm); + } else { + right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) = + right->matrix<TR>().shuffle(perm); + } + right = right_tr.get(); } - SparseMatMul(*left_mat, *right_mat, transpose_a, - ctx->device()->tensorflow_cpu_worker_threads(), - transpose_output, &out); + if (transpose_output) { + SparseMatMul<TR, TL>::Compute( + left->matrix<TR>(), right->matrix<TL>(), transpose_a, + ctx->device()->tensorflow_cpu_worker_threads(), transpose_output, + &out); + } else { + SparseMatMul<TL, TR>::Compute( + left->matrix<TL>(), right->matrix<TR>(), transpose_a, + ctx->device()->tensorflow_cpu_worker_threads(), transpose_output, + &out); + } } private: - // Perform matrix multiplication of "left" and "right", and store the result - // in *"output". - static inline void SparseMatMul( - const ConstMatrixMap& left, const ConstMatrixMap& right, - bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, - bool transpose_output, MatrixMap* output); - - // Computes multiplication of left and num_cols columns of right, and stores - // the output block in *"output" at offsets "output_row_offset" and - // "output_col_offset". If assign is true, assigns the value to that block, - // else adds the values to the existing values. - static inline void ComputeOutputBlock(const std::vector<SparseSlice*>& left, - const ConstMatrixMap& right, - int num_cols, int output_row_offset, - int output_col_offset, bool assign, - bool transpose_output, - MatrixMap* output); - - // Encodes "mat" using a sparse representation and stores that in - // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and - // "slice_num_cols", each grid element is converted into a SparseSlice and - // stored in mat_slices. "slice_block_size" is used to perform further column - // blocking of each slice. - static inline BlockingCounter* CreateSparseSlices( - const ConstMatrixMap& mat, bool transpose, int slice_num_rows, - int slice_block_size, int slice_num_cols, - std::vector<std::vector<SparseSlice*>>* mat_slices, - const DeviceBase::CpuWorkerThreads* thread_pool); - - // This function chops "mat" along column dimension into pieces with at most N - // columns, and concatenates the pieces one after the other in "buffer". It - // returns the list of the pieces in "slices". It returns a BlockingCounter - // which should be used to wait for the shuffle operations to complete. - static inline BlockingCounter* CreateDenseSlices( - const ConstMatrixMap& mat, int row_start, int num_rows, int col_start, - int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, - Matrix* buffer, std::vector<ConstMatrixMap*>* slices); - - // Helper function for CreateDenseSlices to move the data around. It returns a - // BlockingCounter which should be used to wait for the shuffle operations to - // complete. 
- static inline BlockingCounter* ShuffleMatrix( - const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows, - int slice_col_start, int slice_num_cols, const int N, - const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer); - - // Helper function for CreateDenseSlices to create slices. - static inline void SliceMatrix(const Matrix& mat, const int num_rows, - const int num_slices, - std::vector<ConstMatrixMap*>* slices); - - // Heuristics to compute various block sizes. - // KR, NR: block sizes for "right". We run blocking iterations that operate on - // matrices with at most this size. - // KL: grid size along the column dimension used while encoding left. - // IB, JB: number of left and right slices to multiply together. This is used - // for ordering different ComputeBlockOutput operations inside each blocking - // iteration so as to potentially reduce the working set size. - static inline void ComputeBlockSizes(const ConstMatrixMap& left, - const ConstMatrixMap& right, - bool transpose_left, int num_threads, - int* KR, int* NR, int* KL, int* JB, - int* IB); - bool transpose_a_; bool transpose_b_; bool a_is_sparse_; bool b_is_sparse_; + TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp); }; -inline void SparseMatMulOp::ComputeOutputBlock( - const std::vector<SparseSlice*>& left, const ConstMatrixMap& right, - int num_cols, int output_row_offset, int output_col_offset, bool assign, +template <typename TL, typename TR> +inline void SparseMatMul<TL, TR>::ComputeOutputBlock( + const std::vector<SparseSlice<TL>*>& left, + const SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols, + int output_row_offset, int output_col_offset, bool assign, bool transpose_output, MatrixMap* output) { static const Eigen::array<int, 2> perm({1, 0}); int num_rows = left[0]->num_rows; @@ -605,9 +950,9 @@ inline void SparseMatMulOp::ComputeOutputBlock( Matrix out(num_rows, rhs_num_cols); out.setZero(); if (num_cols == N) { - GEPP<N>(left, right, num_cols, &out); + GEPP<TL, TR, N>(left, right, num_cols, &out); } else { - GEPP<-1>(left, right, num_cols, &out); + GEPP<TL, TR, -1>(left, right, num_cols, &out); } if (!assign) { const Eigen::array<int, 2> begin = {output_row_offset, output_col_offset}; @@ -643,10 +988,11 @@ inline void SparseMatMulOp::ComputeOutputBlock( } } -inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( - const ConstMatrixMap& mat, bool transpose, int slice_num_rows, - int slice_block_size, int slice_num_cols, - std::vector<std::vector<SparseSlice*>>* mat_slices, +template <typename TL, typename TR> +inline BlockingCounter* SparseMatMul<TL, TR>::CreateSparseSlices( + const SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose, + int slice_num_rows, int slice_block_size, int slice_num_cols, + std::vector<std::vector<SparseSlice<TL>*>>* mat_slices, const DeviceBase::CpuWorkerThreads* thread_pool) { const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0); const int mat_num_cols = transpose ? 
mat.dimension(0) : mat.dimension(1); @@ -657,12 +1003,13 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( mat_slices->resize(num_slices_dim0); BlockingCounter* counter = new BlockingCounter(num_slices_dim0 * num_slices_dim1); - auto work = [counter, transpose](SparseSlice* sparse_slice, - ConstMatrixMap* slice, int col_offset) { + auto work = [counter, transpose](SparseSlice<TL>* sparse_slice, + SparseMatMul<TL, TR>::ConstMatrixMapL* slice, + int col_offset) { if (transpose) { - sparse_slice->Initialize<true>(*slice, col_offset); + sparse_slice->template Initialize<true>(*slice, col_offset); } else { - sparse_slice->Initialize<false>(*slice, col_offset); + sparse_slice->template Initialize<false>(*slice, col_offset); } delete slice; counter->DecrementCount(); @@ -674,16 +1021,17 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( for (int j = 0; j < num_slices_dim1; ++j) { int num_cols = std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols); - ConstMatrixMap* slice = nullptr; + SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr; if (transpose) { - slice = - new ConstMatrixMap(&mat(0, i * slice_num_rows), mat.dimensions()); + slice = new SparseMatMul<TL, TR>::ConstMatrixMapL( + &mat(0, i * slice_num_rows), mat.dimensions()); } else { DSizes d(num_rows, mat_num_cols); - slice = new ConstMatrixMap(&mat(i * slice_num_rows, 0), d); + slice = new SparseMatMul<TL, TR>::ConstMatrixMapL( + &mat(i * slice_num_rows, 0), d); } - SparseSlice* sparse_slice = - new SparseSlice(num_rows, num_cols, slice_block_size); + auto* sparse_slice = + new SparseSlice<TL>(num_rows, num_cols, slice_block_size); (*mat_slices)[i][j] = sparse_slice; thread_pool->workers->Schedule( std::bind(work, sparse_slice, slice, slice_num_cols * j)); @@ -691,11 +1039,58 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( } return counter; } +#define LOAD(x) Eigen::internal::ploadu<Packet>((x)); +#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x); +#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y); + +template <int NUM_ELEM = -1> +ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc, + int num_elements) { + DCHECK_GE(kNumOperands, 8); + static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16); + const int num = (NUM_ELEM == -1) ? 
num_elements : NUM_ELEM; + DCHECK_EQ(num, num_elements); + const float* src = reinterpret_cast<const float*>(bsrc); + float* dst = reinterpret_cast<float*>(bdst); + for (int index = 0; index + kStep <= num; index += kStep) { + auto in = LOAD(src); + auto tmp = INTERLEAVE(in); + STORE(dst, tmp); + src += kNumOperands; + dst += kNumOperands; + } + if (num % kStep != 0) { + memcpy(dst, src, (num % kStep) * sizeof(bfloat16)); + } +} -inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( - const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows, - int slice_col_start, int slice_num_cols, const int N, - const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer) { +template <typename T> +ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src, + int num_elements) { + if (std::is_same<T, float>::value || kNumOperands < 8) { + memcpy(dst, src, num_elements * sizeof(T)); + } else if (std::is_same<T, bfloat16>::value) { + if (num_elements == N) { + CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements); + } else { + CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements); + } + } else { + LOG(FATAL) << "Unsupported type"; + } +} + +#undef LOAD +#undef Interleave +#undef Store + +template <typename TL, typename TR> +inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix( + const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int slice_row_start, + int slice_num_rows, int slice_col_start, int slice_num_cols, const int N, + const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) { + DCHECK_EQ(N % 2, 0); + DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N); int num_threads = std::min(thread_pool->num_threads, 16); BlockingCounter* counter = new BlockingCounter(num_threads); DCHECK_EQ(N, buffer->dimension(1)); @@ -703,17 +1098,17 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( slice_num_cols, N, buffer, counter](int s, int e) { const int row_start = s % slice_num_rows + slice_row_start; const int col_start = s / slice_num_rows * N + slice_col_start; - float* out_start = &(*buffer)(s, 0); - const float* input_start = &mat(row_start, col_start); - const float* input_end = &mat(slice_row_start + slice_num_rows - 1, - slice_col_start + slice_num_cols - 1); + auto* out_start = &(*buffer)(s, 0); + const auto* input_start = &mat(row_start, col_start); + const auto* input_end = &mat(slice_row_start + slice_num_rows - 1, + slice_col_start + slice_num_cols - 1); const int mat_num_cols = mat.dimension(1); const int row_slice_size = slice_num_rows * mat_num_cols; const int aligned_end = slice_num_cols / N * slice_num_rows; const int e1 = std::min(e, aligned_end); while (s < e1) { - memcpy(out_start, input_start, N * sizeof(float)); + CopyAndMayBeInterleave<TR>(out_start, input_start, N); out_start += N; input_start += mat_num_cols; if (input_start > input_end) { @@ -724,7 +1119,7 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( int s1 = std::max(s, aligned_end); const int copy_num_cols = slice_num_cols % N; while (s1 < e) { - memcpy(out_start, input_start, copy_num_cols * sizeof(float)); + CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols); out_start += N; input_start += mat_num_cols; ++s1; @@ -745,21 +1140,24 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( return counter; } -inline void SparseMatMulOp::SliceMatrix(const Matrix& mat, const int num_rows, - const int num_slices, - std::vector<ConstMatrixMap*>* slices) { +template <typename TL, typename TR> +inline void SparseMatMul<TL, TR>::SliceMatrix( + const MatrixR& 
mat, const int num_rows, const int num_slices, + std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) { slices->resize(num_slices); DSizes d(num_rows, mat.dimension(1)); DCHECK_LE(num_rows * num_slices, mat.dimension(0)); for (int i = 0; i < num_slices; ++i) { - (*slices)[i] = new ConstMatrixMap(&mat(i * num_rows, 0), d); + (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d); } } -inline BlockingCounter* SparseMatMulOp::CreateDenseSlices( - const ConstMatrixMap& mat, int row_start, int num_rows, int col_start, - int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, - Matrix* buffer, std::vector<ConstMatrixMap*>* slices) { +template <typename TL, typename TR> +inline BlockingCounter* SparseMatMul<TL, TR>::CreateDenseSlices( + const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start, + int num_rows, int col_start, int num_cols, + const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer, + std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) { BlockingCounter* shuffle_counter = ShuffleMatrix( mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer); const int num_slices = (num_cols + N - 1) / N; @@ -767,11 +1165,11 @@ inline BlockingCounter* SparseMatMulOp::CreateDenseSlices( return shuffle_counter; } -inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left, - const ConstMatrixMap& right, - bool transpose_left, - int num_threads, int* KR, int* NR, - int* KL, int* JB, int* IB) { +template <typename TL, typename TR> +inline void SparseMatMul<TL, TR>::ComputeBlockSizes( + const SparseMatMul<TL, TR>::ConstMatrixMapL& left, + const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left, + int num_threads, int* KR, int* NR, int* KL, int* JB, int* IB) { // Heuristics for calculating block sizes // Assume two hyperthreads per core. const int est_num_cores = std::max(1, (num_threads + 1) / 2); @@ -838,20 +1236,21 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left, // and update the output block o_ij. These calls are further blocked to // reduce the working set size. In each iteration we take IB elements from // {l_i} and JB elements from {r_j} and compute the IB * JB inner products. -inline void SparseMatMulOp::SparseMatMul( - const ConstMatrixMap& left, const ConstMatrixMap& right, - bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, - bool transpose_output, MatrixMap* output) { +template <typename TL, typename TR> +inline void SparseMatMul<TL, TR>::Compute( + const SparseMatMul<TL, TR>::ConstMatrixMapL& left, + const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left, + const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output, + MatrixMap* output) { const int num_threads = thread_pool->num_threads; int KR, NR, KL, JB, IB; ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL, &JB, &IB); - // Slice the left matrix - std::vector<std::vector<SparseSlice*>> left_slices; + std::vector<std::vector<SparseSlice<TL>*>> left_slices; std::unique_ptr<BlockingCounter> sparse_slice_counter; sparse_slice_counter.reset( - CreateSparseSlices(ConstMatrixMap(left.data(), left.dimensions()), + CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()), transpose_left, M, K, KL, &left_slices, thread_pool)); const int num_left_slices = left_slices.size(); @@ -862,10 +1261,10 @@ inline void SparseMatMulOp::SparseMatMul( // is the block size per iteration. 
const int buffer_num_rows = std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N; - Matrix buffer(buffer_num_rows, N); - std::vector<ConstMatrixMap*> right_slices; + MatrixR buffer(buffer_num_rows, N); + std::vector<ConstMatrixMapR*> right_slices; - std::vector<SparseSlice*> block_left_slices; + std::vector<SparseSlice<TL>*> block_left_slices; std::vector<std::function<void(void)>> tasks; // Number of blocks based on block sizes of KR * NR. const int num_k_blocks = (right_dim0 + KR - 1) / KR; @@ -890,7 +1289,6 @@ inline void SparseMatMulOp::SparseMatMul( const int num_cols = std::min(N, right_num_cols - N * j_inner); for (int i_inner = i_outer; i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) { - // Figure out which left slices to use. block_left_slices.clear(); int begin = kb * KR / KL; int end = std::min<int>((kb + 1) * KR / KL, @@ -900,7 +1298,7 @@ inline void SparseMatMulOp::SparseMatMul( left_slices[i_inner].begin() + begin, left_slices[i_inner].begin() + end); tasks.push_back(std::bind( - &SparseMatMulOp::ComputeOutputBlock, block_left_slices, + &ComputeOutputBlock, block_left_slices, std::ref(*right_slices[j_inner]), num_cols, M * i_inner, N * j_inner + nb * NR, kb == 0, transpose_output, output)); } @@ -933,7 +1331,18 @@ inline void SparseMatMulOp::SparseMatMul( } } -REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU), - SparseMatMulOp); +#define REGISTER_SPARSE_MATMUL(TA, TB) \ + REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<TA>("Ta") \ + .TypeConstraint<TB>("Tb"), \ + SparseMatMulOp<TA, TB>); + +REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); +REGISTER_SPARSE_MATMUL(float, bfloat16); +REGISTER_SPARSE_MATMUL(bfloat16, float); +REGISTER_SPARSE_MATMUL(float, float); + +#undef REGISTER_SPARSE_MATMUL } // end namespace tensorflow |
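
Tying the inner loops above together: for every Index3 triple, the MulAdd3Way family accumulates a1*r1[c] + a2*r2[c] + a3*r3[c] into a run of output columns, where a1..a3 are the three left-hand nonzero values (broadcast into packets by LoadThreeScalars/LoadSixScalars) and r1..r3 point into rows of the right-hand block prepared by CreateDenseSlices; when the right operand is bfloat16, each packet load covers 2*kNumOperands values, which EXPAND_BFLOAT_L/U widen into two float packets before the FMAs. A scalar reference for the float case, hypothetical and shown only to make the arithmetic explicit:

    // Unvectorized reference for MulAdd3Way / MulAdd3Way128 on float inputs.
    // The real kernels process kNumOperands (or 128) columns per step with
    // packet loads and FMAs; this loop states the arithmetic they implement.
    inline void MulAdd3WayReference(float a1, float a2, float a3,
                                    const float* r1, const float* r2,
                                    const float* r3, float* out, int cols) {
      for (int c = 0; c < cols; ++c) {
        out[c] += a1 * r1[c] + a2 * r2[c] + a3 * r3[c];
      }
    }

MulAdd3Way128 performs the same update with cols fixed at 128, fully unrolled when kNumOperands == 8 (AVX); the bfloat16 variants insert the EXPAND_BFLOAT_L/U widening between the loads and the FMAs.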