author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-05-10 12:30:20 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-05-10 13:32:14 -0700
commit     6b373bb149396beec11347881b2d6dedfbcc83c4
tree       9eddf82919c16becd71b55243f48a0967e9b5db4
parent     49d25eae890216f15833adfdd1e668479470745d
Allow bfloat16 inputs to SparseMatMul.
Change: 121980920
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op.cc              | 1019
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op_test.cc         |  133
-rw-r--r--  tensorflow/core/kernels/transpose_op.cc                  |    1
-rw-r--r--  tensorflow/core/ops/math_ops.cc                          |    6
-rw-r--r--  tensorflow/python/BUILD                                  |    2
-rw-r--r--  tensorflow/python/kernel_tests/sparse_matmul_op_test.py  |   97
-rw-r--r--  tensorflow/python/ops/math_grad.py                       |   26
-rw-r--r--  tensorflow/python/ops/math_ops.py                        |    9
8 files changed, 885 insertions, 408 deletions
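The kernel below treats bfloat16 as the upper 16 bits of an IEEE-754 float32 (see ConvertBfloat16ToFloat further down), so widening to float and truncating back are plain bit shifts. A minimal standalone sketch of that round trip, independent of TensorFlow's bfloat16 class (the function names here are illustrative, not part of the patch):

#include <cstdint>
#include <cstring>

// Widen a bfloat16 bit pattern to float: place it in the high 16 bits of a
// 32-bit word; the low 16 mantissa bits become zero.
float BfloatBitsToFloat(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Truncate a float to a bfloat16 bit pattern by dropping the low 16 mantissa
// bits (simple truncation, no rounding).
uint16_t FloatToBfloatBits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}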
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 067aa2624d..6cfaa3b958 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -18,9 +18,11 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/sparse_matmul_op.h"
+
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -35,12 +37,16 @@ namespace tensorflow {
namespace {
+using Eigen::operator==;
typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix;
typedef Eigen::DSizes<Eigen::DenseIndex, 2> DSizes;
-typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>,
- Eigen::Aligned> MatrixMap;
typedef Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>,
- Eigen::Aligned> ConstMatrixMap;
+ Eigen::Aligned>
+ ConstMatrixMap;
+typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ MatrixMap;
+
typedef Eigen::ThreadPoolDevice CPUDevice;
// Blocksizes
@@ -52,8 +58,7 @@ static const int N = 128;
// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
// (num_rows, b), where b = block_size for all but the last block, which may
-// have
-// fewer columns.
+// have fewer columns.
//
// num_rows and block_size are assumed to be <= 256. This allows storing
// different indices as uint8.
@@ -71,7 +76,12 @@ static const int N = 128;
// are the values in the following range:
// [index3[index3_offset[i-1]], index3[index3_offset[i]]). Similarly for
// index_offset.
+template <typename T>
struct SparseSlice {
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMap;
+
public:
// Indices of three elements on the same row.
struct Index3 {
@@ -105,13 +115,13 @@ struct SparseSlice {
// See comments above.
std::vector<int> index3_offset;
std::vector<Index3> index3;
- std::vector<float> data3;
+ std::vector<T> data3;
// See comments above. Similar to "index3" except that each element in "index"
// corresponds to one element in data.
std::vector<int> index_offset;
std::vector<Index> index;
- std::vector<float> data;
+ std::vector<T> data;
// Number of rows and columns for the slice.
const int num_rows;
@@ -121,8 +131,10 @@ struct SparseSlice {
const int block_size;
};
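A concrete illustration of the encoding (the values and column indices below are made up): suppose one row of a block contains

  column:  0    1    2    3    4    5    6    7
  value:   0.0  2.5  0.0  1.0  3.0  0.0  4.0  0.0

Initialize() scans the row for nonzeros and groups them in threes, so this row contributes one Index3 entry {m: row, k1: 1, k2: 3, k3: 4} with {2.5, 1.0, 3.0} appended to data3, while the leftover single nonzero contributes one Index entry {m: row, k: 6} with 4.0 appended to data.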
+template <typename T>
template <bool Transpose>
-void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
+void SparseSlice<T>::Initialize(const SparseSlice<T>::ConstMatrixMap& mat,
+ int col_offset) {
const int mat_rows = Transpose ? mat.dimension(1) : mat.dimension(0);
const int mat_cols = Transpose ? mat.dimension(0) : mat.dimension(1);
DCHECK_LE(num_rows, mat_rows);
@@ -142,6 +154,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
Index3 idx3;
Index idx;
int data3_size = 0;
+ static const T zero(0);
for (int i = 0; i < num_blocks; ++i) {
int num_block_cols = std::min(block_size, num_cols - block_size * i);
for (int row = 0; row < num_rows; ++row) {
@@ -150,17 +163,17 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
// *curr is nonzero and then reads it again on use. However, the result
// of the race is only that some of the "nonzeros" in the resulting sparse
// representation may actually be zero, which is harmless.
- const float* start =
+ const auto* start =
Transpose ? &mat(col_offset, row) : &mat(row, col_offset);
- const float* curr = start;
+ const auto* curr = start;
const int stride = Transpose ? mat.dimension(1) : 1;
- const float* end = start + stride * num_block_cols;
+ const auto* end = start + stride * num_block_cols;
uint8 k = 0;
#define NEXT_ELEM \
curr += stride; \
++k;
while (true) {
- while (curr < end && (*curr == 0)) {
+ while (curr < end && (*curr == zero)) {
NEXT_ELEM;
}
if (curr >= end) break;
@@ -168,7 +181,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
data3.push_back(*curr);
NEXT_ELEM;
- while (curr < end && (*curr == 0)) {
+ while (curr < end && (*curr == zero)) {
NEXT_ELEM;
}
if (curr >= end) break;
@@ -176,7 +189,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
data3.push_back(*curr);
NEXT_ELEM;
- while (curr < end && (*curr == 0)) {
+ while (curr < end && (*curr == zero)) {
NEXT_ELEM;
}
if (curr >= end) break;
@@ -213,7 +226,8 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
DCHECK_EQ(index.size(), data.size());
}
-void SparseSlice::Clear() {
+template <typename T>
+void SparseSlice<T>::Clear() {
index3_offset.clear();
index3.clear();
data3.clear();
@@ -222,107 +236,356 @@ void SparseSlice::Clear() {
data.clear();
}
-#define SCALAR_MULADD(a, inp, out) *out++ += *a * *inp++;
-
-#define SCALAR_MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \
- *out++ += *a1 * *inp1++ + *a2 * *inp2++ + *a3 * *inp3++;
-
typedef Eigen::internal::packet_traits<float>::type Packet;
static const int kNumOperands = (sizeof(Packet) / sizeof(float));
#define LOAD(x) Eigen::internal::pload<Packet>(x);
+#define EXPAND_BFLOAT_L(x, y) \
+ const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
+#define EXPAND_BFLOAT_U(x, y) \
+ const auto y = Eigen::internal::pexpand_bf16_u<Packet>(x);
#define STORE(x, y) Eigen::internal::pstore<float>(x, y);
-#define LOAD_SCALAR(x, y) const auto y = Eigen::internal::pload1<Packet>(x);
#define FMA(a, b, c, d) d = Eigen::internal::pmadd<Packet>(a, b, c);
-// Vectorized version of SCALAR_MULADD.
-#define MULADD(a, inp, out) \
- do { \
- const auto b = LOAD(inp); \
- inp += kNumOperands; \
- auto c = LOAD(out); \
- FMA(a, b, c, c); \
- STORE(out, c); \
- out += kNumOperands; \
- } while (false)
-
-// Vectorized version of SCALAR_MULADD3WAY.
-#define MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out) \
- do { \
- auto c = LOAD(out); \
- const auto b1 = LOAD(inp1); \
- inp1 += kNumOperands; \
- const auto b2 = LOAD(inp2); \
- inp2 += kNumOperands; \
- const auto b3 = LOAD(inp3); \
- inp3 += kNumOperands; \
- FMA(a1, b1, c, c); \
- FMA(a2, b2, c, c); \
- FMA(a3, b3, c, c); \
- STORE(out, c); \
- out += kNumOperands; \
- } while (false)
-
-#ifdef EIGEN_VECTORIZE_AVX2
-// Unroll MULADD3WAY for two iterations
-#define MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out) \
- do { \
- auto c1 = LOAD(out); \
- const auto b1 = LOAD(inp1); \
- const auto b2 = LOAD(inp2); \
- const auto b3 = LOAD(inp3); \
- \
- auto c2 = LOAD(out + kNumOperands); \
- const auto b4 = LOAD(inp1 + kNumOperands); \
- const auto b5 = LOAD(inp2 + kNumOperands); \
- const auto b6 = LOAD(inp3 + kNumOperands); \
- \
- FMA(a1, b1, c1, c1); \
- FMA(a1, b4, c2, c2); \
- FMA(a2, b2, c1, c1); \
- FMA(a2, b5, c2, c2); \
- FMA(a3, b3, c1, c1); \
- FMA(a3, b6, c2, c2); \
- STORE(out, c1); \
- STORE(out + kNumOperands, c2); \
- out += 2 * kNumOperands; \
- inp1 += 2 * kNumOperands; \
- inp2 += 2 * kNumOperands; \
- inp3 += 2 * kNumOperands; \
- } while (false)
-// Further unroll MULADD3WAY.
-#define MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out) \
- MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_16(a1, a2, a3, inp1, inp2, inp3, out);
-#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY_32(a1, a2, a3, inp1, inp2, inp3, out);
-#else
-#define MULADD3WAY_128(a1, a2, a3, inp1, inp2, inp3, out) \
- for (int __i = 0; __i < 128 / (4 * kNumOperands); ++__i) { \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
- MULADD3WAY(a1, a2, a3, inp1, inp2, inp3, out); \
+#define ALWAYS_INLINE EIGEN_ALWAYS_INLINE
+
+ALWAYS_INLINE float ConvertBfloat16ToFloat(const bfloat16* src) {
+ float out = 0;
+ auto tmp = reinterpret_cast<bfloat16*>(&out);
+ tmp[1] = *src;
+ return out;
+}
+
+ALWAYS_INLINE Packet ConvertFourBfloat16ToFloat(const bfloat16* src) {
+ return Eigen::internal::pload4bf16<Packet>(
+ reinterpret_cast<const float*>(src));
+}
+
+ALWAYS_INLINE Packet ConvertTwoBfloat16ToFloat(const bfloat16* src) {
+ return Eigen::internal::pload2bf16<Packet>(
+ reinterpret_cast<const float*>(src));
+}
+
+ALWAYS_INLINE void ScalarMulAdd(const float a, const float** inp, float** out) {
+ **out += a * **inp;
+ ++*inp;
+ ++*out;
+}
+
+ALWAYS_INLINE void ScalarMulAdd(const float a, const bfloat16** inp,
+ float** out) {
+ float inp_f = ConvertBfloat16ToFloat(*inp);
+ **out += a * inp_f;
+ ++*inp;
+ ++*out;
+}
+ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
+ const float a3, const bfloat16** inp1,
+ const bfloat16** inp2,
+ const bfloat16** inp3, float** out) {
+ float inp1_f = ConvertBfloat16ToFloat(*inp1);
+ float inp2_f = ConvertBfloat16ToFloat(*inp2);
+ float inp3_f = ConvertBfloat16ToFloat(*inp3);
+ **out += a1 * inp1_f + a2 * inp2_f + a3 * inp3_f;
+ ++*out;
+ ++*inp1;
+ ++*inp2;
+ ++*inp3;
+}
+
+ALWAYS_INLINE void ScalarMulAdd3Way(const float a1, const float a2,
+ const float a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ **out += a1 * **inp1 + a2 * **inp2 + a3 * **inp3;
+ ++*out;
+ ++*inp1;
+ ++*inp2;
+ ++*inp3;
+}
+
+ALWAYS_INLINE void LoadSingleScalar(const bfloat16** data, Packet* l) {
+ auto tmp = ConvertBfloat16ToFloat(*data);
+ *l = Eigen::internal::pset1<Packet>(tmp);
+ ++*data;
+}
+
+ALWAYS_INLINE void LoadTwoScalars(const bfloat16** data, Packet* l1,
+ Packet* l2) {
+ if (kNumOperands >= 2) {
+ auto tmp = ConvertTwoBfloat16ToFloat(*data);
+ *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
+ *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
+ *data += 2;
+ } else {
+ LoadSingleScalar(data, l1);
+ LoadSingleScalar(data, l2);
+ }
+}
+
+ALWAYS_INLINE void LoadFourScalars(const bfloat16** data, Packet* l1,
+ Packet* l2, Packet* l3, Packet* l4) {
+ if (kNumOperands >= 4) {
+ auto tmp = ConvertFourBfloat16ToFloat(*data);
+ *l1 = Eigen::internal::pbroadcast_first<Packet>(tmp);
+ *l2 = Eigen::internal::pbroadcast_second<Packet>(tmp);
+ *l3 = Eigen::internal::pbroadcast_third<Packet>(tmp);
+ *l4 = Eigen::internal::pbroadcast_fourth<Packet>(tmp);
+ *data += 4;
+ } else {
+ LoadTwoScalars(data, l1, l2);
+ LoadTwoScalars(data, l3, l4);
+ }
+}
+
+ALWAYS_INLINE void LoadSingleScalar(const float** data, Packet* l) {
+ *l = Eigen::internal::pload1<Packet>(*data);
+ ++(*data);
+}
+
+ALWAYS_INLINE void LoadTwoScalars(const float** data, Packet* l1, Packet* l2) {
+ LoadSingleScalar(data, l1);
+ LoadSingleScalar(data, l2);
+}
+
+ALWAYS_INLINE void LoadFourScalars(const float** data, Packet* l1, Packet* l2,
+ Packet* l3, Packet* l4) {
+ LoadTwoScalars(data, l1, l2);
+ LoadTwoScalars(data, l3, l4);
+}
+
+template <typename T>
+ALWAYS_INLINE void LoadThreeScalars(const T** data, Packet* l1, Packet* l2,
+ Packet* l3) {
+ LoadTwoScalars(data, l1, l2);
+ LoadSingleScalar(data, l3);
+}
+
+template <typename T>
+ALWAYS_INLINE void LoadSixScalars(const T** data, Packet* l1, Packet* l2,
+ Packet* l3, Packet* l4, Packet* l5,
+ Packet* l6) {
+ LoadFourScalars(data, l1, l2, l3, l4);
+ LoadTwoScalars(data, l5, l6);
+}
+
+// Vectorized version of ScalarMulAdd.
+ALWAYS_INLINE void MulAdd(const Packet a, const bfloat16** binp, float** out) {
+ auto inp = reinterpret_cast<const float*>(*binp);
+ const auto b = LOAD(inp);
+ EXPAND_BFLOAT_L(b, b_0);
+ EXPAND_BFLOAT_U(b, b_1);
+ *binp += 2 * kNumOperands;
+ auto c1 = LOAD(*out);
+ auto c2 = LOAD(*out + kNumOperands);
+ FMA(a, b_0, c1, c1);
+ FMA(a, b_1, c2, c2);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ *out += 2 * kNumOperands;
+}
+
+// Vectorized version of ScalarMulAdd3Way.
+ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
+ const bfloat16** binp1, const bfloat16** binp2,
+ const bfloat16** binp3, float** out) {
+ auto inp1 = reinterpret_cast<const float*>(*binp1);
+ auto inp2 = reinterpret_cast<const float*>(*binp2);
+ auto inp3 = reinterpret_cast<const float*>(*binp3);
+ auto c1 = LOAD(*out);
+ auto c2 = LOAD(*out + kNumOperands);
+ const auto b1 = LOAD(inp1);
+ EXPAND_BFLOAT_L(b1, b1_0);
+ EXPAND_BFLOAT_U(b1, b1_1);
+ *binp1 += 2 * kNumOperands;
+ const auto b2 = LOAD(inp2);
+ EXPAND_BFLOAT_L(b2, b2_0);
+ EXPAND_BFLOAT_U(b2, b2_1);
+ *binp2 += 2 * kNumOperands;
+ const auto b3 = LOAD(inp3);
+ EXPAND_BFLOAT_L(b3, b3_0);
+ EXPAND_BFLOAT_U(b3, b3_1);
+ *binp3 += 2 * kNumOperands;
+ FMA(a1, b1_0, c1, c1);
+ FMA(a1, b1_1, c2, c2);
+ FMA(a2, b2_0, c1, c1);
+ FMA(a2, b2_1, c2, c2);
+ FMA(a3, b3_0, c1, c1);
+ FMA(a3, b3_1, c2, c2);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ *out += 2 * kNumOperands;
+}
+
+// Unroll MulAdd3Way for two iterations
+ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
+ const Packet a3, const bfloat16** binp1,
+ const bfloat16** binp2, const bfloat16** binp3,
+ float** out) {
+ auto inp1 = reinterpret_cast<const float*>(*binp1);
+ auto inp2 = reinterpret_cast<const float*>(*binp2);
+ auto inp3 = reinterpret_cast<const float*>(*binp3);
+ auto c1 = LOAD(*out);
+ auto c2 = LOAD(*out + kNumOperands);
+ const auto b1 = LOAD(inp1);
+ const auto b2 = LOAD(inp2);
+ const auto b3 = LOAD(inp3);
+
+ EXPAND_BFLOAT_L(b1, b1_0);
+ EXPAND_BFLOAT_U(b1, b1_1);
+ EXPAND_BFLOAT_L(b2, b2_0);
+ EXPAND_BFLOAT_U(b2, b2_1);
+ EXPAND_BFLOAT_L(b3, b3_0);
+ EXPAND_BFLOAT_U(b3, b3_1);
+ auto c3 = LOAD(*out + 2 * kNumOperands);
+ auto c4 = LOAD(*out + 3 * kNumOperands);
+ const auto b4 = LOAD(inp1 + kNumOperands);
+ const auto b5 = LOAD(inp2 + kNumOperands);
+ const auto b6 = LOAD(inp3 + kNumOperands);
+
+ EXPAND_BFLOAT_L(b4, b4_0);
+ EXPAND_BFLOAT_U(b4, b4_1);
+ EXPAND_BFLOAT_L(b5, b5_0);
+ EXPAND_BFLOAT_U(b5, b5_1);
+ EXPAND_BFLOAT_L(b6, b6_0);
+ EXPAND_BFLOAT_U(b6, b6_1);
+
+ FMA(a1, b1_0, c1, c1);
+ FMA(a1, b1_1, c2, c2);
+ FMA(a1, b4_0, c3, c3);
+ FMA(a1, b4_1, c4, c4);
+ FMA(a2, b2_0, c1, c1);
+ FMA(a2, b2_1, c2, c2);
+ FMA(a2, b5_0, c3, c3);
+ FMA(a2, b5_1, c4, c4);
+ FMA(a3, b3_0, c1, c1);
+ FMA(a3, b3_1, c2, c2);
+ FMA(a3, b6_0, c3, c3);
+ FMA(a3, b6_1, c4, c4);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ STORE(*out + 2 * kNumOperands, c3);
+ STORE(*out + 3 * kNumOperands, c4);
+ *out += 4 * kNumOperands;
+ *binp1 += 4 * kNumOperands;
+ *binp2 += 4 * kNumOperands;
+ *binp3 += 4 * kNumOperands;
+}
+
+// Apply MulAdd3Way on 128 operands.
+ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
+ const Packet a3, const bfloat16** inp1,
+ const bfloat16** inp2, const bfloat16** inp3,
+ float** out) {
+ for (int k = 0; k < 128 / (8 * kNumOperands); ++k) {
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
}
-#endif
+}
+// Vectorized version of ScalarMulAdd
+ALWAYS_INLINE void MulAdd(const Packet a, const float** inp, float** out) {
+ const auto b = LOAD(*inp);
+ *inp += kNumOperands;
+ auto c = LOAD(*out);
+ FMA(a, b, c, c);
+ STORE(*out, c);
+ *out += kNumOperands;
+}
+
+// Vectorized version of ScalarMulAdd3Way
+ALWAYS_INLINE void MulAdd3Way(const Packet a1, const Packet a2, const Packet a3,
+ const float** inp1, const float** inp2,
+ const float** inp3, float** out) {
+ auto c = LOAD(*out);
+ const auto b1 = LOAD(*inp1);
+ *inp1 += kNumOperands;
+ const auto b2 = LOAD(*inp2);
+ *inp2 += kNumOperands;
+ const auto b3 = LOAD(*inp3);
+ *inp3 += kNumOperands;
+ FMA(a1, b1, c, c);
+ FMA(a2, b2, c, c);
+ FMA(a3, b3, c, c);
+ STORE(*out, c);
+ *out += kNumOperands;
+}
+
+// Unroll MulAdd3Way for two iterations
+ALWAYS_INLINE void TwoMulAdd3Way(const Packet a1, const Packet a2,
+ const Packet a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ auto c1 = LOAD(*out);
+ const auto b1 = LOAD(*inp1);
+ const auto b2 = LOAD(*inp2);
+ const auto b3 = LOAD(*inp3);
+
+ auto c2 = LOAD(*out + kNumOperands);
+ const auto b4 = LOAD(*inp1 + kNumOperands);
+ const auto b5 = LOAD(*inp2 + kNumOperands);
+ const auto b6 = LOAD(*inp3 + kNumOperands);
+
+ FMA(a1, b1, c1, c1);
+ FMA(a1, b4, c2, c2);
+ FMA(a2, b2, c1, c1);
+ FMA(a2, b5, c2, c2);
+ FMA(a3, b3, c1, c1);
+ FMA(a3, b6, c2, c2);
+ STORE(*out, c1);
+ STORE(*out + kNumOperands, c2);
+ *out += 2 * kNumOperands;
+ *inp1 += 2 * kNumOperands;
+ *inp2 += 2 * kNumOperands;
+ *inp3 += 2 * kNumOperands;
+}
+
+// Unroll MulAdd3Way for four iterations
+ALWAYS_INLINE void FourMulAdd3Way(const Packet a1, const Packet a2,
+ const Packet a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ TwoMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+}
+
+// Apply MulAdd3Way on 128 operands.
+ALWAYS_INLINE void MulAdd3Way128(const Packet a1, const Packet a2,
+ const Packet a3, const float** inp1,
+ const float** inp2, const float** inp3,
+ float** out) {
+ if (kNumOperands == 8) {
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ FourMulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ } else {
+ DCHECK_LE(4 * kNumOperands, 128);
+ for (int i = 0; i < 128 / (4 * kNumOperands); ++i) {
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ MulAdd3Way(a1, a2, a3, inp1, inp2, inp3, out);
+ }
+ }
+}
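All of the packet helpers above compute the same update as the following scalar reference, specialized here to 128 output columns; the bfloat16 overloads additionally widen each input value to float before the multiply-add. This is a sketch for orientation only, not code from the patch:

inline void ReferenceMulAdd3Way128(float a1, float a2, float a3,
                                   const float* inp1, const float* inp2,
                                   const float* inp3, float* out) {
  for (int i = 0; i < 128; ++i) {
    out[i] += a1 * inp1[i] + a2 * inp2[i] + a3 * inp3[i];
  }
}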
// Computes product of "left_slices" with "num_cols" columns of "right", and
// stores the output in *"output".
// Note that left_slices is a list of SparseSlices, which are conceptually
// assumed to be concatenated along the column dimension. Also each SparseSlice
// is encoded as a list of blocks with upto N columns. See SparseSlice for more
// details.
-template <int Cols>
-inline void GEPP(const std::vector<SparseSlice*>& left_slices,
- const ConstMatrixMap& right, const int num_cols,
- Matrix* output) {
+template <typename TL, typename TR, int Cols>
+inline void GEPP(
+ const std::vector<SparseSlice<TL>*>& left_slices,
+ const Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>& right,
+ const int num_cols, Matrix* output) {
const int cols = (Cols == -1) ? num_cols : Cols;
DCHECK_EQ(num_cols, cols);
const int right_num_cols = right.dimension(1);
const int output_num_cols = output->dimension(1);
- const int cols_mod = cols % kNumOperands;
+ static const int kNumOperandsR = kNumOperands * sizeof(float) / sizeof(TR);
+ const int cols_mod = cols % kNumOperandsR;
int k_offset = 0;
// Pre-compute pointers for output matrix.
float* out_ptrs[M];
@@ -332,15 +595,15 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices,
}
for (const auto* left_slice : left_slices) {
const auto& left = *left_slice;
- const float* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr;
- const float* data = (left.data.size() > 0) ? &left.data[0] : nullptr;
+ const auto* data3 = (left.data3.size() > 0) ? &left.data3[0] : nullptr;
+ const auto* data = (left.data.size() > 0) ? &left.data[0] : nullptr;
const int num_blocks = left.index3_offset.size();
int begin3 = 0;
int begin = 0;
for (int i = 0; i < num_blocks; ++i) {
// Pre-compute pointers for right matrix
- const float* right_ptrs[K];
- const float* const right_start = &right(k_offset, 0);
+ const TR* right_ptrs[K];
+ const auto* const right_start = &right(k_offset, 0);
DCHECK_LT(k_offset, right.dimension(0));
for (int j = 0; j < K; ++j) {
right_ptrs[j] = right_start + right_num_cols * j;
@@ -350,79 +613,116 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices,
int j = begin3;
// Loop unrolled for 2 iterations.
for (; j + 1 < end3; j += 2) {
- const float* sl1 = data3++;
- LOAD_SCALAR(sl1, l1);
- const float* sl2 = data3++;
- LOAD_SCALAR(sl2, l2);
- const float* sl3 = data3++;
- LOAD_SCALAR(sl3, l3);
- const float* nsl1 = data3++;
- LOAD_SCALAR(nsl1, nl1);
- const float* nsl2 = data3++;
- LOAD_SCALAR(nsl2, nl2);
- const float* nsl3 = data3++;
- LOAD_SCALAR(nsl3, nl3);
- const SparseSlice::Index3& index = left.index3[j];
- const SparseSlice::Index3& nindex = left.index3[j + 1];
+ Packet l1, l2, l3, nl1, nl2, nl3;
+ LoadSixScalars(&data3, &l1, &l2, &l3, &nl1, &nl2, &nl3);
+ const auto& index = left.index3[j];
+ const auto& nindex = left.index3[j + 1];
float* out = out_ptrs[index.m];
float* nout = out_ptrs[nindex.m];
- const float* r1 = right_ptrs[index.k1];
- const float* r2 = right_ptrs[index.k2];
- const float* r3 = right_ptrs[index.k3];
- const float* nr1 = right_ptrs[nindex.k1];
- const float* nr2 = right_ptrs[nindex.k2];
- const float* nr3 = right_ptrs[nindex.k3];
+ const auto* r1 = right_ptrs[index.k1];
+ const auto* r2 = right_ptrs[index.k2];
+ const auto* r3 = right_ptrs[index.k3];
+
+ const auto* nr1 = right_ptrs[nindex.k1];
+ const auto* nr2 = right_ptrs[nindex.k2];
+ const auto* nr3 = right_ptrs[nindex.k3];
if (cols == 128) {
- MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out);
- MULADD3WAY_128(nl1, nl2, nl3, nr1, nr2, nr3, nout);
+ MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
+ MulAdd3Way128(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
} else {
- for (int n = 0; n < cols / kNumOperands; ++n) {
- MULADD3WAY(l1, l2, l3, r1, r2, r3, out);
- MULADD3WAY(nl1, nl2, nl3, nr1, nr2, nr3, nout);
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
+ MulAdd3Way(nl1, nl2, nl3, &nr1, &nr2, &nr3, &nout);
}
+
+ const float sl1 = Eigen::internal::pfirst<Packet>(l1);
+ const float sl2 = Eigen::internal::pfirst<Packet>(l2);
+ const float sl3 = Eigen::internal::pfirst<Packet>(l3);
+ const float nsl1 = Eigen::internal::pfirst<Packet>(nl1);
+ const float nsl2 = Eigen::internal::pfirst<Packet>(nl2);
+ const float nsl3 = Eigen::internal::pfirst<Packet>(nl3);
for (int k = 0; k < cols_mod; ++k) {
- SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out);
- SCALAR_MULADD3WAY(nsl1, nsl2, nsl3, nr1, nr2, nr3, nout);
+ ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
+ ScalarMulAdd3Way(nsl1, nsl2, nsl3, &nr1, &nr2, &nr3, &nout);
}
}
}
if (j < end3) {
- const float* sl1 = data3++;
- LOAD_SCALAR(sl1, l1);
- const float* sl2 = data3++;
- LOAD_SCALAR(sl2, l2);
- const float* sl3 = data3++;
- LOAD_SCALAR(sl3, l3);
- const SparseSlice::Index3& index = left.index3[j];
+ Packet l1, l2, l3;
+ LoadThreeScalars(&data3, &l1, &l2, &l3);
+
+ const auto& index = left.index3[j];
float* out = out_ptrs[index.m];
- const float* r1 = right_ptrs[index.k1];
- const float* r2 = right_ptrs[index.k2];
- const float* r3 = right_ptrs[index.k3];
+ const auto* r1 = right_ptrs[index.k1];
+ const auto* r2 = right_ptrs[index.k2];
+ const auto* r3 = right_ptrs[index.k3];
if (cols == 128) {
- MULADD3WAY_128(l1, l2, l3, r1, r2, r3, out);
+ MulAdd3Way128(l1, l2, l3, &r1, &r2, &r3, &out);
} else {
- for (int n = 0; n < cols / kNumOperands; ++n) {
- MULADD3WAY(l1, l2, l3, r1, r2, r3, out);
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd3Way(l1, l2, l3, &r1, &r2, &r3, &out);
}
+ const float sl1 = Eigen::internal::pfirst<Packet>(l1);
+ const float sl2 = Eigen::internal::pfirst<Packet>(l2);
+ const float sl3 = Eigen::internal::pfirst<Packet>(l3);
for (int k = 0; k < cols_mod; ++k) {
- SCALAR_MULADD3WAY(sl1, sl2, sl3, r1, r2, r3, out);
+ ScalarMulAdd3Way(sl1, sl2, sl3, &r1, &r2, &r3, &out);
}
}
}
begin3 = end3;
int end = left.index_offset[i];
- for (int j = begin; j < end; ++j) {
- const float* sl = data++;
- LOAD_SCALAR(sl, l);
- const SparseSlice::Index& index = left.index[j];
- const float* r = right_ptrs[index.k];
+ // Loop unrolled for 4 iterations.
+ j = begin;
+ for (; j + 3 < end; j += 4) {
+ Packet l, nl, n2l, n3l;
+ LoadFourScalars(&data, &l, &nl, &n2l, &n3l);
+
+ const auto& index = left.index[j];
+ const auto& nindex = left.index[j + 1];
+ const auto& n2index = left.index[j + 2];
+ const auto& n3index = left.index[j + 3];
+ const auto* r = right_ptrs[index.k];
+ const auto* nr = right_ptrs[nindex.k];
+ const auto* n2r = right_ptrs[n2index.k];
+ const auto* n3r = right_ptrs[n3index.k];
+ float* out = out_ptrs[index.m];
+ float* nout = out_ptrs[nindex.m];
+ float* n2out = out_ptrs[n2index.m];
+ float* n3out = out_ptrs[n3index.m];
+
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd(l, &r, &out);
+ MulAdd(nl, &nr, &nout);
+ MulAdd(n2l, &n2r, &n2out);
+ MulAdd(n3l, &n3r, &n3out);
+ }
+
+ const float sl1 = Eigen::internal::pfirst<Packet>(l);
+ const float sl2 = Eigen::internal::pfirst<Packet>(nl);
+ const float sl3 = Eigen::internal::pfirst<Packet>(n2l);
+ const float sl4 = Eigen::internal::pfirst<Packet>(n3l);
+ for (int k = 0; k < cols_mod; ++k) {
+ ScalarMulAdd(sl1, &r, &out);
+ ScalarMulAdd(sl2, &nr, &nout);
+ ScalarMulAdd(sl3, &n2r, &n2out);
+ ScalarMulAdd(sl4, &n3r, &n3out);
+ }
+ }
+ while (j < end) {
+ Packet l;
+ LoadSingleScalar(&data, &l);
+ const auto& index = left.index[j];
+ const auto* r = right_ptrs[index.k];
float* out = out_ptrs[index.m];
- for (int n = 0; n < cols / kNumOperands; ++n) {
- MULADD(l, r, out);
+ for (int n = 0; n < cols / kNumOperandsR; ++n) {
+ MulAdd(l, &r, &out);
}
+ const float sl = Eigen::internal::pfirst<Packet>(l);
for (int k = 0; k < cols_mod; ++k) {
- SCALAR_MULADD(sl, r, out);
+ ScalarMulAdd(sl, &r, &out);
}
+ j++;
}
k_offset += left.block_size;
begin = end;
@@ -430,21 +730,101 @@ inline void GEPP(const std::vector<SparseSlice*>& left_slices,
}
}
-#undef SCALAR_MULADD
-#undef SCALAR_MULADD3WAY
#undef LOAD
+#undef EXPAND_BFLOAT_L
+#undef EXPAND_BFLOAT_U
#undef STORE
-#undef LOAD_SCALAR
#undef FMA
-#undef MULADD
-#undef MULADD3WAY
-#undef MULADD3WAY_16
-#undef MULADD3WAY_32
-#undef MULADD3WAY_128
} // namespace
+template <typename TL, typename TR>
+class SparseMatMul {
+ typedef Eigen::Tensor<TL, 2, Eigen::RowMajor> MatrixL;
+ typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
+ typedef Eigen::TensorMap<Eigen::Tensor<const TL, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMapL;
+ typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMapR;
+ typedef Eigen::TensorMap<Eigen::Tensor<TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ MatrixMapR;
+ // Perform matrix multiplication of "left" and "right", and store the result
+ // in *"output".
+ public:
+ static inline void Compute(const ConstMatrixMapL& left,
+ const ConstMatrixMapR& right, bool transpose_left,
+ const DeviceBase::CpuWorkerThreads* thread_pool,
+ bool transpose_output, MatrixMap* output);
+
+ private:
+ // Computes multiplication of left and num_cols columns of right, and stores
+ // the output block in *"output" at offsets "output_row_offset" and
+ // "output_col_offset". If assign is true, assigns the value to that block,
+ // else adds the values to the existing values.
+ static inline void ComputeOutputBlock(
+ const std::vector<SparseSlice<TL>*>& left, const ConstMatrixMapR& right,
+ int num_cols, int output_row_offset, int output_col_offset, bool assign,
+ bool transpose_output, MatrixMap* output);
+
+ // Encodes "mat" using a sparse representation and stores that in
+ // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
+ // "slice_num_cols", each grid element is converted into a SparseSlice and
+ // stored in mat_slices. "slice_block_size" is used to perform further column
+ // blocking of each slice.
+ static inline BlockingCounter* CreateSparseSlices(
+ const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
+ int slice_block_size, int slice_num_cols,
+ std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
+ const DeviceBase::CpuWorkerThreads* thread_pool);
+
+ // This function chops "mat" along column dimension into pieces with at most N
+ // columns, and concatenates the pieces one after the other in "buffer". It
+ // returns the list of the pieces in "slices". It returns a BlockingCounter
+ // which should be used to wait for the shuffle operations to complete.
+ static inline BlockingCounter* CreateDenseSlices(
+ const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
+ int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
+ MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);
+
+ // Helper function for CreateDenseSlices to move the data around. It returns a
+ // BlockingCounter which should be used to wait for the shuffle operations to
+ // complete.
+ static inline BlockingCounter* ShuffleMatrix(
+ const ConstMatrixMapR& mat, int slice_row_start, int slice_num_rows,
+ int slice_col_start, int slice_num_cols, const int N,
+ const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer);
+
+ // Helper function for CreateDenseSlices to create slices.
+ static inline void SliceMatrix(const MatrixR& mat, const int num_rows,
+ const int num_slices,
+ std::vector<ConstMatrixMapR*>* slices);
+
+ // Heuristics to compute various block sizes.
+ // KR, NR: block sizes for "right". We run blocking iterations that operate on
+ // matrices with at most this size.
+ // KL: grid size along the column dimension used while encoding left.
+ // IB, JB: number of left and right slices to multiply together. This is used
+ // for ordering different ComputeBlockOutput operations inside each blocking
+ // iteration so as to potentially reduce the working set size.
+ static inline void ComputeBlockSizes(const ConstMatrixMapL& left,
+ const ConstMatrixMapR& right,
+ bool transpose_left, int num_threads,
+ int* KR, int* NR, int* KL, int* JB,
+ int* IB);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul);
+};
+
+template <typename TL, typename TR>
class SparseMatMulOp : public OpKernel {
+ typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
+ typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
+ Eigen::Aligned>
+ ConstMatrixMapR;
+
public:
explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
@@ -461,12 +841,10 @@ class SparseMatMulOp : public OpKernel {
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
errors::InvalidArgument("b is not a matrix"));
- auto left = a.matrix<float>();
- auto right = b.matrix<float>();
- const int m = transpose_a_ ? left.dimension(1) : left.dimension(0);
- const int k = transpose_a_ ? left.dimension(0) : left.dimension(1);
- const int n = transpose_b_ ? right.dimension(0) : right.dimension(1);
- const int k2 = transpose_b_ ? right.dimension(1) : right.dimension(0);
+ const int m = transpose_a_ ? a.dim_size(1) : a.dim_size(0);
+ const int k = transpose_a_ ? a.dim_size(0) : a.dim_size(1);
+ const int n = transpose_b_ ? b.dim_size(0) : b.dim_size(1);
+ const int k2 = transpose_b_ ? b.dim_size(1) : b.dim_size(0);
OP_REQUIRES(ctx, k == k2,
errors::InvalidArgument("Matrix size incompatible: a: ",
@@ -476,127 +854,94 @@ class SparseMatMulOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
auto out = output->matrix<float>();
+ std::unique_ptr<Tensor> a_float;
+ std::unique_ptr<Tensor> b_float;
if (!a_is_sparse_ && !b_is_sparse_) {
- // Fallback to Eigen contract.
- // Note that we currently don't optimize the case where only right is
- // sparse. That can generally be handled by transposing the order of the
- // matmul.
+ auto left = &a;
+ auto right = &b;
+ // TODO(agarwal): multi-thread the conversions from bfloat16 to float.
+ if (std::is_same<TL, bfloat16>::value) {
+ a_float.reset(new Tensor(DT_FLOAT, a.shape()));
+ BFloat16ToFloat(a.flat<bfloat16>().data(),
+ a_float->flat<float>().data(), a.NumElements());
+ left = a_float.get();
+ }
+ if (std::is_same<TR, bfloat16>::value) {
+ b_float.reset(new Tensor(DT_FLOAT, b.shape()));
+ BFloat16ToFloat(b.flat<bfloat16>().data(),
+ b_float->flat<float>().data(), b.NumElements());
+ right = b_float.get();
+ }
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0].first = transpose_a_ ? 0 : 1;
dim_pair[0].second = transpose_b_ ? 1 : 0;
+
out.device(ctx->template eigen_device<CPUDevice>()) =
- left.contract(right, dim_pair);
+ left->matrix<float>().contract(right->matrix<float>(), dim_pair);
return;
}
- auto left_mat = &left;
- auto right_mat = &right;
+
+ auto left = &a;
+ auto right = &b;
bool transpose_output = false;
bool transpose_a = transpose_a_;
bool transpose_b = transpose_b_;
if (!a_is_sparse_) {
// Swap the order of multiplications using the identity:
// A * B = (B' * A')'.
- std::swap(left_mat, right_mat);
+ std::swap(left, right);
std::swap(transpose_a, transpose_b);
transpose_a = !transpose_a;
transpose_b = !transpose_b;
transpose_output = !transpose_output;
}
- std::unique_ptr<Matrix> right_tr_mat;
- std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map;
+
+ std::unique_ptr<Tensor> right_tr;
if (transpose_b) {
// TODO(agarwal): avoid transposing the matrix here and directly handle
// transpose in CreateDenseSlices.
- right_tr_mat.reset(
- new Matrix(right_mat->dimension(1), right_mat->dimension(0)));
+ right_tr.reset(
+ new Tensor(right->dtype(),
+ TensorShape({right->dim_size(1), right->dim_size(0)})));
+
Eigen::array<int, 2> perm({1, 0});
- right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) =
- right_mat->shuffle(perm);
- right_tr_map.reset(new TTypes<float>::ConstMatrix(
- right_tr_mat->data(), right_tr_mat->dimensions()));
- right_mat = right_tr_map.get();
+ if (transpose_output) {
+ right_tr->matrix<TL>().device(ctx->template eigen_device<CPUDevice>()) =
+ right->matrix<TL>().shuffle(perm);
+ } else {
+ right_tr->matrix<TR>().device(ctx->template eigen_device<CPUDevice>()) =
+ right->matrix<TR>().shuffle(perm);
+ }
+ right = right_tr.get();
}
- SparseMatMul(*left_mat, *right_mat, transpose_a,
- ctx->device()->tensorflow_cpu_worker_threads(),
- transpose_output, &out);
+ if (transpose_output) {
+ SparseMatMul<TR, TL>::Compute(
+ left->matrix<TR>(), right->matrix<TL>(), transpose_a,
+ ctx->device()->tensorflow_cpu_worker_threads(), transpose_output,
+ &out);
+ } else {
+ SparseMatMul<TL, TR>::Compute(
+ left->matrix<TL>(), right->matrix<TR>(), transpose_a,
+ ctx->device()->tensorflow_cpu_worker_threads(), transpose_output,
+ &out);
+ }
}
private:
- // Perform matrix multiplication of "left" and "right", and store the result
- // in *"output".
- static inline void SparseMatMul(
- const ConstMatrixMap& left, const ConstMatrixMap& right,
- bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
- bool transpose_output, MatrixMap* output);
-
- // Computes multiplication of left and num_cols columns of right, and stores
- // the output block in *"output" at offsets "output_row_offset" and
- // "output_col_offset". If assign is true, assigns the value to that block,
- // else adds the values to the existing values.
- static inline void ComputeOutputBlock(const std::vector<SparseSlice*>& left,
- const ConstMatrixMap& right,
- int num_cols, int output_row_offset,
- int output_col_offset, bool assign,
- bool transpose_output,
- MatrixMap* output);
-
- // Encodes "mat" using a sparse representation and stores that in
- // "mat_slices". "mat" is broken into a grid with sizes "slice_num_rows" and
- // "slice_num_cols", each grid element is converted into a SparseSlice and
- // stored in mat_slices. "slice_block_size" is used to perform further column
- // blocking of each slice.
- static inline BlockingCounter* CreateSparseSlices(
- const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
- int slice_block_size, int slice_num_cols,
- std::vector<std::vector<SparseSlice*>>* mat_slices,
- const DeviceBase::CpuWorkerThreads* thread_pool);
-
- // This function chops "mat" along column dimension into pieces with at most N
- // columns, and concatenates the pieces one after the other in "buffer". It
- // returns the list of the pieces in "slices". It returns a BlockingCounter
- // which should be used to wait for the shuffle operations to complete.
- static inline BlockingCounter* CreateDenseSlices(
- const ConstMatrixMap& mat, int row_start, int num_rows, int col_start,
- int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
- Matrix* buffer, std::vector<ConstMatrixMap*>* slices);
-
- // Helper function for CreateDenseSlices to move the data around. It returns a
- // BlockingCounter which should be used to wait for the shuffle operations to
- // complete.
- static inline BlockingCounter* ShuffleMatrix(
- const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows,
- int slice_col_start, int slice_num_cols, const int N,
- const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer);
-
- // Helper function for CreateDenseSlices to create slices.
- static inline void SliceMatrix(const Matrix& mat, const int num_rows,
- const int num_slices,
- std::vector<ConstMatrixMap*>* slices);
-
- // Heuristics to compute various block sizes.
- // KR, NR: block sizes for "right". We run blocking iterations that operate on
- // matrices with at most this size.
- // KL: grid size along the column dimension used while encoding left.
- // IB, JB: number of left and right slices to multiply together. This is used
- // for ordering different ComputeBlockOutput operations inside each blocking
- // iteration so as to potentially reduce the working set size.
- static inline void ComputeBlockSizes(const ConstMatrixMap& left,
- const ConstMatrixMap& right,
- bool transpose_left, int num_threads,
- int* KR, int* NR, int* KL, int* JB,
- int* IB);
-
bool transpose_a_;
bool transpose_b_;
bool a_is_sparse_;
bool b_is_sparse_;
+
TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
};
-inline void SparseMatMulOp::ComputeOutputBlock(
- const std::vector<SparseSlice*>& left, const ConstMatrixMap& right,
- int num_cols, int output_row_offset, int output_col_offset, bool assign,
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
+ const std::vector<SparseSlice<TL>*>& left,
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& right, int num_cols,
+ int output_row_offset, int output_col_offset, bool assign,
bool transpose_output, MatrixMap* output) {
static const Eigen::array<int, 2> perm({1, 0});
int num_rows = left[0]->num_rows;
@@ -605,9 +950,9 @@ inline void SparseMatMulOp::ComputeOutputBlock(
Matrix out(num_rows, rhs_num_cols);
out.setZero();
if (num_cols == N) {
- GEPP<N>(left, right, num_cols, &out);
+ GEPP<TL, TR, N>(left, right, num_cols, &out);
} else {
- GEPP<-1>(left, right, num_cols, &out);
+ GEPP<TL, TR, -1>(left, right, num_cols, &out);
}
if (!assign) {
const Eigen::array<int, 2> begin = {output_row_offset, output_col_offset};
@@ -643,10 +988,11 @@ inline void SparseMatMulOp::ComputeOutputBlock(
}
}
-inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
- const ConstMatrixMap& mat, bool transpose, int slice_num_rows,
- int slice_block_size, int slice_num_cols,
- std::vector<std::vector<SparseSlice*>>* mat_slices,
+template <typename TL, typename TR>
+inline BlockingCounter* SparseMatMul<TL, TR>::CreateSparseSlices(
+ const SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
+ int slice_num_rows, int slice_block_size, int slice_num_cols,
+ std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
const DeviceBase::CpuWorkerThreads* thread_pool) {
const int mat_num_rows = transpose ? mat.dimension(1) : mat.dimension(0);
const int mat_num_cols = transpose ? mat.dimension(0) : mat.dimension(1);
@@ -657,12 +1003,13 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
mat_slices->resize(num_slices_dim0);
BlockingCounter* counter =
new BlockingCounter(num_slices_dim0 * num_slices_dim1);
- auto work = [counter, transpose](SparseSlice* sparse_slice,
- ConstMatrixMap* slice, int col_offset) {
+ auto work = [counter, transpose](SparseSlice<TL>* sparse_slice,
+ SparseMatMul<TL, TR>::ConstMatrixMapL* slice,
+ int col_offset) {
if (transpose) {
- sparse_slice->Initialize<true>(*slice, col_offset);
+ sparse_slice->template Initialize<true>(*slice, col_offset);
} else {
- sparse_slice->Initialize<false>(*slice, col_offset);
+ sparse_slice->template Initialize<false>(*slice, col_offset);
}
delete slice;
counter->DecrementCount();
@@ -674,16 +1021,17 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
for (int j = 0; j < num_slices_dim1; ++j) {
int num_cols =
std::min<int>(slice_num_cols, mat_num_cols - j * slice_num_cols);
- ConstMatrixMap* slice = nullptr;
+ SparseMatMul<TL, TR>::ConstMatrixMapL* slice = nullptr;
if (transpose) {
- slice =
- new ConstMatrixMap(&mat(0, i * slice_num_rows), mat.dimensions());
+ slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
+ &mat(0, i * slice_num_rows), mat.dimensions());
} else {
DSizes d(num_rows, mat_num_cols);
- slice = new ConstMatrixMap(&mat(i * slice_num_rows, 0), d);
+ slice = new SparseMatMul<TL, TR>::ConstMatrixMapL(
+ &mat(i * slice_num_rows, 0), d);
}
- SparseSlice* sparse_slice =
- new SparseSlice(num_rows, num_cols, slice_block_size);
+ auto* sparse_slice =
+ new SparseSlice<TL>(num_rows, num_cols, slice_block_size);
(*mat_slices)[i][j] = sparse_slice;
thread_pool->workers->Schedule(
std::bind(work, sparse_slice, slice, slice_num_cols * j));
@@ -691,11 +1039,58 @@ inline BlockingCounter* SparseMatMulOp::CreateSparseSlices(
}
return counter;
}
+#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
+#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
+#define STORE(x, y) Eigen::internal::pstoreu<float>(x, y);
+
+template <int NUM_ELEM = -1>
+ALWAYS_INLINE void CopyAndMayBeInterleaveBfloat16(void* bdst, const void* bsrc,
+ int num_elements) {
+ DCHECK_GE(kNumOperands, 8);
+ static const int kStep = kNumOperands * sizeof(float) / sizeof(bfloat16);
+ const int num = (NUM_ELEM == -1) ? num_elements : NUM_ELEM;
+ DCHECK_EQ(num, num_elements);
+ const float* src = reinterpret_cast<const float*>(bsrc);
+ float* dst = reinterpret_cast<float*>(bdst);
+ for (int index = 0; index + kStep <= num; index += kStep) {
+ auto in = LOAD(src);
+ auto tmp = INTERLEAVE(in);
+ STORE(dst, tmp);
+ src += kNumOperands;
+ dst += kNumOperands;
+ }
+ if (num % kStep != 0) {
+ memcpy(dst, src, (num % kStep) * sizeof(bfloat16));
+ }
+}
-inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
- const ConstMatrixMap& mat, int slice_row_start, int slice_num_rows,
- int slice_col_start, int slice_num_cols, const int N,
- const DeviceBase::CpuWorkerThreads* thread_pool, Matrix* buffer) {
+template <typename T>
+ALWAYS_INLINE void CopyAndMayBeInterleave(void* dst, const void* src,
+ int num_elements) {
+ if (std::is_same<T, float>::value || kNumOperands < 8) {
+ memcpy(dst, src, num_elements * sizeof(T));
+ } else if (std::is_same<T, bfloat16>::value) {
+ if (num_elements == N) {
+ CopyAndMayBeInterleaveBfloat16<N>(dst, src, num_elements);
+ } else {
+ CopyAndMayBeInterleaveBfloat16<-1>(dst, src, num_elements);
+ }
+ } else {
+ LOG(FATAL) << "Unsupported type";
+ }
+}
+
+#undef LOAD
+#undef INTERLEAVE
+#undef STORE
+
+template <typename TL, typename TR>
+inline BlockingCounter* SparseMatMul<TL, TR>::ShuffleMatrix(
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int slice_row_start,
+ int slice_num_rows, int slice_col_start, int slice_num_cols, const int N,
+ const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer) {
+ DCHECK_EQ(N % 2, 0);
+ DCHECK_LE(kNumOperands * sizeof(float) / sizeof(TR), N);
int num_threads = std::min(thread_pool->num_threads, 16);
BlockingCounter* counter = new BlockingCounter(num_threads);
DCHECK_EQ(N, buffer->dimension(1));
@@ -703,17 +1098,17 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
slice_num_cols, N, buffer, counter](int s, int e) {
const int row_start = s % slice_num_rows + slice_row_start;
const int col_start = s / slice_num_rows * N + slice_col_start;
- float* out_start = &(*buffer)(s, 0);
- const float* input_start = &mat(row_start, col_start);
- const float* input_end = &mat(slice_row_start + slice_num_rows - 1,
- slice_col_start + slice_num_cols - 1);
+ auto* out_start = &(*buffer)(s, 0);
+ const auto* input_start = &mat(row_start, col_start);
+ const auto* input_end = &mat(slice_row_start + slice_num_rows - 1,
+ slice_col_start + slice_num_cols - 1);
const int mat_num_cols = mat.dimension(1);
const int row_slice_size = slice_num_rows * mat_num_cols;
const int aligned_end = slice_num_cols / N * slice_num_rows;
const int e1 = std::min(e, aligned_end);
while (s < e1) {
- memcpy(out_start, input_start, N * sizeof(float));
+ CopyAndMayBeInterleave<TR>(out_start, input_start, N);
out_start += N;
input_start += mat_num_cols;
if (input_start > input_end) {
@@ -724,7 +1119,7 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
int s1 = std::max(s, aligned_end);
const int copy_num_cols = slice_num_cols % N;
while (s1 < e) {
- memcpy(out_start, input_start, copy_num_cols * sizeof(float));
+ CopyAndMayBeInterleave<TR>(out_start, input_start, copy_num_cols);
out_start += N;
input_start += mat_num_cols;
++s1;
@@ -745,21 +1140,24 @@ inline BlockingCounter* SparseMatMulOp::ShuffleMatrix(
return counter;
}
-inline void SparseMatMulOp::SliceMatrix(const Matrix& mat, const int num_rows,
- const int num_slices,
- std::vector<ConstMatrixMap*>* slices) {
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::SliceMatrix(
+ const MatrixR& mat, const int num_rows, const int num_slices,
+ std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
slices->resize(num_slices);
DSizes d(num_rows, mat.dimension(1));
DCHECK_LE(num_rows * num_slices, mat.dimension(0));
for (int i = 0; i < num_slices; ++i) {
- (*slices)[i] = new ConstMatrixMap(&mat(i * num_rows, 0), d);
+ (*slices)[i] = new ConstMatrixMapR(&mat(i * num_rows, 0), d);
}
}
-inline BlockingCounter* SparseMatMulOp::CreateDenseSlices(
- const ConstMatrixMap& mat, int row_start, int num_rows, int col_start,
- int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
- Matrix* buffer, std::vector<ConstMatrixMap*>* slices) {
+template <typename TL, typename TR>
+inline BlockingCounter* SparseMatMul<TL, TR>::CreateDenseSlices(
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
+ int num_rows, int col_start, int num_cols,
+ const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
+ std::vector<SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
BlockingCounter* shuffle_counter = ShuffleMatrix(
mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer);
const int num_slices = (num_cols + N - 1) / N;
@@ -767,11 +1165,11 @@ inline BlockingCounter* SparseMatMulOp::CreateDenseSlices(
return shuffle_counter;
}
-inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
- const ConstMatrixMap& right,
- bool transpose_left,
- int num_threads, int* KR, int* NR,
- int* KL, int* JB, int* IB) {
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
+ const SparseMatMul<TL, TR>::ConstMatrixMapL& left,
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
+ int num_threads, int* KR, int* NR, int* KL, int* JB, int* IB) {
// Heuristics for calculating block sizes
// Assume two hyperthreads per core.
const int est_num_cores = std::max(1, (num_threads + 1) / 2);
@@ -838,20 +1236,21 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
// and update the output block o_ij. These calls are further blocked to
// reduce the working set size. In each iteration we take IB elements from
// {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
-inline void SparseMatMulOp::SparseMatMul(
- const ConstMatrixMap& left, const ConstMatrixMap& right,
- bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
- bool transpose_output, MatrixMap* output) {
+template <typename TL, typename TR>
+inline void SparseMatMul<TL, TR>::Compute(
+ const SparseMatMul<TL, TR>::ConstMatrixMapL& left,
+ const SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
+ const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output,
+ MatrixMap* output) {
const int num_threads = thread_pool->num_threads;
int KR, NR, KL, JB, IB;
ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
&JB, &IB);
-
// Slice the left matrix
- std::vector<std::vector<SparseSlice*>> left_slices;
+ std::vector<std::vector<SparseSlice<TL>*>> left_slices;
std::unique_ptr<BlockingCounter> sparse_slice_counter;
sparse_slice_counter.reset(
- CreateSparseSlices(ConstMatrixMap(left.data(), left.dimensions()),
+ CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
transpose_left, M, K, KL, &left_slices, thread_pool));
const int num_left_slices = left_slices.size();
@@ -862,10 +1261,10 @@ inline void SparseMatMulOp::SparseMatMul(
// is the block size per iteration.
const int buffer_num_rows =
std::min(KR, right_dim0) * (std::min(NR, right_dim1) + N - 1) / N;
- Matrix buffer(buffer_num_rows, N);
- std::vector<ConstMatrixMap*> right_slices;
+ MatrixR buffer(buffer_num_rows, N);
+ std::vector<ConstMatrixMapR*> right_slices;
- std::vector<SparseSlice*> block_left_slices;
+ std::vector<SparseSlice<TL>*> block_left_slices;
std::vector<std::function<void(void)>> tasks;
// Number of blocks based on block sizes of KR * NR.
const int num_k_blocks = (right_dim0 + KR - 1) / KR;
@@ -890,7 +1289,6 @@ inline void SparseMatMulOp::SparseMatMul(
const int num_cols = std::min(N, right_num_cols - N * j_inner);
for (int i_inner = i_outer;
i_inner < std::min(num_left_slices, i_outer + IB); ++i_inner) {
- // Figure out which left slices to use.
block_left_slices.clear();
int begin = kb * KR / KL;
int end = std::min<int>((kb + 1) * KR / KL,
@@ -900,7 +1298,7 @@ inline void SparseMatMulOp::SparseMatMul(
left_slices[i_inner].begin() + begin,
left_slices[i_inner].begin() + end);
tasks.push_back(std::bind(
- &SparseMatMulOp::ComputeOutputBlock, block_left_slices,
+ &ComputeOutputBlock, block_left_slices,
std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
N * j_inner + nb * NR, kb == 0, transpose_output, output));
}
@@ -933,7 +1331,18 @@ inline void SparseMatMulOp::SparseMatMul(
}
}
-REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU),
- SparseMatMulOp);
+#define REGISTER_SPARSE_MATMUL(TA, TB) \
+ REGISTER_KERNEL_BUILDER(Name("SparseMatMul") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TA>("Ta") \
+ .TypeConstraint<TB>("Tb"), \
+ SparseMatMulOp<TA, TB>);
+
+REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
+REGISTER_SPARSE_MATMUL(float, bfloat16);
+REGISTER_SPARSE_MATMUL(bfloat16, float);
+REGISTER_SPARSE_MATMUL(float, float);
+
+#undef REGISTER_SPARSE_MATMUL
} // end namespace tensorflow
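One subtlety in SparseMatMulOp::Compute above: when only "b" is sparse, the operands are swapped so that the sparse matrix is always the left argument of the inner kernel, using the identity

  A * B = (B' * A')'

Because the swap also exchanges the element types of the two operands, the transpose_output branch instantiates SparseMatMul<TR, TL> rather than SparseMatMul<TL, TR>; the transposed result is then written into the output inside ComputeOutputBlock.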
diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc
index 2f28b8ff9e..676ea6e28c 100644
--- a/tensorflow/core/kernels/sparse_matmul_op_test.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse_matmul_op.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -26,11 +27,13 @@ limitations under the License.
namespace tensorflow {
random::PhiloxRandom philox(1, 1);
random::SimplePhilox rnd(&philox);
+using Eigen::operator==;
+template <typename T>
void Sparsify(Tensor* t, float sparsity) {
const int64 N = t->NumElements();
CHECK_LE(sparsity, 1);
- auto flat = t->flat<float>();
+ auto flat = t->flat<T>();
if (sparsity == 1) {
flat.setZero();
return;
@@ -38,9 +41,9 @@ void Sparsify(Tensor* t, float sparsity) {
static const uint32 K = 10000;
for (int64 i = 0; i < N; ++i) {
if (rnd.Uniform(K) < sparsity * K) {
- flat(i) = 0;
- } else if (flat(i) == 0) {
- flat(i) = 0.1;
+ flat(i) = T(0);
+ } else if (flat(i) == T(0)) {
+ flat(i) = T(1);
}
}
}
@@ -59,6 +62,7 @@ Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a,
return ret;
}
+template <typename TA, typename TB>
static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d,
float sparsity_a, float sparsity_b,
bool transpose_a, bool transpose_b) {
@@ -66,14 +70,14 @@ static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d,
bool b_sparse = (sparsity_b > 0);
auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d});
- Tensor left(DataTypeToEnum<float>::value, left_shape);
- left.flat<float>().setRandom();
- Sparsify(&left, sparsity_a);
+ Tensor left(DataTypeToEnum<TA>::value, left_shape);
+ left.flat<TA>().setRandom();
+ Sparsify<TA>(&left, sparsity_a);
auto right_shape = transpose_b ? TensorShape({n, d}) : TensorShape({d, n});
- Tensor right(DataTypeToEnum<float>::value, right_shape);
- right.flat<float>().setRandom();
- Sparsify(&right, sparsity_b);
+ Tensor right(DataTypeToEnum<TB>::value, right_shape);
+ right.flat<TB>().setRandom();
+ Sparsify<TB>(&right, sparsity_b);
SparseMatMulNode(g, test::graph::Constant(g, left),
test::graph::Constant(g, right), transpose_a, transpose_b,
@@ -81,59 +85,82 @@ static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d,
return g;
}
+template <typename TA, typename TB>
static Graph* SparseMatMul(int m, int n, int d, float sparsity_a,
float sparsity_b, bool transpose_a,
bool transpose_b) {
Graph* g = new Graph(OpRegistry::Global());
- return SparseMatMulHelper(g, m, n, d, sparsity_a, sparsity_b, transpose_a,
- transpose_b);
+ return SparseMatMulHelper<TA, TB>(g, m, n, d, sparsity_a, sparsity_b,
+ transpose_a, transpose_b);
}
-#define BM_SPARSE(M, K, N, S1, S2, TA, TB) \
- static void BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB( \
- int iters) { \
- testing::StopTiming(); \
- testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
- std::string label = \
- strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", TA, TB, \
- S1 / 100.0, S2 / 100.0); \
- testing::SetLabel(label); \
- testing::UseRealTime(); \
- auto g = SparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, TA, TB); \
- testing::StartTiming(); \
- test::Benchmark("cpu", g).Run(iters); \
- } \
- BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB);
-
-BM_SPARSE(2048, 2048, 2048, 0, 0, false, false);
-BM_SPARSE(2048, 2048, 2048, 1, 0, false, false);
-BM_SPARSE(2048, 2048, 2048, 50, 0, false, false);
-BM_SPARSE(2048, 2048, 2048, 85, 0, false, false);
-BM_SPARSE(2048, 2048, 2048, 99, 0, false, false);
-
-BM_SPARSE(2048, 2048, 2048, 0, 50, false, false);
-BM_SPARSE(2048, 2048, 2048, 0, 85, false, false);
-
-BM_SPARSE(2048, 2048, 2048, 85, 0, true, false);
-BM_SPARSE(2048, 2048, 2048, 85, 0, false, true);
-BM_SPARSE(2048, 2048, 2048, 85, 0, true, true);
-
-BM_SPARSE(2048, 2048, 2048, 0, 85, true, false);
-BM_SPARSE(2048, 2048, 2048, 0, 85, false, true);
-BM_SPARSE(2048, 2048, 2048, 0, 85, true, true);
-
-BM_SPARSE(1024, 1024, 1024, 0, 0, false, false);
-BM_SPARSE(1024, 1024, 1024, 1, 0, false, false);
-BM_SPARSE(1024, 1024, 1024, 85, 0, false, false);
-
-BM_SPARSE(256, 256, 256, 1, 0, false, false);
-BM_SPARSE(512, 512, 512, 1, 0, false, false);
+#define BM_SPARSE(M, K, N, S1, S2, TRA, TRB, TA, TB) \
+ static void \
+ BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB( \
+ int iters) { \
+ testing::StopTiming(); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
+ std::string label = \
+ strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", TRA, TRB, \
+ S1 / 100.0, S2 / 100.0); \
+ testing::SetLabel(label); \
+ testing::UseRealTime(); \
+ auto g = SparseMatMul<TA, TB>(M, N, K, S1 / 100.0, S2 / 100.0, TRA, TRB); \
+ testing::StartTiming(); \
+ test::Benchmark("cpu", g).Run(iters); \
+ } \
+ BENCHMARK( \
+ BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TRA##_##TRB##_##TA##_##TB);
+
+#define BM_SPARSE_FLOAT(M, K, N, S1, S2, TRA, TRB) \
+ BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, float)
+#define BM_SPARSE_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
+ BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, bfloat16)
+#define BM_SPARSE_FLOAT_BFLOAT16(M, K, N, S1, S2, TRA, TRB) \
+ BM_SPARSE(M, K, N, S1, S2, TRA, TRB, float, bfloat16)
+#define BM_SPARSE_BFLOAT16_FLOAT(M, K, N, S1, S2, TRA, TRB) \
+ BM_SPARSE(M, K, N, S1, S2, TRA, TRB, bfloat16, float)
+
+// Test sparsity of a (S1 is the sparsity of the left operand, printed as sp_a)
+BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 0, false, false);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 1, 0, false, false);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 50, 0, false, false);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, false);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 99, 0, false, false);
+// Test sparsity of b (S2 is the sparsity of the right operand, printed as sp_b)
+BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 50, false, false);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, false);
+// Test transposing
+BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, false);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, false, true);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 85, 0, true, true);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, false);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, false, true);
+BM_SPARSE_FLOAT(2048, 2048, 2048, 0, 85, true, true);
+
+// Test smaller sizes
+BM_SPARSE_FLOAT(1024, 1024, 1024, 0, 0, false, false);
+BM_SPARSE_FLOAT(1024, 1024, 1024, 1, 0, false, false);
+BM_SPARSE_FLOAT(1024, 1024, 1024, 85, 0, false, false);
+BM_SPARSE_FLOAT(256, 256, 256, 1, 0, false, false);
+BM_SPARSE_FLOAT(512, 512, 512, 1, 0, false, false);
+
+// Test bfloat16
+BM_SPARSE_BFLOAT16(2048, 2048, 2048, 0, 0, false, false);
+BM_SPARSE_BFLOAT16(2048, 2048, 2048, 1, 0, false, false);
+BM_SPARSE_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
+BM_SPARSE_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);
+BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 85, 0, false, false);
+BM_SPARSE_BFLOAT16_FLOAT(2048, 2048, 2048, 99, 0, false, false);
+BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 85, 0, false, false);
+BM_SPARSE_FLOAT_BFLOAT16(2048, 2048, 2048, 99, 0, false, false);
static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_1,
float sparsity_2) {
Graph* g = new Graph(OpRegistry::Global());
- SparseMatMulHelper(g, d, n, m, sparsity_1, sparsity_2, true, false);
- SparseMatMulHelper(g, m, d, n, sparsity_2, 0, false, true);
+ SparseMatMulHelper<float, float>(g, d, n, m, sparsity_1, sparsity_2, true,
+ false);
+ SparseMatMulHelper<float, float>(g, m, d, n, sparsity_2, 0, false, true);
return g;
}
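For orientation before the remaining files: bfloat16, which the benchmarks above now cover, is a float32 with the low 16 bits dropped (1 sign bit, 8 exponent bits, 7 mantissa bits). A minimal numpy-only sketch of that bit layout, illustrative and not part of the kernel (the kernel's exact rounding behavior may differ; this models plain truncation):

import numpy as np

def truncate_to_bfloat16(x):
  # Zero the low 16 bits of each float32 value; this models truncation,
  # the cheapest float32 -> bfloat16 conversion.
  bits = np.asarray(x, dtype=np.float32).view(np.uint32)
  return (bits & np.uint32(0xFFFF0000)).view(np.float32)

x = np.array([1.2345678], dtype=np.float32)
print(truncate_to_bfloat16(x))  # ~[1.234375]: only 7 mantissa bits survive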
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index fb35d40734..3af6b96c84 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -167,6 +167,7 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
.HostMemory("perm"), \
TransposeCpuOp);
TF_CALL_ALL_TYPES(REGISTER)
+REGISTER(bfloat16);
#undef REGISTER
#if GOOGLE_CUDA
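The bfloat16 registration above matters for the gradient code further down: when transpose_a or transpose_b is set, the SparseMatMul gradient pre-transposes an operand that may itself be bfloat16, so the CPU Transpose kernel has to accept that type. A small usage sketch, assuming the same Python API the tests in this change use:

import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(6.).reshape(2, 3).astype(np.float32))
x_bf16 = tf.cast(x, tf.bfloat16)   # cast as the Python tests below do
x_t = tf.transpose(x_bf16)         # needs the bfloat16 Transpose kernel
with tf.Session() as sess:
  # Cast back to float32 before fetching, since numpy has no bfloat16 type.
  print(sess.run(tf.cast(x_t, tf.float32)))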
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index f3b0d919a3..ba0b5e4bbb 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -594,13 +594,15 @@ transpose_b: If true, "b" is transposed before multiplication.
)doc");
REGISTER_OP("SparseMatMul")
- .Input("a: float")
- .Input("b: float")
+ .Input("a: Ta")
+ .Input("b: Tb")
.Output("product: float")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
.Attr("a_is_sparse: bool = false")
.Attr("b_is_sparse: bool = false")
+ .Attr("Ta: {float, bfloat16} = DT_FLOAT")
+ .Attr("Tb: {float, bfloat16} = DT_FLOAT")
.Doc(R"doc(
Multiply matrix "a" by matrix "b".
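With separate Ta and Tb attrs, the two operands may now carry independent dtypes while the product stays float32. A rough usage sketch through the Python wrapper (mirrors the tests below; not a definitive recipe):

import numpy as np
import tensorflow as tf

a = tf.cast(tf.constant(np.random.rand(8, 4).astype(np.float32)), tf.bfloat16)
b = tf.constant(np.random.rand(4, 16).astype(np.float32))
# Lowered to SparseMatMul with Ta=bfloat16, Tb=float because a_is_sparse=True.
prod = tf.matmul(a, b, a_is_sparse=True)
with tf.Session() as sess:
  out = sess.run(prod)
assert out.dtype == np.float32   # the op always produces a float32 product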
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1c7c4ea1cf..90431d631d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1107,6 +1107,7 @@ medium_kernel_test_list = glob([
"kernel_tests/seq2seq_test.py",
"kernel_tests/slice_op_test.py",
"kernel_tests/sparse_ops_test.py",
+ "kernel_tests/sparse_matmul_op_test.py",
"kernel_tests/sparse_tensor_dense_matmul_op_test.py",
])
@@ -1154,7 +1155,6 @@ cpu_only_kernel_test_list = glob([
"kernel_tests/self_adjoint_eig_op_test.py",
"kernel_tests/sparse_add_op_test.py",
"kernel_tests/sparse_concat_op_test.py",
- "kernel_tests/sparse_matmul_op_test.py",
"kernel_tests/sparse_split_op_test.py",
"kernel_tests/sparse_reorder_op_test.py",
"kernel_tests/sparse_to_dense_op_test.py",
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index d1f3aef880..6142bdab9c 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -22,90 +22,115 @@ import numpy as np
import tensorflow as tf
-def RandMatrix(rows, cols, tr):
+def RandMatrix(rows, cols, tr, round_bfloat=False):
if tr:
rows, cols = cols, rows
- return (np.clip(np.random.uniform(low=-100.0, high=100.0, size=rows * cols),
- 0, 100) / 100).reshape([rows, cols]).astype(np.float32)
+ rand_func = np.random.randint if round_bfloat else np.random.uniform
+ return (np.clip(rand_func(low=-256.0, high=256.0, size=rows * cols),
+ -64, 64) / 128.0).reshape([rows, cols]).astype(np.float32)
class SparseMatMulTest(tf.test.TestCase):
- def _testCpuMatmul(self, x, y, tr_a=False, tr_b=False,
- sp_a=True, sp_b=False):
- x_mat = np.matrix(x)
- if tr_a:
- x_mat = np.transpose(x_mat)
- y_mat = np.matrix(y)
- if tr_b:
- y_mat = np.transpose(y_mat)
- np_ans = x_mat * y_mat
+ def _testCpuMatmul(self, x, y,
+ tr_a=False, tr_b=False,
+ sp_a=True, sp_b=False,
+ x_dtype=tf.float32,
+ y_dtype=tf.float32):
with self.test_session(use_gpu=False):
- tf_ans = tf.matmul(x, y,
+ tf_x = tf.cast(x, x_dtype)
+ tf_y = tf.cast(y, y_dtype)
+ tf_ans = tf.matmul(tf_x, tf_y,
transpose_a=tr_a, transpose_b=tr_b,
a_is_sparse=sp_a,
b_is_sparse=sp_b)
out = tf_ans.eval()
- self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
+ np_x = tf.cast(tf_x, tf.float32).eval()
+ np_y = tf.cast(tf_y, tf.float32).eval()
+
+ if tr_a:
+ np_x = np.transpose(np_x)
+ if tr_b:
+ np_y = np.transpose(np_y)
+
+ np_ans = np.matrix(np_x) * np.matrix(np_y)
self.assertShapeEqual(np_ans, tf_ans)
+ self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
- def testFloatBasic(self):
+ def testBasic(self):
x = np.arange(0., 4.).reshape([4, 1]).astype(np.float32)
y = np.arange(-1., 1.).reshape([1, 2]).astype(np.float32)
- self._testCpuMatmul(x, y)
+ for x_dtype in (tf.float32, tf.bfloat16):
+ for y_dtype in (tf.float32, tf.bfloat16):
+ self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)
# Tests setting one dimension to be a high value.
- def testFloatLarge(self):
+ def testLarge(self):
r1 = np.random.randint(6000, 20000)
r2 = np.random.randint(1, 10)
r3 = np.random.randint(1, 10)
for m, k, n in [(r1, r2, r3),
(r2, r1, r3),
(r2, r3, r1)]:
- x = RandMatrix(m, k, False)
- y = RandMatrix(k, n, False)
- self._testCpuMatmul(x, y)
- self._testCpuMatmul(x, y, sp_a=False, sp_b=True)
+ for x_dtype in (tf.float32, tf.bfloat16):
+ for y_dtype in (tf.float32, tf.bfloat16):
+ x = RandMatrix(m, k, False)
+ y = RandMatrix(k, n, False)
+ self._testCpuMatmul(x, y, x_dtype=x_dtype, y_dtype=y_dtype)
# Tests random sized matrices.
- def testFloatRandom(self):
+ def testRandom(self):
for _ in range(10):
for tr_a in [True, False]:
for tr_b in [True, False]:
for sp_a in [True, False]:
for sp_b in [True, False]:
- n, k, m = np.random.randint(1, 100, size=3)
- x = RandMatrix(n, k, tr_a)
- y = RandMatrix(k, m, tr_b)
- self._testCpuMatmul(x, y, tr_a, tr_b, sp_a, sp_b)
+ for x_dtype in (tf.float32, tf.bfloat16):
+ for y_dtype in (tf.float32, tf.bfloat16):
+ n, k, m = np.random.randint(1, 100, size=3)
+ x = RandMatrix(n, k, tr_a)
+ y = RandMatrix(k, m, tr_b)
+ self._testCpuMatmul(x, y, tr_a, tr_b, sp_a, sp_b,
+ x_dtype=x_dtype, y_dtype=y_dtype)
class MatMulGradientTest(tf.test.TestCase):
- def _testGradients(self, tr_a, tr_b, sp_a, sp_b, name):
+ def _testGradients(self, tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, name):
with self.test_session():
- a = tf.constant(RandMatrix(3, 2, tr_a), dtype=tf.float32)
- b = tf.constant(RandMatrix(2, 4, tr_b), dtype=tf.float32)
- m = tf.matmul(a, b,
+ a = tf.constant(RandMatrix(3, 2, tr_a, round_bfloat=True),
+ dtype=tf.float32)
+ b = tf.constant(RandMatrix(2, 4, tr_b, round_bfloat=True),
+ dtype=tf.float32)
+ tf_a = tf.cast(a, a_dtype) if a_dtype != tf.float32 else a
+ tf_b = tf.cast(b, b_dtype) if b_dtype != tf.float32 else b
+
+ m = tf.matmul(tf_a, tf_b,
name=name,
transpose_a=tr_a,
transpose_b=tr_b,
a_is_sparse=sp_a,
b_is_sparse=sp_b)
err = (tf.test.compute_gradient_error(a, [2, 3]
- if tr_a else [3, 2], m, [3, 4]) +
+ if tr_a else [3, 2], m, [3, 4],
+ x_init_value=a.eval(),
+ delta=1/64.) +
tf.test.compute_gradient_error(b, [4, 2]
- if tr_b else [2, 4], m, [3, 4]))
- print("sparse_matmul gradient err = ", err)
- self.assertLess(err, 1e-3)
+ if tr_b else [2, 4], m, [3, 4],
+ x_init_value=b.eval(),
+ delta=1/64.))
+ self.assertLess(err, 1/128.)
def testGradientInput(self):
for tr_a in [True, False]:
for tr_b in [True, False]:
for sp_a in [True, False]:
for sp_b in [True, False]:
- name = "sparse_matmul_%s_%s_%s_%s" % (tr_a, tr_b, sp_a, sp_b)
- self._testGradients(tr_a, tr_b, sp_a, sp_b, name)
+ for a_dtype in (tf.float32, tf.bfloat16):
+ for b_dtype in (tf.float32, tf.bfloat16):
+ name = "sparse_matmul_%s_%s_%s_%s" % (tr_a, tr_b, sp_a, sp_b)
+ self._testGradients(tr_a, tr_b, sp_a, sp_b,
+ a_dtype, b_dtype, name)
if __name__ == "__main__":
tf.test.main()
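A note on round_bfloat=True above: RandMatrix then draws integers in [-256, 256), clips them to [-64, 64], and divides by 128, so every value needs at most 7 mantissa bits and the bfloat16 cast is exact; the numeric gradient (delta=1/64) is therefore not polluted by representation error. A quick illustrative check of that claim:

import numpy as np
import tensorflow as tf

vals = (np.arange(-64, 65) / 128.0).astype(np.float32)
with tf.Session() as sess:
  roundtrip = sess.run(tf.cast(tf.cast(tf.constant(vals), tf.bfloat16),
                               tf.float32))
assert np.array_equal(roundtrip, vals)   # the cast loses nothing on these values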
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index d5c6e9fc91..ab71e28cab 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -526,7 +526,8 @@ def _SparseMatMulGrad(op, grad):
# Use heuristic to figure out if grad might be sparse
grad: (grad.op.type == "ReluGrad")
}
- def _SparseMatMul(t1, t2, transpose_a=False, transpose_b=False):
+ def _SparseMatMul(t1, t2, out_dtype,
+ transpose_a=False, transpose_b=False):
"""Helper function to create SparseMatMul op."""
assert t1 in is_sparse and t2 in is_sparse
@@ -535,25 +536,30 @@ def _SparseMatMulGrad(op, grad):
if transpose_b:
t2 = array_ops.transpose(t2)
transpose_b = False
- return math_ops.matmul(t1, t2,
+ prod = math_ops.matmul(t1, t2,
transpose_a=transpose_a,
transpose_b=transpose_b,
a_is_sparse=t1_sparse,
b_is_sparse=t2_sparse)
+ if prod.dtype != out_dtype:
+ prod = math_ops.cast(prod, out_dtype)
+ return prod
+ dtype_a = op.inputs[0].dtype
+ dtype_b = op.inputs[1].dtype
if not t_a and not t_b:
- return (_SparseMatMul(grad, op.inputs[1], transpose_b=True),
- _SparseMatMul(op.inputs[0], grad, transpose_a=True))
+ return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
elif not t_a and t_b:
- return (_SparseMatMul(grad, op.inputs[1]),
- _SparseMatMul(grad, op.inputs[0], transpose_a=True))
+ return (_SparseMatMul(grad, op.inputs[1], dtype_a),
+ _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
elif t_a and not t_b:
- return (_SparseMatMul(op.inputs[1], grad, transpose_b=True),
- _SparseMatMul(op.inputs[0], grad))
+ return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad, dtype_b))
elif t_a and t_b:
- return (_SparseMatMul(op.inputs[1], grad,
+ return (_SparseMatMul(op.inputs[1], grad, dtype_a,
transpose_a=True, transpose_b=True),
- _SparseMatMul(grad, op.inputs[0],
+ _SparseMatMul(grad, op.inputs[0], dtype_b,
transpose_a=True, transpose_b=True))
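The out_dtype cast in _SparseMatMul keeps each returned gradient in the dtype of the corresponding input, even though the SparseMatMul product itself is always float32. A construction-only sketch of that behavior (illustrative; tensor names are made up):

import numpy as np
import tensorflow as tf

a = tf.cast(tf.constant(np.random.rand(3, 2).astype(np.float32)), tf.bfloat16)
b = tf.constant(np.random.rand(2, 4).astype(np.float32))
m = tf.matmul(a, b, a_is_sparse=True)   # SparseMatMul, float32 output
grad_a, grad_b = tf.gradients(m, [a, b])
assert grad_a.dtype == tf.bfloat16      # cast back to the input's dtype
assert grad_b.dtype == tf.float32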
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9ebf251574..f0f438a33d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1087,7 +1087,14 @@ def matmul(a, b,
with ops.op_scope([a, b], name, "MatMul") as name:
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")
- if a.dtype == dtypes.float32 and (a_is_sparse or b_is_sparse):
+ sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
+ use_sparse_matmul = (a.dtype in sparse_matmul_types and
+ b.dtype in sparse_matmul_types and
+ (a_is_sparse or b_is_sparse))
+ if dtypes.bfloat16 in (a.dtype, b.dtype):
+ # matmul currently doesn't handle bfloat16 inputs.
+ use_sparse_matmul = True
+ if use_sparse_matmul:
return sparse_matmul(a, b,
transpose_a=transpose_a,
transpose_b=transpose_b,
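The dispatch above also means a plain matmul on bfloat16 tensors is routed to SparseMatMul even when neither *_is_sparse flag is set, since the dense MatMul kernel does not yet handle bfloat16. Illustrative only:

import numpy as np
import tensorflow as tf

a = tf.cast(tf.constant(np.random.rand(5, 3).astype(np.float32)), tf.bfloat16)
b = tf.cast(tf.constant(np.random.rand(3, 7).astype(np.float32)), tf.bfloat16)
prod = tf.matmul(a, b)   # no sparsity hints, still lowered to SparseMatMul
print(prod.op.type)      # "SparseMatMul"
print(prod.dtype)        # float32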