From 9a1e2d5d3d2c6420c410378c385b0c4665cedb9b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 21 Dec 2016 11:04:40 -0800
Subject: Added experimental support for libxsmm sparse matrix-dense matrix
 multiplication.

Needs a new enough version of libxsmm with sparse support, and needs some
patches to work reliably at all sizes.

Change: 142680668
---
 libxsmm.BUILD                               |  21 +-
 tensorflow/core/kernels/BUILD               |  11 +-
 tensorflow/core/kernels/sparse_matmul_op.cc | 335 ++++++++++++++++++++++++++--
 tensorflow/workspace.bzl                    |   8 +-
 4 files changed, 344 insertions(+), 31 deletions(-)

diff --git a/libxsmm.BUILD b/libxsmm.BUILD
index c618e822e7..a0aab0f5b7 100644
--- a/libxsmm.BUILD
+++ b/libxsmm.BUILD
@@ -8,7 +8,7 @@ exports_files(["LICENSE"])
 # Arguments to ./scripts/libxsmm_interface.py, see that file for detailed description.
 # precision: SP & DP
 # prefetch: 1 (auto)
-libxsmm_interface_arguments = "0 0 1"
+libxsmm_interface_arguments = "0 1"
 
 # Arguments to ./scripts/libxsmm_config.py, see that file for detailed description.
 # ilp64: no
@@ -60,6 +60,8 @@ cc_library(
         "src/libxsmm_dump.c",
         "src/libxsmm_malloc.c",
         "src/libxsmm_gemm.c",
+        "src/libxsmm_gemm_diff.c",
+        "src/libxsmm_hash.c",
         "src/libxsmm_timer.c",
         "src/libxsmm_trace.c",
         "src/libxsmm_trans.c",
@@ -87,17 +89,11 @@ cc_library(
         "include/libxsmm_sync.h",
         "include/libxsmm_timer.h",
         "include/libxsmm_typedefs.h",
-        "src/libxsmm_gemm_diff.c",
-        "src/libxsmm_cpuid_x86.c",
-        "src/libxsmm_hash.c",
         # Generated:
         "include/libxsmm.h",
         "include/libxsmm_config.h",
         "include/libxsmm_dispatch.h",
-    ] + glob([
-        "src/*.h",
-        "src/template/*.c",
-    ]),
+    ],
     copts = [
         "-mavx",  # JIT does not work without avx anyway, and this silences some CRC32 warnings.
         "-Wno-vla",  # Libxsmm convolutions heavily use VLA.
@@ -107,12 +103,13 @@ cc_library(
         "LIBXSMM_CPUID_X86_NOINLINE",
         "__BLAS=0",
     ],
-    includes = ["include"],
+    includes = [
+        "include",
+        "src",
+        "src/template",
+    ],
     linkopts = ["-ldl"],
     visibility = ["//visibility:public"],
-    deps = [
-        ":libxsmm_headers",
-    ],
 )
 
 py_library(
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 4ef10f4b18..ae5fcf0186 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1972,8 +1972,17 @@ cc_library(
 
 tf_kernel_library(
     name = "sparse_matmul_op",
+    defines = select({
+        ":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
+        "//conditions:default": [],
+    }),
     prefix = "sparse_matmul_op",
-    deps = MATH_DEPS,
+    deps = MATH_DEPS + select({
+        ":xsmm": [
+            "@libxsmm_archive//:xsmm_avx",
+        ],
+        "//conditions:default": [],
+    }),
 )
 
 cc_library(
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index c5460c8db1..9545839184 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -32,8 +32,13 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" - +#ifdef TENSORFLOW_USE_LIBXSMM +#include "third_party/libxsmm/include/libxsmm_intrinsics_x86.h" +#include "third_party/libxsmm/include/libxsmm_spmdm.h" +#endif namespace tensorflow { @@ -753,10 +758,16 @@ class SparseMatMul { typedef Eigen::TensorMap, Eigen::Aligned> MatrixMapR; + + public: + // Not used; added to match interface of LibxsmmSparseMatMul + struct TensorInfoCache {}; + // Perform matrix multiplication of "left" and "right", and store the result // in *"output". public: - static inline void Compute(const ConstMatrixMapL& left, + static inline void Compute(TensorInfoCache* cache, + const ConstMatrixMapL& left, const ConstMatrixMapR& right, bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output, MatrixMap* output); @@ -820,7 +831,106 @@ class SparseMatMul { TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMul); }; +#ifdef TENSORFLOW_USE_LIBXSMM template +class LibxsmmSparseMatMul { + typedef Eigen::Tensor MatrixL; + typedef Eigen::Tensor MatrixR; + typedef Eigen::TensorMap, + Eigen::Aligned> + ConstMatrixMapL; + typedef Eigen::TensorMap, + Eigen::Aligned> + ConstMatrixMapR; + typedef Eigen::TensorMap, + Eigen::Aligned> + MatrixMapR; + + public: + // This structure contains a set of libxsmm kernels for sizes that have been + // encountered previously by this operator so that libxsmm does not need to + // reallocate its scratchpad memory each time (which hurts performance + // substantially). + struct TensorInfoCache { + struct TensorInfoCacheEntry { + // Parameters for kernel + int M; + int K; + int N; + int max_threads; + // libxsmm handle and matrix data + libxsmm_spmdm_handle handle; + libxsmm_CSR_sparseslice* output_csr; + // Chain to non-libxsmm implementation's cache in case that ever becomes + // useful (it is an empty struct right now) + typename SparseMatMul::TensorInfoCache + non_libxsmm_cache; // Currently not used + }; + // protects entries; invariant: entries is a valid std::multimap + tensorflow::mutex lock; + // Because there could be multiple matrix multiplies with the same sizes + // going on at the same time, we need to allow multiple cache entries for a + // given set of parameters. Taking and returning entries is used to make + // sure the same cache entry is not used from two threads at a time. 
+    std::multimap<std::tuple<int, int, int, int>,
+                  std::unique_ptr<TensorInfoCacheEntry>>
+        entries GUARDED_BY(lock);
+
+    TensorInfoCache() : lock(), entries() {}
+    // Look up and remove first entry with these parameters, creating one if
+    // there isn't one
+    std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N,
+                                                           int max_threads)
+        LOCKS_EXCLUDED(lock) {
+      tensorflow::mutex_lock ml(lock);
+      auto key = std::make_tuple(M, K, N, max_threads);
+      auto it = entries.find(key);
+      if (it != entries.end()) {
+        auto val = std::move(it->second);
+        entries.erase(it);
+        return val;
+      } else {
+        std::unique_ptr<TensorInfoCacheEntry> e{
+            new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}};
+        libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr);
+        return e;
+      }
+    }
+    // Add a cache entry with certain parameters
+    void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e)
+        LOCKS_EXCLUDED(lock) {
+      tensorflow::mutex_lock ml(lock);
+      auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
+      entries.insert(std::make_pair(key, std::move(e)));
+    }
+    ~TensorInfoCache() {
+      tensorflow::mutex_lock ml(lock);
+      for (auto& p : entries) {
+        libxsmm_spmdm_destroy(&p.second->handle);
+      }
+      entries.clear();
+    }
+
+   private:
+    TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache);
+  };
+
+  // Perform matrix multiplication of "left" and "right", and store the result
+  // in *"output".
+ public:
+  static inline void Compute(TensorInfoCache* cache,
+                             const ConstMatrixMapL& left,
+                             const ConstMatrixMapR& right, bool transpose_left,
+                             const DeviceBase::CpuWorkerThreads* thread_pool,
+                             bool transpose_output, MatrixMap* output);
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(LibxsmmSparseMatMul);
+};
+#endif
+
+template <typename TL, typename TR,
+          template <typename TL2, typename TR2> class DoMatMul>
 class SparseMatMulOp : public OpKernel {
   typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
   typedef Eigen::TensorMap<Eigen::Tensor<TR, 2, Eigen::RowMajor>,
@@ -927,15 +1037,15 @@ class SparseMatMulOp : public OpKernel {
     }
 
     if (transpose_output) {
-      SparseMatMul<TL, TR>::Compute(
-          left->matrix<TL>(), right->matrix<TR>(), transpose_a,
-          ctx->device()->tensorflow_cpu_worker_threads(), transpose_output,
-          &out);
+      DoMatMul<TL, TR>::Compute(&this->cache_tr_, left->matrix<TL>(),
+                                right->matrix<TR>(), transpose_a,
+                                ctx->device()->tensorflow_cpu_worker_threads(),
+                                transpose_output, &out);
     } else {
-      SparseMatMul<TL, TR>::Compute(
-          left->matrix<TL>(), right->matrix<TR>(), transpose_a,
-          ctx->device()->tensorflow_cpu_worker_threads(), transpose_output,
-          &out);
+      DoMatMul<TL, TR>::Compute(&this->cache_nt_, left->matrix<TL>(),
+                                right->matrix<TR>(), transpose_a,
+                                ctx->device()->tensorflow_cpu_worker_threads(),
+                                transpose_output, &out);
     }
   }
 
@@ -945,6 +1055,11 @@ class SparseMatMulOp : public OpKernel {
   bool a_is_sparse_;
   bool b_is_sparse_;
 
+  // Cache for non-transposed-output multiply
+  typename DoMatMul<TL, TR>::TensorInfoCache cache_nt_;
+  // Cache for transposed-output multiply
+  typename DoMatMul<TL, TR>::TensorInfoCache cache_tr_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
 };
 
@@ -1219,6 +1334,182 @@ inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
   DCHECK_EQ(N * sizeof(float) % 64, size_t{0});
 }
 
+#ifdef TENSORFLOW_USE_LIBXSMM
+
+template <typename F>
+void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool,
+                       const F& f) {
+  int num_threads = thread_pool->num_threads;
+  if (num_threads == 0) {
+    LOG(FATAL) << "Have 0 threads in thread pool";
+  } else if (num_threads == 1) {
+    f(0);
+  } else {
+    BlockingCounter counter(num_threads - 1);
+    for (int i = 1; i < num_threads; ++i) {
+      thread_pool->workers->Schedule([&, i]() {
+        f(i);
+        counter.DecrementCount();
+      });
+    }
+    f(0);
+    counter.Wait();
+  }
+}
+
+template <typename T>
+struct empty_type_wrapper {};
+
+// Copies of interface to libxsmm_spmdm_createSparseSlice_*_notrans_thread to
+// allow overloading
+void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
+    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
+    const float* A, libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id,
+    int tid, int nthreads) {
+  return libxsmm_spmdm_createSparseSlice_fp32_thread(
+      handle, transA, A, libxsmm_output_csr_a, block_id, tid, nthreads);
+}
+void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
+    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
+    char transA, const bfloat16* A,
+    libxsmm_CSR_sparseslice* libxsmm_output_csr_a, int block_id, int tid,
+    int nthreads) {
+  return libxsmm_spmdm_createSparseSlice_bfloat16_thread(
+      handle, transA, reinterpret_cast<const uint16*>(A), libxsmm_output_csr_a,
+      block_id, tid, nthreads);
+}
+
+void wrapper_libxsmm_spmdm_compute_generic_thread(
+    empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
+    char transA, char transB, const bfloat16* alpha,
+    libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, const bfloat16* beta,
+    float* C, int block_id, int tid, int nthreads) {
+  return libxsmm_spmdm_compute_bfloat16_thread(
+      handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
+      reinterpret_cast<const uint16*>(B), reinterpret_cast<const uint16*>(beta),
+      C, block_id, tid, nthreads);
+}
+void wrapper_libxsmm_spmdm_compute_generic_thread(
+    empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
+    char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
+    const float* B, const float* beta, float* C, int block_id, int tid,
+    int nthreads) {
+  return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
+                                           A_sparse, B, beta, C, block_id, tid,
+                                           nthreads);
+}
+
+class PinnedToCurrentCPU {
+  bool valid;
+  cpu_set_t old_cpu_set;
+
+ public:
+  PinnedToCurrentCPU() : valid(false) {
+    int ret = 0;
+    ret = sched_getaffinity(0, sizeof(cpu_set_t), &old_cpu_set);
+    if (ret != 0) {
+      PLOG(WARNING) << "sched_getaffinity";
+      return;
+    }
+    valid = true;
+    cpu_set_t new_cpu_set;
+    CPU_ZERO(&new_cpu_set);
+    CPU_SET(sched_getcpu(), &new_cpu_set);
+    ret = sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set);
+    if (ret != 0) {
+      PLOG(WARNING) << "sched_setaffinity";
+    }
+  }
+  ~PinnedToCurrentCPU() {
+    if (!valid) return;
+    // No reason to trap errors here
+    sched_setaffinity(0, sizeof(cpu_set_t), &old_cpu_set);
+  }
+};
+
+template <typename TL, typename TR>
+inline void LibxsmmSparseMatMul<TL, TR>::Compute(
+    typename LibxsmmSparseMatMul<TL, TR>::TensorInfoCache* cache,
+    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapL& left,
+    const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
+    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
+    bool transpose_output, MatrixMap* output) {
+  if (transpose_output || transpose_left) {
+    // Not handled by libxsmm currently
+    SparseMatMul<TL, TR>::Compute(
+        nullptr /* Assumes no cached data for fallback */, left, right,
+        transpose_left, thread_pool, transpose_output, output);
+    return;
+  }
+  const int num_threads = thread_pool->num_threads;
+  const int left_dim0 = transpose_left ? left.dimension(1) : left.dimension(0);
+  const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1);
+  const int right_dim0 = right.dimension(0);
+  const int right_dim1 = right.dimension(1);
+  CHECK_EQ(left_dim1, right_dim0);
+  CHECK_EQ(left_dim0,
+           (transpose_output ? output->dimension(1) : output->dimension(0)));
+  CHECK_EQ(right_dim1,
+           (transpose_output ? output->dimension(0) : output->dimension(1)));
+  CHECK(!transpose_output);
+  if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
+    // Causes problems in libxsmm
+    SparseMatMul<TL, TR>::Compute(
+        nullptr /* Assumes no cached data for fallback */, left, right,
+        transpose_left, thread_pool, transpose_output, output);
+    return;
+  }
+  auto left_data = left.data();
+  auto right_data = right.data();
+  auto output_data = output->data();
+  // Initialize libxsmm for this matrix; make sure another thread doesn't use
+  // this handle
+  auto entry =
+      cache->take_cache_entry(left_dim0, right_dim0, right_dim1, num_threads);
+  // Convert the left matrix to compressed sparse row (CSR) format
+  ptrdiff_t total_num_creation_blocks =
+      libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle);
+  std::atomic<int> cur_create_block_number;
+  cur_create_block_number.store(0);
+  do_on_all_threads(thread_pool, [&](int i) {
+    PinnedToCurrentCPU pin;
+    while (true) {
+      int work_item = cur_create_block_number.fetch_add(1);
+      if (work_item >= total_num_creation_blocks) break;
+      wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
+          empty_type_wrapper<TL>{}, &entry->handle,
+          (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
+          i, num_threads);
+    }
+  });
+  // Do matrix-matrix multiplication
+  // TODO(jewillco): libxsmm doesn't support beta != 1 yet -- remove when
+  // release includes beta handling
+  memset(output_data, 0, left_dim0 * right_dim1 * sizeof(TR));
+  ptrdiff_t total_num_mult_blocks =
+      libxsmm_spmdm_get_num_compute_blocks(&entry->handle);
+  std::atomic<int> cur_mult_block_number;
+  cur_mult_block_number.store(0);
+  do_on_all_threads(thread_pool, [&](int i) {
+    PinnedToCurrentCPU pin;
+    while (true) {
+      int work_item = cur_mult_block_number.fetch_add(1);
+      if (work_item >= total_num_mult_blocks) break;
+      const TL alpha(1.0);  // Stored in a variable so we can get a pointer
+      const TL beta(0.0);   // Stored in a variable so we can get a pointer
+      wrapper_libxsmm_spmdm_compute_generic_thread(
+          empty_type_wrapper<TL>{}, &entry->handle,
+          (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
+          right_data, &beta, output_data, work_item, i, num_threads);
+    }
+  });
+  // Put handle + CSR storage back into cache
+  cache->return_cache_entry(std::move(entry));
+}
+
+#endif  // TENSORFLOW_USE_LIBXSMM
+
 // Here is an overview of the SparseMatMul code. Note that we assume that the
 // left matrix is sparse.
 //
@@ -1249,10 +1540,11 @@ inline void SparseMatMul<TL, TR>::ComputeBlockSizes(
 // {l_i} and JB elements from {r_j} and compute the IB * JB inner products.
 template <typename TL, typename TR>
 inline void SparseMatMul<TL, TR>::Compute(
+    typename SparseMatMul<TL, TR>::TensorInfoCache* /*cache*/,
     const typename SparseMatMul<TL, TR>::ConstMatrixMapL& left,
-    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right, bool transpose_left,
-    const DeviceBase::CpuWorkerThreads* thread_pool, bool transpose_output,
-    MatrixMap* output) {
+    const typename SparseMatMul<TL, TR>::ConstMatrixMapR& right,
+    bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
+    bool transpose_output, MatrixMap* output) {
   const int num_threads = thread_pool->num_threads;
   int KR, NR, KL, JB, IB;
   ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
@@ -1347,12 +1639,27 @@ inline void SparseMatMul<TL, TR>::Compute(
                               .Device(DEVICE_CPU)        \
                               .TypeConstraint<TA>("Ta")  \
                               .TypeConstraint<TB>("Tb"), \
-                          SparseMatMulOp<TA, TB>);
+                          SparseMatMulOp<TA, TB, SparseMatMul>);
+
+#ifdef TENSORFLOW_USE_LIBXSMM
+#define REGISTER_SPARSE_MATMUL_LIBXSMM(TA, TB)           \
+  REGISTER_KERNEL_BUILDER(Name("SparseMatMul")           \
+                              .Device(DEVICE_CPU)        \
+                              .TypeConstraint<TA>("Ta")  \
+                              .TypeConstraint<TB>("Tb"), \
+                          SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>);
+#endif
 
 REGISTER_SPARSE_MATMUL(bfloat16, bfloat16);
+
 REGISTER_SPARSE_MATMUL(float, bfloat16);
+
 REGISTER_SPARSE_MATMUL(bfloat16, float);
+
+#ifdef TENSORFLOW_USE_LIBXSMM
+REGISTER_SPARSE_MATMUL_LIBXSMM(float, float);
+#else
 REGISTER_SPARSE_MATMUL(float, float);
+#endif
 
 #undef REGISTER_SPARSE_MATMUL
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 0cb4054e22..c1ac7e1ac3 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -28,11 +28,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
   native.new_http_archive(
       name = "libxsmm_archive",
       urls = [
-          "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.5.tar.gz",
-          "https://github.com/hfp/libxsmm/archive/1.5.tar.gz",
+          "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.6.1.tar.gz",
+          "https://github.com/hfp/libxsmm/archive/1.6.1.tar.gz",
       ],
-      sha256 = "c52568c5e0e8dc9d8fcf869a716d73598e52f71c3d83af5a4c0b3be81403b423",
-      strip_prefix = "libxsmm-1.5",
+      sha256 = "1dd81077b186300122dc8a8f1872c21fd2bd9b88286ab9f068cc7b62fa7593a7",
+      strip_prefix = "libxsmm-1.6.1",
       build_file = str(Label("//:libxsmm.BUILD")),
   )
-- 
cgit v1.2.3
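
Three idioms in this change are easy to miss in diff form, so short standalone
sketches follow (placed after the patch proper; every name in them is a
hypothetical stand-in rather than code from this patch or from libxsmm).

First, the TensorInfoCache added above hands out exclusive ownership of
per-size libxsmm state: take_cache_entry removes a matching entry from a
multimap keyed by (M, K, N, max_threads), or builds a fresh one, and
return_cache_entry puts it back, so two concurrent multiplies with identical
shapes never share a handle. A minimal sketch of that take/return pattern,
using std::mutex in place of tensorflow::mutex and a bare Entry struct in
place of the libxsmm handle plus CSR storage:

    #include <map>
    #include <memory>
    #include <mutex>
    #include <tuple>

    struct Entry {
      int M, K, N, max_threads;
      // Expensive per-size state (e.g. a libxsmm handle) would live here.
    };

    class SizeKeyedCache {
     public:
      // Take exclusive ownership of an entry for these parameters, building
      // one if none is free. A multimap (rather than a map) lets two
      // concurrent multiplies of the same size each hold their own entry.
      std::unique_ptr<Entry> Take(int M, int K, int N, int max_threads) {
        std::lock_guard<std::mutex> guard(mu_);
        auto key = std::make_tuple(M, K, N, max_threads);
        auto it = entries_.find(key);
        if (it != entries_.end()) {
          std::unique_ptr<Entry> e = std::move(it->second);
          entries_.erase(it);
          return e;
        }
        return std::unique_ptr<Entry>(new Entry{M, K, N, max_threads});
      }

      // Give an entry back so later multiplies of the same size can reuse it.
      void Return(std::unique_ptr<Entry> e) {
        std::lock_guard<std::mutex> guard(mu_);
        auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads);
        entries_.insert(std::make_pair(key, std::move(e)));
      }

     private:
      std::mutex mu_;
      std::multimap<std::tuple<int, int, int, int>, std::unique_ptr<Entry>>
          entries_;
    };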
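
Second, both phases of LibxsmmSparseMatMul<TL, TR>::Compute (CSR creation,
then the multiply) hand libxsmm block indices to the worker pool through a
shared atomic counter rather than a static partition, so faster threads simply
claim more blocks. A self-contained sketch of that pattern, with std::thread
standing in for DeviceBase::CpuWorkerThreads and a hypothetical process_block
in place of the per-block libxsmm calls:

    #include <atomic>
    #include <cstdio>
    #include <thread>
    #include <vector>

    void process_block(int block_id, int tid) {
      std::printf("block %d handled by thread %d\n", block_id, tid);
    }

    int main() {
      const int num_threads = 4;
      const int total_blocks = 16;
      std::atomic<int> next_block(0);
      std::vector<std::thread> workers;
      for (int tid = 0; tid < num_threads; ++tid) {
        workers.emplace_back([&, tid]() {
          while (true) {
            // Claim the next unprocessed block; stop when none remain.
            int block = next_block.fetch_add(1);
            if (block >= total_blocks) break;
            process_block(block, tid);
          }
        });
      }
      for (auto& t : workers) t.join();
      return 0;
    }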
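
Finally, empty_type_wrapper<T> is plain tag dispatch: libxsmm exports separate
fp32 and bfloat16 entry points under different names, and passing a zero-size
tag argument lets overload resolution pick the matching wrapper at compile
time inside a template. A sketch of the idiom, with hypothetical kernel_fp32
and kernel_u16 functions in place of the libxsmm ones:

    #include <cstdint>
    #include <cstdio>

    template <typename T>
    struct empty_type_wrapper {};

    // Two C-style entry points with distinct names, as in libxsmm.
    void kernel_fp32(const float*) { std::printf("fp32 kernel\n"); }
    void kernel_u16(const std::uint16_t*) { std::printf("u16 kernel\n"); }

    // Same-named overloads; the tag type selects the implementation.
    void wrapper_kernel(empty_type_wrapper<float>, const float* data) {
      kernel_fp32(data);
    }
    void wrapper_kernel(empty_type_wrapper<std::uint16_t>,
                        const std::uint16_t* data) {
      kernel_u16(data);
    }

    // Generic code dispatches on T with no runtime branch and no class
    // template specialization.
    template <typename T>
    void run(const T* data) {
      wrapper_kernel(empty_type_wrapper<T>{}, data);
    }

    int main() {
      float f = 1.0f;
      std::uint16_t u = 2;
      run(&f);  // resolves to kernel_fp32
      run(&u);  // resolves to kernel_u16
      return 0;
    }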