diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-06 15:15:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-06 15:21:23 -0700 |
commit | a2ee8bca3f3fa08bf022f5c8a68c3e1cf829f2cd (patch) | |
tree | 35d554fcf7dc24246d33be4a405ef7a8d6cf8abc /tensorflow/stream_executor/blas.h | |
parent | 755fa7b501b5a1dadf2b8a1814d74d4451a05975 (diff) |
Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.
PiperOrigin-RevId: 161137741
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r-- | tensorflow/stream_executor/blas.h | 25 |
1 file changed, 24 insertions, 1 deletion
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 07a0f7ccd6..cfff3649c8 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -97,8 +97,9 @@ enum class ComputationType { kF16, // 16-bit floating-point kF32, // 32-bit floating-point kF64, // 64-bit floating-point + kI32, // 32-bit integer kComplexF32, // Complex number comprised of two f32s. - kComplexF64 // Complex number comprised of two f64s. + kComplexF64, // Complex number comprised of two f64s. }; // Converts a ComputationType to a string. @@ -108,6 +109,15 @@ string ComputationTypeString(ComputationType ty); // as a hint to the blas library. typedef int64 AlgorithmType; +// blas uses -1 to represent the default algorithm. This happens to match up +// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast +// to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert +// to ensure that this assumption does not break. +// If another blas implementation uses a different value for the default +// algorithm, then it needs to convert kDefaultGemmAlgo to that value +// (e.g. via a function called ToWhateverGemmAlgo). +constexpr AlgorithmType kDefaultGemmAlgo = -1; + // Describes the result of a performance experiment, usually timing the speed of // a particular AlgorithmType. // @@ -946,6 +956,12 @@ class BlasSupport { // creating a new Stream for each attempt. 
virtual bool DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda, + const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c, + int ldc, ComputationType computation_type, AlgorithmType algorithm, + ProfileResult *output_profile_result) = 0; + virtual bool DoBlasGemmWithAlgorithm( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, int lda, const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, @@ -1739,6 +1755,13 @@ class BlasSupport { override; \ bool DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, \ + int lda, const DeviceMemory<int8> &b, int ldb, int beta, \ + DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type, \ + blas::AlgorithmType algorithm, \ + blas::ProfileResult *output_profile_result) override; \ + bool DoBlasGemmWithAlgorithm( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \ const DeviceMemory<Eigen::half> &a, int lda, \ const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, \ |