Diffstat (limited to 'tensorflow/stream_executor')
 tensorflow/stream_executor/blas.h                      |  66
 tensorflow/stream_executor/cuda/cuda_blas.cc           | 217
 tensorflow/stream_executor/cuda/cuda_dnn.cc            |  78
 tensorflow/stream_executor/cuda/cuda_driver.cc         |  16
 tensorflow/stream_executor/dnn.h                       |  15
 tensorflow/stream_executor/stream.cc                   | 304
 tensorflow/stream_executor/stream.h                    |  44
 tensorflow/stream_executor/stream_executor_internal.cc |  12
 tensorflow/stream_executor/stream_executor_internal.h  |   2
 tensorflow/stream_executor/stream_executor_pimpl.cc    |   5
 tensorflow/stream_executor/stream_executor_pimpl.h     |   5
 tensorflow/stream_executor/stream_test.cc              |  90
 12 files changed, 722 insertions(+), 132 deletions(-)
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index ea87744b22..7f851e3646 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1121,6 +1121,40 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ // Batched gemm with strides instead of pointer arrays.
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+
// Computes a matrix-matrix product where one input matrix is Hermitian:
//
// c <- alpha * a * b + beta * c,
@@ -1990,6 +2024,38 @@ class BlasSupport {
int ldb, std::complex<double> beta, \
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \
int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \
+ const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \
+ DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+ int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \
+ int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \
+ int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, int64 stride_a, \
+ const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \
+ DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
+ int ldc, int64 stride_c, int batch_count); \
bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
uint64 m, uint64 n, std::complex<float> alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \
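
For reference, a minimal usage sketch of the strided-batched entry point added above, invoked through the Stream wrapper introduced later in this change. The helper, the square n x n shape, and the packed column-major layout are illustrative assumptions, not part of the change:

    #include <cstdint>
    #include "tensorflow/stream_executor/stream.h"

    namespace se = stream_executor;

    // Computes batch_count independent products c[i] = a[i] * b[i], where
    // matrix i of each operand starts at element offset i * stride inside one
    // packed, column-major device allocation (hypothetical helper).
    void BatchedMatMul(se::Stream* stream, const se::DeviceMemory<float>& a,
                       const se::DeviceMemory<float>& b,
                       se::DeviceMemory<float>* c, int n, int batch_count) {
      const int64_t stride = static_cast<int64_t>(n) * n;  // n x n matrices
      stream->ThenBlasGemmStridedBatched(
          se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
          /*m=*/n, /*n=*/n, /*k=*/n, /*alpha=*/1.0f, a, /*lda=*/n, stride, b,
          /*ldb=*/n, stride, /*beta=*/0.0f, c, /*ldc=*/n, stride, batch_count);
    }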
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 874bf0e8cb..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
#if CUDA_VERSION >= 8000
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched)
#endif
#if CUDA_VERSION >= 9000
@@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
#if CUDA_VERSION >= 9010
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
#endif
} // namespace wrap
@@ -643,7 +648,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
}
#endif
cublasStatus_t ret = cublas_func(parent_, blas_, args...);
- if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
+ if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) {
LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
<< ToString(ret);
}
@@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm(
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
&cc_minor);
- // GPUs < sm_70 don't support Volta hardware.
+ // GPUs < sm_70 don't support tensor ops.
if (cc_major >= 7 && TensorOpMathEnabled()) {
use_tensor_ops = true;
}
@@ -2139,6 +2144,10 @@ static bool UsesTensorOps(blas::AlgorithmType algo) {
template <typename InType>
static bool TensorOpsAvailable(int cc_major) {
#if CUDA_VERSION >= 9000
+ // cublas *does* allow tensor ops on inputs that are not fp16, so this is not
+ // strictly correct. We can't simply enable it, though, as that would change
+  // clients' behavior significantly: using tensor ops on fp32 inputs causes
+  // them to be rounded to fp16.
if (cc_major >= 7 && TensorOpMathEnabled() &&
std::is_same<InType, Eigen::half>::value) {
return true;
@@ -2160,16 +2169,30 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
&cc_major, &cc_minor) &&
cc_major < 5) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major
+ << cc_minor << " devices don't support explicit gemm algorithms.";
return false;
}
if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
+ if (std::is_same<InT, Eigen::half>::value) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+ << algorithm
+ << " uses tensor ops, but tensor ops are not available in sm"
+ << cc_major << "X devices.";
+ } else {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm "
+ << algorithm
+ << " uses tensor ops, but the input data type is not fp16.";
+ }
return false;
}
// Either both 'alpha' and 'beta' need to be pointers to device memory, or
// they need to be both host scalars.
if (alpha.is_pointer() != beta.is_pointer()) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` "
+ "and `beta` is a pointer, but the other is not.";
return false;
}
@@ -2177,6 +2200,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false because "
+ "output_profile_result was given, but we were unable to "
+ "create a CUDATimer.";
return false;
}
}
@@ -2186,6 +2212,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
std::max({m, n, k}) >= 2097153 && cc_major < 7) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around cudnn "
+ "<9.2 bug with m, n, or k >= 2097153. See b/79126339.";
return false;
}
#endif
@@ -2211,6 +2239,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
// CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsCUDAStream(stream))) {
+ VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop "
+ "CUDATimer.";
return false;
}
output_profile_result->set_is_valid(true);
@@ -2223,26 +2253,60 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
bool CUDABlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
-// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
-// were first introduced in CUDA 8.
-// Note that when CUDA version and compute capability is not sufficient, we
-// still return the out_algorithms. Caller needs to make sure that in this case,
-// the returned vector is empty.
- for (cublasGemmAlgo_t algo : {
- CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
- CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
- CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
+ // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
+ // were first introduced in CUDA 8.
+ //
+  // Note that when the CUDA version and compute capability are not
+  // sufficient, we still return the out_algorithms. The caller needs to make
+  // sure that in this case, the returned vector is empty.
+ *out_algorithms = {
+ CUBLAS_GEMM_DFALT,
+ CUBLAS_GEMM_ALGO0,
+ CUBLAS_GEMM_ALGO1,
+ CUBLAS_GEMM_ALGO2,
+ CUBLAS_GEMM_ALGO3,
+ CUBLAS_GEMM_ALGO4,
+ CUBLAS_GEMM_ALGO5,
+ CUBLAS_GEMM_ALGO6,
+ CUBLAS_GEMM_ALGO7,
#if CUDA_VERSION >= 9000
- CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
- CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
- CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
- CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
- CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
- CUBLAS_GEMM_ALGO2_TENSOR_OP
+ CUBLAS_GEMM_ALGO8,
+ CUBLAS_GEMM_ALGO9,
+ CUBLAS_GEMM_ALGO10,
+ CUBLAS_GEMM_ALGO11,
+ CUBLAS_GEMM_ALGO12,
+ CUBLAS_GEMM_ALGO13,
+ CUBLAS_GEMM_ALGO14,
+ CUBLAS_GEMM_ALGO15,
+ CUBLAS_GEMM_ALGO16,
+ CUBLAS_GEMM_ALGO17,
+ CUBLAS_GEMM_DFALT_TENSOR_OP,
+ CUBLAS_GEMM_ALGO0_TENSOR_OP,
+ CUBLAS_GEMM_ALGO1_TENSOR_OP,
+ CUBLAS_GEMM_ALGO2_TENSOR_OP,
+ CUBLAS_GEMM_ALGO3_TENSOR_OP,
+ CUBLAS_GEMM_ALGO4_TENSOR_OP,
#endif
- }) {
- out_algorithms->push_back(algo);
- }
+#if CUDA_VERSION >= 9200
+ CUBLAS_GEMM_ALGO18,
+ CUBLAS_GEMM_ALGO19,
+ CUBLAS_GEMM_ALGO20,
+ CUBLAS_GEMM_ALGO21,
+ CUBLAS_GEMM_ALGO22,
+ CUBLAS_GEMM_ALGO23,
+ CUBLAS_GEMM_ALGO5_TENSOR_OP,
+ CUBLAS_GEMM_ALGO6_TENSOR_OP,
+ CUBLAS_GEMM_ALGO7_TENSOR_OP,
+ CUBLAS_GEMM_ALGO8_TENSOR_OP,
+ CUBLAS_GEMM_ALGO9_TENSOR_OP,
+ CUBLAS_GEMM_ALGO10_TENSOR_OP,
+ CUBLAS_GEMM_ALGO11_TENSOR_OP,
+ CUBLAS_GEMM_ALGO12_TENSOR_OP,
+ CUBLAS_GEMM_ALGO13_TENSOR_OP,
+ CUBLAS_GEMM_ALGO14_TENSOR_OP,
+ CUBLAS_GEMM_ALGO15_TENSOR_OP,
+#endif
+ };
return true;
}
@@ -2564,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched(
return status.ok();
}
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+ int cc_major, cc_minor;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor)) {
+ // GPUs < sm_70 don't support tensor ops.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#if CUDA_VERSION >= 9010
+ if (cc_major >= 5) {
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasGemmStridedBatchedEx, stream,
+ true /* = pointer_mode_host */, true /* = err_on_failure */,
+ use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+ m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+ CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c),
+ CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
+ if (ok) {
+ return true;
+ }
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+#endif
+ }
+#endif
+  // CUDA_VERSION < 9.1, SM < 5.0, or the compute capability could not be
+  // queried. Fall back to a per-batch loop.
+ for (int batch = 0; batch < batch_count; ++batch) {
+ const auto *a_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
+ const auto *b_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b);
+ auto *c_matrix =
+ reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
+        true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
+ lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
+ SE_CUDA_DATA_HALF, ldc);
+ if (!ok) {
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+ }
+ return true;
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
blas::UpperLower uplo, uint64 m, uint64 n,
std::complex<float> alpha,
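
The fp16 path above falls back to a per-batch loop when cublasGemmStridedBatchedEx is unavailable. A generic, self-contained sketch of that fallback shape, with strides counted in elements rather than bytes (GemmFn and all names are illustrative):

    #include <cstdint>

    // Emulates one strided-batched GEMM with batch_count single GEMM calls
    // whose operand pointers advance by the per-batch stride.
    template <typename T, typename GemmFn>
    bool EmulateStridedBatched(GemmFn gemm, const T* a, int64_t stride_a,
                               const T* b, int64_t stride_b, T* c,
                               int64_t stride_c, int batch_count) {
      for (int batch = 0; batch < batch_count; ++batch) {
        if (!gemm(a + batch * stride_a, b + batch * stride_b,
                  c + batch * stride_c)) {
          return false;  // Bail out on the first failed per-batch call.
        }
      }
      return true;
    }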
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 1c3940e92c..55408ab9ab 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1986,15 +1986,14 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmDesc& algorithm_desc,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv,
- const CudnnTensorDescriptor& output_nd,
+ const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc,
ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is
// functionally correct because the convolution is run with the algorithm of
// the last call to this function, but should be fixed anyway.
- conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+  if (TF_PREDICT_FALSE(!algorithm_desc)) {
+    return port::Status(port::error::INVALID_ARGUMENT,
+                        "No AlgorithmDesc provided");
+  }
+  conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled());
// Query the size of the workspace and allocate it.
size_t size_in_bytes;
@@ -2002,8 +2001,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
cudnn.handle(),
/*xDesc=*/input_nd.handle(),
/*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
+ /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(*algorithm_desc),
/*sizeInBytes=*/&size_in_bytes));
+
+  algorithm_desc->set_scratch_size(size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2028,15 +2033,14 @@ port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardDataWorkspace(
Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmDesc& algorithm_desc,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv,
- const CudnnTensorDescriptor& output_nd,
+ const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc,
ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is
// functionally correct because the convolution is run with the algorithm of
// the last call to this function, but should be fixed anyway.
- conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+  if (TF_PREDICT_FALSE(!algorithm_desc)) {
+    return port::Status(port::error::INVALID_ARGUMENT,
+                        "No AlgorithmDesc provided");
+  }
+  conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled());
// Query the size of the workspace and allocate it.
size_t size_in_bytes;
@@ -2046,8 +2050,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
/*dyDesc=*/output_nd.handle(),
/*convDesc=*/conv.handle(),
/*dxDesc=*/input_nd.handle(),
- /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
+ /*algo=*/ToConvBackwardDataAlgo(*algorithm_desc),
/*sizeInBytes=*/&size_in_bytes));
+
+  algorithm_desc->set_scratch_size(size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2072,15 +2082,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace(
port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardFilterWorkspace(
Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmDesc& algorithm_desc,
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
const CudnnConvolutionDescriptor& conv,
- const CudnnTensorDescriptor& output_nd,
+ const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc,
ScratchAllocator* scratch_allocator) {
// TODO(csigg): This has side effects on the convolution descriptor. It is
// functionally correct because the convolution is run with the algorithm of
// the last call to this function, but should be fixed anyway.
- conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+  if (TF_PREDICT_FALSE(!algorithm_desc)) {
+    return port::Status(port::error::INVALID_ARGUMENT,
+                        "No AlgorithmDesc provided");
+  }
+  conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled());
// Query the size of the workspace and allocate it.
size_t size_in_bytes;
@@ -2090,8 +2099,14 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
/*dyDesc=*/output_nd.handle(),
/*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(),
- /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
+ /*algo=*/ToConvBackwardFilterAlgo(*algorithm_desc),
/*sizeInBytes=*/&size_in_bytes));
+
+  algorithm_desc->set_scratch_size(size_in_bytes);
int64 size_in_bytes_int64 = size_in_bytes;
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
@@ -2138,7 +2153,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
}
auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
- stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc,
scratch_allocator);
if (scratch_or.ok()) {
@@ -2155,11 +2170,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
"while a secondary algorithm is not provided.");
}
- SE_ASSIGN_OR_RETURN(
- *scratch, AllocateCudnnConvolutionForwardWorkspace(
- stream, cudnn, algorithm_config.algorithm_no_scratch(),
- input_nd, filter, conv, output_nd, scratch_allocator));
- return algorithm_config.algorithm_no_scratch();
+ algo_desc = algorithm_config.algorithm_no_scratch();
+ SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
+ stream, cudnn, input_nd, filter, conv,
+ output_nd, &algo_desc, scratch_allocator));
+ return algo_desc;
}
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
@@ -2187,7 +2202,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
}
auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
- stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc,
scratch_allocator);
if (scratch_or.ok()) {
@@ -2204,11 +2219,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
"while a secondary algorithm is not provided.");
}
- SE_ASSIGN_OR_RETURN(
- *scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
- stream, cudnn, algorithm_config.algorithm_no_scratch(),
- input_nd, filter, conv, output_nd, scratch_allocator));
- return algorithm_config.algorithm_no_scratch();
+ algo_desc = algorithm_config.algorithm_no_scratch();
+ SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
+ stream, cudnn, input_nd, filter, conv,
+ output_nd, &algo_desc, scratch_allocator));
+ return algo_desc;
}
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
@@ -2236,7 +2251,7 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
}
auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
- stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc,
scratch_allocator);
if (scratch_or.ok()) {
@@ -2253,11 +2268,11 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
"while a secondary algorithm is not provided.");
}
- SE_ASSIGN_OR_RETURN(*scratch,
- AllocateCudnnConvolutionBackwardFilterWorkspace(
- stream, cudnn, algorithm_config.algorithm(), input_nd,
- filter, conv, output_nd, scratch_allocator));
- return algorithm_config.algorithm_no_scratch();
+ algo_desc = algorithm_config.algorithm_no_scratch();
+ SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
+ stream, cudnn, input_nd, filter, conv,
+ output_nd, &algo_desc, scratch_allocator));
+ return algo_desc;
}
// A helper class to set env-vars and choose options for cudnn-related
@@ -3082,8 +3097,7 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
- // zero-initialized.
- // TODO(timshen): Add an nvbugs/ link.
+ // zero-initialized, nvbugs/2254619.
if (CUDNN_VERSION >= 7000 &&
algorithm_config.algorithm().algo_id() ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
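
The three Get*Algorithm functions above now pass the chosen descriptor by pointer so the workspace query can record its scratch size before the descriptor is returned. A condensed sketch of that convention against the types in this file; the helper itself is illustrative, not part of the change:

    // Hypothetical helper mirroring the tail of
    // GetCudnnConvolutionForwardAlgorithm.
    port::StatusOr<dnn::AlgorithmDesc> PickForwardAlgoWithScratch(
        Stream* stream, const CudnnHandle& cudnn,
        const CudnnTensorDescriptor& input_nd,
        const CudnnFilterDescriptor& filter,
        const CudnnConvolutionDescriptor& conv,
        const CudnnTensorDescriptor& output_nd,
        const dnn::AlgorithmConfig& algorithm_config,
        ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
      dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
      SE_ASSIGN_OR_RETURN(
          *scratch, AllocateCudnnConvolutionForwardWorkspace(
                        stream, cudnn, input_nd, filter, conv, output_nd,
                        &algo_desc, scratch_allocator));
      // algo_desc.scratch_size() was set by the workspace query above.
      return algo_desc;
    }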
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index dbece3adf9..f982f34b98 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/human_readable.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -66,14 +67,17 @@ class CreatedContexts {
return Live()->find(context) != Live()->end();
}
- // Adds context to the live set.
+ // Adds context to the live set, or returns it if it's already present.
static CudaContext* Add(CUcontext context) {
CHECK(context != nullptr);
mutex_lock lock(mu_);
- auto cuda_context = new CudaContext(context, next_id_++);
- Live()->insert(
- std::make_pair(context, std::unique_ptr<CudaContext>(cuda_context)));
- return cuda_context;
+ auto insert_result = Live()->insert(std::make_pair(context, nullptr));
+ auto it = insert_result.first;
+ if (insert_result.second) {
+ // context was not present in the map. Add it.
+ it->second = MakeUnique<CudaContext>(context, next_id_++);
+ }
+ return it->second.get();
}
// Removes context from the live set.
@@ -427,7 +431,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
*context = CreatedContexts::Add(new_context);
CHECK(*context != nullptr)
<< "success in this call must entail non-null result";
- VLOG(2) << "created context " << context << " for this thread";
+ VLOG(2) << "created or reused context " << context << " for this thread";
return port::Status::OK();
}
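
A generic, self-contained sketch of the insert-or-reuse idiom CreatedContexts::Add adopts above: one map operation either creates the owned value or hands back the existing one, so adding the same key twice is safe (Widget and the int key are illustrative stand-ins):

    #include <map>
    #include <memory>

    struct Widget {
      explicit Widget(int id) : id(id) {}
      int id;
    };

    Widget* AddOrReuse(std::map<int, std::unique_ptr<Widget>>* live, int key,
                       int* next_id) {
      auto insert_result = live->emplace(key, nullptr);
      auto it = insert_result.first;
      if (insert_result.second) {
        // The key was not present; create its value in place, exactly once.
        it->second = std::make_unique<Widget>((*next_id)++);
      }
      return it->second.get();
    }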
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index a7449c2df4..9abfa1db6a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -713,15 +713,23 @@ class PoolingDescriptor {
class AlgorithmDesc {
public:
typedef int64 Index;
- AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
+ AlgorithmDesc()
+ : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {}
AlgorithmDesc(Index a, bool use_tensor_ops)
- : algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
+ : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {}
+ AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size)
+ : algo_(a),
+ tensor_ops_enabled_(use_tensor_ops),
+ scratch_size_(scratch_size) {}
bool is_default() const { return algo_ == kDefaultAlgorithm; }
bool tensor_ops_enabled() const { return tensor_ops_enabled_; }
Index algo_id() const { return algo_; }
+ size_t scratch_size() const { return scratch_size_; }
+ void set_scratch_size(size_t val) { scratch_size_ = val; }
bool operator==(const AlgorithmDesc& other) const {
return this->algo_ == other.algo_ &&
- this->tensor_ops_enabled_ == other.tensor_ops_enabled_;
+ this->tensor_ops_enabled_ == other.tensor_ops_enabled_ &&
+ this->scratch_size_ == other.scratch_size_;
}
uint64 hash() const;
@@ -729,6 +737,7 @@ class AlgorithmDesc {
enum { kDefaultAlgorithm = -1 };
Index algo_;
bool tensor_ops_enabled_;
+ size_t scratch_size_;
};
// Describes the result from a perf experiment.
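
Because scratch_size_ now participates in operator==, two descriptors naming the same algorithm but carrying different recorded scratch sizes no longer compare equal. A small sketch, with illustrative values:

    #include <cassert>
    #include "tensorflow/stream_executor/dnn.h"

    void AlgorithmDescScratchExample() {
      namespace dnn = stream_executor::dnn;
      dnn::AlgorithmDesc with_scratch(/*a=*/7, /*use_tensor_ops=*/true,
                                      /*scratch_size=*/1 << 20);
      dnn::AlgorithmDesc without_scratch(/*a=*/7, /*use_tensor_ops=*/true);
      assert(with_scratch.algo_id() == without_scratch.algo_id());
      assert(!(with_scratch == without_scratch));  // 1 MiB vs 0 bytes.
    }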
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 6248aa2d01..9efd34de24 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) {
}
string ToVlogString(const DeviceMemoryBase *memory) {
- return ToVlogString(*memory);
+ return memory == nullptr ? "null" : ToVlogString(*memory);
}
string ToVlogString(const Eigen::half &h) {
@@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream,
// constructing all the strings in params is expensive.
CHECK(VLOG_IS_ON(1));
- string str = port::StrCat("Called Stream::", function_name, "(");
+ string str = port::StrCat(stream->DebugStreamPointers(),
+ " Called Stream::", function_name, "(");
const char *separator = "";
for (const auto &param : params) {
port::StrAppend(&str, separator, param.first, "=", param.second);
separator = ", ";
}
- port::StrAppend(&str, ") stream=", ToVlogString(stream));
+ port::StrAppend(&str, ")");
if (VLOG_IS_ON(10)) {
port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
}
@@ -267,13 +268,13 @@ Stream::Stream(StreamExecutor *parent,
Stream::~Stream() {
VLOG_CALL();
- temporary_memory_manager_.ForceDeallocateAll();
// Ensure the stream is completed.
auto status = BlockHostUntilDone();
if (!status.ok()) {
LOG(WARNING) << "Error blocking host until done in stream destructor: "
<< status;
}
+ temporary_memory_manager_.ForceDeallocateAll();
if (allocated_) {
parent_->DeallocateStream(this);
@@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer(
Stream *Stream::GetOrCreateSubStream() {
mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (stream.second) {
- stream.second = false;
- return stream.first.get();
+
+ // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
+ // we encounter along the way.
+ for (int64 index = 0; index < sub_streams_.size();) {
+ std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+ if (pair.second) {
+ // The sub_stream is reusable.
+ Stream *sub_stream = pair.first.get();
+ if (sub_stream->ok()) {
+ VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
+ << sub_stream->DebugStreamPointers();
+ pair.second = false;
+ return sub_stream;
+ }
+
+ // The stream is reusable and not ok. Streams have a monotonic state
+ // machine; the stream will remain in !ok forever. Swap it with the last
+ // stream and pop it off.
+ const int64 last = sub_streams_.size() - 1;
+ if (index != last) {
+ std::swap(pair, sub_streams_[last]);
+ }
+ sub_streams_.pop_back();
+ VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ } else {
+ // The sub_stream is not reusable, move on to the next one.
+ ++index;
}
}
+
+ // No streams are reusable; create a new stream.
sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
false);
Stream *sub_stream = sub_streams_.back().first.get();
sub_stream->Init();
CHECK(ok_) << "sub-stream failed to be initialized";
+ VLOG(1) << DebugStreamPointers() << " created new sub_stream "
+ << sub_stream->DebugStreamPointers();
return sub_stream;
}
void Stream::ReturnSubStream(Stream *sub_stream) {
mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (stream.first.get() == sub_stream) {
- // Streams have a monotonic state machine; if a stream
- // encounters an error, it will remain in an error state
- // forever. Only allow re-use of ok streams.
- //
- // TODO(toddw): Improve this mechanism, if necessary, to drop
- // failed streams completely.
- const bool ready_to_reuse = sub_stream->ok();
- stream.second = ready_to_reuse;
- return;
+
+ // Look for the sub-stream.
+ for (int64 index = 0; index < sub_streams_.size(); ++index) {
+ std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+ if (pair.first.get() != sub_stream) {
+ continue;
}
+
+ // Found the sub_stream.
+ if (sub_stream->ok()) {
+ VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ pair.second = true;
+ } else {
+ // The returned stream is not ok. Streams have a monotonic state
+ // machine; the stream will remain in !ok forever. Swap it with the last
+ // stream and pop it off.
+ VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ const int64 last = sub_streams_.size() - 1;
+ if (index != last) {
+ std::swap(pair, sub_streams_[last]);
+ }
+ sub_streams_.pop_back();
+ }
+ return;
}
- LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
+
+ LOG(FATAL) << DebugStreamPointers()
+ << " did not create the returned sub-stream "
+ << sub_stream->DebugStreamPointers();
}
Stream &Stream::ThenStartTimer(Timer *t) {
@@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) {
if (ok()) {
CheckError(parent_->StartTimer(this, t));
} else {
- LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
+ LOG(INFO) << DebugStreamPointers()
+ << " did not enqueue 'start timer': " << t;
}
return *this;
}
@@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) {
if (ok()) {
CheckError(parent_->StopTimer(this, t));
} else {
- LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
+ LOG(INFO) << DebugStreamPointers()
+ << " did not enqueue 'stop timer': " << t;
}
return *this;
}
@@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) {
CheckError(parent_->CreateStreamDependency(this, other));
} else {
SetError();
- LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
+ LOG(INFO) << DebugStreamPointers() << " did not wait for "
+ << other->DebugStreamPointers();
}
return *this;
}
@@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) {
<< "at fault. Monitor for further errors.";
}
} else {
- LOG(INFO) << "stream " << this << " did not wait for an event.";
+ LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
}
return *this;
}
@@ -4685,6 +4734,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
scratch_allocator);
}
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<Eigen::half> &, int, int64,
+ const DeviceMemory<Eigen::half> &, int, int64, float,
+ DeviceMemory<Eigen::half> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<float> &, int, int64,
+ const DeviceMemory<float> &, int, int64, float,
+ DeviceMemory<float> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
+ const DeviceMemory<double> &, int, int64,
+ const DeviceMemory<double> &, int, int64, double,
+ DeviceMemory<double> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, int64, const DeviceMemory<std::complex<float>> &, int,
+ int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, int64, const DeviceMemory<std::complex<double>> &, int,
+ int64, std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
@@ -4693,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
CheckError(rng->SetSeed(this, seed, seed_bytes));
} else {
SetError();
- LOG(INFO) << "stream " << this << " unable to initialize RNG";
+ LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
}
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not set RNG seed: " << static_cast<const void *>(seed)
<< "; bytes: " << seed_bytes;
}
@@ -4711,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4727,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4743,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4758,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4774,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform(
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4790,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform(
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "stream " << this
- << " attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4805,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
if (ok()) {
CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy device-to-host; source: " << gpu_src.opaque();
}
return *this;
@@ -4818,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
if (ok()) {
CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy host-to-device; source: " << host_src;
}
return *this;
@@ -4831,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
if (ok()) {
CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy gpu-to-gpu; source: " << &gpu_src;
}
return *this;
@@ -4843,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
if (ok()) {
CheckError(parent_->MemZero(this, location, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memzero GPU location; source: " << location;
}
return *this;
@@ -4856,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
if (ok()) {
CheckError(parent_->Memset32(this, location, pattern, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memset GPU location; source: " << location
<< "; size: " << size << "; pattern: " << std::hex << pattern;
}
@@ -5125,12 +5288,25 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
if (ok()) {
CheckError(parent_->HostCallback(this, callback));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " was in error state before adding host callback";
}
return *this;
}
+Stream &Stream::ThenDoHostCallbackWithStatus(
+ std::function<port::Status()> callback) {
+ VLOG_CALL(PARAM(callback));
+
+ if (ok()) {
+ CheckError(parent_->HostCallback(this, std::move(callback)));
+ } else {
+ LOG(WARNING) << "stream " << DebugStreamPointers()
+ << " was in error state before adding host callback";
+ }
+ return *this;
+}
+
Stream &Stream::ThenFft(fft::Plan *plan,
const DeviceMemory<std::complex<float>> &input,
DeviceMemory<std::complex<float>> *output) {
@@ -5141,8 +5317,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5158,8 +5335,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5174,8 +5352,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5190,8 +5369,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5207,8 +5387,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5224,8 +5405,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5252,7 +5434,7 @@ port::Status Stream::BlockHostUntilDone() {
port::Status status = port::Status(
port::error::INTERNAL,
"stream did not block host until done; was already in an error state");
- LOG(INFO) << status << " " << this;
+ LOG(INFO) << DebugStreamPointers() << " " << status;
return status;
}
@@ -5263,4 +5445,10 @@ port::Status Stream::BlockHostUntilDone() {
return error;
}
+string Stream::DebugStreamPointers() const {
+ // Relies on the ToVlogString(const void*) overload above.
+ return port::StrCat("[stream=", ToVlogString(this),
+ ",impl=", ToVlogString(implementation_.get()), "]");
+}
+
} // namespace stream_executor
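
GetOrCreateSubStream and ReturnSubStream above drop failed sub-streams with a swap-and-pop, which erases from a vector in O(1) when element order does not matter. The generic form, as a self-contained sketch:

    #include <cstddef>
    #include <utility>
    #include <vector>

    template <typename T>
    void SwapAndPop(std::vector<T>* v, size_t index) {
      const size_t last = v->size() - 1;
      if (index != last) {
        std::swap((*v)[index], (*v)[last]);  // Move the victim to the back.
      }
      v->pop_back();  // Drop it without shifting the remaining elements.
    }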
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 706442a666..e1629b5b30 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -122,10 +122,14 @@ class Stream {
// Get or create a sub-stream from this stream. If there is any sub-stream in
// the pool that can be reused then just return this sub-stream. Otherwise
// create a new sub-stream.
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
// Return the sub-stream back to the host stream so that it can be reused
// later. Sub-streams that are !ok() will not be reused.
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
// Allocate temporary memories. The stream will deallocate them when blocked
@@ -1557,6 +1561,38 @@ class Stream {
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count);
// See BlasSupport::DoBlasHemm.
Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
@@ -2009,6 +2045,11 @@ class Stream {
// negative effects on performance.
Stream &ThenDoHostCallback(std::function<void()> callback);
+ // Entrains onto the stream a callback to the host (from the device).
+  // Behaves as ThenDoHostCallback above, but takes a callback that returns a
+  // Status instead of void. This overload should be preferred if the callback
+  // could fail.
+ Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
+
// Returns the StreamExecutor (parent object) associated with this stream.
StreamExecutor *parent() const {
CHECK(parent_ != nullptr);
@@ -2019,6 +2060,9 @@ class Stream {
// with this stream.
internal::TemporaryMemoryManager *temporary_memory_manager();
+ // Returns a debugging string "[stream=0x...,impl=0x...]".
+ string DebugStreamPointers() const;
+
private:
friend class host::HostBlas; // for parent_.
friend class host::HostFft; // for parent_.
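
A usage sketch for the new status-returning host callback declared above; FlushResultsToDisk and EnqueueFlush are hypothetical:

    #include "tensorflow/stream_executor/stream.h"

    namespace se = stream_executor;

    bool FlushResultsToDisk();  // Hypothetical helper; may fail.

    void EnqueueFlush(se::Stream* stream) {
      stream->ThenDoHostCallbackWithStatus([]() -> se::port::Status {
        if (!FlushResultsToDisk()) {
          // A failed Status is surfaced by the executor; the default adapter
          // in stream_executor_internal.cc logs it as a warning.
          return se::port::Status(se::port::error::INTERNAL, "flush failed");
        }
        return se::port::Status::OK();
      });
    }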
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc
index 8297228e6f..7df6a361c6 100644
--- a/tensorflow/stream_executor/stream_executor_internal.cc
+++ b/tensorflow/stream_executor/stream_executor_internal.cc
@@ -36,5 +36,17 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() {
StreamExecutorFactory MakeHostExecutorImplementation;
+// TODO(b/112125301): Consolidate this down to one implementation of
+// HostCallback, taking a callback that returns a Status.
+bool StreamExecutorInterface::HostCallback(
+ Stream* stream, std::function<port::Status()> callback) {
+ return HostCallback(stream, [callback]() {
+ port::Status s = callback();
+ if (!s.ok()) {
+ LOG(WARNING) << "HostCallback failed: " << s;
+ }
+ });
+}
+
} // namespace internal
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 52b3dc04c4..59a477b5c9 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -239,6 +239,8 @@ class StreamExecutorInterface {
const DeviceMemoryBase &gpu_src,
uint64 size) = 0;
virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0;
+ virtual bool HostCallback(Stream *stream,
+ std::function<port::Status()> callback);
virtual port::Status AllocateEvent(Event *event) = 0;
virtual port::Status DeallocateEvent(Event *event) = 0;
virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 2e0137a485..9515d8e62a 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -699,6 +699,11 @@ bool StreamExecutor::HostCallback(Stream *stream,
return implementation_->HostCallback(stream, std::move(callback));
}
+bool StreamExecutor::HostCallback(Stream *stream,
+ std::function<port::Status()> callback) {
+ return implementation_->HostCallback(stream, std::move(callback));
+}
+
port::Status StreamExecutor::AllocateEvent(Event *event) {
return implementation_->AllocateEvent(event);
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 47b3a2b030..437f298616 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -549,6 +549,11 @@ class StreamExecutor {
// See Stream::ThenDoHostCallback for full details.
bool HostCallback(Stream *stream, std::function<void()> callback);
+ // Entrains on a stream a user-specified function to be run on the host.
+ // See Stream::ThenDoHostCallback for full details.
+ // This is the preferred form for a callback that may return an error.
+ bool HostCallback(Stream *stream, std::function<port::Status()> callback);
+
// Performs platform-specific allocation and initialization of an event.
port::Status AllocateEvent(Event *event);
diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc
index 47dd675834..cfc051fd09 100644
--- a/tensorflow/stream_executor/stream_test.cc
+++ b/tensorflow/stream_executor/stream_test.cc
@@ -95,18 +95,18 @@ TEST_F(StreamTest, TwoSubStreams) {
EXPECT_NE(sub_stream3, sub_stream4);
}
-TEST_F(StreamTest, FailedSubStreamNotReused) {
+TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) {
std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
Stream stream(executor.get());
stream.Init();
EXPECT_TRUE(stream.ok());
- // Get a sub-stream.
+ // Get sub_stream1.
Stream* sub_stream1 = stream.GetOrCreateSubStream();
EXPECT_TRUE(sub_stream1->ok());
- // Force an error on the stream; here we call a method that requires
- // DNN support, which we know the Host platform doesn't support.
+ // Force an error on sub_stream1; here we call a method that requires DNN
+ // support, which we know the Host platform doesn't support.
sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
EXPECT_FALSE(sub_stream1->ok());
@@ -115,20 +115,84 @@ TEST_F(StreamTest, FailedSubStreamNotReused) {
Stream* sub_stream2 = stream.GetOrCreateSubStream();
EXPECT_TRUE(sub_stream2->ok());
- // The underlying streams should be different. They would have been
- // the same, but since we forced an error on sub_stream1, it will
- // not be re-used. Sadly we can't just check:
+ // The underlying sub_streams should be different. They would have been the
+ // same, but since we forced an error on sub_stream1, it will not be
+ // re-used. Sadly we can't just check:
// EXPECT_NE(sub_stream1, sub_stream2);
//
- // The above should hold logically, but it may fail if the new
- // stream instance allocated for sub_stream2 happens to reside in
- // the same memory address as sub_stream1.
+ // The above should hold logically, but it may fail if the new Stream instance
+ // allocated for sub_stream2 happens to reside in the same memory address as
+ // sub_stream1.
//
+  // Checking sub_stream2->ok() serves as a good-enough check.
- // Return sub_stream2 and get sub_stream3. The previous error on
- // sub_stream1 has no effect on these streams, and they are the
- // same.
+ // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+ // has no effect on these streams, and they are the same.
+ stream.ReturnSubStream(sub_stream2);
+ Stream* sub_stream3 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream3->ok());
+ EXPECT_EQ(sub_stream2, sub_stream3);
+}
+
+TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get and return sub_stream1.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+ stream.ReturnSubStream(sub_stream1);
+
+ // Force an error on sub_stream1; here we call a method that requires DNN
+ // support, which we know the Host platform doesn't support.
+ //
+  // It is a bit weird to use sub_stream1 after it has already been returned.
+  // By doing this, we're simulating an asynchronous error that occurs during
+  // execution of the sub_stream's enqueued work, after the sub_stream has
+  // been returned.
+ //
+ // E.g. the following is a common pattern of usage, where the execution of the
+ // operations enqueued onto the sub streams may occur after the streams have
+ // already been returned.
+ //
+ // void EnqueueOnSubStreams(Stream* stream) {
+ // Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ // Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ // // ... enqueue some operations on the sub streams ...
+ // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2);
+ // stream.ReturnSubStream(sub_stream1);
+ // stream.ReturnSubStream(sub_stream2);
+ // }
+ //
+ // Stream* main_stream = ...;
+ // EnqueueOnSubStreams(main_stream);
+ // main_stream.BlockHostUntilDone();
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone;
+ // GetOrCreateSubStream can still return a sub-stream that has not encountered
+ // an error yet, but will encounter one in the future, based on previously
+ // enqueued operations.
+ sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(sub_stream1->ok());
+
+ // Get and return sub_stream2.
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+
+ // The underlying streams should be different. They would have been the same,
+ // but since we forced an error on sub_stream1, it will not be re-used. Sadly
+ // we can't just check:
+ // EXPECT_NE(sub_stream1, sub_stream2);
+ //
+ // The above should hold logically, but it may fail if the new stream instance
+ // allocated for sub_stream2 happens to reside in the same memory address as
+ // sub_stream1.
+ //
+  // Checking sub_stream2->ok() serves as a good-enough check.
+
+ // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+ // has no effect on these streams, and they are the same.
stream.ReturnSubStream(sub_stream2);
Stream* sub_stream3 = stream.GetOrCreateSubStream();
EXPECT_TRUE(sub_stream3->ok());