author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-01-18 15:50:03 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-01-18 16:06:08 -0800
commit    9ce3d17f65017c683caf842264d7dd1582b6139f (patch)
tree      6bf8e67ebb9806604cff28fe28cf0f8c2a6782bc /tensorflow/core/kernels/sparse_matmul_op.cc
parent    88c0db60efb7b1a908d0d8fec76f7c4072b51b4a (diff)
Update libxsmm version to 1.6.5.
Change: 144893095
Diffstat (limited to 'tensorflow/core/kernels/sparse_matmul_op.cc')
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op.cc | 26
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 31cca59f50..f8ed7b5bfe 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -1386,21 +1386,21 @@ void wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
void wrapper_libxsmm_spmdm_compute_generic_thread(
empty_type_wrapper<bfloat16>, const libxsmm_spmdm_handle* handle,
char transA, char transB, const bfloat16* alpha,
- libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, const bfloat16* beta,
- float* C, int block_id, int tid, int nthreads) {
+ libxsmm_CSR_sparseslice* A_sparse, const bfloat16* B, char transC,
+ const bfloat16* beta, float* C, int block_id, int tid, int nthreads) {
return libxsmm_spmdm_compute_bfloat16_thread(
handle, transA, transB, reinterpret_cast<const uint16*>(alpha), A_sparse,
- reinterpret_cast<const uint16*>(B), reinterpret_cast<const uint16*>(beta),
- C, block_id, tid, nthreads);
+ reinterpret_cast<const uint16*>(B), transC,
+ reinterpret_cast<const uint16*>(beta), C, block_id, tid, nthreads);
}
void wrapper_libxsmm_spmdm_compute_generic_thread(
empty_type_wrapper<float>, const libxsmm_spmdm_handle* handle, char transA,
char transB, const float* alpha, libxsmm_CSR_sparseslice* A_sparse,
- const float* B, const float* beta, float* C, int block_id, int tid,
- int nthreads) {
+ const float* B, char transC, const float* beta, float* C, int block_id,
+ int tid, int nthreads) {
return libxsmm_spmdm_compute_fp32_thread(handle, transA, transB, alpha,
- A_sparse, B, beta, C, block_id, tid,
- nthreads);
+ A_sparse, B, transC, beta, C,
+ block_id, tid, nthreads);
}
class PinnedToCurrentCPU {
@@ -1438,7 +1438,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
const typename LibxsmmSparseMatMul<TL, TR>::ConstMatrixMapR& right,
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
bool transpose_output, MatrixMap* output) {
- if (transpose_output || transpose_left) {
+ if (false) {
// Not handled by libxsmm currently
SparseMatMul<TL, TR>::Compute(
nullptr /* Assumes no cached data for fallback */, left, right,
@@ -1455,7 +1455,6 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
(transpose_output ? output->dimension(1) : output->dimension(0)));
CHECK_EQ(right_dim1,
(transpose_output ? output->dimension(0) : output->dimension(1)));
- CHECK(!transpose_output);
if (left_dim0 < 32 || left_dim1 < 32 || right_dim1 < 32) {
// Causes problems in libxsmm
SparseMatMul<TL, TR>::Compute(
@@ -1482,7 +1481,7 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
if (work_item >= total_num_creation_blocks) break;
wrapper_libxsmm_spmdm_createSparseSlice_generic_thread(
empty_type_wrapper<TL>{}, &entry->handle,
- (transpose_left ? 'T' : 'N'), left_data, entry->output_csr, work_item,
+ (transpose_left ? 'Y' : 'N'), left_data, entry->output_csr, work_item,
i, num_threads);
}
});
@@ -1504,8 +1503,9 @@ inline void LibxsmmSparseMatMul<TL, TR>::Compute(
const TL beta(0.0); // Stored in a variable so we can get a pointer
wrapper_libxsmm_spmdm_compute_generic_thread(
empty_type_wrapper<TL>{}, &entry->handle,
- (transpose_left ? 'T' : 'N'), 'N', &alpha, entry->output_csr,
- right_data, &beta, output_data, work_item, i, num_threads);
+ (transpose_left ? 'Y' : 'N'), 'N', &alpha, entry->output_csr,
+ right_data, (transpose_output ? 'Y' : 'N'), &beta, output_data,
+ work_item, i, num_threads);
}
});
// Put handle + CSR storage back into cache
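Note (not part of the commit): a minimal sketch of how the updated fp32 compute wrapper is invoked once the transC argument introduced for libxsmm 1.6.5 is in place. It simply mirrors the call site in LibxsmmSparseMatMul::Compute shown in the last hunk; names such as entry, right_data, output_data, work_item, i, and num_threads are the surrounding locals of that function, not new API.

  // Sketch only: the new transC argument (between B and beta) selects
  // whether the output C is written transposed ('Y') or not ('N'),
  // matching the diff above.
  const TL alpha(1.0);  // Stored in variables so we can take pointers.
  const TL beta(0.0);
  wrapper_libxsmm_spmdm_compute_generic_thread(
      empty_type_wrapper<TL>{}, &entry->handle,
      (transpose_left ? 'Y' : 'N'), 'N', &alpha, entry->output_csr,
      right_data, (transpose_output ? 'Y' : 'N'), &beta, output_data,
      work_item, i, num_threads);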