aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/cuda/cuda_blas.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/cuda/cuda_blas.h')
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.h10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 80cda97117..deb211c04b 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -84,7 +84,7 @@ class CUDABlas : public blas::BlasSupport {
template <typename FuncT, typename... Args>
bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
- Args... args);
+ bool use_tensor_op_math, Args... args);
// Convenience functions that call DoBlasInternalImpl with different values
// for err_on_failure.
@@ -92,13 +92,17 @@ class CUDABlas : public blas::BlasSupport {
bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
Args... args) {
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
- /*err_on_failure=*/true, args...);
+ /*err_on_failure=*/true, /*use_tensor_ops=*/false,
+ args...);
}
template <typename FuncT, typename... Args>
bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
bool pointer_mode_host, Args... args) {
+ // Tensor ops are hard-coded off in this path, but can still be enabled with
+ // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
- /*err_on_failure=*/false, args...);
+ /*err_on_failure=*/false,
+ /*use_tensor_ops=*/false, args...);
}
// A helper function to implement DoBlasGemmBatched interfaces for generic