diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-06 15:15:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-06 15:21:23 -0700 |
commit | a2ee8bca3f3fa08bf022f5c8a68c3e1cf829f2cd (patch) | |
tree | 35d554fcf7dc24246d33be4a405ef7a8d6cf8abc /tensorflow/stream_executor/stream.cc | |
parent | 755fa7b501b5a1dadf2b8a1814d74d4451a05975 (diff) |
Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.
PiperOrigin-RevId: 161137741
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 21 |
1 files changed, 21 insertions, 0 deletions
```diff
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 97193af777..9b4a4c4fb1 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -3484,6 +3484,27 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
 
 Stream &Stream::ThenBlasGemmWithAlgorithm(
     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+            PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+            PARAM(algorithm));
+
+  ThenBlasWithProfileImpl<
+      blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
+      const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
+      DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+              m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+              algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
     int ldc, blas::ComputationType computation_type,
```