aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--tensorflow/stream_executor/blas.h66
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc217
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc108
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h21
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.cc133
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.cc181
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.h18
-rw-r--r--tensorflow/stream_executor/cuda/cuda_stream.h4
-rw-r--r--tensorflow/stream_executor/dnn.cc2
-rw-r--r--tensorflow/stream_executor/dnn.h36
-rw-r--r--tensorflow/stream_executor/host/host_gpu_executor.cc2
-rw-r--r--tensorflow/stream_executor/host/host_gpu_executor.h4
-rw-r--r--tensorflow/stream_executor/host/host_stream.cc26
-rw-r--r--tensorflow/stream_executor/host/host_stream.h4
-rw-r--r--tensorflow/stream_executor/module_spec.h66
-rw-r--r--tensorflow/stream_executor/stream.cc364
-rw-r--r--tensorflow/stream_executor/stream.h67
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.cc12
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.h70
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc43
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h81
-rw-r--r--tensorflow/stream_executor/stream_test.cc203
22 files changed, 1344 insertions, 384 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index ea87744b22..7f851e3646 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1121,6 +1121,40 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ // Batched gemm with strides instead of pointer arrays.
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+
// Computes a matrix-matrix product where one input matrix is Hermitian:
//
// c <- alpha * a * b + beta * c,
@@ -1990,6 +2024,38 @@ class BlasSupport {
int ldb, std::complex<double> beta, \
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \
int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \
+ const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \
+ DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+ int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \
+ int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \
+ int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, int64 stride_a, \
+ const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \
+ DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
+ int ldc, int64 stride_c, int batch_count); \
bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
uint64 m, uint64 n, std::complex<float> alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 874bf0e8cb..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
#if CUDA_VERSION >= 8000
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched)
#endif
#if CUDA_VERSION >= 9000
@@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
#if CUDA_VERSION >= 9010
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
#endif
} // namespace wrap
@@ -643,7 +648,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
}
#endif
cublasStatus_t ret = cublas_func(parent_, blas_, args...);
- if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
+ if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
<< ToString(ret);
}
@@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm(
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
&cc_minor);
- // GPUs < sm_70 don't support Volta hardware.
+ // GPUs < sm_70 don't support tensor ops.
if (cc_major >= 7 && TensorOpMathEnabled()) {
use_tensor_ops = true;
}
@@ -2139,6 +2144,10 @@ static bool UsesTensorOps(blas::AlgorithmType algo) {
template <typename InType>
static bool TensorOpsAvailable(int cc_major) {
#if CUDA_VERSION >= 9000
+ // cublas *does* allow tensor ops on inputs that are not fp16, so this is not
+ // strictly correct. We can't simply enable it, though, as that would change
+ // clients' behavior significantly: Using tensor ops on fp32 inputs cause them
+ // to be rounded to fp16.
if (cc_major >= 7 && TensorOpMathEnabled() &&
std::is_same<InType, Eigen::half>::value) {
return true;
@@ -2160,16 +2169,30 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
&cc_major, &cc_minor) &&
cc_major < 5) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major
+ << cc_minor << " devices don't support explicit gemm algorithms.";
return false;
}
if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
+ if (std::is_same<InT, Eigen::half>::value) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+ << algorithm
+ << " uses tensor ops, but tensor ops are not available in sm"
+ << cc_major << "X devices.";
+ } else {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+ << algorithm
+ << " uses tensor ops, but the input data type is not fp16.";
+ }
return false;
}
// Either both 'alpha' and 'beta' need to be pointers to device memory, or
// they need to be both host scalars.
if (alpha.is_pointer() != beta.is_pointer()) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` "
+ "and `beta` is a pointer, but the other is not.";
return false;
}
@@ -2177,6 +2200,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because "
+ "output_profile_result was given, but we were unable to "
+ "create a CUDATimer.";
return false;
}
}
@@ -2186,6 +2212,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
std::max({m, n, k}) >= 2097153 && cc_major < 7) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around cudnn "
+ "<9.2 bug with m, n, or k >= 2097153. See b/79126339.";
return false;
}
#endif
@@ -2211,6 +2239,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
// CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsCUDAStream(stream))) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop "
+ "CUDATimer.";
return false;
}
output_profile_result->set_is_valid(true);
@@ -2223,26 +2253,60 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
bool CUDABlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
-// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
-// were first introduced in CUDA 8.
-// Note that when CUDA version and compute capability is not sufficient, we
-// still return the out_algorithms. Caller needs to make sure that in this case,
-// the returned vector is empty.
- for (cublasGemmAlgo_t algo : {
- CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
- CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
- CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+ // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+ // were first introduced in CUDA 8.
+ //
+ // Note that when CUDA version and compute capability is not sufficient, we
+ // still return the out_algorithms. Caller needs to make sure that in this
+ // case, the returned vector is empty.
+ *out_algorithms = {
+ CUBLAS_GEMM_DFALT,
+ CUBLAS_GEMM_ALGO0,
+ CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2,
+ CUBLAS_GEMM_ALGO3,
+ CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5,
+ CUBLAS_GEMM_ALGO6,
+ CUBLAS_GEMM_ALGO7,
#if CUDA_VERSION >= 9000
- CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
- CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
- CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
- CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
- CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
- CUBLAS_GEMM_ALGO2_TENSOR_OP
+ CUBLAS_GEMM_ALGO8,
+ CUBLAS_GEMM_ALGO9,
+ CUBLAS_GEMM_ALGO10,
+ CUBLAS_GEMM_ALGO11,
+ CUBLAS_GEMM_ALGO12,
+ CUBLAS_GEMM_ALGO13,
+ CUBLAS_GEMM_ALGO14,
+ CUBLAS_GEMM_ALGO15,
+ CUBLAS_GEMM_ALGO16,
+ CUBLAS_GEMM_ALGO17,
+ CUBLAS_GEMM_DFALT_TENSOR_OP,
+ CUBLAS_GEMM_ALGO0_TENSOR_OP,
+ CUBLAS_GEMM_ALGO1_TENSOR_OP,
+ CUBLAS_GEMM_ALGO2_TENSOR_OP,
+ CUBLAS_GEMM_ALGO3_TENSOR_OP,
+ CUBLAS_GEMM_ALGO4_TENSOR_OP,
#endif
- }) {
- out_algorithms->push_back(algo);
- }
+#if CUDA_VERSION >= 9200
+ CUBLAS_GEMM_ALGO18,
+ CUBLAS_GEMM_ALGO19,
+ CUBLAS_GEMM_ALGO20,
+ CUBLAS_GEMM_ALGO21,
+ CUBLAS_GEMM_ALGO22,
+ CUBLAS_GEMM_ALGO23,
+ CUBLAS_GEMM_ALGO5_TENSOR_OP,
+ CUBLAS_GEMM_ALGO6_TENSOR_OP,
+ CUBLAS_GEMM_ALGO7_TENSOR_OP,
+ CUBLAS_GEMM_ALGO8_TENSOR_OP,
+ CUBLAS_GEMM_ALGO9_TENSOR_OP,
+ CUBLAS_GEMM_ALGO10_TENSOR_OP,
+ CUBLAS_GEMM_ALGO11_TENSOR_OP,
+ CUBLAS_GEMM_ALGO12_TENSOR_OP,
+ CUBLAS_GEMM_ALGO13_TENSOR_OP,
+ CUBLAS_GEMM_ALGO14_TENSOR_OP,
+ CUBLAS_GEMM_ALGO15_TENSOR_OP,
+#endif
+ };
return true;
}
@@ -2564,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched(
return status.ok();
}
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+ int cc_major, cc_minor;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor)) {
+ // GPUs < sm_70 don't support tensor ops.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#if CUDA_VERSION >= 9010
+ if (cc_major >= 5) {
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasGemmStridedBatchedEx, stream,
+ true /* = pointer_mode_host */, true /* = err_on_failure */,
+ use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+ m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+ CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c),
+ CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
+ if (ok) {
+ return true;
+ }
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+#endif
+ }
+#endif
+ // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
+ for (int batch = 0; batch < batch_count; ++batch) {
+ const auto *a_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
+ const auto *b_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b);
+ auto *c_matrix =
+ reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
+ true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
+ lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
+ SE_CUDA_DATA_HALF, ldc);
+ if (!ok) {
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+ }
+ return true;
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
blas::UpperLower uplo, uint64 m, uint64 n,
std::complex<float> alpha,
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 84916385a8..55408ab9ab 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -322,6 +322,7 @@ port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
CudnnSupport::CudnnSupport(CUDAExecutor* parent) : parent_(parent) {}
port::Status CudnnSupport::Init() {
+ ScopedActivateExecutorContext context(parent_);
cudnnHandle_t cudnn_handle = nullptr;
auto status = cudnnCreate(&cudnn_handle);
if (status == CUDNN_STATUS_SUCCESS) {
@@ -791,6 +792,11 @@ class CudnnActivationDescriptor {
double relu_ceiling = 0.0;
cudnnActivationMode_t mode;
switch (activation_mode) {
+#if CUDNN_VERSION >= 7100
+ case dnn::ActivationMode::kNone:
+ mode = CUDNN_ACTIVATION_IDENTITY;
+ break;
+#endif
case dnn::ActivationMode::kRelu6:
relu_ceiling = 6.0;
mode = CUDNN_ACTIVATION_CLIPPED_RELU;
@@ -1980,15 +1986,14 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmDesc& algorithm_desc,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv,
- const CudnnTensorDescriptor& output_nd,
+ const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc,
ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is
// functionally correct because the convolution is run with the algorithm of
// the last call to this function, but should be fixed anyway.
- conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+ conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled());
// Query the size of the workspace and allocate it.
size_t size_in_bytes;
@@ -1996,8 +2001,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
cudnn.handle(),
/*xDesc=*/input_nd.handle(),
/*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
+ /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(*algorithm_desc),
/*sizeInBytes=*/&size_in_bytes));
+
+ if (TF_PREDICT_FALSE(!algorithm_desc)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No AlgorithmDesc provided");
+ }
+ algorithm_desc->set_scratch_size(size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2022,15 +2033,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardDataWorkspace(
Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmDesc& algorithm_desc,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv,
- const CudnnTensorDescriptor& output_nd,
+ const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc,
ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is
// functionally correct because the convolution is run with the algorithm of
// the last call to this function, but should be fixed anyway.
- conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+ conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled());
// Query the size of the workspace and allocate it.
size_t size_in_bytes;
@@ -2040,8 +2050,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
/*dyDesc=*/output_nd.handle(),
/*convDesc=*/conv.handle(),
/*dxDesc=*/input_nd.handle(),
- /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
+ /*algo=*/ToConvBackwardDataAlgo(*algorithm_desc),
/*sizeInBytes=*/&size_in_bytes));
+
+ if (TF_PREDICT_FALSE(!algorithm_desc)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No AlgorithmDesc provided");
+ }
+ algorithm_desc->set_scratch_size(size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2066,15 +2082,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardFilterWorkspace(
Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmDesc& algorithm_desc,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv,
- const CudnnTensorDescriptor& output_nd,
+ const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc,
ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is
// functionally correct because the convolution is run with the algorithm of
// the last call to this function, but should be fixed anyway.
- conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+ conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled());
// Query the size of the workspace and allocate it.
size_t size_in_bytes;
@@ -2084,8 +2099,14 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
/*dyDesc=*/output_nd.handle(),
/*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(),
- /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
+ /*algo=*/ToConvBackwardFilterAlgo(*algorithm_desc),
/*sizeInBytes=*/&size_in_bytes));
+
+ if (TF_PREDICT_FALSE(!algorithm_desc)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No AlgorithmDesc provided");
+ }
+ algorithm_desc->set_scratch_size(size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2132,7 +2153,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
}
auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
- stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc,
scratch_allocator);
if (scratch_or.ok()) {
@@ -2149,11 +2170,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
"while a secondary algorithm is not provided.");
}
- SE_ASSIGN_OR_RETURN(
- *scratch, AllocateCudnnConvolutionForwardWorkspace(
- stream, cudnn, algorithm_config.algorithm_no_scratch(),
- input_nd, filter, conv, output_nd, scratch_allocator));
- return algorithm_config.algorithm_no_scratch();
+ algo_desc = algorithm_config.algorithm_no_scratch();
+ SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
+ stream, cudnn, input_nd, filter, conv,
+ output_nd, &algo_desc, scratch_allocator));
+ return algo_desc;
}
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
@@ -2181,7 +2202,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
}
auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
- stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc,
scratch_allocator);
if (scratch_or.ok()) {
@@ -2198,11 +2219,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
"while a secondary algorithm is not provided.");
}
- SE_ASSIGN_OR_RETURN(
- *scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
- stream, cudnn, algorithm_config.algorithm_no_scratch(),
- input_nd, filter, conv, output_nd, scratch_allocator));
- return algorithm_config.algorithm_no_scratch();
+ algo_desc = algorithm_config.algorithm_no_scratch();
+ SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
+ stream, cudnn, input_nd, filter, conv,
+ output_nd, &algo_desc, scratch_allocator));
+ return algo_desc;
}
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
@@ -2230,7 +2251,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
}
auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
- stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc,
scratch_allocator);
if (scratch_or.ok()) {
@@ -2247,11 +2268,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
"while a secondary algorithm is not provided.");
}
- SE_ASSIGN_OR_RETURN(*scratch,
- AllocateCudnnConvolutionBackwardFilterWorkspace(
- stream, cudnn, algorithm_config.algorithm(), input_nd,
- filter, conv, output_nd, scratch_allocator));
- return algorithm_config.algorithm_no_scratch();
+ algo_desc = algorithm_config.algorithm_no_scratch();
+ SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
+ stream, cudnn, input_nd, filter, conv,
+ output_nd, &algo_desc, scratch_allocator));
+ return algo_desc;
}
// A helper class to set env-vars and choose options for cudnn-related
@@ -2480,10 +2501,11 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- if (activation_mode != dnn::ActivationMode::kRelu) {
+ if (activation_mode != dnn::ActivationMode::kRelu &&
+ activation_mode != dnn::ActivationMode::kNone) {
return port::Status(port::error::INVALID_ARGUMENT,
"cudnnConvolutionBiasActivationForward() only supports "
- "Relu activation.");
+ "Relu or None activation.");
}
CudnnTensorDescriptor conv_input_nd(
@@ -3075,8 +3097,7 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
- // zero-initialized.
- // TODO(timshen): Add an nvbugs/ link.
+ // zero-initialized, nvbugs/2254619.
if (CUDNN_VERSION >= 7000 &&
algorithm_config.algorithm().algo_id() ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
@@ -3603,7 +3624,7 @@ bool CudnnSupport::DoPoolForward(
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<double>* output_data) {
+ DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
// Alpha is the scaling factor for input.
double alpha = 1.0;
// Beta is the scaling factor for output.
@@ -3628,7 +3649,7 @@ bool CudnnSupport::DoPoolForward(
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<float>* output_data) {
+ DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
@@ -3653,7 +3674,8 @@ bool CudnnSupport::DoPoolForward(
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<Eigen::half>* output_data) {
+ DeviceMemory<Eigen::half>* output_data,
+ ScratchAllocator* workspace_allocator) {
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
@@ -3679,7 +3701,8 @@ bool CudnnSupport::DoPoolBackward(
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<double>& output_data,
const DeviceMemory<double>& input_diff_data,
- DeviceMemory<double>* output_diff_data) {
+ DeviceMemory<double>* output_diff_data,
+ ScratchAllocator* workspace_allocator) {
// Alpha is the scaling factor for input.
double alpha = 1.0;
// Beta is the scaling factor for output.
@@ -3708,7 +3731,8 @@ bool CudnnSupport::DoPoolBackward(
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<float>& output_data,
const DeviceMemory<float>& input_diff_data,
- DeviceMemory<float>* output_diff_data) {
+ DeviceMemory<float>* output_diff_data,
+ ScratchAllocator* workspace_allocator) {
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
@@ -3737,7 +3761,8 @@ bool CudnnSupport::DoPoolBackward(
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<Eigen::half>& output_data,
const DeviceMemory<Eigen::half>& input_diff_data,
- DeviceMemory<Eigen::half>* output_diff_data) {
+ DeviceMemory<Eigen::half>* output_diff_data,
+ ScratchAllocator* workspace_allocator) {
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
@@ -3806,7 +3831,8 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
const DeviceMemory<float>& normalized_data,
const DeviceMemory<float>& normalized_variable_gradient,
- DeviceMemory<float>* raw_variable_gradient) {
+ DeviceMemory<float>* raw_variable_gradient,
+ ScratchAllocator* workspace_allocator) {
// Check for unsupported modes.
if (normalize_descriptor.wrap_around()) {
LOG(ERROR) << "CUDA LRN does not support cudnn-around mode";
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index c924d41cb5..9d88f971bb 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -515,21 +515,24 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<double>* output_data) override;
+ DeviceMemory<double>* output_data,
+ ScratchAllocator* workspace_allocator) override;
bool DoPoolForward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<float>* output_data) override;
+ DeviceMemory<float>* output_data,
+ ScratchAllocator* workspace_allocator) override;
bool DoPoolForward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<Eigen::half>* output_data) override;
+ DeviceMemory<Eigen::half>* output_data,
+ ScratchAllocator* workspace_allocator) override;
bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
@@ -538,7 +541,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<double>& output_data,
const DeviceMemory<double>& input_diff_data,
- DeviceMemory<double>* output_diff_data) override;
+ DeviceMemory<double>* output_diff_data,
+ ScratchAllocator* workspace_allocator) override;
bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
@@ -547,7 +551,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<float>& output_data,
const DeviceMemory<float>& input_diff_data,
- DeviceMemory<float>* output_diff_data) override;
+ DeviceMemory<float>* output_diff_data,
+ ScratchAllocator* workspace_allocator) override;
bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
@@ -556,7 +561,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<Eigen::half>& output_data,
const DeviceMemory<Eigen::half>& input_diff_data,
- DeviceMemory<Eigen::half>* output_diff_data) override;
+ DeviceMemory<Eigen::half>* output_diff_data,
+ ScratchAllocator* workspace_allocator) override;
bool DoNormalize(Stream* stream,
const dnn::NormalizeDescriptor& normalize_descriptor,
@@ -575,7 +581,8 @@ class CudnnSupport : public dnn::DnnSupport {
const DeviceMemory<float>& raw_data,
const DeviceMemory<float>& normalized_data,
const DeviceMemory<float>& normalized_variable_gradient,
- DeviceMemory<float>* raw_variable_gradient) override;
+ DeviceMemory<float>* raw_variable_gradient,
+ ScratchAllocator* workspace_allocator) override;
bool DoDepthConcatenate(
Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index d508f6594a..f982f34b98 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/human_readable.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -66,14 +67,17 @@ class CreatedContexts {
return Live()->find(context) != Live()->end();
}
- // Adds context to the live set.
+ // Adds context to the live set, or returns it if it's already present.
static CudaContext* Add(CUcontext context) {
CHECK(context != nullptr);
mutex_lock lock(mu_);
- auto cuda_context = new CudaContext(context, next_id_++);
- Live()->insert(
- std::make_pair(context, std::unique_ptr<CudaContext>(cuda_context)));
- return cuda_context;
+ auto insert_result = Live()->insert(std::make_pair(context, nullptr));
+ auto it = insert_result.first;
+ if (insert_result.second) {
+ // context was not present in the map. Add it.
+ it->second = MakeUnique<CudaContext>(context, next_id_++);
+ }
+ return it->second.get();
}
// Removes context from the live set.
@@ -102,117 +106,16 @@ class CreatedContexts {
/* static */ int64 CreatedContexts::next_id_ = 1; // 0 means "no context"
// Formats CUresult to output prettified values into a log stream.
-// Error summaries taken from:
-// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc6c391505e117393cc2558fff6bfc2e9
-//
-// TODO(leary) switch to cuGetErrorName when updated cuda.h is available.
string ToString(CUresult result) {
-#define OSTREAM_CUDA_ERROR(__name) \
- case CUDA_ERROR_##__name: \
- return "CUDA_ERROR_" #__name;
-
-///////////////
-// NOTE: here we specify return code values outside of the enum explicitly
-// because our in-tree cuda.h is from the CUDA 5.5 SDK, but CUDA 6.0+ driver
-// libraries are deployed in the fleet these error codes are backwards
-// compatible, but if we see a "new" one, we want to be able to identify it in
-// the logs.
-//
-// Once we get a cuda.h that has cuGetErrorName (TODO is above) we can
-// eliminate this function and just rely on the driver to provide us these
-// strings.
-//
-// NOTE: "Must reboot all context" below is shorthand for, "must
-// destroy/recreate the offending context and any allocation which come from
-// it if you are to continue using CUDA."
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wswitch"
- switch (result) {
- OSTREAM_CUDA_ERROR(INVALID_VALUE)
- OSTREAM_CUDA_ERROR(OUT_OF_MEMORY)
- OSTREAM_CUDA_ERROR(NOT_INITIALIZED)
- OSTREAM_CUDA_ERROR(DEINITIALIZED)
- OSTREAM_CUDA_ERROR(NO_DEVICE)
- OSTREAM_CUDA_ERROR(INVALID_DEVICE)
- OSTREAM_CUDA_ERROR(INVALID_IMAGE)
- OSTREAM_CUDA_ERROR(INVALID_CONTEXT)
- OSTREAM_CUDA_ERROR(INVALID_HANDLE)
- OSTREAM_CUDA_ERROR(NOT_FOUND)
- OSTREAM_CUDA_ERROR(NOT_READY)
- OSTREAM_CUDA_ERROR(NO_BINARY_FOR_GPU)
-
- // Encountered an uncorrectable ECC error during execution.
- OSTREAM_CUDA_ERROR(ECC_UNCORRECTABLE)
-
- // Load/store on an invalid address. Must reboot all context.
- case 700:
- return "CUDA_ERROR_ILLEGAL_ADDRESS";
- // Passed too many / wrong arguments, too many threads for register count.
- case 701:
- return "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES";
- // Kernel took too long to execute.
- case 702:
- return "CUDA_ERROR_LAUNCH_TIMEOUT";
- // Kernel launch uses an incompatible texturing mode.
- case 703:
- return "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING";
- // Trying to re-enable peer access that already has it enabled.
- case 704:
- return "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED";
- // Trying to disable peer access that has not yet been enabled.
- case 705:
- return "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED";
- // Primary context for the specified device has already been initialized.
- case 708:
- return "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE";
- // Context current to calling thread has been destroyed or is a primary
- // context that has not yet been initialized.
- case 709:
- return "CUDA_ERROR_CONTEXT_IS_DESTROYED";
- // Device-side assert triggered during kernel execution. Must reboot all
- // context.
- case 710:
- return "CUDA_ERROR_ASSERT";
- // Hardware resources to enable peer access have been exhausted.
- case 711:
- return "CUDA_ERROR_TOO_MANY_PEERS";
- // Memory range has already been registered.
- case 712:
- return "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED";
- // Pointer does not correspond to any currently registered memory region.
- case 713:
- return "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED";
- // Due to stack corruption or exceeding stack size limit. Must reboot all
- // context.
- case 714:
- return "CUDA_ERROR_HARDWARE_STACK_ERROR";
- case 715:
- return "CUDA_ERROR_ILLEGAL_INSTRUCTION";
- // Load/store on an unaligned memory address. Must reboot all context.
- case 716:
- return "CUDA_ERROR_MISALIGNED_ADDRESS";
- // Device instruction with specific address space given address not
- // belonging to allowed address space. Must reboot all context.
- case 717:
- return "CUDA_ERROR_INVALID_ADDRESS_SPACE";
- // Device program counter wrapped its address space. Must reboot all
- // context.
- case 718:
- return "CUDA_ERROR_INVALID_PC";
- // Exception on device while executing a kernel; e.g. deref invalid device
- // pointer, accessing OOB shared memory. Must reboot all context.
- case 719:
- return "CUDA_ERROR_LAUNCH_FAILED";
-
- OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE)
- OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED)
- OSTREAM_CUDA_ERROR(NOT_PERMITTED)
- OSTREAM_CUDA_ERROR(NOT_SUPPORTED)
- OSTREAM_CUDA_ERROR(UNKNOWN) // Unknown internal error to CUDA.
- default:
- return port::StrCat("CUresult(", static_cast<int>(result), ")");
+ const char *error_name;
+ if (cuGetErrorName(result, &error_name)) {
+ return port::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
+ }
+ const char *error_string;
+ if (cuGetErrorString(result, &error_string)) {
+ return error_name;
}
-#pragma GCC diagnostic pop
+ return port::StrCat(error_name, ": ", error_string);
}
// Returns the current context and checks that it is in the set of CUDA contexts
@@ -528,7 +431,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
*context = CreatedContexts::Add(new_context);
CHECK(*context != nullptr)
<< "success in this call must entail non-null result";
- VLOG(2) << "created context " << context << " for this thread";
+ VLOG(2) << "created or reused context " << context << " for this thread";
return port::Status::OK();
}
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index f11022ef1d..73f05b94db 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -206,6 +206,48 @@ static string GetBinaryDir(bool strip_exe) {
return exe_path;
}
+bool CUDAExecutor::LoadModuleFromCuBin(const char *cubin, CUmodule *module) {
+ uint64_t module_refcount;
+ std::tie(*module, module_refcount) = gpu_binary_to_module_[cubin];
+
+ if (*module == nullptr) {
+ auto load_status = CUDADriver::LoadCubin(context_, cubin, module);
+ if (!load_status.ok()) {
+ LOG(ERROR) << "failed to load CUBIN: " << load_status;
+ return false;
+ }
+ module_refcount = 1;
+ VLOG(3) << "Loaded CUBIN " << static_cast<const void *>(cubin)
+ << " as module " << *module;
+ } else {
+ ++module_refcount;
+ VLOG(3) << "CUBIN " << static_cast<const void *>(cubin)
+ << " is already loaded as module " << *module;
+ }
+ gpu_binary_to_module_[cubin] = {*module, module_refcount};
+ return true;
+}
+
+bool CUDAExecutor::LoadModuleFromPtx(const char *ptx, CUmodule *module) {
+ uint64_t module_refcount;
+ std::tie(*module, module_refcount) = gpu_binary_to_module_[ptx];
+
+ if (*module == nullptr) {
+ if (!CUDADriver::LoadPtx(context_, ptx, module)) {
+ return false;
+ }
+ VLOG(3) << "Loaded PTX " << static_cast<const void *>(ptx) << " as module "
+ << *module;
+ module_refcount = 1;
+ } else {
+ ++module_refcount;
+ VLOG(3) << "PTX " << static_cast<const void *>(ptx)
+ << " is already loaded as module " << module;
+ }
+ gpu_binary_to_module_[ptx] = {*module, module_refcount};
+ return true;
+}
+
bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
KernelBase *kernel) {
CUDAKernel *cuda_kernel = AsCUDAKernel(kernel);
@@ -215,28 +257,13 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();
if (spec.has_cuda_cubin_in_memory()) {
+ mutex_lock lock{in_memory_modules_mu_};
kernelname = &spec.cuda_cubin_in_memory().kernelname();
const char *cubin = spec.cuda_cubin_in_memory().bytes();
- mutex_lock lock{in_memory_modules_mu_};
- uint64_t module_refcount;
- std::tie(module, module_refcount) = gpu_binary_to_module_[cubin];
-
- if (module == nullptr) {
- auto load_status = CUDADriver::LoadCubin(context_, cubin, &module);
- if (!load_status.ok()) {
- LOG(ERROR) << "failed to load CUBIN: " << load_status;
- return false;
- }
- module_refcount = 1;
- VLOG(3) << "Loaded CUBIN " << static_cast<const void *>(cubin)
- << " as module " << module;
- } else {
- ++module_refcount;
- VLOG(3) << "CUBIN " << static_cast<const void *>(cubin)
- << " is already loaded as module " << module;
+ if (!LoadModuleFromCuBin(cubin, &module)) {
+ return false;
}
kernel_to_gpu_binary_[kernel] = cubin;
- gpu_binary_to_module_[cubin] = {module, module_refcount};
} else if (spec.has_cuda_ptx_in_memory()) {
kernelname = &spec.cuda_ptx_in_memory().kernelname();
@@ -254,24 +281,10 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
}
mutex_lock lock{in_memory_modules_mu_};
- uint64_t module_refcount;
- std::tie(module, module_refcount) = gpu_binary_to_module_[ptx];
-
- if (module == nullptr) {
- if (!CUDADriver::LoadPtx(context_, ptx, &module)) {
- LOG(ERROR) << "failed to load PTX for kernel " << *kernelname;
- return false;
- }
- VLOG(3) << "Loaded PTX " << static_cast<const void *>(ptx)
- << " as module " << module;
- module_refcount = 1;
- } else {
- ++module_refcount;
- VLOG(3) << "PTX " << static_cast<const void *>(ptx)
- << " is already loaded as module " << module;
+ if (!LoadModuleFromPtx(ptx, &module)) {
+ return false;
}
kernel_to_gpu_binary_[kernel] = ptx;
- gpu_binary_to_module_[ptx] = {module, module_refcount};
} else {
LOG(WARNING) << "no method of loading CUDA kernel provided";
return false;
@@ -295,6 +308,23 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
return true;
}
+bool CUDAExecutor::UnloadGpuBinary(const void *gpu_binary) {
+ auto module_it = gpu_binary_to_module_.find(gpu_binary);
+ if (gpu_binary_to_module_.end() == module_it) {
+ VLOG(3) << "No loaded CUDA module for " << gpu_binary;
+ return false;
+ }
+ auto &module = module_it->second.first;
+ auto &refcount = module_it->second.second;
+ VLOG(3) << "Found CUDA module " << module << " with refcount " << refcount;
+ if (--refcount == 0) {
+ VLOG(3) << "Unloading CUDA module " << module;
+ CUDADriver::UnloadModule(context_, module);
+ gpu_binary_to_module_.erase(module_it);
+ }
+ return true;
+}
+
void CUDAExecutor::UnloadKernel(const KernelBase *kernel) {
VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name();
@@ -307,25 +337,52 @@ void CUDAExecutor::UnloadKernel(const KernelBase *kernel) {
}
VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
<< " has loaded GPU code " << gpu_binary_it->second;
- auto module_it = gpu_binary_to_module_.find(gpu_binary_it->second);
- if (gpu_binary_to_module_.end() == module_it) {
- VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
- << " has no loaded CUDA module.";
- return; // This kernel never loaded any modules
- }
- auto &module = module_it->second.first;
- auto &refcount = module_it->second.second;
- VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
- << " has loaded GPU code " << gpu_binary_it->second
- << " into CUDA module " << module << " with refcount " << refcount;
- if (--refcount == 0) {
- VLOG(3) << "Unloading CUDA module " << module;
- CUDADriver::UnloadModule(context_, module);
- gpu_binary_to_module_.erase(module_it);
- }
+ UnloadGpuBinary(gpu_binary_it->second);
kernel_to_gpu_binary_.erase(gpu_binary_it);
}
+bool CUDAExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) {
+ // In CUDAExecutor we store the pointer to the GPU binary (PTX or CUBIN) as
+ // ModuleHandle::id().
+ CUmodule cu_module;
+ if (spec.has_cuda_cubin_in_memory()) {
+ mutex_lock lock{in_memory_modules_mu_};
+ if (!LoadModuleFromCuBin(
+ reinterpret_cast<const char *>(spec.cuda_cubin_in_memory().data()),
+ &cu_module)) {
+ return false;
+ }
+ *module_handle = ModuleHandle(const_cast<void *>(
+ static_cast<const void *>(spec.cuda_cubin_in_memory().data())));
+ return true;
+ } else if (spec.has_cuda_ptx_in_memory()) {
+ if (cc_major_ == 0 && cc_minor_ == 0) {
+ return false;
+ }
+
+ if (!spec.cuda_ptx_in_memory()) {
+ return false;
+ }
+
+ mutex_lock lock{in_memory_modules_mu_};
+ if (!LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)) {
+ return false;
+ }
+ *module_handle = ModuleHandle(const_cast<void *>(
+ static_cast<const void *>(spec.cuda_ptx_in_memory())));
+ return true;
+ }
+ LOG(WARNING) << "no method of loading CUDA module provided";
+ return false;
+}
+
+bool CUDAExecutor::UnloadModule(ModuleHandle module_handle) {
+ const char *gpu_binary = reinterpret_cast<const char *>(module_handle.id());
+ mutex_lock lock{in_memory_modules_mu_};
+ return UnloadGpuBinary(gpu_binary);
+}
+
bool CUDAExecutor::GetKernelMetadata(CUDAKernel *cuda_kernel,
KernelMetadata *kernel_metadata) {
int value;
@@ -783,16 +840,26 @@ bool CUDAExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
return CUDADriver::GetDeviceMemoryInfo(context_, free, total);
}
-bool CUDAExecutor::GetSymbol(const string& symbol_name, void **mem,
+bool CUDAExecutor::GetSymbol(const string &symbol_name,
+ ModuleHandle module_handle, void **mem,
size_t *bytes) {
+ auto lookup_in_module = [&](CUmodule module) {
+ CHECK(module != nullptr);
+ return CUDADriver::GetModuleSymbol(context_, module, symbol_name.c_str(),
+ reinterpret_cast<CUdeviceptr *>(mem),
+ bytes);
+ };
+
{ // give limited scope to mutex_lock
mutex_lock lock{in_memory_modules_mu_};
+ if (static_cast<bool>(module_handle)) {
+ auto it = gpu_binary_to_module_.find(module_handle.id());
+ CHECK(it != gpu_binary_to_module_.end());
+ return lookup_in_module(it->second.first);
+ }
+
for (auto &it : gpu_binary_to_module_) {
- CUmodule module = it.second.first;
- CHECK(module != nullptr);
- if (CUDADriver::GetModuleSymbol(context_, module, symbol_name.c_str(),
- reinterpret_cast<CUdeviceptr *>(mem),
- bytes)) {
+ if (lookup_in_module(it.second.first)) {
return true;
}
}
@@ -844,7 +911,7 @@ CUDAExecutor::GetTimerImplementation() {
return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this));
}
-void *CUDAExecutor::CudaContextHack() { return context_; }
+void *CUDAExecutor::GpuContextHack() { return context_; }
CudaContext* CUDAExecutor::cuda_context() { return context_; }
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 773cbfb8a1..8a954d5461 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -62,6 +62,9 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
bool GetKernel(const MultiKernelLoaderSpec &spec,
KernelBase *kernel) override;
void UnloadKernel(const KernelBase *kernel) override;
+ bool LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) override;
+ bool UnloadModule(ModuleHandle module_handle) override;
bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &k,
@@ -175,7 +178,8 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
// Search for the symbol and returns a device pointer and size.
// Returns false if symbol does not exist.
- bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) override;
+ bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+ void **mem, size_t *bytes) override;
DeviceDescription *PopulateDeviceDescription() const override;
@@ -210,7 +214,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override;
- void *CudaContextHack() override;
+ void *GpuContextHack() override;
CudaContext* cuda_context();
@@ -239,6 +243,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
void VlogOccupancyInfo(const KernelBase &kernel, const ThreadDim &thread_dims,
const BlockDim &block_dims);
+ bool LoadModuleFromCuBin(const char *cubin, CUmodule *module)
+ EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
+
+ // Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated.
+ bool LoadModuleFromPtx(const char *ptx, CUmodule *module)
+ EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
+
+ bool UnloadGpuBinary(const void *gpu_binary)
+ EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);
+
// Guards the in-memory-module mapping.
mutex in_memory_modules_mu_;
diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h
index 02edff6431..bb8bda4755 100644
--- a/tensorflow/stream_executor/cuda/cuda_stream.h
+++ b/tensorflow/stream_executor/cuda/cuda_stream.h
@@ -40,8 +40,8 @@ class CUDAStream : public internal::StreamInterface {
// Note: teardown is handled by a parent's call to DeallocateStream.
~CUDAStream() override {}
- void *CudaStreamHack() override { return cuda_stream_; }
- void **CudaStreamMemberHack() override {
+ void *GpuStreamHack() override { return cuda_stream_; }
+ void **GpuStreamMemberHack() override {
return reinterpret_cast<void **>(&cuda_stream_);
}
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 82aa8ceb32..2a30f922bc 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -117,6 +117,8 @@ string FilterLayoutString(FilterLayout layout) {
switch (layout) {
case FilterLayout::kOutputInputYX:
return "OutputInputYX";
+ case FilterLayout::kOutputYXInput:
+ return "OutputYXInput";
case FilterLayout::kOutputInputYX4:
return "OutputInputYX4";
case FilterLayout::kInputYXOutput:
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 9eca5abe1a..9abfa1db6a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -713,15 +713,23 @@ class PoolingDescriptor {
class AlgorithmDesc {
public:
typedef int64 Index;
- AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
+ AlgorithmDesc()
+ : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {}
AlgorithmDesc(Index a, bool use_tensor_ops)
- : algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
+ : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {}
+ AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size)
+ : algo_(a),
+ tensor_ops_enabled_(use_tensor_ops),
+ scratch_size_(scratch_size) {}
bool is_default() const { return algo_ == kDefaultAlgorithm; }
bool tensor_ops_enabled() const { return tensor_ops_enabled_; }
Index algo_id() const { return algo_; }
+ size_t scratch_size() const { return scratch_size_; }
+ void set_scratch_size(size_t val) { scratch_size_ = val; }
bool operator==(const AlgorithmDesc& other) const {
return this->algo_ == other.algo_ &&
- this->tensor_ops_enabled_ == other.tensor_ops_enabled_;
+ this->tensor_ops_enabled_ == other.tensor_ops_enabled_ &&
+ this->scratch_size_ == other.scratch_size_;
}
uint64 hash() const;
@@ -729,6 +737,7 @@ class AlgorithmDesc {
enum { kDefaultAlgorithm = -1 };
Index algo_;
bool tensor_ops_enabled_;
+ size_t scratch_size_;
};
// Describes the result from a perf experiment.
@@ -1552,14 +1561,16 @@ class DnnSupport {
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<float>* output_data) = 0;
+ DeviceMemory<float>* output_data,
+ ScratchAllocator* workspace_allocator) = 0;
virtual bool DoPoolForward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<double>* output_data) {
+ DeviceMemory<double>* output_data,
+ ScratchAllocator* workspace_allocator) {
LOG(FATAL) << "DoPoolForward not implemented for double.";
return false;
}
@@ -1569,7 +1580,8 @@ class DnnSupport {
const dnn::BatchDescriptor& input_dimensions,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& output_dimensions,
- DeviceMemory<Eigen::half>* output_data) {
+ DeviceMemory<Eigen::half>* output_data,
+ ScratchAllocator* workspace_allocator) {
LOG(FATAL) << "DoPoolForward not implemented for float16.";
return false;
}
@@ -1582,7 +1594,8 @@ class DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<double>& output_data,
const DeviceMemory<double>& input_diff_data,
- DeviceMemory<double>* output_diff_data) {
+ DeviceMemory<double>* output_diff_data,
+ ScratchAllocator* workspace_allocator) {
LOG(FATAL) << "DoPoolBackward not implemented.";
return false;
}
@@ -1594,7 +1607,8 @@ class DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<float>& output_data,
const DeviceMemory<float>& input_diff_data,
- DeviceMemory<float>* output_diff_data) {
+ DeviceMemory<float>* output_diff_data,
+ ScratchAllocator* workspace_allocator) {
LOG(FATAL) << "DoPoolBackward not implemented.";
return false;
}
@@ -1606,7 +1620,8 @@ class DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
const DeviceMemory<Eigen::half>& output_data,
const DeviceMemory<Eigen::half>& input_diff_data,
- DeviceMemory<Eigen::half>* output_diff_data) {
+ DeviceMemory<Eigen::half>* output_diff_data,
+ ScratchAllocator* workspace_allocator) {
LOG(FATAL) << "DoPoolBackward not implemented.";
return false;
}
@@ -1653,7 +1668,8 @@ class DnnSupport {
const DeviceMemory<float>& raw_data,
const DeviceMemory<float>& normalized_data,
const DeviceMemory<float>& normalized_variable_gradient,
- DeviceMemory<float>* raw_variable_gradient) {
+ DeviceMemory<float>* raw_variable_gradient,
+ ScratchAllocator* workspace_allocator) {
return false;
}
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index 3cd97b3cf1..8adf739b17 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -93,7 +93,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream,
// the nature of the HostExecutor) memcpy on the stream (HostStream)
// associated with the HostExecutor.
AsHostStream(stream)->EnqueueTask(
- [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); });
+ [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); });
return true;
}
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index e82f57569f..7ba1f18101 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -88,7 +88,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
uint64 size) override;
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return false; }
+ bool SynchronizeAllActivity() override { return true; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override;
bool SynchronousMemSet(DeviceMemoryBase *location, int value,
@@ -202,7 +202,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
return std::unique_ptr<internal::TimerInterface>(new HostTimer());
}
- void *CudaContextHack() override { return nullptr; }
+ void *GpuContextHack() override { return nullptr; }
private:
const PluginConfig plugin_config_;
diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc
index 5a7d3b3dd4..bfbfb56cd7 100644
--- a/tensorflow/stream_executor/host/host_stream.cc
+++ b/tensorflow/stream_executor/host/host_stream.cc
@@ -28,18 +28,28 @@ HostStream::HostStream()
HostStream::~HostStream() {}
bool HostStream::EnqueueTask(std::function<void()> task) {
+ struct NotifiedTask {
+ HostStream* stream;
+ std::function<void()> task;
+
+ void operator()() {
+ task();
+ // Destroy the task before unblocking its waiters, as BlockHostUntilDone()
+ // should guarantee that all tasks are destroyed.
+ task = std::function<void()>();
+ {
+ mutex_lock lock(stream->mu_);
+ --stream->pending_tasks_;
+ }
+ stream->completion_condition_.notify_all();
+ }
+ };
+
{
mutex_lock lock(mu_);
++pending_tasks_;
}
- host_executor_->Schedule([this, task]() {
- task();
- {
- mutex_lock lock(mu_);
- --pending_tasks_;
- }
- completion_condition_.notify_all();
- });
+ host_executor_->Schedule(NotifiedTask{this, std::move(task)});
return true;
}
diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h
index 5d7b8a3782..be88f074cf 100644
--- a/tensorflow/stream_executor/host/host_stream.h
+++ b/tensorflow/stream_executor/host/host_stream.h
@@ -34,8 +34,8 @@ class HostStream : public internal::StreamInterface {
bool EnqueueTask(std::function<void()> task);
- void *CudaStreamHack() override { return nullptr; }
- void **CudaStreamMemberHack() override { return nullptr; }
+ void *GpuStreamHack() override { return nullptr; }
+ void **GpuStreamMemberHack() override { return nullptr; }
void BlockUntilDone();
diff --git a/tensorflow/stream_executor/module_spec.h b/tensorflow/stream_executor/module_spec.h
new file mode 100644
index 0000000000..75bdfed2d7
--- /dev/null
+++ b/tensorflow/stream_executor/module_spec.h
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_
+#define TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_
+
+#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace stream_executor {
+
+// Describes how to load a module on a target platform.
+//
+// The exact meaning of a "module" may differ from platform to platform but
+// loosely speaking a module a collection of kernels and global variables. It
+// corresponds to CUmodule when running on CUDA.
+class MultiModuleLoaderSpec {
+ public:
+ bool has_cuda_cubin_in_memory() const { return has_cuda_cubin_in_memory_; }
+ port::ArraySlice<const uint8> cuda_cubin_in_memory() const {
+ CHECK(has_cuda_cubin_in_memory());
+ return {cuda_cubin_in_memory_.data(), cuda_cubin_in_memory_.size()};
+ }
+
+ bool has_cuda_ptx_in_memory() const { return has_cuda_ptx_in_memory_; }
+ const char* cuda_ptx_in_memory() const {
+ CHECK(has_cuda_ptx_in_memory());
+ return cuda_ptx_in_memory_;
+ }
+
+ void AddCudaCubinInMemory(port::ArraySlice<const uint8> cubin_bytes) {
+ CHECK(!cubin_bytes.empty());
+ has_cuda_cubin_in_memory_ = true;
+ cuda_cubin_in_memory_ = cubin_bytes;
+ }
+
+ void AddCudaPtxInMemory(const char* ptx) {
+ has_cuda_ptx_in_memory_ = true;
+ // The CUDA driver does not like getting an empty string as PTX.
+ cuda_ptx_in_memory_ = *ptx ? ptx : nullptr;
+ }
+
+ private:
+ port::ArraySlice<const uint8> cuda_cubin_in_memory_;
+ bool has_cuda_cubin_in_memory_ = false;
+ const char* cuda_ptx_in_memory_;
+ bool has_cuda_ptx_in_memory_ = false;
+};
+
+} // namespace stream_executor
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 9369183133..9efd34de24 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) {
}
string ToVlogString(const DeviceMemoryBase *memory) {
- return ToVlogString(*memory);
+ return memory == nullptr ? "null" : ToVlogString(*memory);
}
string ToVlogString(const Eigen::half &h) {
@@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream,
// constructing all the strings in params is expensive.
CHECK(VLOG_IS_ON(1));
- string str = port::StrCat("Called Stream::", function_name, "(");
+ string str = port::StrCat(stream->DebugStreamPointers(),
+ " Called Stream::", function_name, "(");
const char *separator = "";
for (const auto &param : params) {
port::StrAppend(&str, separator, param.first, "=", param.second);
separator = ", ";
}
- port::StrAppend(&str, ") stream=", ToVlogString(stream));
+ port::StrAppend(&str, ")");
if (VLOG_IS_ON(10)) {
port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
}
@@ -267,6 +268,12 @@ Stream::Stream(StreamExecutor *parent,
Stream::~Stream() {
VLOG_CALL();
+ // Ensure the stream is completed.
+ auto status = BlockHostUntilDone();
+ if (!status.ok()) {
+ LOG(WARNING) << "Error blocking host until done in stream destructor: "
+ << status;
+ }
temporary_memory_manager_.ForceDeallocateAll();
if (allocated_) {
@@ -1377,15 +1384,16 @@ Stream &Stream::ThenPoolForward(
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<double> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<double> *output_data) {
+ DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
- PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
- input_data, output_dimensions,
- output_data));
+ input_data, output_dimensions, output_data,
+ workspace_allocator));
} else {
SetError();
LOG(WARNING)
@@ -1401,15 +1409,16 @@ Stream &Stream::ThenPoolForward(
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<float> *output_data) {
+ DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
- PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
- input_data, output_dimensions,
- output_data));
+ input_data, output_dimensions, output_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1422,15 +1431,17 @@ Stream &Stream::ThenPoolForward(
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<Eigen::half> *output_data) {
+ DeviceMemory<Eigen::half> *output_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
- PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
- input_data, output_dimensions,
- output_data));
+ input_data, output_dimensions, output_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1445,16 +1456,19 @@ Stream &Stream::ThenPoolBackward(
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<double> &output_data,
const DeviceMemory<double> &input_diff_data,
- DeviceMemory<double> *output_diff_data) {
+ DeviceMemory<double> *output_diff_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
- PARAM(input_diff_data), PARAM(output_diff_data));
+ PARAM(input_diff_data), PARAM(output_diff_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
- input_diff_data, output_diff_data));
+ input_diff_data, output_diff_data,
+ workspace_allocator));
} else {
SetError();
LOG(WARNING)
@@ -1472,16 +1486,19 @@ Stream &Stream::ThenPoolBackward(
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<float> &output_data,
const DeviceMemory<float> &input_diff_data,
- DeviceMemory<float> *output_diff_data) {
+ DeviceMemory<float> *output_diff_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
- PARAM(input_diff_data), PARAM(output_diff_data));
+ PARAM(input_diff_data), PARAM(output_diff_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
- input_diff_data, output_diff_data));
+ input_diff_data, output_diff_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1496,16 +1513,19 @@ Stream &Stream::ThenPoolBackward(
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<Eigen::half> &output_data,
const DeviceMemory<Eigen::half> &input_diff_data,
- DeviceMemory<Eigen::half> *output_diff_data) {
+ DeviceMemory<Eigen::half> *output_diff_data,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
- PARAM(input_diff_data), PARAM(output_diff_data));
+ PARAM(input_diff_data), PARAM(output_diff_data),
+ PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
input_data, output_dimensions, output_data,
- input_diff_data, output_diff_data));
+ input_diff_data, output_diff_data,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1552,16 +1572,18 @@ Stream &Stream::ThenNormalizeBackwardWithDimensions(
const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
const DeviceMemory<float> &normalized_data,
const DeviceMemory<float> &normalized_variable_gradient,
- DeviceMemory<float> *raw_variable_gradient) {
+ DeviceMemory<float> *raw_variable_gradient,
+ ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
PARAM(normalized_data), PARAM(normalized_variable_gradient),
- PARAM(raw_variable_gradient));
+ PARAM(raw_variable_gradient), PARAM(workspace_allocator));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoNormalizeBackwardWithDimensions(
this, normalize_descriptor, dimensions, raw_data, normalized_data,
- normalized_variable_gradient, raw_variable_gradient));
+ normalized_variable_gradient, raw_variable_gradient,
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -1901,30 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer(
Stream *Stream::GetOrCreateSubStream() {
mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (stream.second) {
- stream.second = false;
- return stream.first.get();
+
+ // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
+ // we encounter along the way.
+ for (int64 index = 0; index < sub_streams_.size();) {
+ std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+ if (pair.second) {
+ // The sub_stream is reusable.
+ Stream *sub_stream = pair.first.get();
+ if (sub_stream->ok()) {
+ VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
+ << sub_stream->DebugStreamPointers();
+ pair.second = false;
+ return sub_stream;
+ }
+
+ // The stream is reusable and not ok. Streams have a monotonic state
+ // machine; the stream will remain in !ok forever. Swap it with the last
+ // stream and pop it off.
+ const int64 last = sub_streams_.size() - 1;
+ if (index != last) {
+ std::swap(pair, sub_streams_[last]);
+ }
+ sub_streams_.pop_back();
+ VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ } else {
+ // The sub_stream is not reusable, move on to the next one.
+ ++index;
}
}
+
+ // No streams are reusable; create a new stream.
sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
false);
Stream *sub_stream = sub_streams_.back().first.get();
sub_stream->Init();
CHECK(ok_) << "sub-stream failed to be initialized";
+ VLOG(1) << DebugStreamPointers() << " created new sub_stream "
+ << sub_stream->DebugStreamPointers();
return sub_stream;
}
void Stream::ReturnSubStream(Stream *sub_stream) {
mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (stream.first.get() == sub_stream) {
- stream.second = true;
- return;
+
+ // Look for the sub-stream.
+ for (int64 index = 0; index < sub_streams_.size(); ++index) {
+ std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+ if (pair.first.get() != sub_stream) {
+ continue;
}
+
+ // Found the sub_stream.
+ if (sub_stream->ok()) {
+ VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ pair.second = true;
+ } else {
+ // The returned stream is not ok. Streams have a monotonic state
+ // machine; the stream will remain in !ok forever. Swap it with the last
+ // stream and pop it off.
+ VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ const int64 last = sub_streams_.size() - 1;
+ if (index != last) {
+ std::swap(pair, sub_streams_[last]);
+ }
+ sub_streams_.pop_back();
+ }
+ return;
}
- LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
+
+ LOG(FATAL) << DebugStreamPointers()
+ << " did not create the returned sub-stream "
+ << sub_stream->DebugStreamPointers();
}
Stream &Stream::ThenStartTimer(Timer *t) {
@@ -1933,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) {
if (ok()) {
CheckError(parent_->StartTimer(this, t));
} else {
- LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
+ LOG(INFO) << DebugStreamPointers()
+ << " did not enqueue 'start timer': " << t;
}
return *this;
}
@@ -1944,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) {
if (ok()) {
CheckError(parent_->StopTimer(this, t));
} else {
- LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
+ LOG(INFO) << DebugStreamPointers()
+ << " did not enqueue 'stop timer': " << t;
}
return *this;
}
@@ -1957,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) {
CheckError(parent_->CreateStreamDependency(this, other));
} else {
SetError();
- LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
+ LOG(INFO) << DebugStreamPointers() << " did not wait for "
+ << other->DebugStreamPointers();
}
return *this;
}
@@ -1974,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) {
<< "at fault. Monitor for further errors.";
}
} else {
- LOG(INFO) << "stream " << this << " did not wait for an event.";
+ LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
}
return *this;
}
@@ -4657,6 +4734,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
scratch_allocator);
}
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<Eigen::half> &, int, int64,
+ const DeviceMemory<Eigen::half> &, int, int64, float,
+ DeviceMemory<Eigen::half> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<float> &, int, int64,
+ const DeviceMemory<float> &, int, int64, float,
+ DeviceMemory<float> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
+ const DeviceMemory<double> &, int, int64,
+ const DeviceMemory<double> &, int, int64, double,
+ DeviceMemory<double> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, int64, const DeviceMemory<std::complex<float>> &, int,
+ int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, int64, const DeviceMemory<std::complex<double>> &, int,
+ int64, std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
@@ -4665,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
CheckError(rng->SetSeed(this, seed, seed_bytes));
} else {
SetError();
- LOG(INFO) << "stream " << this << " unable to initialize RNG";
+ LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
}
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not set RNG seed: " << static_cast<const void *>(seed)
<< "; bytes: " << seed_bytes;
}
@@ -4683,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4699,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4715,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4730,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4746,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform(
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4762,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform(
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "stream " << this
- << " attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4777,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
if (ok()) {
CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy device-to-host; source: " << gpu_src.opaque();
}
return *this;
@@ -4790,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
if (ok()) {
CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy host-to-device; source: " << host_src;
}
return *this;
@@ -4803,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
if (ok()) {
CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy gpu-to-gpu; source: " << &gpu_src;
}
return *this;
@@ -4815,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
if (ok()) {
CheckError(parent_->MemZero(this, location, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memzero GPU location; source: " << location;
}
return *this;
@@ -4828,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
if (ok()) {
CheckError(parent_->Memset32(this, location, pattern, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memset GPU location; source: " << location
<< "; size: " << size << "; pattern: " << std::hex << pattern;
}
@@ -5097,12 +5288,25 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
if (ok()) {
CheckError(parent_->HostCallback(this, callback));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " was in error state before adding host callback";
}
return *this;
}
+Stream &Stream::ThenDoHostCallbackWithStatus(
+ std::function<port::Status()> callback) {
+ VLOG_CALL(PARAM(callback));
+
+ if (ok()) {
+ CheckError(parent_->HostCallback(this, std::move(callback)));
+ } else {
+ LOG(WARNING) << "stream " << DebugStreamPointers()
+ << " was in error state before adding host callback";
+ }
+ return *this;
+}
+
Stream &Stream::ThenFft(fft::Plan *plan,
const DeviceMemory<std::complex<float>> &input,
DeviceMemory<std::complex<float>> *output) {
@@ -5113,8 +5317,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5130,8 +5335,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5146,8 +5352,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5162,8 +5369,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5179,8 +5387,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5196,8 +5405,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5224,7 +5434,7 @@ port::Status Stream::BlockHostUntilDone() {
port::Status status = port::Status(
port::error::INTERNAL,
"stream did not block host until done; was already in an error state");
- LOG(INFO) << status << " " << this;
+ LOG(INFO) << DebugStreamPointers() << " " << status;
return status;
}
@@ -5235,4 +5445,10 @@ port::Status Stream::BlockHostUntilDone() {
return error;
}
+string Stream::DebugStreamPointers() const {
+ // Relies on the ToVlogString(const void*) overload above.
+ return port::StrCat("[stream=", ToVlogString(this),
+ ",impl=", ToVlogString(implementation_.get()), "]");
+}
+
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index e8885e1eb6..e1629b5b30 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -122,10 +122,14 @@ class Stream {
// Get or create a sub-stream from this stream. If there is any sub-stream in
// the pool that can be reused then just return this sub-stream. Otherwise
// create a new sub-stream.
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
// Return the sub-stream back to the host stream so that it can be reused
- // later.
+ // later. Sub-streams that are !ok() will not be reused.
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
// Allocate temporary memories. The stream will deallocate them when blocked
@@ -629,19 +633,22 @@ class Stream {
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<double> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<double> *output_data);
+ DeviceMemory<double> *output_data,
+ ScratchAllocator *workspace_allocator = nullptr);
Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<float> *output_data);
+ DeviceMemory<float> *output_data,
+ ScratchAllocator *workspace_allocator = nullptr);
Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<Eigen::half> *output_data);
+ DeviceMemory<Eigen::half> *output_data,
+ ScratchAllocator *workspace_allocator = nullptr);
Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
@@ -649,7 +656,8 @@ class Stream {
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<double> &output_data,
const DeviceMemory<double> &input_diff_data,
- DeviceMemory<double> *output_diff_data);
+ DeviceMemory<double> *output_diff_data,
+ ScratchAllocator *workspace_allocator = nullptr);
Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
@@ -657,7 +665,8 @@ class Stream {
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<float> &output_data,
const DeviceMemory<float> &input_diff_data,
- DeviceMemory<float> *output_diff_data);
+ DeviceMemory<float> *output_diff_data,
+ ScratchAllocator *workspace_allocator = nullptr);
Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
@@ -665,7 +674,8 @@ class Stream {
const dnn::BatchDescriptor &output_dimensions,
const DeviceMemory<Eigen::half> &output_data,
const DeviceMemory<Eigen::half> &input_diff_data,
- DeviceMemory<Eigen::half> *output_diff_data);
+ DeviceMemory<Eigen::half> *output_diff_data,
+ ScratchAllocator *workspace_allocator = nullptr);
Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
const DeviceMemory<float> &input_data,
@@ -684,7 +694,8 @@ class Stream {
const DeviceMemory<float> &raw_data,
const DeviceMemory<float> &normalized_data,
const DeviceMemory<float> &normalized_variable_gradient,
- DeviceMemory<float> *raw_variable_gradient);
+ DeviceMemory<float> *raw_variable_gradient,
+ ScratchAllocator *workspace_allocator = nullptr);
Stream &ThenActivate(dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor &dimensions,
@@ -1550,6 +1561,38 @@ class Stream {
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count);
// See BlasSupport::DoBlasHemm.
Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
@@ -2002,6 +2045,11 @@ class Stream {
// negative effects on performance.
Stream &ThenDoHostCallback(std::function<void()> callback);
+ // Entrains onto the stream a callback to the host (from the device).
+ // Behaves as ThenDoHostCallback above, but returns a Status instead of void.
+ // This overload should be preferred if the callback could fail.
+ Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
+
// Returns the StreamExecutor (parent object) associated with this stream.
StreamExecutor *parent() const {
CHECK(parent_ != nullptr);
@@ -2012,6 +2060,9 @@ class Stream {
// with this stream.
internal::TemporaryMemoryManager *temporary_memory_manager();
+ // Returns a debugging string "[stream=0x...,impl=0x...]".
+ string DebugStreamPointers() const;
+
private:
friend class host::HostBlas; // for parent_.
friend class host::HostFft; // for parent_.
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc
index 8297228e6f..7df6a361c6 100644
--- a/tensorflow/stream_executor/stream_executor_internal.cc
+++ b/tensorflow/stream_executor/stream_executor_internal.cc
@@ -36,5 +36,17 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() {
StreamExecutorFactory MakeHostExecutorImplementation;
+// TODO(b/112125301): Consolodate this down to one implementation of
+// HostCallback, taking a callback that returns a Status.
+bool StreamExecutorInterface::HostCallback(
+ Stream* stream, std::function<port::Status()> callback) {
+ return HostCallback(stream, [callback]() {
+ port::Status s = callback();
+ if (!s.ok()) {
+ LOG(WARNING) << "HostCallback failed: " << s;
+ }
+ });
+}
+
} // namespace internal
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 9c989b971d..59a477b5c9 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -36,20 +36,38 @@ limitations under the License.
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/kernel_spec.h"
#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/module_spec.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/trace_listener.h"
-#include "tensorflow/stream_executor/lib/inlined_vector.h"
namespace stream_executor {
class Stream;
class Timer;
+// An opaque handle to a loaded module.
+//
+// An instance of this is returned from StreamExecutor::GetModule.
+class ModuleHandle {
+ public:
+ /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {}
+
+ // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a
+ // null pointer.
+ void *id() const { return id_; }
+
+ explicit operator bool() const { return id() != nullptr; }
+
+ private:
+ void *id_;
+};
+
namespace internal {
// Platform-dependent interface class for the generic Events interface, in
@@ -100,19 +118,20 @@ class StreamInterface {
// Default destructor for the abstract interface.
virtual ~StreamInterface() {}
- // Returns the CUDA stream associated with this platform's stream
+ // Returns the GPU stream associated with this platform's stream
// implementation.
//
- // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
- // fatal error if it is not. This hack is made available solely for use from
- // distbelief code, which temporarily has strong ties to CUDA as a platform.
- virtual void *CudaStreamHack() { return nullptr; }
-
- // See the above comment on CudaStreamHack -- this further breaks abstraction
- // for Eigen within distbelief, which has strong ties to CUDA as a platform,
- // and a historical attachment to a programming model which takes a
+ // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
+ // causing a fatal error if it is not. This hack is made available solely for
+ // use from distbelief code, which temporarily has strong ties to CUDA or
+ // ROCm as a platform.
+ virtual void *GpuStreamHack() { return nullptr; }
+
+ // See the above comment on GpuStreamHack -- this further breaks abstraction
+ // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a
+ // platform, and a historical attachment to a programming model which takes a
// stream-slot rather than a stream-value.
- virtual void **CudaStreamMemberHack() { return nullptr; }
+ virtual void **GpuStreamMemberHack() { return nullptr; }
private:
SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
@@ -163,6 +182,11 @@ class StreamExecutorInterface {
KernelBase *kernel) {
return false;
}
+ virtual bool LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) {
+ return false;
+ }
+ virtual bool UnloadModule(ModuleHandle module_handle) { return false; }
virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &k,
const KernelArgsArrayBase &args) {
@@ -212,9 +236,11 @@ class StreamExecutorInterface {
virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
const void *host_src, uint64 size) = 0;
virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
- const DeviceMemoryBase &host_src,
+ const DeviceMemoryBase &gpu_src,
uint64 size) = 0;
virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0;
+ virtual bool HostCallback(Stream *stream,
+ std::function<port::Status()> callback);
virtual port::Status AllocateEvent(Event *event) = 0;
virtual port::Status DeallocateEvent(Event *event) = 0;
virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
@@ -246,7 +272,12 @@ class StreamExecutorInterface {
// null, however, both of them cannot be null at the same time. To use
// constant memory in CUDA, GetSymbol has to be used. Returns true if symbol
// is found.
- virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) {
+ //
+ // If ModuleHandle is set then we search for `symbol_name` only within the
+ // module corresponding to `module_handle`. Otherwise all loaded modules are
+ // searched.
+ virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+ void **mem, size_t *bytes) {
return false;
}
@@ -324,13 +355,14 @@ class StreamExecutorInterface {
virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;
- // Returns the CUDA context associated with this StreamExecutor platform
- // implementation.
+ // Returns the CUDA or ROCm context associated with this StreamExecutor
+ // platform implementation.
//
- // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
- // fatal error if it is not. This hack is made available solely for use from
- // distbelief code, which temporarily has strong ties to CUDA as a platform.
- virtual void *CudaContextHack() { return nullptr; }
+ // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm,
+ // causing a fatal error if it is not. This hack is made available solely for
+ // use from distbelief code, which temporarily has strong ties to CUDA or ROCm
+ // as a platform.
+ virtual void *GpuContextHack() { return nullptr; }
private:
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 000795ff00..9515d8e62a 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -220,6 +220,15 @@ void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
implementation_->UnloadKernel(kernel);
}
+bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle) {
+ return implementation_->LoadModule(spec, module_handle);
+}
+
+bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
+ return implementation_->UnloadModule(module_handle);
+}
+
void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
<< ") mem->size()=" << mem->size() << StackTraceIfVLOG10();
@@ -459,9 +468,34 @@ void *StreamExecutor::Allocate(uint64 size) {
return buf;
}
-bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
+port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
+ const string &symbol_name, ModuleHandle module_handle) {
+ // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
+ // be nullptr/0 for consistency with DeviceMemory semantics.
+ void *opaque = nullptr;
+ size_t bytes = 0;
+ if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
+ return DeviceMemoryBase(opaque, bytes);
+ }
+
+ if (static_cast<bool>(module_handle)) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::StrCat("Check if module containing symbol ", symbol_name,
+ " is loaded (module_handle = ",
+ reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
+ } else {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::StrCat("Check if kernel using the symbol is loaded: ",
+ symbol_name));
+ }
+}
+
+bool StreamExecutor::GetSymbol(const string &symbol_name,
+ ModuleHandle module_handle, void **mem,
size_t *bytes) {
- return implementation_->GetSymbol(symbol_name, mem, bytes);
+ return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}
void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
@@ -665,6 +699,11 @@ bool StreamExecutor::HostCallback(Stream *stream,
return implementation_->HostCallback(stream, std::move(callback));
}
+bool StreamExecutor::HostCallback(Stream *stream,
+ std::function<port::Status()> callback) {
+ return implementation_->HostCallback(stream, std::move(callback));
+}
+
port::Status StreamExecutor::AllocateEvent(Event *event) {
return implementation_->AllocateEvent(event);
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index ad80a1ba25..437f298616 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -106,6 +106,16 @@ class StreamExecutor {
// Releases any state associated with the previously loaded kernel.
void UnloadKernel(const KernelBase *kernel);
+ // Loads a module for the platform this StreamExecutor is acting upon.
+ //
+ // `spec` describes the module to be loaded. On success writes the handle for
+ // the loaded module to `module_handle` and returns true. Else returns false.
+ bool LoadModule(const MultiModuleLoaderSpec &spec,
+ ModuleHandle *module_handle);
+
+ // Unloads the module with handle `module_handle`.
+ bool UnloadModule(ModuleHandle module_handle);
+
// Synchronously allocates an array on the device of type T with element_count
// elements.
template <typename T>
@@ -169,8 +179,16 @@ class StreamExecutor {
// type of symbol and T match.
// - Note: symbol_name should include its namespace as well. For example,
// pass "nms0::symbol" if referring to nms0::symbol.
+ //
+ // If `module_handle` is set then searches only within the module
+ // corresponding to `module_handle`.
template <typename T>
- port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name);
+ port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
+ ModuleHandle module_handle = {});
+
+ // An untyped version of GetSymbol.
+ port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
+ const string &symbol_name, ModuleHandle module_handle = {});
// Deallocate the DeviceMemory previously allocated via this interface.
// Deallocation of a nullptr-representative value is permitted.
@@ -507,7 +525,8 @@ class StreamExecutor {
// Finds and retrieves device memory for the symbol on the underlying
// platform.
- bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes);
+ bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
+ void **mem, size_t *bytes);
// Entrains a memcpy operation onto stream, with a host destination location
// host_dst and a device memory source, with target size size.
@@ -530,6 +549,11 @@ class StreamExecutor {
// See Stream::ThenDoHostCallback for full details.
bool HostCallback(Stream *stream, std::function<void()> callback);
+ // Entrains on a stream a user-specified function to be run on the host.
+ // See Stream::ThenDoHostCallback for full details.
+ // This is the preferred form for a callback that may return an error.
+ bool HostCallback(Stream *stream, std::function<port::Status()> callback);
+
// Performs platform-specific allocation and initialization of an event.
port::Status AllocateEvent(Event *event);
@@ -678,6 +702,41 @@ class StreamExecutor {
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};
+// A wrapper around ModuleHandle that uses RAII to manage its lifetime.
+class ScopedModuleHandle {
+ public:
+ explicit ScopedModuleHandle(StreamExecutor *executor,
+ ModuleHandle module_handle)
+ : executor_(executor), module_handle_(module_handle) {}
+
+ ScopedModuleHandle(ScopedModuleHandle &&other) {
+ executor_ = other.executor_;
+ module_handle_ = other.module_handle_;
+ other.executor_ = nullptr;
+ other.module_handle_ = ModuleHandle();
+ }
+
+ ScopedModuleHandle &operator=(ScopedModuleHandle &&other) {
+ executor_ = other.executor_;
+ module_handle_ = other.module_handle_;
+ other.executor_ = nullptr;
+ other.module_handle_ = ModuleHandle();
+ return *this;
+ }
+
+ ~ScopedModuleHandle() {
+ if (static_cast<bool>(module_handle_)) {
+ CHECK(executor_->UnloadModule(module_handle_));
+ }
+ }
+
+ private:
+ StreamExecutor *executor_;
+ ModuleHandle module_handle_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle);
+};
+
////////////
// Inlines
@@ -690,19 +749,13 @@ inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
template <typename T>
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
- const string &symbol_name) {
- // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
- // be nullptr/0 for consistency with DeviceMemory semantics.
- void *opaque = nullptr;
- size_t bytes = 0;
- if (GetSymbol(symbol_name, &opaque, &bytes)) {
- CHECK_EQ(bytes % sizeof(T), 0);
- return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
+ const string &symbol_name, ModuleHandle module_handle) {
+ port::StatusOr<DeviceMemoryBase> untyped_symbol =
+ GetUntypedSymbol(symbol_name, module_handle);
+ if (!untyped_symbol.ok()) {
+ return untyped_symbol.status();
}
- return port::Status(
- port::error::NOT_FOUND,
- port::StrCat("Check if kernel using the symbol is loaded: ",
- symbol_name));
+ return DeviceMemory<T>(untyped_symbol.ValueOrDie());
}
template <typename ElemT>
diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc
new file mode 100644
index 0000000000..cfc051fd09
--- /dev/null
+++ b/tensorflow/stream_executor/stream_test.cc
@@ -0,0 +1,203 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/stream_executor.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace stream_executor {
+namespace {
+
+class StreamTest : public ::testing::Test {
+ protected:
+ std::unique_ptr<StreamExecutor> NewStreamExecutor() {
+ Platform* platform =
+ MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie();
+ StreamExecutorConfig config(/*ordinal=*/0);
+ return platform->GetUncachedExecutor(config).ConsumeValueOrDie();
+ }
+};
+
+TEST_F(StreamTest, NoInitNotOk) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ EXPECT_FALSE(stream.ok());
+}
+
+TEST_F(StreamTest, InitOk) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+}
+
+TEST_F(StreamTest, OneSubStream) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get and return a sub-stream. Sub-streams are always initialized.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+ stream.ReturnSubStream(sub_stream1);
+
+ // Get and return another sub-stream.
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+ stream.ReturnSubStream(sub_stream1);
+
+ // The underlying sub-streams should be the same, since sub_stream1
+ // was returned before we tried to get sub_stream2.
+ EXPECT_EQ(sub_stream1, sub_stream2);
+}
+
+TEST_F(StreamTest, TwoSubStreams) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get two sub-streams.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+
+ // The underlying sub-streams should be different, since neither
+ // sub-stream has been returned.
+ EXPECT_NE(sub_stream1, sub_stream2);
+
+ // Return sub_stream1 and get sub_stream3, which should be the same.
+ stream.ReturnSubStream(sub_stream1);
+ Stream* sub_stream3 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream3->ok());
+ EXPECT_EQ(sub_stream1, sub_stream3);
+ EXPECT_NE(sub_stream2, sub_stream3);
+
+ // Return sub_stream2 and get sub_stream4, which should be the same.
+ stream.ReturnSubStream(sub_stream2);
+ Stream* sub_stream4 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream4->ok());
+ EXPECT_EQ(sub_stream2, sub_stream4);
+ EXPECT_NE(sub_stream3, sub_stream4);
+}
+
+TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get sub_stream1.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+
+ // Force an error on sub_stream1; here we call a method that requires DNN
+ // support, which we know the Host platform doesn't support.
+ sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(sub_stream1->ok());
+
+ // Return sub_stream1 and get sub_stream2.
+ stream.ReturnSubStream(sub_stream1);
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+
+ // The underlying sub_streams should be different. They would have been the
+ // same, but since we forced an error on sub_stream1, it will not be
+ // re-used. Sadly we can't just check:
+ // EXPECT_NE(sub_stream1, sub_stream2);
+ //
+ // The above should hold logically, but it may fail if the new Stream instance
+ // allocated for sub_stream2 happens to reside in the same memory address as
+ // sub_stream1.
+ //
+ // The check that sub_stream2->ok() serves as a good-enough check.
+
+ // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+ // has no effect on these streams, and they are the same.
+ stream.ReturnSubStream(sub_stream2);
+ Stream* sub_stream3 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream3->ok());
+ EXPECT_EQ(sub_stream2, sub_stream3);
+}
+
+TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get and return sub_stream1.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+ stream.ReturnSubStream(sub_stream1);
+
+ // Force an error on sub_stream1; here we call a method that requires DNN
+ // support, which we know the Host platform doesn't support.
+ //
+ // It is a bit weird to use sub_stream1 after it has already been returned. By
+ // doing this, we're simulating an asynchronous error that occurs during
+ // execution of the sub_stream, that occurs after the sub_stream is returned.
+ //
+ // E.g. the following is a common pattern of usage, where the execution of the
+ // operations enqueued onto the sub streams may occur after the streams have
+ // already been returned.
+ //
+ // void EnqueueOnSubStreams(Stream* stream) {
+ // Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ // Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ // // ... enqueue some operations on the sub streams ...
+ // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2);
+ // stream.ReturnSubStream(sub_stream1);
+ // stream.ReturnSubStream(sub_stream2);
+ // }
+ //
+ // Stream* main_stream = ...;
+ // EnqueueOnSubStreams(main_stream);
+ // main_stream.BlockHostUntilDone();
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone;
+ // GetOrCreateSubStream can still return a sub-stream that has not encountered
+ // an error yet, but will encounter one in the future, based on previously
+ // enqueued operations.
+ sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(sub_stream1->ok());
+
+ // Get and return sub_stream2.
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+
+ // The underlying streams should be different. They would have been the same,
+ // but since we forced an error on sub_stream1, it will not be re-used. Sadly
+ // we can't just check:
+ // EXPECT_NE(sub_stream1, sub_stream2);
+ //
+ // The above should hold logically, but it may fail if the new stream instance
+ // allocated for sub_stream2 happens to reside in the same memory address as
+ // sub_stream1.
+ //
+ // The check that sub_stream2->ok() serves as a good-enough check.
+
+ // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+ // has no effect on these streams, and they are the same.
+ stream.ReturnSubStream(sub_stream2);
+ Stream* sub_stream3 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream3->ok());
+ EXPECT_EQ(sub_stream2, sub_stream3);
+}
+
+} // namespace
+} // namespace stream_executor