diff options
Diffstat (limited to 'tensorflow/core/kernels/sparse_matmul_op.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_matmul_op.cc | 279 |
1 files changed, 65 insertions, 214 deletions
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index 2ed0522ce4..46e743b4cf 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -837,15 +837,6 @@ class SparseMatMul { }; #ifdef TENSORFLOW_USE_LIBXSMM -#ifdef EXTRA_CACHE_LOGGING -static tensorflow::mutex global_cache_stats_lock; -static int total_num_entries_outstanding GUARDED_BY(global_cache_stats_lock) = - 0; -static int total_num_entries_in_cache GUARDED_BY(global_cache_stats_lock) = 0; -#endif // EXTRA_CACHE_LOGGING - -static const int max_entries_per_graph_node = 40; - template <typename TL, typename TR> class LibxsmmSparseMatMul { typedef Eigen::Tensor<TL, 2, Eigen::RowMajor> MatrixL; @@ -861,7 +852,6 @@ class LibxsmmSparseMatMul { MatrixMapR; public: -#if 1 // This structure contains a set of libxsmm kernels for sizes that have been // encountered previously by this operator so that libxsmm does not need to // reallocate its scratchpad memory each time (which hurts performance @@ -880,181 +870,57 @@ class LibxsmmSparseMatMul { // useful (it is an empty struct right now) typename SparseMatMul<TL, TR>::TensorInfoCache non_libxsmm_cache; // Currently not used - TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCacheEntry); - ~TensorInfoCacheEntry() { -#ifdef EXTRA_CACHE_LOGGING - LOG(INFO) << "Deleting tensor cache entry at " << (void*)this; -#endif // EXTRA_CACHE_LOGGING - libxsmm_spmdm_destroy(&handle); - } }; - // protects entries; invariant: entries is a valid std::list. + // protects entries; invariant: entries is a valid std::multimap tensorflow::mutex lock; // Because there could be multiple matrix multiplies with the same sizes // going on at the same time, we need to allow multiple cache entries for a // given set of parameters. Taking and returning entries is used to make // sure the same cache entry is not used from two threads at a time. - using entries_map_type = std::list<std::pair< - std::tuple<int, int, int, int>, - std::unique_ptr<TensorInfoCacheEntry>>>; // multimap in LRU order - entries_map_type entries GUARDED_BY( - lock); // MRU element at end so reverse search will find it first - int num_entries_outstanding GUARDED_BY(lock); - - TensorInfoCache() : lock(), entries(), num_entries_outstanding(0) {} + std::multimap<std::tuple<int, int, int, int>, + std::unique_ptr<TensorInfoCacheEntry>> + entries GUARDED_BY(lock); + + TensorInfoCache() : lock(), entries() {} // Look up and remove first entry with these parameters, creating one if // there isn't one std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N, int max_threads) -#ifdef EXTRA_CACHE_LOGGING - LOCKS_EXCLUDED(lock, global_cache_stats_lock) -#else - LOCKS_EXCLUDED(lock) -#endif - { + LOCKS_EXCLUDED(lock) { tensorflow::mutex_lock ml(lock); -#ifdef EXTRA_CACHE_LOGGING - tensorflow::mutex_lock ml2(global_cache_stats_lock); -#endif auto key = std::make_tuple(M, K, N, max_threads); - auto it_rev = - std::find_if(entries.rbegin(), entries.rend(), - [&](const typename entries_map_type::value_type& e) { - return e.first == key; - }); - auto it = - (it_rev == entries.rend() ? entries.end() : std::next(it_rev).base()); + auto it = entries.find(key); if (it != entries.end()) { auto val = std::move(it->second); entries.erase(it); - ++num_entries_outstanding; -#ifdef EXTRA_CACHE_LOGGING - ++total_num_entries_outstanding; - --total_num_entries_in_cache; - LOG(INFO) << "Used existing cache entry at " << (void*)val.get() - << " for " << M << "x" << K << "x" << N << " max_threads " - << max_threads - << ", num_entries_outstanding = " << num_entries_outstanding - << ", new cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif return val; } else { - while (!entries.empty() && - entries.size() + num_entries_outstanding + 1 > - max_entries_per_graph_node) { -#ifdef EXTRA_CACHE_LOGGING - LOG(INFO) << "Removing old cache entry at " - << (void*)entries.front().second.get(); -#endif - entries.pop_front(); - } std::unique_ptr<TensorInfoCacheEntry> e{ new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}}; // setup scoped allocator, which uses cpu_allocator() for this scope const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator; libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr); - ++num_entries_outstanding; -#ifdef EXTRA_CACHE_LOGGING - ++total_num_entries_outstanding; - LOG(INFO) << "Created cache entry at " << (void*)e.get() << " for " << M - << "x" << K << "x" << N << " max_threads " << max_threads - << ", num_entries_outstanding = " << num_entries_outstanding - << ", new cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif return e; } } // Add a cache entry with certain parameters void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e) -#ifdef EXTRA_CACHE_LOGGING - LOCKS_EXCLUDED(lock, global_cache_stats_lock) -#else - LOCKS_EXCLUDED(lock) -#endif - { + LOCKS_EXCLUDED(lock) { tensorflow::mutex_lock ml(lock); -#ifdef EXTRA_CACHE_LOGGING - tensorflow::mutex_lock ml2(global_cache_stats_lock); -#endif auto key = std::make_tuple(e->M, e->K, e->N, e->max_threads); - --num_entries_outstanding; -#ifdef EXTRA_CACHE_LOGGING - --total_num_entries_outstanding; - LOG(INFO) << "Returned cache entry at " << (void*)e.get() << " for " - << e->M << "x" << e->K << "x" << e->N << " max_threads " - << e->max_threads - << ", num_entries_outstanding = " << num_entries_outstanding - << ", prev cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif - entries.push_back(std::make_pair(key, std::move(e))); -#ifdef EXTRA_CACHE_LOGGING - ++total_num_entries_in_cache; -#endif + entries.insert(std::make_pair(key, std::move(e))); } ~TensorInfoCache() { tensorflow::mutex_lock ml(lock); -#ifdef EXTRA_CACHE_LOGGING - tensorflow::mutex_lock ml2(global_cache_stats_lock); - LOG(INFO) << "Deleting TensorInfoCache, cache size = " << entries.size() - << ", total num_entries_outstanding = " - << total_num_entries_outstanding - << ", total cache size = " << total_num_entries_in_cache; -#endif - CHECK_EQ(num_entries_outstanding, 0); + for (auto& p : entries) { + libxsmm_spmdm_destroy(&p.second->handle); + } entries.clear(); } private: TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache); }; -#else - // This structure contains a set of libxsmm kernels for sizes that have been - // encountered previously by this operator so that libxsmm does not need to - // reallocate its scratchpad memory each time (which hurts performance - // substantially). - struct TensorInfoCache { - struct TensorInfoCacheEntry { - // Parameters for kernel - int M; - int K; - int N; - int max_threads; - // libxsmm handle and matrix data - libxsmm_spmdm_handle handle; - libxsmm_CSR_sparseslice* output_csr; - // Chain to non-libxsmm implementation's cache in case that ever becomes - // useful (it is an empty struct right now) - typename SparseMatMul<TL, TR>::TensorInfoCache - non_libxsmm_cache; // Currently not used - }; - TensorInfoCache() {} - // Look up and remove first entry with these parameters, creating one if - // there isn't one - std::unique_ptr<TensorInfoCacheEntry> take_cache_entry(int M, int K, int N, - int max_threads) { - std::unique_ptr<TensorInfoCacheEntry> e{ - new TensorInfoCacheEntry{M, K, N, max_threads, {}, nullptr}}; - libxsmm_spmdm_init(M, N, K, max_threads, &e->handle, &e->output_csr); - return e; - } - // Add a cache entry with certain parameters - void return_cache_entry(std::unique_ptr<TensorInfoCacheEntry> e) { - libxsmm_spmdm_destroy(&e->handle); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(TensorInfoCache); - }; -#endif // Perform matrix multiplication of "left" and "right", and store the result // in *"output". @@ -1479,21 +1345,21 @@ inline void SparseMatMul<TL, TR>::ComputeBlockSizes( template <typename F> void do_on_all_threads(const DeviceBase::CpuWorkerThreads* thread_pool, - ptrdiff_t max_thread_count, const F& f) { + const F& f) { int num_threads = thread_pool->num_threads; if (num_threads == 0) { LOG(FATAL) << "Have 0 threads in thread pool"; } else if (num_threads == 1) { - f(0, 1); + f(0); } else { BlockingCounter counter(num_threads - 1); for (int i = 1; i < num_threads; ++i) { thread_pool->workers->Schedule([&, i]() { - f(i, num_threads); + f(i); counter.DecrementCount(); }); } - f(0, num_threads); + f(0); counter.Wait(); } } @@ -1522,24 +1388,21 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( void wrapper_libxsmm_spmdm_compute_generic_thread( empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle, - char transA, char transB, libxsmm_CSR_sparseslice* A_sparse, - const bfloat16* B, char transC, float* C, int block_id, int tid, - int nthreads) { - const uint16 alpha = 1; - const uint16 beta = 0; + char transA, char transB, const bfloat16* alpha, + libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC, + const bfloat16* beta, float* C, int block_id, int tid, int nthreads) { return libxsmm_spmdm_compute_bfloat16_thread( - handle, transA, transB, &alpha, A_sparse, - reinterpret_cast<const uint16*>(B), transC, &beta, C, block_id, tid, - nthreads); + handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse, + reinterpret_cast<const uint16*>(B), transC, + reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads); } void wrapper_libxsmm_spmdm_compute_generic_thread( empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA, - char transB, libxsmm_CSR_sparseslice* A_sparse, const float* B, char transC, - float* C, int block_id, int tid, int nthreads) { - const float alpha = 1.f; - const float beta = 0.f; - return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, &alpha, - A_sparse, B, transC, &beta, C, + char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse, + const float* B, char transC, const float* beta, float* C, int block_id, + int tid, int nthreads) { + return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha, + A_sparse, B, transC, beta, C, block_id, tid, nthreads); } @@ -1590,13 +1453,11 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute( const int left_dim1 = transpose_left ? left.dimension(0) : left.dimension(1); const int right_dim0 = right.dimension(0); const int right_dim1 = right.dimension(1); - const int output_dim0 = - transpose_output ? output->dimension(1) : output->dimension(0); - const int output_dim1 = - transpose_output ? output->dimension(0) : output->dimension(1); CHECK_EQ(left_dim1, right_dim0); - CHECK_EQ(left_dim0, output_dim0); - CHECK_EQ(right_dim1, output_dim1); + CHECK_EQ(left_dim0, + (transpose_output ? output->dimension(1) : output->dimension(0))); + CHECK_EQ(right_dim1, + (transpose_output ? output->dimension(0) : output->dimension(1))); if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) { // Causes problems in libxsmm SparseMatMul<TL, TR>::Compute( @@ -1614,50 +1475,42 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute( // Convert the left matrix to compressed sparse row (CSR) format ptrdiff_t total_num_creation_blocks = libxsmm_spmdm_get_num_createSparseSlice_blocks(&entry->handle); - ptrdiff_t total_num_mult_blocks = - libxsmm_spmdm_get_num_compute_blocks(&entry->handle); - bool use_libxsmm = - !(total_num_creation_blocks + total_num_mult_blocks < num_threads && - !transpose_left && !transpose_output); - if (!use_libxsmm) { - // Avoid some performance issues in libxsmm (FIXME) - cache->return_cache_entry(std::move(entry)); - SparseMatMul<TL, TR>::Compute( - nullptr /* Assumes no cached data for fallback */, left, right, - transpose_left, thread_pool, transpose_output, output); - return; - } std::atomic<int> cur_create_block_number; cur_create_block_number.store(0); - do_on_all_threads(thread_pool, total_num_creation_blocks, - [&](int i, int actual_num_threads) { - PinnedToCurrentCPU pin; - while (true) { - int work_item = cur_create_block_number.fetch_add(1); - if (work_item >= total_num_creation_blocks) break; - wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( - empty_type_wrapper<TL>{}, &entry->handle, - (transpose_left ? 'T' : 'N'), left_data, - entry->output_csr, work_item, i, - actual_num_threads); - } - }); + do_on_all_threads(thread_pool, [&](int i) { + PinnedToCurrentCPU pin; + while (true) { + int work_item = cur_create_block_number.fetch_add(1); + if (work_item >= total_num_creation_blocks) break; + wrapper_libxsmm_spmdm_createSparseSlice_generic_thread( + empty_type_wrapper<TL>{}, &entry->handle, + (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item, + i, num_threads); + } + }); // Do matrix-matrix multiplication + // TODO(jewillco): libxsmm doesn't support beta != 1 yet -- remove when + // release + // includes beta handling + memset(output_data, 0, left_dim0 * right_dim1 * sizeof(TR)); + ptrdiff_t total_num_mult_blocks = + libxsmm_spmdm_get_num_compute_blocks(&entry->handle); std::atomic<int> cur_mult_block_number; cur_mult_block_number.store(0); - do_on_all_threads( - thread_pool, total_num_mult_blocks, [&](int i, int actual_num_threads) { - PinnedToCurrentCPU pin; - while (true) { - int work_item = cur_mult_block_number.fetch_add(1); - if (work_item >= total_num_mult_blocks) break; - wrapper_libxsmm_spmdm_compute_generic_thread( - empty_type_wrapper<TL>{}, &entry->handle, - (transpose_left ? 'T' : 'N'), 'N', entry->output_csr, right_data, - (transpose_output ? 'T' : 'N'), output_data, work_item, i, - actual_num_threads); - } - }); + do_on_all_threads(thread_pool, [&](int i) { + PinnedToCurrentCPU pin; + while (true) { + int work_item = cur_mult_block_number.fetch_add(1); + if (work_item >= total_num_mult_blocks) break; + const TL alpha(1.0); // Stored in a variable so we can get a pointer + const TL beta(0.0); // Stored in a variable so we can get a pointer + wrapper_libxsmm_spmdm_compute_generic_thread( + empty_type_wrapper<TL>{}, &entry->handle, + (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr, + right_data, (transpose_output ? 'T' : 'N'), &beta, output_data, + work_item, i, num_threads); + } + }); // Put handle + CSR storage back into cache cache->return_cache_entry(std::move(entry)); } @@ -1803,17 +1656,15 @@ inline void SparseMatMul<TL, TR>::Compute( SparseMatMulOp<TA, TB, LibxsmmSparseMatMul>); #endif +REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); + REGISTER_SPARSE_MATMUL(float, bfloat16); REGISTER_SPARSE_MATMUL(bfloat16, float); #ifdef TENSORFLOW_USE_LIBXSMM -REGISTER_SPARSE_MATMUL_LIBXSMM(bfloat16, bfloat16); - REGISTER_SPARSE_MATMUL_LIBXSMM(float, float); #else -REGISTER_SPARSE_MATMUL(bfloat16, bfloat16); - REGISTER_SPARSE_MATMUL(float, float); #endif |