author: A. Unique TensorFlower <nobody@tensorflow.org> 2016-05-11 09:46:11 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-05-11 10:51:58 -0700
commit 523055469c8a61425e3b8f104be67787c2933ccb (patch)
tree 2bab6823c11e909543614358364766b4e3de669c /tensorflow/stream_executor/blas.h
parent 939ede027be73ecafcc422371afe27dceccc720d (diff)
Add fp16 matrix multiplication (GEMM) support to StreamExecutor, gated on compilation with CUDA 7.5; fp16 convolutions via cuDNN will come soon. This does not update any TensorFlow ops, but it is a prerequisite for doing so.

Note: fp16 axpy and dot do not exist in CUDA 7.5 and have therefore not been added; CUDA 8.0 supports both (through the axpyEx and dotEx interfaces).

Change: 122069402
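For context, here is a minimal caller-side sketch (not part of this commit) of how the new overload would be reached, assuming a matching Stream::ThenBlasGemm overload that forwards to the BlasSupport::DoBlasGemm declared in the diff below; the HalfGemm wrapper name and the alpha/beta values are illustrative only.

// Hypothetical usage sketch -- assumes the Stream::ThenBlasGemm overload
// accompanying this change; HalfGemm is an illustrative wrapper, not an
// API from this patch.
#include "tensorflow/stream_executor/stream.h"
#include "third_party/eigen3/Eigen/Core"

namespace gpu = perftools::gputools;

// Enqueues c = 1.0f * a * b + 0.0f * c on half-precision device buffers;
// note alpha and beta are float, per the new interface.
bool HalfGemm(gpu::Stream *stream, gpu::uint64 m, gpu::uint64 n,
              gpu::uint64 k,
              const gpu::DeviceMemory<Eigen::half> &a, int lda,
              const gpu::DeviceMemory<Eigen::half> &b, int ldb,
              gpu::DeviceMemory<Eigen::half> *c, int ldc) {
  stream->ThenBlasGemm(gpu::blas::Transpose::kNoTranspose,
                       gpu::blas::Transpose::kNoTranspose, m, n, k,
                       /*alpha=*/1.0f, a, lda, b, ldb,
                       /*beta=*/0.0f, c, ldc);
  return stream->ok();
}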
Diffstat (limited to 'tensorflow/stream_executor/blas.h')
-rw-r--r-- tensorflow/stream_executor/blas.h | 17 +++++++++++++++++
1 file changed, 17 insertions(+), 0 deletions(-)
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 1f1d427c45..ab4f125861 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"
+#include "third_party/eigen3/Eigen/Core"
namespace perftools {
namespace gputools {
@@ -846,6 +847,17 @@ class BlasSupport {
   // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
   // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
   // op(b) is a k-by-n matrix; c is an m-by-n matrix.
+  //
+  // Note: The half interface uses float precision internally; the version
+  // that uses half precision internally is not yet supported. There is no
+  // batched version of the half-precision interface.
+  virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
+                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+                          float alpha,
+                          const DeviceMemory<Eigen::half> &a, int lda,
+                          const DeviceMemory<Eigen::half> &b, int ldb,
+                          float beta,
+                          DeviceMemory<Eigen::half> *c, int ldc) = 0;
   virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                           float alpha, const DeviceMemory<float> &a, int lda,
@@ -1599,6 +1611,11 @@ class BlasSupport {
                   DeviceMemory<std::complex<double>> *x, int incx) override; \
   bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
                   blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
+                  float alpha, const DeviceMemory<Eigen::half> &a, int lda, \
+                  const DeviceMemory<Eigen::half> &b, int ldb, float beta, \
+                  DeviceMemory<Eigen::half> *c, int ldc) override; \
+  bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
+                  blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
                   float alpha, const DeviceMemory<float> &a, int lda, \
                   const DeviceMemory<float> &b, int ldb, float beta, \
                   DeviceMemory<float> *c, int ldc) override; \
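As a sketch of what likely backs this on CUDA 7.5 (an assumption about the CUDA BLAS plugin, not shown in this diff): cublasSgemmEx takes half-precision operands but computes in float, which matches the "uses float precision internally" note in the new header comment above. The HalfGemmViaCublas wrapper below is illustrative only.

// Illustrative sketch of the probable cuBLAS call behind the fp16 path on
// CUDA 7.5 -- an assumption, not taken from this diff.
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t HalfGemmViaCublas(cublasHandle_t handle, int m, int n, int k,
                                 float alpha, const __half *a, int lda,
                                 const __half *b, int ldb, float beta,
                                 __half *c, int ldc) {
  // CUBLAS_DATA_HALF is the CUDA 7.5 enum; CUDA 8.0 renames it CUDA_R_16F.
  return cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
                       a, CUBLAS_DATA_HALF, lda,
                       b, CUBLAS_DATA_HALF, ldb, &beta,
                       c, CUBLAS_DATA_HALF, ldc);
}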