diff options
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.h')
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_blas.h | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index 80cda97117..deb211c04b 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -84,7 +84,7 @@ class CUDABlas : public blas::BlasSupport { template <typename FuncT, typename... Args> bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream, bool pointer_mode_host, bool err_on_failure, - Args... args); + bool use_tensor_op_math, Args... args); // Convenience functions that call DoBlasInternalImpl with different values // for err_on_failure. @@ -92,13 +92,17 @@ class CUDABlas : public blas::BlasSupport { bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host, Args... args) { return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host, - /*err_on_failure=*/true, args...); + /*err_on_failure=*/true, /*use_tensor_ops=*/false, + args...); } template <typename FuncT, typename... Args> bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream, bool pointer_mode_host, Args... args) { + // Tensor ops are hard-coded off in this path, but can still be enabled with + // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl(). return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host, - /*err_on_failure=*/false, args...); + /*err_on_failure=*/false, + /*use_tensor_ops=*/false, args...); } // A helper function to implement DoBlasGemmBatched interfaces for generic |