diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-05-11 09:46:11 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-11 10:51:58 -0700 |
commit | 523055469c8a61425e3b8f104be67787c2933ccb (patch) | |
tree | 2bab6823c11e909543614358364766b4e3de669c /tensorflow/stream_executor/blas.h | |
parent | 939ede027be73ecafcc422371afe27dceccc720d (diff) |
Add fp16 matrix multiplication (GEMM) support to StreamExecutor, gated on
compilation with CUDA 7.5; fp16 convolutions via cuDNN will come soon.
This does not update any TensorFlow ops, but it is a dependency of doing
that.
Note: fp16 axpy and dot do not exist in CUDA 7.5 and have thus not been added.
CUDA 8.0 supports both (through the axpyEx and dotEx interfaces).
Change: 122069402
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r-- | tensorflow/stream_executor/blas.h | 17 |
1 file changed, 17 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 1f1d427c45..ab4f125861 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/array_slice.h" #include "tensorflow/stream_executor/platform/port.h" +#include "third_party/eigen3/Eigen/Core" namespace perftools { namespace gputools { @@ -846,6 +847,17 @@ class BlasSupport { // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix; // op(b) is a k-by-n matrix; c is an m-by-n matrix. + // + // Note: The half interface uses float precision internally; the version + // that uses half precision internally is not yet supported. There is no + // batched version of the half-precision interface. + virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, + blas::Transpose transb, uint64 m, uint64 n, uint64 k, + float alpha, + const DeviceMemory<Eigen::half> &a, int lda, + const DeviceMemory<Eigen::half> &b, int ldb, + float beta, + DeviceMemory<Eigen::half> *c, int ldc) = 0; virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda, @@ -1599,6 +1611,11 @@ class BlasSupport { DeviceMemory<std::complex<double>> *x, int incx) override; \ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ + float alpha, const DeviceMemory<Eigen::half> &a, int lda, \ + const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ + DeviceMemory<Eigen::half> *c, int ldc) override; \ + bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ + blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ float alpha, const DeviceMemory<float> &a, int lda, \ const DeviceMemory<float> &b, int ldb, float beta, \ DeviceMemory<float> *c, 
int ldc) override; \ |