From 6b373bb149396beec11347881b2d6dedfbcc83c4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 10 May 2016 12:30:20 -0800 Subject: Allow bfloat16 inputs to SparseMatMul. Change: 121980920 --- tensorflow/core/kernels/sparse_matmul_op.cc | 1019 ++++++++++++++------ tensorflow/core/kernels/sparse_matmul_op_test.cc | 133 ++- tensorflow/core/kernels/transpose_op.cc | 1 + tensorflow/core/ops/math_ops.cc | 6 +- tensorflow/python/BUILD | 2 +- .../python/kernel_tests/sparse_matmul_op_test.py | 97 +- tensorflow/python/ops/math_grad.py | 26 +- tensorflow/python/ops/math_ops.py | 9 +- 8 files changed, 885 insertions(+), 408 deletions(-) diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index 067aa2624d..6cfaa3b958 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -18,9 +18,11 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/sparse_matmul_op.h" + #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -35,12 +37,16 @@ namespace tensorflow { namespace { +using Eigen::operator==; typedef Eigen::Tensor Matrix; typedef Eigen::DSizes DSizes; -typedef Eigen::TensorMap, - Eigen::Aligned> MatrixMap; typedef Eigen::TensorMap, - Eigen::Aligned> ConstMatrixMap; + Eigen::Aligned> + ConstMatrixMap; +typedef Eigen::TensorMap, + Eigen::Aligned> + MatrixMap; + typedef Eigen::ThreadPoolDevice CPUDevice; // Blocksizes @@ -52,8 +58,7 @@ static const int N = 128; // This stores a sparse representation of a slice of a matrix with size // (num_rows, num_cols). The slice is represented as a series of blocks of size // (num_rows, b), where b = block_size for all but the last block, which may -// have -// fewer columns. +// have fewer columns. // // num_rows and block_size are assumed to be <= 256. This allows storing // different indices as uint8. @@ -71,7 +76,12 @@ static const int N = 128; // are the values in the following range: // [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for // index_offset. +template struct SparseSlice { + typedef Eigen::TensorMap, + Eigen::Aligned> + ConstMatrixMap; + public: // Indices of three elements on the same row. struct Index3 { @@ -105,13 +115,13 @@ struct SparseSlice { // See comments above. std::vector index3_offset; std::vector index3; - std::vector data3; + std::vector data3; // See comments above. Similar to "index3" except that each element in "index" // corresponds to one element in data. std::vector index_offset; std::vector index; - std::vector data; + std::vector data; // Number of rows and columns for the slice. const int num_rows; @@ -121,8 +131,10 @@ struct SparseSlice { const int block_size; }; +template template -void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { +void SparseSlice::Initialize(const SparseSlice::ConstMatrixMap& mat, + int col_offset) { const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0); const int mat_cols = Transpose ? 
mat.dimension(0) : mat.dimension(1); DCHECK_LE(num_rows, mat_rows); @@ -142,6 +154,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { Index3 idx3; Index idx; int data3_size = 0; + static const T zero(0); for (int i = 0; i < num_blocks; ++i) { int num_block_cols = std::min(block_size, num_cols - block_size * i); for (int row = 0; row < num_rows; ++row) { @@ -150,17 +163,17 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { // *curr is nonzero and then reads it again on use. However, the result // of the race is only that some of the "nonzeros" in the resulting sparse // representation may actually be zero, which is harmless. - const float* start = + const auto* start = Transpose ? &mat(col_offset, row) : &mat(row, col_offset); - const float* curr = start; + const auto* curr = start; const int stride = Transpose ? mat.dimension(1) : 1; - const float* end = start + stride * num_block_cols; + const auto* end = start + stride * num_block_cols; uint8 k = 0; #define NEXT_ELEM \ curr += stride; \ ++k; while (true) { - while (curr < end && (*curr == 0)) { + while (curr < end && (*curr == zero)) { NEXT_ELEM; } if (curr >= end) break; @@ -168,7 +181,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { data3.push_back(*curr); NEXT_ELEM; - while (curr < end && (*curr == 0)) { + while (curr < end && (*curr == zero)) { NEXT_ELEM; } if (curr >= end) break; @@ -176,7 +189,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { data3.push_back(*curr); NEXT_ELEM; - while (curr < end && (*curr == 0)) { + while (curr < end && (*curr == zero)) { NEXT_ELEM; } if (curr >= end) break; @@ -213,7 +226,8 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { DCHECK_EQ(index.size(), data.size()); } -void SparseSlice::Clear() { +template +void SparseSlice::Clear() { index3_offset.clear(); index3.clear(); data3.clear(); @@ -222,107 +236,356 @@ void SparseSlice::Clear() { data.clear(); } -#define SCALAR_MULADD(a, inp, out) *out++ += *a * *inp++; - -#define SCALAR_MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \ - *out++ += *a1 * *inp1++ + *a2 * *inp2++ + *a3 * *inp3++; - typedef Eigen::internal::packet_traits::type Packet; static const int kNumOperands = (sizeof(Packet) / sizeof(float)); #define LOAD(x) Eigen::internal::pload(x); +#define EXPAND_BFLOAT_L(x, y) \ + const auto y = Eigen::internal::pexpand_bf16_l(x); +#define EXPAND_BFLOAT_U(x, y) \ + const auto y = Eigen::internal::pexpand_bf16_u(x); #define STORE(x, y) Eigen::internal::pstore(x, y); -#define LOAD_SCALAR(x, y) const auto y = Eigen::internal::pload1(x); #define FMA(a, b, c, d) d = Eigen::internal::pmadd(a, b, c); -// Vectorized version of SCALAR_MULADD. -#define MULADD(a, inp, out) \ - do { \ - const auto b = LOAD(inp); \ - inp += kNumOperands; \ - auto c = LOAD(out); \ - FMA(a, b, c, c); \ - STORE(out, c); \ - out += kNumOperands; \ - } while (false) - -// Vectorized version of SCALAR_MULADD3WAY. 
-#define MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \ - do { \ - auto c = LOAD(out); \ - const auto b1 = LOAD(inp1); \ - inp1 += kNumOperands; \ - const auto b2 = LOAD(inp2); \ - inp2 += kNumOperands; \ - const auto b3 = LOAD(inp3); \ - inp3 += kNumOperands; \ - FMA(a1, b1, c, c); \ - FMA(a2, b2, c, c); \ - FMA(a3, b3, c, c); \ - STORE(out, c); \ - out += kNumOperands; \ - } while (false) - -#ifdef EIGEN_VECTORIZE_AVX2 -// Unroll MULADD3WAY for two iterations -#define MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out) \ - do { \ - auto c1 = LOAD(out); \ - const auto b1 = LOAD(inp1); \ - const auto b2 = LOAD(inp2); \ - const auto b3 = LOAD(inp3); \ - \ - auto c2 = LOAD(out + kNumOperands); \ - const auto b4 = LOAD(inp1 + kNumOperands); \ - const auto b5 = LOAD(inp2 + kNumOperands); \ - const auto b6 = LOAD(inp3 + kNumOperands); \ - \ - FMA(a1, b1, c1, c1); \ - FMA(a1, b4, c2, c2); \ - FMA(a2, b2, c1, c1); \ - FMA(a2, b5, c2, c2); \ - FMA(a3, b3, c1, c1); \ - FMA(a3, b6, c2, c2); \ - STORE(out, c1); \ - STORE(out + kNumOperands, c2); \ - out += 2 * kNumOperands; \ - inp1 += 2 * kNumOperands; \ - inp2 += 2 * kNumOperands; \ - inp3 += 2 * kNumOperands; \ - } while (false) -// Further unroll MULADD3WAY. -#define MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out) \ - MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); -#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); -#else -#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \ - for (int __i = 0; __i < 128 / (4 * kNumOperands); ++__i) { \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ - MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \ +#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE + +ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) { + float out = 0; + auto tmp = reinterpret_cast(&out); + tmp[1] = *src; + return out; +} + +ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) { + return Eigen::internal::pload4bf16( + reinterpret_cast(src)); +} + +ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) { + return Eigen::internal::pload2bf16( + reinterpret_cast(src)); +} + +ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) { + **out += a * **inp; + ++*inp; + ++*out; +} + +ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp, + float** out) { + float inp_f = ConvertBfloat16ToFloat(*inp); + **out += a * inp_f; + ++*inp; + ++*out; +} +ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2, + const float a3, const bfloat16** inp1, + const bfloat16** inp2, + const bfloat16** inp3, float** out) { + float inp1_f = ConvertBfloat16ToFloat(*inp1); + float inp2_f = ConvertBfloat16ToFloat(*inp2); + float inp3_f = ConvertBfloat16ToFloat(*inp3); + **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f; + ++*out; + ++*inp1; + ++*inp2; + ++*inp3; +} + +ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2, + const float a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3; + ++*out; + ++*inp1; + ++*inp2; + ++*inp3; +} + +ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) { + auto tmp 
= ConvertBfloat16ToFloat(*data); + *l = Eigen::internal::pset1(tmp); + ++*data; +} + +ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1, + Packet* l2) { + if (kNumOperands >= 2) { + auto tmp = ConvertTwoBfloat16ToFloat(*data); + *l1 = Eigen::internal::pbroadcast_first(tmp); + *l2 = Eigen::internal::pbroadcast_second(tmp); + *data += 2; + } else { + LoadSingleScalar(data, l1); + LoadSingleScalar(data, l2); + } +} + +ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1, + Packet* l2, Packet* l3, Packet* l4) { + if (kNumOperands >= 4) { + auto tmp = ConvertFourBfloat16ToFloat(*data); + *l1 = Eigen::internal::pbroadcast_first(tmp); + *l2 = Eigen::internal::pbroadcast_second(tmp); + *l3 = Eigen::internal::pbroadcast_third(tmp); + *l4 = Eigen::internal::pbroadcast_fourth(tmp); + *data += 4; + } else { + LoadTwoScalars(data, l1, l2); + LoadTwoScalars(data, l3, l4); + } +} + +ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) { + *l = Eigen::internal::pload1(*data); + ++(*data); +} + +ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) { + LoadSingleScalar(data, l1); + LoadSingleScalar(data, l2); +} + +ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2, + Packet* l3, Packet* l4) { + LoadTwoScalars(data, l1, l2); + LoadTwoScalars(data, l3, l4); +} + +template +ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2, + Packet* l3) { + LoadTwoScalars(data, l1, l2); + LoadSingleScalar(data, l3); +} + +template +ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2, + Packet* l3, Packet* l4, Packet* l5, + Packet* l6) { + LoadFourScalars(data, l1, l2, l3, l4); + LoadTwoScalars(data, l5, l6); +} + +// Vectorized version of ScalarMulAdd. +ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) { + auto inp = reinterpret_cast(*binp); + const auto b = LOAD(inp); + EXPAND_BFLOAT_L(b, b_0); + EXPAND_BFLOAT_U(b, b_1); + *binp += 2 * kNumOperands; + auto c1 = LOAD(*out); + auto c2 = LOAD(*out + kNumOperands); + FMA(a, b_0, c1, c1); + FMA(a, b_1, c2, c2); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + *out += 2 * kNumOperands; +} + +// Vectorized version of ScalarMulAdd3Way. 
+ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3, + const bfloat16** binp1, const bfloat16** binp2, + const bfloat16** binp3, float** out) { + auto inp1 = reinterpret_cast(*binp1); + auto inp2 = reinterpret_cast(*binp2); + auto inp3 = reinterpret_cast(*binp3); + auto c1 = LOAD(*out); + auto c2 = LOAD(*out + kNumOperands); + const auto b1 = LOAD(inp1); + EXPAND_BFLOAT_L(b1, b1_0); + EXPAND_BFLOAT_U(b1, b1_1); + *binp1 += 2 * kNumOperands; + const auto b2 = LOAD(inp2); + EXPAND_BFLOAT_L(b2, b2_0); + EXPAND_BFLOAT_U(b2, b2_1); + *binp2 += 2 * kNumOperands; + const auto b3 = LOAD(inp3); + EXPAND_BFLOAT_L(b3, b3_0); + EXPAND_BFLOAT_U(b3, b3_1); + *binp3 += 2 * kNumOperands; + FMA(a1, b1_0, c1, c1); + FMA(a1, b1_1, c2, c2); + FMA(a2, b2_0, c1, c1); + FMA(a2, b2_1, c2, c2); + FMA(a3, b3_0, c1, c1); + FMA(a3, b3_1, c2, c2); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + *out += 2 * kNumOperands; +} + +// Unroll MulAdd3Way for two iterations +ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2, + const Packet a3, const bfloat16** binp1, + const bfloat16** binp2, const bfloat16** binp3, + float** out) { + auto inp1 = reinterpret_cast(*binp1); + auto inp2 = reinterpret_cast(*binp2); + auto inp3 = reinterpret_cast(*binp3); + auto c1 = LOAD(*out); + auto c2 = LOAD(*out + kNumOperands); + const auto b1 = LOAD(inp1); + const auto b2 = LOAD(inp2); + const auto b3 = LOAD(inp3); + + EXPAND_BFLOAT_L(b1, b1_0); + EXPAND_BFLOAT_U(b1, b1_1); + EXPAND_BFLOAT_L(b2, b2_0); + EXPAND_BFLOAT_U(b2, b2_1); + EXPAND_BFLOAT_L(b3, b3_0); + EXPAND_BFLOAT_U(b3, b3_1); + auto c3 = LOAD(*out + 2 * kNumOperands); + auto c4 = LOAD(*out + 3 * kNumOperands); + const auto b4 = LOAD(inp1 + kNumOperands); + const auto b5 = LOAD(inp2 + kNumOperands); + const auto b6 = LOAD(inp3 + kNumOperands); + + EXPAND_BFLOAT_L(b4, b4_0); + EXPAND_BFLOAT_U(b4, b4_1); + EXPAND_BFLOAT_L(b5, b5_0); + EXPAND_BFLOAT_U(b5, b5_1); + EXPAND_BFLOAT_L(b6, b6_0); + EXPAND_BFLOAT_U(b6, b6_1); + + FMA(a1, b1_0, c1, c1); + FMA(a1, b1_1, c2, c2); + FMA(a1, b4_0, c3, c3); + FMA(a1, b4_1, c4, c4); + FMA(a2, b2_0, c1, c1); + FMA(a2, b2_1, c2, c2); + FMA(a2, b5_0, c3, c3); + FMA(a2, b5_1, c4, c4); + FMA(a3, b3_0, c1, c1); + FMA(a3, b3_1, c2, c2); + FMA(a3, b6_0, c3, c3); + FMA(a3, b6_1, c4, c4); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + STORE(*out + 2 * kNumOperands, c3); + STORE(*out + 3 * kNumOperands, c4); + *out += 4 * kNumOperands; + *binp1 += 4 * kNumOperands; + *binp2 += 4 * kNumOperands; + *binp3 += 4 * kNumOperands; +} + +// Apply MulAdd3Way on 128 operands. 
+ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2, + const Packet a3, const bfloat16** inp1, + const bfloat16** inp2, const bfloat16** inp3, + float** out) { + for (int k = 0; k < 128 / (8 * kNumOperands); ++k) { + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); } -#endif +} +// Vectorized version of ScalarMulAdd +ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) { + const auto b = LOAD(*inp); + *inp += kNumOperands; + auto c = LOAD(*out); + FMA(a, b, c, c); + STORE(*out, c); + *out += kNumOperands; +} + +// Vectorized version of ScalarMulAdd3Way +ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3, + const float** inp1, const float** inp2, + const float** inp3, float** out) { + auto c = LOAD(*out); + const auto b1 = LOAD(*inp1); + *inp1 += kNumOperands; + const auto b2 = LOAD(*inp2); + *inp2 += kNumOperands; + const auto b3 = LOAD(*inp3); + *inp3 += kNumOperands; + FMA(a1, b1, c, c); + FMA(a2, b2, c, c); + FMA(a3, b3, c, c); + STORE(*out, c); + *out += kNumOperands; +} + +// Unroll MulAdd3Way for two iterations +ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2, + const Packet a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + auto c1 = LOAD(*out); + const auto b1 = LOAD(*inp1); + const auto b2 = LOAD(*inp2); + const auto b3 = LOAD(*inp3); + + auto c2 = LOAD(*out + kNumOperands); + const auto b4 = LOAD(*inp1 + kNumOperands); + const auto b5 = LOAD(*inp2 + kNumOperands); + const auto b6 = LOAD(*inp3 + kNumOperands); + + FMA(a1, b1, c1, c1); + FMA(a1, b4, c2, c2); + FMA(a2, b2, c1, c1); + FMA(a2, b5, c2, c2); + FMA(a3, b3, c1, c1); + FMA(a3, b6, c2, c2); + STORE(*out, c1); + STORE(*out + kNumOperands, c2); + *out += 2 * kNumOperands; + *inp1 += 2 * kNumOperands; + *inp2 += 2 * kNumOperands; + *inp3 += 2 * kNumOperands; +} + +// Unroll MulAdd3Way for four iterations +ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2, + const Packet a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); +} + +// Apply MulAdd3Way on 128 operands. +ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2, + const Packet a3, const float** inp1, + const float** inp2, const float** inp3, + float** out) { + if (kNumOperands == 8) { + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + } else { + DCHECK_LE(4 * kNumOperands, 128); + for (int i = 0; i < 128 / (4 * kNumOperands); ++i) { + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out); + } + } +} // Computes product of "left_slices" with "num_cols" columns of "right", and // stores the output in *"output". // Note that left_slices is a list of SparseSlices, which are conceptually // assumed to be concatenated along the column dimension. Also each SparseSlice // is encoded as a list of blocks with upto N columns. See SparseSlice for more // details. 
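The GEPP kernel below consumes the helpers defined above. The bfloat16 variants (ConvertBfloat16ToFloat, EXPAND_BFLOAT_L/EXPAND_BFLOAT_U, and the bfloat16 MulAdd overloads) work because bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so widening back to float is a copy into the high half of a zeroed 32-bit word rather than a real format conversion. A minimal standalone sketch of that round trip, assuming NumPy; the helper names are illustrative and not part of this change, and rounding is ignored (truncation only):

import numpy as np

def float32_to_bfloat16_bits(x):
  # Keep only the top 16 bits: sign, 8 exponent bits, 7 mantissa bits.
  x = np.asarray(x, dtype=np.float32)
  return (x.view(np.uint32) >> np.uint32(16)).astype(np.uint16)

def bfloat16_bits_to_float32(bits):
  # Zero-fill the low 16 bits to widen back to float32, as the kernel does.
  bits = np.asarray(bits, dtype=np.uint16)
  return (bits.astype(np.uint32) << np.uint32(16)).view(np.float32)

vals = np.array([1.5, -2.0, 0.1], dtype=np.float32)
print(bfloat16_bits_to_float32(float32_to_bfloat16_bits(vals)))
# 1.5 and -2.0 survive exactly; 0.1 loses its low mantissa bits.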
-template -inline void GEPP(const std::vector& left_slices, - const ConstMatrixMap& right, const int num_cols, - Matrix* output) { +template +inline void GEPP( + const std::vector*>& left_slices, + const Eigen::TensorMap, + Eigen::Aligned>& right, + const int num_cols, Matrix* output) { const int cols = (Cols == -1) ? num_cols : Cols; DCHECK_EQ(num_cols, cols); const int right_num_cols = right.dimension(1); const int output_num_cols = output->dimension(1); - const int cols_mod = cols % kNumOperands; + static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR); + const int cols_mod = cols % kNumOperandsR; int k_offset = 0; // Pre-compute pointers for output matrix. float* out_ptrs[M]; @@ -332,15 +595,15 @@ inline void GEPP(const std::vector& left_slices, } for (const auto* left_slice : left_slices) { const auto& left = *left_slice; - const float* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr; - const float* data = (left.data.size() > 0) ? &left.data[0] : nullptr; + const auto* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr; + const auto* data = (left.data.size() > 0) ? &left.data[0] : nullptr; const int num_blocks = left.index3_offset.size(); int begin3 = 0; int begin = 0; for (int i = 0; i < num_blocks; ++i) { // Pre-compute pointers for right matrix - const float* right_ptrs[K]; - const float* const right_start = &right(k_offset, 0); + const TR* right_ptrs[K]; + const auto* const right_start = &right(k_offset, 0); DCHECK_LT(k_offset, right.dimension(0)); for (int j = 0; j < K; ++j) { right_ptrs[j] = right_start + right_num_cols * j; @@ -350,79 +613,116 @@ inline void GEPP(const std::vector& left_slices, int j = begin3; // Loop unrolled for 2 iterations. for (; j + 1 < end3; j += 2) { - const float* sl1 = data3++; - LOAD_SCALAR(sl1, l1); - const float* sl2 = data3++; - LOAD_SCALAR(sl2, l2); - const float* sl3 = data3++; - LOAD_SCALAR(sl3, l3); - const float* nsl1 = data3++; - LOAD_SCALAR(nsl1, nl1); - const float* nsl2 = data3++; - LOAD_SCALAR(nsl2, nl2); - const float* nsl3 = data3++; - LOAD_SCALAR(nsl3, nl3); - const SparseSlice::Index3& index = left.index3[j]; - const SparseSlice::Index3& nindex = left.index3[j + 1]; + Packet l1, l2, l3, nl1, nl2, nl3; + LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3); + const auto& index = left.index3[j]; + const auto& nindex = left.index3[j + 1]; float* out = out_ptrs[index.m]; float* nout = out_ptrs[nindex.m]; - const float* r1 = right_ptrs[index.k1]; - const float* r2 = right_ptrs[index.k2]; - const float* r3 = right_ptrs[index.k3]; - const float* nr1 = right_ptrs[nindex.k1]; - const float* nr2 = right_ptrs[nindex.k2]; - const float* nr3 = right_ptrs[nindex.k3]; + const auto* r1 = right_ptrs[index.k1]; + const auto* r2 = right_ptrs[index.k2]; + const auto* r3 = right_ptrs[index.k3]; + + const auto* nr1 = right_ptrs[nindex.k1]; + const auto* nr2 = right_ptrs[nindex.k2]; + const auto* nr3 = right_ptrs[nindex.k3]; if (cols == 128) { - MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out); - MULADD3WAY_128(nl1, nl2, nl3, nr1, nr2, nr3, nout); + MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out); + MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout); } else { - for (int n = 0; n < cols / kNumOperands; ++n) { - MULADD3WAY(l1, l2, l3, r1, r2, r3, out); - MULADD3WAY(nl1, nl2, nl3, nr1, nr2, nr3, nout); + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out); + MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout); } + + const float sl1 = Eigen::internal::pfirst(l1); + const float sl2 
= Eigen::internal::pfirst(l2); + const float sl3 = Eigen::internal::pfirst(l3); + const float nsl1 = Eigen::internal::pfirst(nl1); + const float nsl2 = Eigen::internal::pfirst(nl2); + const float nsl3 = Eigen::internal::pfirst(nl3); for (int k = 0; k < cols_mod; ++k) { - SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out); - SCALAR_MULADD3WAY(nsl1, nsl2, nsl3, nr1, nr2, nr3, nout); + ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out); + ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout); } } } if (j < end3) { - const float* sl1 = data3++; - LOAD_SCALAR(sl1, l1); - const float* sl2 = data3++; - LOAD_SCALAR(sl2, l2); - const float* sl3 = data3++; - LOAD_SCALAR(sl3, l3); - const SparseSlice::Index3& index = left.index3[j]; + Packet l1, l2, l3; + LoadThreeScalars(&data3, &l1, &l2, &l3); + + const auto& index = left.index3[j]; float* out = out_ptrs[index.m]; - const float* r1 = right_ptrs[index.k1]; - const float* r2 = right_ptrs[index.k2]; - const float* r3 = right_ptrs[index.k3]; + const auto* r1 = right_ptrs[index.k1]; + const auto* r2 = right_ptrs[index.k2]; + const auto* r3 = right_ptrs[index.k3]; if (cols == 128) { - MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out); + MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out); } else { - for (int n = 0; n < cols / kNumOperands; ++n) { - MULADD3WAY(l1, l2, l3, r1, r2, r3, out); + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out); } + const float sl1 = Eigen::internal::pfirst(l1); + const float sl2 = Eigen::internal::pfirst(l2); + const float sl3 = Eigen::internal::pfirst(l3); for (int k = 0; k < cols_mod; ++k) { - SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out); + ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out); } } } begin3 = end3; int end = left.index_offset[i]; - for (int j = begin; j < end; ++j) { - const float* sl = data++; - LOAD_SCALAR(sl, l); - const SparseSlice::Index& index = left.index[j]; - const float* r = right_ptrs[index.k]; + // Loop unrolled for 4 iterations. 
+ j = begin; + for (; j + 3 < end; j += 4) { + Packet l, nl, n2l, n3l; + LoadFourScalars(&data, &l, &nl, &n2l, &n3l); + + const auto& index = left.index[j]; + const auto& nindex = left.index[j + 1]; + const auto& n2index = left.index[j + 2]; + const auto& n3index = left.index[j + 3]; + const auto* r = right_ptrs[index.k]; + const auto* nr = right_ptrs[nindex.k]; + const auto* n2r = right_ptrs[n2index.k]; + const auto* n3r = right_ptrs[n3index.k]; + float* out = out_ptrs[index.m]; + float* nout = out_ptrs[nindex.m]; + float* n2out = out_ptrs[n2index.m]; + float* n3out = out_ptrs[n3index.m]; + + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd(l, &r, &out); + MulAdd(nl, &nr, &nout); + MulAdd(n2l, &n2r, &n2out); + MulAdd(n3l, &n3r, &n3out); + } + + const float sl1 = Eigen::internal::pfirst(l); + const float sl2 = Eigen::internal::pfirst(nl); + const float sl3 = Eigen::internal::pfirst(n2l); + const float sl4 = Eigen::internal::pfirst(n3l); + for (int k = 0; k < cols_mod; ++k) { + ScalarMulAdd(sl1, &r, &out); + ScalarMulAdd(sl2, &nr, &nout); + ScalarMulAdd(sl3, &n2r, &n2out); + ScalarMulAdd(sl4, &n3r, &n3out); + } + } + while (j < end) { + Packet l; + LoadSingleScalar(&data, &l); + const auto& index = left.index[j]; + const auto* r = right_ptrs[index.k]; float* out = out_ptrs[index.m]; - for (int n = 0; n < cols / kNumOperands; ++n) { - MULADD(l, r, out); + for (int n = 0; n < cols / kNumOperandsR; ++n) { + MulAdd(l, &r, &out); } + const float sl = Eigen::internal::pfirst(l); for (int k = 0; k < cols_mod; ++k) { - SCALAR_MULADD(sl, r, out); + ScalarMulAdd(sl, &r, &out); } + j++; } k_offset += left.block_size; begin = end; @@ -430,21 +730,101 @@ inline void GEPP(const std::vector& left_slices, } } -#undef SCALAR_MULADD -#undef SCALAR_MULADD3WAY #undef LOAD +#undef EXPAND_BFLOAT_L +#undef EXPAND_BFLOAT_U #undef STORE -#undef LOAD_SCALAR #undef FMA -#undef MULADD -#undef MULADD3WAY -#undef MULADD3WAY_16 -#undef MULADD3WAY_32 -#undef MULADD3WAY_128 } // namespace +template +class SparseMatMul { + typedef Eigen::Tensor MatrixL; + typedef Eigen::Tensor MatrixR; + typedef Eigen::TensorMap, + Eigen::Aligned> + ConstMatrixMapL; + typedef Eigen::TensorMap, + Eigen::Aligned> + ConstMatrixMapR; + typedef Eigen::TensorMap, + Eigen::Aligned> + MatrixMapR; + // Perform matrix multiplication of "left" and "right", and store the result + // in *"output". + public: + static inline void Compute(const ConstMatrixMapL& left, + const ConstMatrixMapR& right, bool transpose_left, + const DeviceBase::CpuWorkerThreads* thread_pool, + bool transpose_output, MatrixMap* output); + + private: + // Computes multiplication of left and num_cols columns of right, and stores + // the output block in *"output" at offsets "output_row_offset" and + // "output_col_offset". If assign is true, assigns the value to that block, + // else adds the values to the existing values. + static inline void ComputeOutputBlock( + const std::vector*>& left, const ConstMatrixMapR& right, + int num_cols, int output_row_offset, int output_col_offset, bool assign, + bool transpose_output, MatrixMap* output); + + // Encodes "mat" using a sparse representation and stores that in + // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and + // "slice_num_cols", each grid element is converted into a SparseSlice and + // stored in mat_slices. "slice_block_size" is used to perform further column + // blocking of each slice. 
+ static inline BlockingCounter* CreateSparseSlices( + const ConstMatrixMapL& mat, bool transpose, int slice_num_rows, + int slice_block_size, int slice_num_cols, + std::vector*>>* mat_slices, + const DeviceBase::CpuWorkerThreads* thread_pool); + + // This function chops "mat" along column dimension into pieces with at most N + // columns, and concatenates the pieces one after the other in "buffer". It + // returns the list of the pieces in "slices". It returns a BlockingCounter + // which should be used to wait for the shuffle operations to complete. + static inline BlockingCounter* CreateDenseSlices( + const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start, + int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, + MatrixR* buffer, std::vector* slices); + + // Helper function for CreateDenseSlices to move the data around. It returns a + // BlockingCounter which should be used to wait for the shuffle operations to + // complete. + static inline BlockingCounter* ShuffleMatrix( + const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows, + int slice_col_start, int slice_num_cols, const int N, + const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer); + + // Helper function for CreateDenseSlices to create slices. + static inline void SliceMatrix(const MatrixR& mat, const int num_rows, + const int num_slices, + std::vector* slices); + + // Heuristics to compute various block sizes. + // KR, NR: block sizes for "right". We run blocking iterations that operate on + // matrices with at most this size. + // KL: grid size along the column dimension used while encoding left. + // IB, JB: number of left and right slices to multiply together. This is used + // for ordering different ComputeBlockOutput operations inside each blocking + // iteration so as to potentially reduce the working set size. + static inline void ComputeBlockSizes(const ConstMatrixMapL& left, + const ConstMatrixMapR& right, + bool transpose_left, int num_threads, + int* KR, int* NR, int* KL, int* JB, + int* IB); + + TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul); +}; + +template class SparseMatMulOp : public OpKernel { + typedef Eigen::Tensor MatrixR; + typedef Eigen::TensorMap, + Eigen::Aligned> + ConstMatrixMapR; + public: explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); @@ -461,12 +841,10 @@ class SparseMatMulOp : public OpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()), errors::InvalidArgument("b is not a matrix")); - auto left = a.matrix(); - auto right = b.matrix(); - const int m = transpose_a_ ? left.dimension(1) : left.dimension(0); - const int k = transpose_a_ ? left.dimension(0) : left.dimension(1); - const int n = transpose_b_ ? right.dimension(0) : right.dimension(1); - const int k2 = transpose_b_ ? right.dimension(1) : right.dimension(0); + const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0); + const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1); + const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1); + const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0); OP_REQUIRES(ctx, k == k2, errors::InvalidArgument("Matrix size incompatible: a: ", @@ -476,127 +854,94 @@ class SparseMatMulOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output)); auto out = output->matrix(); + std::unique_ptr a_float; + std::unique_ptr b_float; if (!a_is_sparse_ && !b_is_sparse_) { - // Fallback to Eigen contract. 
- // Note that we currently don't optimize the case where only right is - // sparse. That can generally be handled by transposing the order of the - // matmul. + auto left = &a; + auto right = &b; + // TODO(agarwal): multi-thread the conversions from bfloat16 to float. + if (std::is_same::value) { + a_float.reset(new Tensor(DT_FLOAT, a.shape())); + BFloat16ToFloat(a.flat().data(), + a_float->flat().data(), a.NumElements()); + left = a_float.get(); + } + if (std::is_same::value) { + b_float.reset(new Tensor(DT_FLOAT, b.shape())); + BFloat16ToFloat(b.flat().data(), + b_float->flat().data(), b.NumElements()); + right = b_float.get(); + } Eigen::array, 1> dim_pair; dim_pair[0].first = transpose_a_ ? 0 : 1; dim_pair[0].second = transpose_b_ ? 1 : 0; + out.device(ctx->template eigen_device()) = - left.contract(right, dim_pair); + left->matrix().contract(right->matrix(), dim_pair); return; } - auto left_mat = &left; - auto right_mat = &right; + + auto left = &a; + auto right = &b; bool transpose_output = false; bool transpose_a = transpose_a_; bool transpose_b = transpose_b_; if (!a_is_sparse_) { // Swap the order of multiplications using the identity: // A * B = (B' * A')'. - std::swap(left_mat, right_mat); + std::swap(left, right); std::swap(transpose_a, transpose_b); transpose_a = !transpose_a; transpose_b = !transpose_b; transpose_output = !transpose_output; } - std::unique_ptr right_tr_mat; - std::unique_ptr::ConstMatrix> right_tr_map; + + std::unique_ptr right_tr; if (transpose_b) { // TODO(agarwal): avoid transposing the matrix here and directly handle // transpose in CreateDenseSlices. - right_tr_mat.reset( - new Matrix(right_mat->dimension(1), right_mat->dimension(0))); + right_tr.reset( + new Tensor(right->dtype(), + TensorShape({right->dim_size(1), right->dim_size(0)}))); + Eigen::array perm({1, 0}); - right_tr_mat->device(ctx->template eigen_device()) = - right_mat->shuffle(perm); - right_tr_map.reset(new TTypes::ConstMatrix( - right_tr_mat->data(), right_tr_mat->dimensions())); - right_mat = right_tr_map.get(); + if (transpose_output) { + right_tr->matrix().device(ctx->template eigen_device()) = + right->matrix().shuffle(perm); + } else { + right_tr->matrix().device(ctx->template eigen_device()) = + right->matrix().shuffle(perm); + } + right = right_tr.get(); } - SparseMatMul(*left_mat, *right_mat, transpose_a, - ctx->device()->tensorflow_cpu_worker_threads(), - transpose_output, &out); + if (transpose_output) { + SparseMatMul::Compute( + left->matrix(), right->matrix(), transpose_a, + ctx->device()->tensorflow_cpu_worker_threads(), transpose_output, + &out); + } else { + SparseMatMul::Compute( + left->matrix(), right->matrix(), transpose_a, + ctx->device()->tensorflow_cpu_worker_threads(), transpose_output, + &out); + } } private: - // Perform matrix multiplication of "left" and "right", and store the result - // in *"output". - static inline void SparseMatMul( - const ConstMatrixMap& left, const ConstMatrixMap& right, - bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, - bool transpose_output, MatrixMap* output); - - // Computes multiplication of left and num_cols columns of right, and stores - // the output block in *"output" at offsets "output_row_offset" and - // "output_col_offset". If assign is true, assigns the value to that block, - // else adds the values to the existing values. 
- static inline void ComputeOutputBlock(const std::vector& left, - const ConstMatrixMap& right, - int num_cols, int output_row_offset, - int output_col_offset, bool assign, - bool transpose_output, - MatrixMap* output); - - // Encodes "mat" using a sparse representation and stores that in - // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and - // "slice_num_cols", each grid element is converted into a SparseSlice and - // stored in mat_slices. "slice_block_size" is used to perform further column - // blocking of each slice. - static inline BlockingCounter* CreateSparseSlices( - const ConstMatrixMap& mat, bool transpose, int slice_num_rows, - int slice_block_size, int slice_num_cols, - std::vector>* mat_slices, - const DeviceBase::CpuWorkerThreads* thread_pool); - - // This function chops "mat" along column dimension into pieces with at most N - // columns, and concatenates the pieces one after the other in "buffer". It - // returns the list of the pieces in "slices". It returns a BlockingCounter - // which should be used to wait for the shuffle operations to complete. - static inline BlockingCounter* CreateDenseSlices( - const ConstMatrixMap& mat, int row_start, int num_rows, int col_start, - int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, - Matrix* buffer, std::vector* slices); - - // Helper function for CreateDenseSlices to move the data around. It returns a - // BlockingCounter which should be used to wait for the shuffle operations to - // complete. - static inline BlockingCounter* ShuffleMatrix( - const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows, - int slice_col_start, int slice_num_cols, const int N, - const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer); - - // Helper function for CreateDenseSlices to create slices. - static inline void SliceMatrix(const Matrix& mat, const int num_rows, - const int num_slices, - std::vector* slices); - - // Heuristics to compute various block sizes. - // KR, NR: block sizes for "right". We run blocking iterations that operate on - // matrices with at most this size. - // KL: grid size along the column dimension used while encoding left. - // IB, JB: number of left and right slices to multiply together. This is used - // for ordering different ComputeBlockOutput operations inside each blocking - // iteration so as to potentially reduce the working set size. 
- static inline void ComputeBlockSizes(const ConstMatrixMap& left, - const ConstMatrixMap& right, - bool transpose_left, int num_threads, - int* KR, int* NR, int* KL, int* JB, - int* IB); - bool transpose_a_; bool transpose_b_; bool a_is_sparse_; bool b_is_sparse_; + TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp); }; -inline void SparseMatMulOp::ComputeOutputBlock( - const std::vector& left, const ConstMatrixMap& right, - int num_cols, int output_row_offset, int output_col_offset, bool assign, +template +inline void SparseMatMul::ComputeOutputBlock( + const std::vector*>& left, + const SparseMatMul::ConstMatrixMapR& right, int num_cols, + int output_row_offset, int output_col_offset, bool assign, bool transpose_output, MatrixMap* output) { static const Eigen::array perm({1, 0}); int num_rows = left[0]->num_rows; @@ -605,9 +950,9 @@ inline void SparseMatMulOp::ComputeOutputBlock( Matrix out(num_rows, rhs_num_cols); out.setZero(); if (num_cols == N) { - GEPP(left, right, num_cols, &out); + GEPP(left, right, num_cols, &out); } else { - GEPP<-1>(left, right, num_cols, &out); + GEPP(left, right, num_cols, &out); } if (!assign) { const Eigen::array begin = {output_row_offset, output_col_offset}; @@ -643,10 +988,11 @@ inline void SparseMatMulOp::ComputeOutputBlock( } } -inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( - const ConstMatrixMap& mat, bool transpose, int slice_num_rows, - int slice_block_size, int slice_num_cols, - std::vector>* mat_slices, +template +inline BlockingCounter* SparseMatMul::CreateSparseSlices( + const SparseMatMul::ConstMatrixMapL& mat, bool transpose, + int slice_num_rows, int slice_block_size, int slice_num_cols, + std::vector*>>* mat_slices, const DeviceBase::CpuWorkerThreads* thread_pool) { const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0); const int mat_num_cols = transpose ? 
mat.dimension(0) : mat.dimension(1); @@ -657,12 +1003,13 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( mat_slices->resize(num_slices_dim0); BlockingCounter* counter = new BlockingCounter(num_slices_dim0 * num_slices_dim1); - auto work = [counter, transpose](SparseSlice* sparse_slice, - ConstMatrixMap* slice, int col_offset) { + auto work = [counter, transpose](SparseSlice* sparse_slice, + SparseMatMul::ConstMatrixMapL* slice, + int col_offset) { if (transpose) { - sparse_slice->Initialize(*slice, col_offset); + sparse_slice->template Initialize(*slice, col_offset); } else { - sparse_slice->Initialize(*slice, col_offset); + sparse_slice->template Initialize(*slice, col_offset); } delete slice; counter->DecrementCount(); @@ -674,16 +1021,17 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( for (int j = 0; j < num_slices_dim1; ++j) { int num_cols = std::min(slice_num_cols, mat_num_cols - j * slice_num_cols); - ConstMatrixMap* slice = nullptr; + SparseMatMul::ConstMatrixMapL* slice = nullptr; if (transpose) { - slice = - new ConstMatrixMap(&mat(0, i * slice_num_rows), mat.dimensions()); + slice = new SparseMatMul::ConstMatrixMapL( + &mat(0, i * slice_num_rows), mat.dimensions()); } else { DSizes d(num_rows, mat_num_cols); - slice = new ConstMatrixMap(&mat(i * slice_num_rows, 0), d); + slice = new SparseMatMul::ConstMatrixMapL( + &mat(i * slice_num_rows, 0), d); } - SparseSlice* sparse_slice = - new SparseSlice(num_rows, num_cols, slice_block_size); + auto* sparse_slice = + new SparseSlice(num_rows, num_cols, slice_block_size); (*mat_slices)[i][j] = sparse_slice; thread_pool->workers->Schedule( std::bind(work, sparse_slice, slice, slice_num_cols * j)); @@ -691,11 +1039,58 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices( } return counter; } +#define LOAD(x) Eigen::internal::ploadu((x)); +#define INTERLEAVE(x) Eigen::internal::pinterleave4x64(x); +#define STORE(x, y) Eigen::internal::pstoreu(x, y); + +template +ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc, + int num_elements) { + DCHECK_GE(kNumOperands, 8); + static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16); + const int num = (NUM_ELEM == -1) ? 
num_elements : NUM_ELEM; + DCHECK_EQ(num, num_elements); + const float* src = reinterpret_cast(bsrc); + float* dst = reinterpret_cast(bdst); + for (int index = 0; index + kStep <= num; index += kStep) { + auto in = LOAD(src); + auto tmp = INTERLEAVE(in); + STORE(dst, tmp); + src += kNumOperands; + dst += kNumOperands; + } + if (num % kStep != 0) { + memcpy(dst, src, (num % kStep) * sizeof(bfloat16)); + } +} -inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( - const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows, - int slice_col_start, int slice_num_cols, const int N, - const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer) { +template +ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src, + int num_elements) { + if (std::is_same::value || kNumOperands < 8) { + memcpy(dst, src, num_elements * sizeof(T)); + } else if (std::is_same::value) { + if (num_elements == N) { + CopyAndMayBeInterleaveBfloat16(dst, src, num_elements); + } else { + CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements); + } + } else { + LOG(FATAL) << "Unsupported type"; + } +} + +#undef LOAD +#undef Interleave +#undef Store + +template +inline BlockingCounter* SparseMatMul::ShuffleMatrix( + const SparseMatMul::ConstMatrixMapR& mat, int slice_row_start, + int slice_num_rows, int slice_col_start, int slice_num_cols, const int N, + const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) { + DCHECK_EQ(N % 2, 0); + DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N); int num_threads = std::min(thread_pool->num_threads, 16); BlockingCounter* counter = new BlockingCounter(num_threads); DCHECK_EQ(N, buffer->dimension(1)); @@ -703,17 +1098,17 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( slice_num_cols, N, buffer, counter](int s, int e) { const int row_start = s % slice_num_rows + slice_row_start; const int col_start = s / slice_num_rows * N + slice_col_start; - float* out_start = &(*buffer)(s, 0); - const float* input_start = &mat(row_start, col_start); - const float* input_end = &mat(slice_row_start + slice_num_rows - 1, - slice_col_start + slice_num_cols - 1); + auto* out_start = &(*buffer)(s, 0); + const auto* input_start = &mat(row_start, col_start); + const auto* input_end = &mat(slice_row_start + slice_num_rows - 1, + slice_col_start + slice_num_cols - 1); const int mat_num_cols = mat.dimension(1); const int row_slice_size = slice_num_rows * mat_num_cols; const int aligned_end = slice_num_cols / N * slice_num_rows; const int e1 = std::min(e, aligned_end); while (s < e1) { - memcpy(out_start, input_start, N * sizeof(float)); + CopyAndMayBeInterleave(out_start, input_start, N); out_start += N; input_start += mat_num_cols; if (input_start > input_end) { @@ -724,7 +1119,7 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( int s1 = std::max(s, aligned_end); const int copy_num_cols = slice_num_cols % N; while (s1 < e) { - memcpy(out_start, input_start, copy_num_cols * sizeof(float)); + CopyAndMayBeInterleave(out_start, input_start, copy_num_cols); out_start += N; input_start += mat_num_cols; ++s1; @@ -745,21 +1140,24 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix( return counter; } -inline void SparseMatMulOp::SliceMatrix(const Matrix& mat, const int num_rows, - const int num_slices, - std::vector* slices) { +template +inline void SparseMatMul::SliceMatrix( + const MatrixR& mat, const int num_rows, const int num_slices, + std::vector::ConstMatrixMapR*>* slices) { slices->resize(num_slices); DSizes d(num_rows, mat.dimension(1)); 
DCHECK_LE(num_rows * num_slices, mat.dimension(0)); for (int i = 0; i < num_slices; ++i) { - (*slices)[i] = new ConstMatrixMap(&mat(i * num_rows, 0), d); + (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d); } } -inline BlockingCounter* SparseMatMulOp::CreateDenseSlices( - const ConstMatrixMap& mat, int row_start, int num_rows, int col_start, - int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool, - Matrix* buffer, std::vector* slices) { +template +inline BlockingCounter* SparseMatMul::CreateDenseSlices( + const SparseMatMul::ConstMatrixMapR& mat, int row_start, + int num_rows, int col_start, int num_cols, + const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer, + std::vector::ConstMatrixMapR*>* slices) { BlockingCounter* shuffle_counter = ShuffleMatrix( mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer); const int num_slices = (num_cols + N - 1) / N; @@ -767,11 +1165,11 @@ inline BlockingCounter* SparseMatMulOp::CreateDenseSlices( return shuffle_counter; } -inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left, - const ConstMatrixMap& right, - bool transpose_left, - int num_threads, int* KR, int* NR, - int* KL, int* JB, int* IB) { +template +inline void SparseMatMul::ComputeBlockSizes( + const SparseMatMul::ConstMatrixMapL& left, + const SparseMatMul::ConstMatrixMapR& right, bool transpose_left, + int num_threads, int* KR, int* NR, int* KL, int* JB, int* IB) { // Heuristics for calculating block sizes // Assume two hyperthreads per core. const int est_num_cores = std::max(1, (num_threads + 1) / 2); @@ -838,20 +1236,21 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left, // and update the output block o_ij. These calls are further blocked to // reduce the working set size. In each iteration we take IB elements from // {l_i} and JB elements from {r_j} and compute the IB * JB inner products. -inline void SparseMatMulOp::SparseMatMul( - const ConstMatrixMap& left, const ConstMatrixMap& right, - bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, - bool transpose_output, MatrixMap* output) { +template +inline void SparseMatMul::Compute( + const SparseMatMul::ConstMatrixMapL& left, + const SparseMatMul::ConstMatrixMapR& right, bool transpose_left, + const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output, + MatrixMap* output) { const int num_threads = thread_pool->num_threads; int KR, NR, KL, JB, IB; ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL, &JB, &IB); - // Slice the left matrix - std::vector> left_slices; + std::vector*>> left_slices; std::unique_ptr sparse_slice_counter; sparse_slice_counter.reset( - CreateSparseSlices(ConstMatrixMap(left.data(), left.dimensions()), + CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()), transpose_left, M, K, KL, &left_slices, thread_pool)); const int num_left_slices = left_slices.size(); @@ -862,10 +1261,10 @@ inline void SparseMatMulOp::SparseMatMul( // is the block size per iteration. const int buffer_num_rows = std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N; - Matrix buffer(buffer_num_rows, N); - std::vector right_slices; + MatrixR buffer(buffer_num_rows, N); + std::vector right_slices; - std::vector block_left_slices; + std::vector*> block_left_slices; std::vector> tasks; // Number of blocks based on block sizes of KR * NR. 
const int num_k_blocks = (right_dim0 + KR - 1) / KR; @@ -890,7 +1289,6 @@ inline void SparseMatMulOp::SparseMatMul( const int num_cols = std::min(N, right_num_cols - N * j_inner); for (int i_inner = i_outer; i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) { - // Figure out which left slices to use. block_left_slices.clear(); int begin = kb * KR / KL; int end = std::min((kb + 1) * KR / KL, @@ -900,7 +1298,7 @@ inline void SparseMatMulOp::SparseMatMul( left_slices[i_inner].begin() + begin, left_slices[i_inner].begin() + end); tasks.push_back(std::bind( - &SparseMatMulOp::ComputeOutputBlock, block_left_slices, + &ComputeOutputBlock, block_left_slices, std::ref(*right_slices[j_inner]), num_cols, M * i_inner, N * j_inner + nb * NR, kb == 0, transpose_output, output)); } @@ -933,7 +1331,18 @@ inline void SparseMatMulOp::SparseMatMul( } } -REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU), - SparseMatMulOp); +#define REGISTER_SPARSE_MATMUL(TA, TB) \ + REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Ta") \ + .TypeConstraint("Tb"), \ + SparseMatMulOp); + +REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); +REGISTER_SPARSE_MATMUL(float, bfloat16); +REGISTER_SPARSE_MATMUL(bfloat16, float); +REGISTER_SPARSE_MATMUL(float, float); + +#undef REGISTER_SPARSE_MATMUL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc index 2f28b8ff9e..676ea6e28c 100644 --- a/tensorflow/core/kernels/sparse_matmul_op_test.cc +++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/kernels/sparse_matmul_op.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/bfloat16.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/node_builder.h" @@ -26,11 +27,13 @@ limitations under the License. namespace tensorflow { random::PhiloxRandom philox(1, 1); random::SimplePhilox rnd(&philox); +using Eigen::operator==; +template void Sparsify(Tensor* t, float sparsity) { const int64 N = t->NumElements(); CHECK_LE(sparsity, 1); - auto flat = t->flat(); + auto flat = t->flat(); if (sparsity == 1) { flat.setZero(); return; @@ -38,9 +41,9 @@ void Sparsify(Tensor* t, float sparsity) { static const uint32 K = 10000; for (int64 i = 0; i < N; ++i) { if (rnd.Uniform(K) < sparsity * K) { - flat(i) = 0; - } else if (flat(i) == 0) { - flat(i) = 0.1; + flat(i) = T(0); + } else if (flat(i) == T(0)) { + flat(i) = T(1); } } } @@ -59,6 +62,7 @@ Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a, return ret; } +template static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, float sparsity_a, float sparsity_b, bool transpose_a, bool transpose_b) { @@ -66,14 +70,14 @@ static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, bool b_sparse = (sparsity_b > 0); auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d}); - Tensor left(DataTypeToEnum::value, left_shape); - left.flat().setRandom(); - Sparsify(&left, sparsity_a); + Tensor left(DataTypeToEnum::value, left_shape); + left.flat().setRandom(); + Sparsify(&left, sparsity_a); auto right_shape = transpose_b ? 
TensorShape({n, d}) : TensorShape({d, n}); - Tensor right(DataTypeToEnum::value, right_shape); - right.flat().setRandom(); - Sparsify(&right, sparsity_b); + Tensor right(DataTypeToEnum::value, right_shape); + right.flat().setRandom(); + Sparsify(&right, sparsity_b); SparseMatMulNode(g, test::graph::Constant(g, left), test::graph::Constant(g, right), transpose_a, transpose_b, @@ -81,59 +85,82 @@ static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, return g; } +template static Graph* SparseMatMul(int m, int n, int d, float sparsity_a, float sparsity_b, bool transpose_a, bool transpose_b) { Graph* g = new Graph(OpRegistry::Global()); - return SparseMatMulHelper(g, m, n, d, sparsity_a, sparsity_b, transpose_a, - transpose_b); + return SparseMatMulHelper(g, m, n, d, sparsity_a, sparsity_b, + transpose_a, transpose_b); } -#define BM_SPARSE(M, K, N, S1, S2, TA, TB) \ - static void BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB( \ - int iters) { \ - testing::StopTiming(); \ - testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ - std::string label = \ - strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", TA, TB, \ - S1 / 100.0, S2 / 100.0); \ - testing::SetLabel(label); \ - testing::UseRealTime(); \ - auto g = SparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, TA, TB); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ - } \ - BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB); - -BM_SPARSE(2048, 2048, 2048, 0, 0, false, false); -BM_SPARSE(2048, 2048, 2048, 1, 0, false, false); -BM_SPARSE(2048, 2048, 2048, 50, 0, false, false); -BM_SPARSE(2048, 2048, 2048, 85, 0, false, false); -BM_SPARSE(2048, 2048, 2048, 99, 0, false, false); - -BM_SPARSE(2048, 2048, 2048, 0, 50, false, false); -BM_SPARSE(2048, 2048, 2048, 0, 85, false, false); - -BM_SPARSE(2048, 2048, 2048, 85, 0, true, false); -BM_SPARSE(2048, 2048, 2048, 85, 0, false, true); -BM_SPARSE(2048, 2048, 2048, 85, 0, true, true); - -BM_SPARSE(2048, 2048, 2048, 0, 85, true, false); -BM_SPARSE(2048, 2048, 2048, 0, 85, false, true); -BM_SPARSE(2048, 2048, 2048, 0, 85, true, true); - -BM_SPARSE(1024, 1024, 1024, 0, 0, false, false); -BM_SPARSE(1024, 1024, 1024, 1, 0, false, false); -BM_SPARSE(1024, 1024, 1024, 85, 0, false, false); - -BM_SPARSE(256, 256, 256, 1, 0, false, false); -BM_SPARSE(512, 512, 512, 1, 0, false, false); +#define BM_SPARSE(M, K, N, S1, S2, TRA, TRB, TA, TB) \ + static void \ + BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB( \ + int iters) { \ + testing::StopTiming(); \ + testing::ItemsProcessed(static_cast(iters) * M * K * N * 2); \ + std::string label = \ + strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", TRA, TRB, \ + S1 / 100.0, S2 / 100.0); \ + testing::SetLabel(label); \ + testing::UseRealTime(); \ + auto g = SparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, TRA, TRB); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ + } \ + BENCHMARK( \ + BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB); + +#define BM_SPARSE_FLOAT(M, K, N, S1, S2, TRA, TRB) \ + BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, float) +#define BM_SPARSE_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \ + BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, bfloat16) +#define BM_SPARSE_FLOAT_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \ + BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, bfloat16) +#define BM_SPARSE_BFLOAT16_FLOAT(M, K, N, S1, S2, TRA, TRB) \ + BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, float) + +// Test sparse b 
+BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 0, false, false); +BM_SPARSE_FLOAT(2048, 2048, 2048, 1, 0, false, false); +BM_SPARSE_FLOAT(2048, 2048, 2048, 50, 0, false, false); +BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, false); +BM_SPARSE_FLOAT(2048, 2048, 2048, 99, 0, false, false); +// Test sparse a +BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 50, false, false); +BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, false); +// Test transposing +BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, false); +BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, true); +BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, true); +BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, false); +BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, true); +BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, true); + +// Test smaller sizes +BM_SPARSE_FLOAT(1024, 1024, 1024, 0, 0, false, false); +BM_SPARSE_FLOAT(1024, 1024, 1024, 1, 0, false, false); +BM_SPARSE_FLOAT(1024, 1024, 1024, 85, 0, false, false); +BM_SPARSE_FLOAT(256, 256, 256, 1, 0, false, false); +BM_SPARSE_FLOAT(512, 512, 512, 1, 0, false, false); + +// Test bfloat16 +BM_SPARSE_BFLOAT16(2048, 2048, 2048, 0, 0, false, false); +BM_SPARSE_BFLOAT16(2048, 2048, 2048, 1, 0, false, false); +BM_SPARSE_BFLOAT16(2048, 2048, 2048, 85, 0, false, false); +BM_SPARSE_BFLOAT16(2048, 2048, 2048, 99, 0, false, false); +BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 85, 0, false, false); +BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 99, 0, false, false); +BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 85, 0, false, false); +BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 99, 0, false, false); static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_1, float sparsity_2) { Graph* g = new Graph(OpRegistry::Global()); - SparseMatMulHelper(g, d, n, m, sparsity_1, sparsity_2, true, false); - SparseMatMulHelper(g, m, d, n, sparsity_2, 0, false, true); + SparseMatMulHelper(g, d, n, m, sparsity_1, sparsity_2, true, + false); + SparseMatMulHelper(g, m, d, n, sparsity_2, 0, false, true); return g; } diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index fb35d40734..3af6b96c84 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -167,6 +167,7 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, .HostMemory("perm"), \ TransposeCpuOp); TF_CALL_ALL_TYPES(REGISTER) +REGISTER(bfloat16); #undef REGISTER #if GOOGLE_CUDA diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index f3b0d919a3..ba0b5e4bbb 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -594,13 +594,15 @@ transpose_b: If true, "b" is transposed before multiplication. )doc"); REGISTER_OP("SparseMatMul") - .Input("a: float") - .Input("b: float") + .Input("a: Ta") + .Input("b: Tb") .Output("product: float") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") .Attr("a_is_sparse: bool = false") .Attr("b_is_sparse: bool = false") + .Attr("Ta: {float, bfloat16} = DT_FLOAT") + .Attr("Tb: {float, bfloat16} = DT_FLOAT") .Doc(R"doc( Multiply matrix "a" by matrix "b". 
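Together with the new Ta/Tb attrs on the op definition above, the Python-side changes in this patch let tf.matmul hand bfloat16 operands to the SparseMatMul kernel while the product stays float. A usage sketch in the spirit of the updated sparse_matmul_op_test.py; the matrix contents and session setup here are illustrative, not taken from the patch:

import numpy as np
import tensorflow as tf

x = np.random.rand(128, 64).astype(np.float32)
x[x < 0.8] = 0.0  # mostly zeros, so the a_is_sparse hint is worthwhile
y = np.random.rand(64, 32).astype(np.float32)

with tf.Session() as sess:
  a = tf.cast(tf.constant(x), tf.bfloat16)  # Ta = bfloat16
  b = tf.constant(y)                        # Tb = float32
  # a_is_sparse routes this matmul through the SparseMatMul kernel.
  product = tf.matmul(a, b, a_is_sparse=True)
  print(sess.run(product).dtype)            # float32, per "product: float"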
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1c7c4ea1cf..90431d631d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1107,6 +1107,7 @@ medium_kernel_test_list = glob([
     "kernel_tests/seq2seq_test.py",
     "kernel_tests/slice_op_test.py",
     "kernel_tests/sparse_ops_test.py",
+    "kernel_tests/sparse_matmul_op_test.py",
     "kernel_tests/sparse_tensor_dense_matmul_op_test.py",
 ])
 
@@ -1154,7 +1155,6 @@ cpu_only_kernel_test_list = glob([
     "kernel_tests/self_adjoint_eig_op_test.py",
     "kernel_tests/sparse_add_op_test.py",
     "kernel_tests/sparse_concat_op_test.py",
-    "kernel_tests/sparse_matmul_op_test.py",
     "kernel_tests/sparse_split_op_test.py",
     "kernel_tests/sparse_reorder_op_test.py",
     "kernel_tests/sparse_to_dense_op_test.py",
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index d1f3aef880..6142bdab9c 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -22,90 +22,115 @@ import numpy as np
 import tensorflow as tf
 
 
-def RandMatrix(rows, cols, tr):
+def RandMatrix(rows, cols, tr, round_bfloat=False):
   if tr:
     rows, cols = cols, rows
-  return (np.clip(np.random.uniform(low=-100.0, high=100.0, size=rows * cols),
-                  0, 100) / 100).reshape([rows, cols]).astype(np.float32)
+  rand_func = np.random.randint if round_bfloat else np.random.uniform
+  return (np.clip(rand_func(low=-256.0, high=256.0, size=rows * cols),
+                  -64, 64) / 128.0).reshape([rows, cols]).astype(np.float32)
 
 
 class SparseMatMulTest(tf.test.TestCase):
 
-  def _testCpuMatmul(self, x, y, tr_a=False, tr_b=False,
-                     sp_a=True, sp_b=False):
-    x_mat = np.matrix(x)
-    if tr_a:
-      x_mat = np.transpose(x_mat)
-    y_mat = np.matrix(y)
-    if tr_b:
-      y_mat = np.transpose(y_mat)
-    np_ans = x_mat * y_mat
+  def _testCpuMatmul(self, x, y,
+                     tr_a=False, tr_b=False,
+                     sp_a=True, sp_b=False,
+                     x_dtype=tf.float32,
+                     y_dtype=tf.float32):
     with self.test_session(use_gpu=False):
-      tf_ans = tf.matmul(x, y,
+      tf_x = tf.cast(x, x_dtype)
+      tf_y = tf.cast(y, y_dtype)
+      tf_ans = tf.matmul(tf_x, tf_y,
                          transpose_a=tr_a, transpose_b=tr_b,
                          a_is_sparse=sp_a,
                          b_is_sparse=sp_b)
       out = tf_ans.eval()
-    self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
+      np_x = tf.cast(tf_x, tf.float32).eval()
+      np_y = tf.cast(tf_y, tf.float32).eval()
+
+    if tr_a:
+      np_x = np.transpose(np_x)
+    if tr_b:
+      np_y = np.transpose(np_y)
+
+    np_ans = np.matrix(np_x) * np.matrix(np_y)
     self.assertShapeEqual(np_ans, tf_ans)
+    self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
 
-  def testFloatBasic(self):
+  def testBasic(self):
     x = np.arange(0., 4.).reshape([4, 1]).astype(np.float32)
     y = np.arange(-1., 1.).reshape([1, 2]).astype(np.float32)
-    self._testCpuMatmul(x, y)
+    for x_dtype in (tf.float32, tf.bfloat16):
+      for y_dtype in (tf.float32, tf.bfloat16):
+        self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)
 
   # Tests setting one dimension to be a high value.
-  def testFloatLarge(self):
+  def testLarge(self):
     r1 = np.random.randint(6000, 20000)
     r2 = np.random.randint(1, 10)
     r3 = np.random.randint(1, 10)
     for m, k, n in [(r1, r2, r3), (r2, r1, r3), (r2, r3, r1)]:
-      x = RandMatrix(m, k, False)
-      y = RandMatrix(k, n, False)
-      self._testCpuMatmul(x, y)
-      self._testCpuMatmul(x, y, sp_a=False, sp_b=True)
+      for x_dtype in (tf.float32, tf.bfloat16):
+        for y_dtype in (tf.float32, tf.bfloat16):
+          x = RandMatrix(m, k, False)
+          y = RandMatrix(k, n, False)
+          self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)
 
   # Tests random sized matrices.
-  def testFloatRandom(self):
+  def testRandom(self):
     for _ in range(10):
       for tr_a in [True, False]:
         for tr_b in [True, False]:
           for sp_a in [True, False]:
             for sp_b in [True, False]:
-              n, k, m = np.random.randint(1, 100, size=3)
-              x = RandMatrix(n, k, tr_a)
-              y = RandMatrix(k, m, tr_b)
-              self._testCpuMatmul(x, y, tr_a, tr_b, sp_a, sp_b)
+              for x_dtype in (tf.float32, tf.bfloat16):
+                for y_dtype in (tf.float32, tf.bfloat16):
+                  n, k, m = np.random.randint(1, 100, size=3)
+                  x = RandMatrix(n, k, tr_a)
+                  y = RandMatrix(k, m, tr_b)
+                  self._testCpuMatmul(x, y, tr_a, tr_b, sp_a, sp_b,
+                                      x_dtype=x_dtype, y_dtype=y_dtype)
 
 
 class MatMulGradientTest(tf.test.TestCase):
 
-  def _testGradients(self, tr_a, tr_b, sp_a, sp_b, name):
+  def _testGradients(self, tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, name):
     with self.test_session():
-      a = tf.constant(RandMatrix(3, 2, tr_a), dtype=tf.float32)
-      b = tf.constant(RandMatrix(2, 4, tr_b), dtype=tf.float32)
-      m = tf.matmul(a, b,
+      a = tf.constant(RandMatrix(3, 2, tr_a, round_bfloat=True),
+                      dtype=tf.float32)
+      b = tf.constant(RandMatrix(2, 4, tr_b, round_bfloat=True),
+                      dtype=tf.float32)
+      tf_a = tf.cast(a, a_dtype) if a_dtype != tf.float32 else a
+      tf_b = tf.cast(b, b_dtype) if b_dtype != tf.float32 else b
+
+      m = tf.matmul(tf_a, tf_b,
                     name=name,
                     transpose_a=tr_a, transpose_b=tr_b,
                     a_is_sparse=sp_a,
                     b_is_sparse=sp_b)
       err = (tf.test.compute_gradient_error(a, [2, 3]
-                                            if tr_a else [3, 2], m, [3, 4]) +
+                                            if tr_a else [3, 2], m, [3, 4],
+                                            x_init_value=a.eval(),
+                                            delta=1/64.) +
              tf.test.compute_gradient_error(b, [4, 2]
-                                            if tr_b else [2, 4], m, [3, 4]))
-      print("sparse_matmul gradient err = ", err)
-      self.assertLess(err, 1e-3)
+                                            if tr_b else [2, 4], m, [3, 4],
+                                            x_init_value=b.eval(),
+                                            delta=1/64.))
+      self.assertLess(err, 1/128.)
   def testGradientInput(self):
     for tr_a in [True, False]:
       for tr_b in [True, False]:
         for sp_a in [True, False]:
           for sp_b in [True, False]:
-            name = "sparse_matmul_%s_%s_%s_%s" % (tr_a, tr_b, sp_a, sp_b)
-            self._testGradients(tr_a, tr_b, sp_a, sp_b, name)
+            for a_dtype in (tf.float32, tf.bfloat16):
+              for b_dtype in (tf.float32, tf.bfloat16):
+                name = "sparse_matmul_%s_%s_%s_%s" % (tr_a, tr_b, sp_a, sp_b)
+                self._testGradients(tr_a, tr_b, sp_a, sp_b,
+                                    a_dtype, b_dtype, name)
 
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index d5c6e9fc91..ab71e28cab 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -526,7 +526,8 @@ def _SparseMatMulGrad(op, grad):
       # Use heuristic to figure out if grad might be sparse
       grad: (grad.op.type == "ReluGrad")
   }
-  def _SparseMatMul(t1, t2, transpose_a=False, transpose_b=False):
+  def _SparseMatMul(t1, t2, out_dtype,
+                    transpose_a=False, transpose_b=False):
     """Helper function to create SparseMatMul op."""
 
     assert t1 in is_sparse and t2 in is_sparse
@@ -535,25 +536,30 @@ def _SparseMatMulGrad(op, grad):
     if transpose_b:
       t2 = array_ops.transpose(t2)
       transpose_b = False
-    return math_ops.matmul(t1, t2,
+    prod = math_ops.matmul(t1, t2,
                            transpose_a=transpose_a,
                            transpose_b=transpose_b,
                            a_is_sparse=t1_sparse,
                            b_is_sparse=t2_sparse)
+    if prod.dtype != out_dtype:
+      prod = math_ops.cast(prod, out_dtype)
+    return prod
 
+  dtype_a = op.inputs[0].dtype
+  dtype_b = op.inputs[1].dtype
   if not t_a and not t_b:
-    return (_SparseMatMul(grad, op.inputs[1], transpose_b=True),
-            _SparseMatMul(op.inputs[0], grad, transpose_a=True))
+    return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
+            _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
   elif not t_a and t_b:
-    return (_SparseMatMul(grad, op.inputs[1]),
-            _SparseMatMul(grad, op.inputs[0], transpose_a=True))
+    return (_SparseMatMul(grad, op.inputs[1], dtype_a),
+            _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
   elif t_a and not t_b:
-    return (_SparseMatMul(op.inputs[1], grad, transpose_b=True),
-            _SparseMatMul(op.inputs[0], grad))
+    return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
+            _SparseMatMul(op.inputs[0], grad, dtype_b))
   elif t_a and t_b:
-    return (_SparseMatMul(op.inputs[1], grad,
+    return (_SparseMatMul(op.inputs[1], grad, dtype_a,
                           transpose_a=True, transpose_b=True),
-            _SparseMatMul(grad, op.inputs[0],
+            _SparseMatMul(grad, op.inputs[0], dtype_b,
                           transpose_a=True, transpose_b=True))
 
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9ebf251574..f0f438a33d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1087,7 +1087,14 @@ def matmul(a, b,
   with ops.op_scope([a, b], name, "MatMul") as name:
     a = ops.convert_to_tensor(a, name="a")
     b = ops.convert_to_tensor(b, name="b")
-    if a.dtype == dtypes.float32 and (a_is_sparse or b_is_sparse):
+    sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
+    use_sparse_matmul = (a.dtype in sparse_matmul_types and
+                         b.dtype in sparse_matmul_types and
+                         (a_is_sparse or b_is_sparse))
+    if dtypes.bfloat16 in (a.dtype, b.dtype):
+      # matmul currently doesn't handle bfloat16 inputs.
+      use_sparse_matmul = True
+    if use_sparse_matmul:
       return sparse_matmul(a, b,
                            transpose_a=transpose_a,
                            transpose_b=transpose_b,
-- 
cgit v1.2.3
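
The math_ops.py hunk above decides between the dense MatMul and SparseMatMul
paths with a small heuristic. A standalone sketch of that decision rule, in
plain Python with dtype names standing in for the real DType objects (the
helper name and the string dtypes are illustrative, not part of the patch):

  SPARSE_MATMUL_DTYPES = ("float32", "bfloat16")

  def use_sparse_matmul(a_dtype, b_dtype, a_is_sparse, b_is_sparse):
    # Prefer SparseMatMul when both dtypes are supported by it and the
    # caller declared one of the operands sparse.
    use_sparse = (a_dtype in SPARSE_MATMUL_DTYPES and
                  b_dtype in SPARSE_MATMUL_DTYPES and
                  (a_is_sparse or b_is_sparse))
    # Dense MatMul has no bfloat16 kernel at this point, so any bfloat16
    # operand forces the SparseMatMul path even for dense inputs.
    if "bfloat16" in (a_dtype, b_dtype):
      use_sparse = True
    return use_sparse

  assert use_sparse_matmul("float32", "float32", True, False)
  assert use_sparse_matmul("bfloat16", "float32", False, False)
  assert not use_sparse_matmul("float32", "float32", False, False)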