about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/stream_executor/blas.h
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org>, 2017-07-06 15:15:27 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-07-06 15:21:23 -0700
commit: a2ee8bca3f3fa08bf022f5c8a68c3e1cf829f2cd (patch)
tree: 35d554fcf7dc24246d33be4a405ef7a8d6cf8abc /tensorflow/stream_executor/blas.h
parent: 755fa7b501b5a1dadf2b8a1814d74d4451a05975 (diff)
Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.
PiperOrigin-RevId: 161137741
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r-- tensorflow/stream_executor/blas.h 25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 07a0f7ccd6..cfff3649c8 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -97,8 +97,9 @@ enum class ComputationType {
kF16, // 16-bit floating-point
kF32, // 32-bit floating-point
kF64, // 64-bit floating-point
+ kI32, // 32-bit integer
kComplexF32, // Complex number comprised of two f32s.
- kComplexF64 // Complex number comprised of two f64s.
+ kComplexF64, // Complex number comprised of two f64s.
};
// Converts a ComputationType to a string.
@@ -108,6 +109,15 @@ string ComputationTypeString(ComputationType ty);
// as a hint to the blas library.
typedef int64 AlgorithmType;
+// blas uses -1 to represent the default algorithm. This happens to match up
+// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
+// to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
+// to ensure that this assumption does not break.
+// If another blas implementation uses a different value for the default
+// algorithm, then it needs to convert kDefaultGemmAlgo to that value
+// (e.g. via a function called ToWhateverGemmAlgo).
+constexpr AlgorithmType kDefaultGemmAlgo = -1;
+
// Describes the result of a performance experiment, usually timing the speed of
// a particular AlgorithmType.
//
@@ -946,6 +956,12 @@ class BlasSupport {
// creating a new Stream for each attempt.
virtual bool DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+ const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c,
+ int ldc, ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
+ virtual bool DoBlasGemmWithAlgorithm(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const Eigen::half &alpha,
const DeviceMemory<Eigen::half> &a, int lda,
const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
@@ -1739,6 +1755,13 @@ class BlasSupport {
override; \
bool DoBlasGemmWithAlgorithm( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, \
+ int lda, const DeviceMemory<int8> &b, int ldb, int beta, \
+ DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type, \
+ blas::AlgorithmType algorithm, \
+ blas::ProfileResult *output_profile_result) override; \
+ bool DoBlasGemmWithAlgorithm( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \
const DeviceMemory<Eigen::half> &a, int lda, \
const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, \