about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-07-06 15:15:27 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-07-06 15:21:23 -0700
commit: a2ee8bca3f3fa08bf022f5c8a68c3e1cf829f2cd (patch)
tree: 35d554fcf7dc24246d33be4a405ef7a8d6cf8abc /tensorflow/stream_executor/stream.cc
parent: 755fa7b501b5a1dadf2b8a1814d74d4451a05975 (diff)
Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.
PiperOrigin-RevId: 161137741
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- tensorflow/stream_executor/stream.cc 21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 97193af777..9b4a4c4fb1 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -3484,6 +3484,27 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
// Overload of ThenBlasGemmWithAlgorithm added by this commit for
// DeviceMemory<int8> inputs `a` and `b`, a DeviceMemory<int> output `c`, and
// int `alpha` / `beta` scalars (the commit message describes this as
// int8 x int8 -> int32 matrix multiplication via cublasGemmEx).
//
// The body does two things:
//   1. Logs the call through VLOG_CALL. Every argument except
//      `output_profile_result` is included in the logged parameter list.
//   2. Instantiates ThenBlasWithProfileImpl with this overload's argument
//      types and invokes it with `this`,
//      &blas::BlasSupport::DoBlasGemmWithAlgorithm, the arguments in their
//      original order, and `output_profile_result`, returning whatever Stream&
//      that call yields.
//
// NOTE(review): how ThenBlasWithProfileImpl treats a null
// `output_profile_result`, or a stream already in an error state, is not
// visible in this hunk -- confirm against its definition.
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+ const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+ PARAM(algorithm));
+
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
+ const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
+ DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+ m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ algorithm, output_profile_result);
+}
+
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
int ldc, blas::ComputationType computation_type,