Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--  tensorflow/stream_executor/BUILD | 39
-rw-r--r--  tensorflow/stream_executor/blas.cc | 57
-rw-r--r--  tensorflow/stream_executor/blas.h | 1780
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_activation.cc | 30
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_activation.h | 53
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.cc | 2184
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.h | 100
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_diagnostics.cc | 260
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_diagnostics.h | 85
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc | 1074
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h | 206
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_driver.cc | 1608
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_driver.h | 460
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_event.cc | 56
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_event.h | 49
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_fft.cc | 327
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_fft.h | 95
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.cc | 1082
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.h | 270
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_helpers.h | 95
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_kernel.h | 115
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_platform.cc | 172
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_platform.h | 98
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_rng.cc | 317
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_rng.h | 89
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_stream.cc | 51
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_stream.h | 74
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_timer.cc | 73
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_timer.h | 69
-rw-r--r--  tensorflow/stream_executor/cuda/multi_op_activation.h | 16
-rw-r--r--  tensorflow/stream_executor/device_description.cc | 221
-rw-r--r--  tensorflow/stream_executor/device_description.h | 370
-rw-r--r--  tensorflow/stream_executor/device_memory.h | 284
-rw-r--r--  tensorflow/stream_executor/device_options.h | 70
-rw-r--r--  tensorflow/stream_executor/dnn.cc | 297
-rw-r--r--  tensorflow/stream_executor/dnn.h | 895
-rw-r--r--  tensorflow/stream_executor/dso_loader.cc | 208
-rw-r--r--  tensorflow/stream_executor/dso_loader.h | 107
-rw-r--r--  tensorflow/stream_executor/event.cc | 48
-rw-r--r--  tensorflow/stream_executor/event.h | 63
-rw-r--r--  tensorflow/stream_executor/executor_cache.cc | 43
-rw-r--r--  tensorflow/stream_executor/executor_cache.h | 45
-rw-r--r--  tensorflow/stream_executor/fft.h | 187
-rw-r--r--  tensorflow/stream_executor/gcuda.cc | 87
-rw-r--r--  tensorflow/stream_executor/gcuda.h | 415
-rw-r--r--  tensorflow/stream_executor/gpu_launch_dim.h | 8
-rw-r--r--  tensorflow/stream_executor/kernel.cc | 95
-rw-r--r--  tensorflow/stream_executor/kernel.h | 499
-rw-r--r--  tensorflow/stream_executor/kernel_cache_config.h | 29
-rw-r--r--  tensorflow/stream_executor/kernel_spec.cc | 236
-rw-r--r--  tensorflow/stream_executor/kernel_spec.h | 365
-rw-r--r--  tensorflow/stream_executor/launch_dim.h | 65
-rw-r--r--  tensorflow/stream_executor/lib/array_slice.h | 17
-rw-r--r--  tensorflow/stream_executor/lib/casts.h | 85
-rw-r--r--  tensorflow/stream_executor/lib/demangle.cc | 38
-rw-r--r--  tensorflow/stream_executor/lib/demangle.h | 16
-rw-r--r--  tensorflow/stream_executor/lib/env.h | 29
-rw-r--r--  tensorflow/stream_executor/lib/error.h | 16
-rw-r--r--  tensorflow/stream_executor/lib/human_readable.h | 58
-rw-r--r--  tensorflow/stream_executor/lib/initialize.h | 35
-rw-r--r--  tensorflow/stream_executor/lib/inlined_vector.h | 16
-rw-r--r--  tensorflow/stream_executor/lib/mathutil.h | 88
-rw-r--r--  tensorflow/stream_executor/lib/notification.h | 16
-rw-r--r--  tensorflow/stream_executor/lib/numbers.cc | 27
-rw-r--r--  tensorflow/stream_executor/lib/numbers.h | 19
-rw-r--r--  tensorflow/stream_executor/lib/path.cc | 50
-rw-r--r--  tensorflow/stream_executor/lib/path.h | 44
-rw-r--r--  tensorflow/stream_executor/lib/process_state.cc | 37
-rw-r--r--  tensorflow/stream_executor/lib/process_state.h | 17
-rw-r--r--  tensorflow/stream_executor/lib/ptr_util.h | 48
-rw-r--r--  tensorflow/stream_executor/lib/stacktrace.h | 18
-rw-r--r--  tensorflow/stream_executor/lib/static_threadlocal.h | 30
-rw-r--r--  tensorflow/stream_executor/lib/status.h | 23
-rw-r--r--  tensorflow/stream_executor/lib/status_macros.h | 54
-rw-r--r--  tensorflow/stream_executor/lib/statusor.h | 234
-rw-r--r--  tensorflow/stream_executor/lib/str_util.h | 30
-rw-r--r--  tensorflow/stream_executor/lib/strcat.h | 17
-rw-r--r--  tensorflow/stream_executor/lib/stringpiece.h | 17
-rw-r--r--  tensorflow/stream_executor/lib/stringprintf.h | 18
-rw-r--r--  tensorflow/stream_executor/lib/thread_options.h | 16
-rw-r--r--  tensorflow/stream_executor/lib/threadpool.h | 19
-rw-r--r--  tensorflow/stream_executor/machine_manager.cc | 276
-rw-r--r--  tensorflow/stream_executor/machine_manager.h | 197
-rw-r--r--  tensorflow/stream_executor/multi_platform_manager.cc | 66
-rw-r--r--  tensorflow/stream_executor/multi_platform_manager.h | 144
-rw-r--r--  tensorflow/stream_executor/platform.cc | 115
-rw-r--r--  tensorflow/stream_executor/platform.h | 185
-rw-r--r--  tensorflow/stream_executor/platform/default/mutex.h | 60
-rw-r--r--  tensorflow/stream_executor/platform/logging.h | 21
-rw-r--r--  tensorflow/stream_executor/platform/mutex.h | 12
-rw-r--r--  tensorflow/stream_executor/platform/port.h | 40
-rw-r--r--  tensorflow/stream_executor/platform/thread_annotations.h | 6
-rw-r--r--  tensorflow/stream_executor/plugin.cc | 40
-rw-r--r--  tensorflow/stream_executor/plugin.h | 74
-rw-r--r--  tensorflow/stream_executor/plugin_registry.cc | 228
-rw-r--r--  tensorflow/stream_executor/plugin_registry.h | 155
-rw-r--r--  tensorflow/stream_executor/rng.cc | 36
-rw-r--r--  tensorflow/stream_executor/rng.h | 80
-rw-r--r--  tensorflow/stream_executor/shared_memory_config.h | 21
-rw-r--r--  tensorflow/stream_executor/stream.cc | 3329
-rw-r--r--  tensorflow/stream_executor/stream.h | 1258
-rw-r--r--  tensorflow/stream_executor/stream_executor.h | 50
-rw-r--r--  tensorflow/stream_executor/stream_executor_internal.cc | 65
-rw-r--r--  tensorflow/stream_executor/stream_executor_internal.h | 364
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc | 642
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h | 725
-rw-r--r--  tensorflow/stream_executor/temporary_device_memory.cc | 53
-rw-r--r--  tensorflow/stream_executor/temporary_device_memory.h | 123
-rw-r--r--  tensorflow/stream_executor/temporary_memory_manager.cc | 113
-rw-r--r--  tensorflow/stream_executor/temporary_memory_manager.h | 138
-rw-r--r--  tensorflow/stream_executor/timer.cc | 41
-rw-r--r--  tensorflow/stream_executor/timer.h | 60
-rw-r--r--  tensorflow/stream_executor/trace_listener.h | 59
113 files changed, 25529 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
new file mode 100644
index 0000000000..b91fe431f6
--- /dev/null
+++ b/tensorflow/stream_executor/BUILD
@@ -0,0 +1,39 @@
+licenses(["restricted"])
+
+load("/tensorflow/tensorflow", "if_cuda")
+
+cc_library(
+    name = "stream_executor",
+    srcs = glob(
+        [
+            "*.cc",
+            "lib/*.cc",
+        ],
+        exclude = [
+            "**/*_test.cc",
+        ],
+    ) + if_cuda(
+        glob([
+            "cuda/*.cc",
+        ]),
+    ),
+    hdrs = glob([
+        "*.h",
+        "lib/*.h",
+        "platform/**/*.h",
+    ]),
+    data = [
+        "//tensorflow/core:cuda",
+        "//third_party/gpus/cuda:cublas",
+        "//third_party/gpus/cuda:cudnn",
+    ],
+    linkopts = [
+        "-ldl",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:lib",
+        "//third_party/gpus/cuda:cuda_headers",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc
new file mode 100644
index 0000000000..70a6bb7030
--- /dev/null
+++ b/tensorflow/stream_executor/blas.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/stream_executor/blas.h"
+
+#include "tensorflow/stream_executor/lib/strcat.h"
+
+namespace perftools {
+namespace gputools {
+namespace blas {
+
+string TransposeString(Transpose t) {
+ switch (t) {
+ case Transpose::kNoTranspose:
+ return "NoTranspose";
+ case Transpose::kTranspose:
+ return "Transpose";
+ case Transpose::kConjugateTranspose:
+ return "ConjugateTranspose";
+ default:
+ LOG(FATAL) << "Unknown transpose " << static_cast<int32>(t);
+ }
+}
+
+string UpperLowerString(UpperLower ul) {
+ switch (ul) {
+ case UpperLower::kUpper:
+ return "Upper";
+ case UpperLower::kLower:
+ return "Lower";
+ default:
+ LOG(FATAL) << "Unknown upperlower " << static_cast<int32>(ul);
+ }
+}
+
+string DiagonalString(Diagonal d) {
+ switch (d) {
+ case Diagonal::kUnit:
+ return "Unit";
+ case Diagonal::kNonUnit:
+ return "NonUnit";
+ default:
+ LOG(FATAL) << "Unknown diagonal " << static_cast<int32>(d);
+ }
+}
+
+string SideString(Side s) {
+ switch (s) {
+ case Side::kLeft:
+ return "Left";
+ case Side::kRight:
+ return "Right";
+ default:
+ LOG(FATAL) << "Unknown side " << static_cast<int32>(s);
+ }
+}
+
+} // namespace blas
+} // namespace gputools
+} // namespace perftools
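A minimal usage sketch for the string helpers above (illustrative only, not part of this commit's diff): TransposeString and the blas::Transpose enum come from the files in this change, while the LogGemmShape function name and the LOG(INFO) call are assumptions introduced here for illustration.

// Illustrative sketch, not part of this commit. Shows the string helpers
// from blas.cc being folded into a log message; LogGemmShape is a
// hypothetical caller, not an API defined in this change.
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/platform/logging.h"

namespace gpu = perftools::gputools;

void LogGemmShape(gpu::blas::Transpose transa, gpu::blas::Transpose transb,
                  long long m, long long n, long long k) {
  LOG(INFO) << "gemm: transa=" << gpu::blas::TransposeString(transa)
            << " transb=" << gpu::blas::TransposeString(transb)
            << " m=" << m << " n=" << n << " k=" << k;
}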
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
new file mode 100644
index 0000000000..f6ee29837d
--- /dev/null
+++ b/tensorflow/stream_executor/blas.h
@@ -0,0 +1,1780 @@
+// Exposes the family of BLAS routines as pre-canned high performance calls for
+// use in conjunction with the StreamExecutor abstraction.
+//
+// Note that this interface is optionally supported by platforms; see
+// StreamExecutor::SupportsBlas() for details.
+//
+// This abstraction makes it simple to entrain BLAS operations on GPU data into
+// a Stream -- users typically will not use this API directly, but will use the
+// Stream builder methods to entrain these operations "under the hood". For
+// example:
+//
+// DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024);
+// DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024);
+// // ... populate x and y ...
+// Stream stream{stream_exec};
+// stream
+// .Init()
+// .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1)
+// .BlockHostUntilDone();
+//
+// By using stream operations in this manner the user can easily intermix custom
+// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS
+// routines.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
+
+#include <complex>
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+
+template <typename ElemT>
+class DeviceMemory;
+
+namespace blas {
+
+// Specifies whether the input matrix will be transposed or
+// transposed+conjugated before any BLAS operations.
+enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose };
+
+// Returns a name for t.
+string TransposeString(Transpose t);
+
+// Specifies whether the upper or lower triangular part of a
+// symmetric/Hermitian matrix is used.
+enum class UpperLower { kUpper, kLower };
+
+// Returns a name for ul.
+string UpperLowerString(UpperLower ul);
+
+// Specifies whether a matrix is unit triangular.
+enum class Diagonal { kUnit, kNonUnit };
+
+// Returns a name for d.
+string DiagonalString(Diagonal d);
+
+// Specifies whether a Hermitian matrix appears on the left or right in
+// operation.
+enum class Side { kLeft, kRight };
+
+// Returns a name for s.
+string SideString(Side s);
+
+// BLAS support interface -- this can be derived from a GPU executor when the
+// underlying platform has an BLAS library implementation available. See
+// StreamExecutor::AsBlas().
+//
+// Thread-hostile: CUDA associates a CUDA-context with a particular thread in
+// the system. Any operation that a user attempts to perform by enqueueing BLAS
+// operations on a thread not-associated with the CUDA-context has unknown
+// behavior at the current time; see b/13176597
+class BlasSupport {
+ public:
+ virtual ~BlasSupport() {}
+
+ // Computes the sum of magnitudes of the vector elements.
+ // result <- |Re x(1)| + |Im x(1)| + |Re x(2)| + |Im x(2)|+ ... + |Re x(n)|
+ // + |Im x(n)|.
+ // Note that Im x(i) = 0 for real types float/double.
+ virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *result) = 0;
+ virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *result) = 0;
+ virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result) = 0;
+ virtual bool DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result) = 0;
+
+ // Performs a BLAS y <- ax+y operation.
+ virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *y, int incy) = 0;
+ virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Copies vector to another vector: y <- x.
+ virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *y, int incy) = 0;
+ virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Performs a BLAS dot product result <- x . y.
+ virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *result) = 0;
+ virtual bool DoBlasDot(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *result) = 0;
+
+ // Performs a BLAS dot product result <- conj(x) . y for complex types.
+ virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result) = 0;
+ virtual bool DoBlasDotc(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result) = 0;
+
+ // Performs a BLAS dot product result <- x . y for complex types. Note that
+ // x is unconjugated in this routine.
+ virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result) = 0;
+ virtual bool DoBlasDotu(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result) = 0;
+
+ // Computes the Euclidean norm of a vector: result <- ||x||.
+ // See the following link for more information of Euclidean norm:
+ // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm
+ virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *result) = 0;
+ virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *result) = 0;
+ virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result) = 0;
+ virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result) = 0;
+
+ // Performs rotation of points in the plane:
+ // x(i) = c*x(i) + s*y(i)
+ // y(i) = c*y(i) - s*x(i).
+ virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy, float c,
+ float s) = 0;
+ virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy, double c,
+ double s) = 0;
+ virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<float>> *x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy,
+ float c, float s) = 0;
+ virtual bool DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<double>> *x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy,
+ double c, double s) = 0;
+
+ // Computes the parameters for a Givens rotation.
+ // Given the Cartesian coordinates (a, b) of a point, these routines return
+ // the parameters c, s, r, and z associated with the Givens rotation. The
+ // parameters c and s define a unitary matrix such that:
+ //
+ // | c s |.| a | = | r |
+ // | -s c | | b | | 0 |
+ //
+ // The parameter z is defined such that if |a| > |b|, z is s; otherwise if
+ // c is not 0 z is 1/c; otherwise z is 1.
+ virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
+ DeviceMemory<float> *b, DeviceMemory<float> *c,
+ DeviceMemory<float> *s) = 0;
+ virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
+ DeviceMemory<double> *b, DeviceMemory<double> *c,
+ DeviceMemory<double> *s) = 0;
+ virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
+ DeviceMemory<std::complex<float>> *b,
+ DeviceMemory<float> *c,
+ DeviceMemory<std::complex<float>> *s) = 0;
+ virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
+ DeviceMemory<std::complex<double>> *b,
+ DeviceMemory<double> *c,
+ DeviceMemory<std::complex<double>> *s) = 0;
+
+ // Performs modified Givens rotation of points in the plane.
+ // Given two vectors x and y, each vector element of these vectors is replaced
+ // as follows:
+ //
+ // | x(i) | = H | x(i) |
+ // | y(i) | | y(i) |
+ //
+ // for i=1 to n, where H is a modified Givens transformation matrix whose
+ // values are stored in the param[1] through param[4] array.
+ // For more information please Google this routine.
+ virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
+ DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy,
+ const DeviceMemory<float> &param) = 0;
+ virtual bool DoBlasRotm(Stream *stream, uint64 elem_count,
+ DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy,
+ const DeviceMemory<double> &param) = 0;
+
+ // Computes the parameters for a modified Givens rotation.
+ // Given Cartesian coordinates (x1, y1) of an input vector, these routines
+ // compute the components of a modified Givens transformation matrix H that
+ // zeros the y-component of the resulting vector:
+ //
+ // | x1 | = H | x1 * sqrt(d1) |
+ // | 0 | | y1 * sqrt(d1) |
+ //
+ // For more information please Google this routine.
+ virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
+ DeviceMemory<float> *d2, DeviceMemory<float> *x1,
+ const DeviceMemory<float> &y1,
+ DeviceMemory<float> *param) = 0;
+ virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
+ DeviceMemory<double> *d2, DeviceMemory<double> *x1,
+ const DeviceMemory<double> &y1,
+ DeviceMemory<double> *param) = 0;
+
+ // Computes the product of a vector by a scalar: x <- a*x.
+ virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
+ DeviceMemory<float> *x, int incx) = 0;
+ virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
+ DeviceMemory<double> *x, int incx) = 0;
+ virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
+ DeviceMemory<std::complex<float>> *x, int incx) = 0;
+ virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
+ DeviceMemory<std::complex<double>> *x, int incx) = 0;
+ virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
+ std::complex<float> alpha,
+ DeviceMemory<std::complex<float>> *x, int incx) = 0;
+ virtual bool DoBlasScal(Stream *stream, uint64 elem_count,
+ std::complex<double> alpha,
+ DeviceMemory<std::complex<double>> *x, int incx) = 0;
+
+ // Swaps a vector with another vector.
+ virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy) = 0;
+ virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<float>> *x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<double>> *x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Finds the index of the element with maximum absolute value.
+ virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<int> *result) = 0;
+ virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<int> *result) = 0;
+ virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result) = 0;
+ virtual bool DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<int> *result) = 0;
+
+ // Finds the index of the element with minimum absolute value.
+ virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<int> *result) = 0;
+ virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<int> *result) = 0;
+ virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result) = 0;
+ virtual bool DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<int> *result) = 0;
+
+ // Computes a matrix-vector product using a general band matrix:
+ //
+ // y <- alpha * a * x + beta * y,
+ // or
+ // y <- alpha * a' * x + beta * y,
+ // or
+ // y <- alpha * conj(a') * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an m-by-n general band matrix, with kl
+ // sub-diagonals and ku super-diagonals; x is a vector with
+ // n(trans==kNoTranspose)/m(otherwise) elements;
+ // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
+ virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy) = 0;
+ virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Computes a matrix-vector product using a general matrix.
+ //
+ // y <- alpha * a * x + beta * y,
+ // or
+ // y <- alpha * a' * x + beta * y,
+ // or
+ // y <- alpha * conj(a') * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector
+ // with n(trans==kNoTranspose)/m(otherwise) elements;
+ // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements.
+ virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x, int incx,
+ float beta, DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) = 0;
+ virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Performs a rank-1 update of a general matrix.
+ //
+ // a <- alpha * x * y' + a,
+ //
+ // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
+ // an m-by-n general matrix.
+ virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda) = 0;
+ virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda) = 0;
+
+ // Performs a rank-1 update (conjugated) of a general matrix.
+ //
+ // a <- alpha * x * conj(y') + a,
+ //
+ // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
+ // an m-by-n general matrix.
+ virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda) = 0;
+ virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda) = 0;
+
+ // Performs a rank-1 update (unconjugated) of a general matrix.
+ //
+ // a <- alpha * x * y' + a,
+ //
+ // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is
+ // an m-by-n general matrix.
+ virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda) = 0;
+ virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda) = 0;
+
+ // Computes a matrix-vector product using a Hermitian band matrix.
+ //
+ // y <- alpha * a * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k
+ // super-diagonals; x and y are n-element vectors.
+ virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Computes a matrix-vector product using a Hermitian matrix.
+ //
+ // y <- alpha * a * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are
+ // n-element vectors.
+ virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Performs a rank-1 update of a Hermitian matrix.
+ //
+ // a <- alpha * x * conj(x') + a,
+ //
+ // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
+ // matrix.
+ virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *a, int lda) = 0;
+ virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *a, int lda) = 0;
+
+ // Performs a rank-2 update of a Hermitian matrix.
+ //
+ // a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a,
+ //
+ // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
+ // matrix.
+ virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda) = 0;
+ virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda) = 0;
+
+ // Computes a matrix-vector product using a Hermitian packed matrix.
+ //
+ // y <- alpha * a * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in
+ // packed form; x and y are n-element vectors.
+ virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &ap,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) = 0;
+ virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &ap,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) = 0;
+
+ // Performs a rank-1 update of a Hermitian packed matrix.
+ //
+ // a <- alpha * x * conj(x') + a,
+ //
+ // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian
+ // matrix, supplied in packed form.
+ virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *ap) = 0;
+ virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *ap) = 0;
+
+ // Performs a rank-2 update of a Hermitian packed matrix.
+ //
+ // a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a,
+ //
+ // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian
+ // matrix, supplied in packed form.
+ virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *ap) = 0;
+ virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *ap) = 0;
+
+ // Computes a matrix-vector product using a symmetric band matrix.
+ //
+ // y <- alpha * a * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k
+ // super-diagonals; x and y are n-element vectors.
+ virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x, int incx,
+ float beta, DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) = 0;
+
+ // Computes a matrix-vector product using a symmetric packed matrix.
+ //
+ // y <- alpha * a * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in
+ // packed form; x and y are n-element vectors.
+ virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &ap,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &ap,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy) = 0;
+
+ // Performs a rank-1 update of a symmetric packed matrix.
+ //
+ // a <- alpha * x * x' + a,
+ //
+ // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
+ // matrix, supplied in packed form.
+ virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *ap) = 0;
+ virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *ap) = 0;
+
+ // Performs a rank-2 update of a symmetric packed matrix.
+ //
+ // a <- alpha * x * x' + alpha * y * x' + a,
+ //
+ // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
+ // matrix, supplied in packed form.
+ virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *ap) = 0;
+ virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *ap) = 0;
+
+ // Computes a matrix-vector product for a symmetric matrix.
+ //
+ // y <- alpha * a * x + beta * y,
+ //
+ // alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
+ // n-element vectors.
+ virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) = 0;
+ virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy) = 0;
+
+ // Performs a rank-1 update of a symmetric matrix.
+ //
+ // a <- alpha * x * x' + a,
+ //
+ // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
+ // matrix.
+ virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *a, int lda) = 0;
+ virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *a, int lda) = 0;
+
+ // Performs a rank-2 update of symmetric matrix.
+ //
+ // a <- alpha * x * x' + alpha * y * x' + a,
+ //
+ // alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
+ // matrix.
+ virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda) = 0;
+ virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda) = 0;
+
+ // Computes a matrix-vector product using a triangular band matrix.
+ //
+ // x <- a * x,
+ // or
+ // x <- a' * x,
+ // or
+ // x <- conj(a') * x,
+ //
+ // a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
+ // with k+1 diagonals; x is a n-element vector.
+ virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) = 0;
+ virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) = 0;
+ virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) = 0;
+ virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) = 0;
+
+ // Solves a system of linear equations whose coefficients are in a triangular
+ // band matrix as below:
+ //
+ // a * x = b,
+ // or
+ // a' * x = b,
+ // or
+ // conj(a') * x = b,
+ //
+ // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
+ // lower triangular band matrix, with k+1 diagonals.
+ virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) = 0;
+ virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) = 0;
+ virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) = 0;
+ virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) = 0;
+
+ // Computes a matrix-vector product using a triangular packed matrix.
+ //
+ // x <- a * x,
+ // or
+ // x <- a' * x,
+ // or
+ // x <- conj(a') * x,
+ //
+ // a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
+ // supplied in packed form; x is a n-element vector.
+ virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx) = 0;
+ virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap,
+ DeviceMemory<double> *x, int incx) = 0;
+ virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx) = 0;
+ virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx) = 0;
+
+ // Solves a system of linear equations whose coefficients are in a triangular
+ // packed matrix as below:
+ //
+ // a * x = b,
+ // or
+ // a' * x = b,
+ // or
+ // conj(a') * x = b,
+ //
+ // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
+ // lower triangular matrix, supplied in packed form.
+ virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx) = 0;
+ virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap,
+ DeviceMemory<double> *x, int incx) = 0;
+ virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx) = 0;
+ virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx) = 0;
+
+ // Computes a matrix-vector product using a triangular matrix.
+ //
+ // x <- a * x,
+ // or
+ // x <- a' * x,
+ // or
+ // x <- conj(a') * x,
+ //
+ // a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is a
+ // n-element vector.
+ virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) = 0;
+ virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) = 0;
+ virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx) = 0;
+ virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx) = 0;
+
+ // Solves a system of linear equations whose coefficients are in a triangular
+ // matrix as below:
+ //
+ // a * x = b,
+ // or
+ // a' * x = b,
+ // or
+ // conj(a') * x = b,
+ //
+ // b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
+ // lower triangular matrix.
+ virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) = 0;
+ virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) = 0;
+ virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx) = 0;
+ virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx) = 0;
+
+ // Computes a matrix-matrix product with general matrices:
+ //
+ // c <- alpha * op(a) * op(b) + beta * c,
+ //
+ // op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
+ // beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
+ // op(b) is a k-by-n matrix; c is an m-by-n matrix.
+ virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) = 0;
+ virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc) = 0;
+ virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) = 0;
+ virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) = 0;
+
+ // Computes a batch of matrix-matrix product with general matrices.
+ // This is a batched version of DoBlasGemm.
+ // The batched GEMM computes matrix product for each input/output in a, b,
+ // and c, which contain batch_count DeviceMemory objects.
+ virtual bool DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha,
+ const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
+ const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
+ std::complex<float> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
+ std::complex<double> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
+ int batch_count) = 0;
+
+ // Computes a matrix-matrix product where one input matrix is Hermitian:
+ //
+ // c <- alpha * a * b + beta * c,
+ // or
+ // c <- alpha * b * a + beta * c,
+ //
+ // alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n
+ // matrices.
+ virtual bool DoBlasHemm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) = 0;
+ virtual bool DoBlasHemm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) = 0;
+
+ // Performs a Hermitian rank-k update.
+ //
+ // c <- alpha * a * conj(a') + beta * c,
+ // or
+ // c <- alpha * conj(a') * a + beta * c,
+ //
+ // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a is an n-by-k
+ // matrix in the first case and a k-by-n matrix in the second case.
+ virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ float alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc) = 0;
+ virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ double alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc) = 0;
+
+ // Performs a Hermitian rank-2k update.
+ //
+ // c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
+ // or
+ // c <- alpha * conj(b') * a + conj(alpha) * conj(a') * b + beta * c,
+ //
+ // alpha and beta are scalars; c is a n-by-n Hermitian matrix; a and b are
+ // n-by-k matrices in the first case and k-by-n matrices in the second case.
+ virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc) = 0;
+ virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc) = 0;
+
+ // Computes a matrix-matrix product where one input matrix is symmetric.
+ //
+ // c <- alpha * a * b + beta * c,
+ // or
+ // c <- alpha * b * a + beta * c,
+ //
+ // alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n
+ // matrices.
+ virtual bool DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) = 0;
+ virtual bool DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc) = 0;
+ virtual bool DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) = 0;
+ virtual bool DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) = 0;
+
+ // Performs a symmetric rank-k update.
+ //
+ // c <- alpha * a * a' + beta * c,
+ // or
+ // c <- alpha * a' * a + beta * c,
+ //
+ // alpha and beta are scalars; c is a n-by-n symmetric matrix; a is an n-by-k
+ // matrix in the first case and a k-by-n matrix in the second case.
+ virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ float beta, DeviceMemory<float> *c, int ldc) = 0;
+ virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ double beta, DeviceMemory<double> *c, int ldc) = 0;
+ virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) = 0;
+ virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) = 0;
+
+ // Performs a symmetric rank-2k update.
+ //
+ // c <- alpha * a * b' + alpha * b * a' + beta * c,
+ // or
+ // c <- alpha * b' * a + alpha * a' * b + beta * c,
+ //
+ // alpha and beta are scalars; c is a n-by-n symmetric matrix; a and b are
+ // n-by-k matrices in the first case and k-by-n matrices in the second case.
+ virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) = 0;
+ virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc) = 0;
+ virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) = 0;
+ virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) = 0;
+
+ // Computes a matrix-matrix product where one input matrix is triangular.
+ //
+ // b <- alpha * op(a) * b,
+ // or
+ // b <- alpha * b * op(a)
+ //
+ // alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
+ // or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
+ // op(a) = conj(a').
+ virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *b, int ldb) = 0;
+ virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *b, int ldb) = 0;
+ virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb) = 0;
+ virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb) = 0;
+
+ // Solves a triangular matrix equation.
+ //
+ // op(a) * x = alpha * b,
+ // or
+ // x * op(a) = alpha * b
+ //
+ // alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
+ // upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
+ // or op(a) = conj(a').
+ virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *b, int ldb) = 0;
+ virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *b, int ldb) = 0;
+ virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb) = 0;
+ virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb) = 0;
+
+ protected:
+ BlasSupport() {}
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
+};
+
+// Macro used to quickly declare overrides for abstract virtuals in the
+// BlasSupport base class; a usage sketch follows the macro definition below.
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
+ bool DoBlasAsum(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<float> *result) override; \
+ bool DoBlasAsum(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<double> *result) override; \
+ bool DoBlasAsum(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<float> *result) override; \
+ bool DoBlasAsum(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ DeviceMemory<double> *result) override; \
+ bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, \
+ const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasAxpy(Stream *stream, uint64 elem_count, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasAxpy(Stream *stream, uint64 elem_count, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasCopy(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasCopy(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasCopy(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasCopy(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasDot(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<float> &x, int incx, \
+ const DeviceMemory<float> &y, int incy, \
+ DeviceMemory<float> *result) override; \
+ bool DoBlasDot(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<double> &x, int incx, \
+ const DeviceMemory<double> &y, int incy, \
+ DeviceMemory<double> *result) override; \
+ bool DoBlasDotc(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ const DeviceMemory<std::complex<float>> &y, int incy, \
+ DeviceMemory<std::complex<float>> *result) override; \
+ bool DoBlasDotc(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ const DeviceMemory<std::complex<double>> &y, int incy, \
+ DeviceMemory<std::complex<double>> *result) override; \
+ bool DoBlasDotu(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ const DeviceMemory<std::complex<float>> &y, int incy, \
+ DeviceMemory<std::complex<float>> *result) override; \
+ bool DoBlasDotu(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ const DeviceMemory<std::complex<double>> &y, int incy, \
+ DeviceMemory<std::complex<double>> *result) override; \
+ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<float> *result) override; \
+ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<double> *result) override; \
+ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<float> *result) override; \
+ bool DoBlasNrm2(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ DeviceMemory<double> *result) override; \
+ bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \
+ int incx, DeviceMemory<float> *y, int incy, float c, float s) \
+ override; \
+ bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \
+ int incx, DeviceMemory<double> *y, int incy, double c, \
+ double s) override; \
+ bool DoBlasRot(Stream *stream, uint64 elem_count, \
+ DeviceMemory<std::complex<float>> *x, int incx, \
+ DeviceMemory<std::complex<float>> *y, int incy, float c, \
+ float s) override; \
+ bool DoBlasRot(Stream *stream, uint64 elem_count, \
+ DeviceMemory<std::complex<double>> *x, int incx, \
+ DeviceMemory<std::complex<double>> *y, int incy, double c, \
+ double s) override; \
+ bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, \
+ DeviceMemory<float> *b, DeviceMemory<float> *c, \
+ DeviceMemory<float> *s) override; \
+ bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, \
+ DeviceMemory<double> *b, DeviceMemory<double> *c, \
+ DeviceMemory<double> *s) override; \
+ bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, \
+ DeviceMemory<std::complex<float>> *b, \
+ DeviceMemory<float> *c, \
+ DeviceMemory<std::complex<float>> *s) override; \
+ bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, \
+ DeviceMemory<std::complex<double>> *b, \
+ DeviceMemory<double> *c, \
+ DeviceMemory<std::complex<double>> *s) override; \
+ bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \
+ int incx, DeviceMemory<float> *y, int incy, \
+ const DeviceMemory<float> &param) override; \
+ bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \
+ int incx, DeviceMemory<double> *y, int incy, \
+ const DeviceMemory<double> &param) override; \
+ bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, \
+ DeviceMemory<float> *d2, DeviceMemory<float> *x1, \
+ const DeviceMemory<float> &y1, DeviceMemory<float> *param) \
+ override; \
+ bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, \
+ DeviceMemory<double> *d2, DeviceMemory<double> *x1, \
+ const DeviceMemory<double> &y1, \
+ DeviceMemory<double> *param) override; \
+ bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \
+ DeviceMemory<float> *x, int incx) override; \
+ bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \
+ DeviceMemory<double> *x, int incx) override; \
+ bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \
+ DeviceMemory<std::complex<float>> *x, int incx) override; \
+ bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \
+ DeviceMemory<std::complex<double>> *x, int incx) override; \
+ bool DoBlasScal(Stream *stream, uint64 elem_count, \
+ std::complex<float> alpha, \
+ DeviceMemory<std::complex<float>> *x, int incx) override; \
+ bool DoBlasScal(Stream *stream, uint64 elem_count, \
+ std::complex<double> alpha, \
+ DeviceMemory<std::complex<double>> *x, int incx) override; \
+ bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \
+ int incx, DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \
+ int incx, DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasSwap(Stream *stream, uint64 elem_count, \
+ DeviceMemory<std::complex<float>> *x, int incx, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasSwap(Stream *stream, uint64 elem_count, \
+ DeviceMemory<std::complex<double>> *x, int incx, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasIamax(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasIamax(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasIamax(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasIamax(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasIamin(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasIamin(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasIamin(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasIamin(Stream *stream, uint64 elem_count, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ DeviceMemory<int> *result) override; \
+ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ uint64 kl, uint64 ku, float alpha, \
+ const DeviceMemory<float> &a, int lda, \
+ const DeviceMemory<float> &x, int incx, float beta, \
+ DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ uint64 kl, uint64 ku, double alpha, \
+ const DeviceMemory<double> &a, int lda, \
+ const DeviceMemory<double> &x, int incx, double beta, \
+ DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ uint64 kl, uint64 ku, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ uint64 kl, uint64 ku, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ float alpha, const DeviceMemory<float> &a, int lda, \
+ const DeviceMemory<float> &x, int incx, float beta, \
+ DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ double alpha, const DeviceMemory<double> &a, int lda, \
+ const DeviceMemory<double> &x, int incx, double beta, \
+ DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \
+ const DeviceMemory<float> &x, int incx, \
+ const DeviceMemory<float> &y, int incy, \
+ DeviceMemory<float> *a, int lda) override; \
+ bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, \
+ const DeviceMemory<double> &x, int incx, \
+ const DeviceMemory<double> &y, int incy, \
+ DeviceMemory<double> *a, int lda) override; \
+ bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ const DeviceMemory<std::complex<float>> &y, int incy, \
+ DeviceMemory<std::complex<float>> *a, int lda) override; \
+ bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ const DeviceMemory<std::complex<double>> &y, int incy, \
+ DeviceMemory<std::complex<double>> *a, int lda) override; \
+ bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ const DeviceMemory<std::complex<float>> &y, int incy, \
+ DeviceMemory<std::complex<float>> *a, int lda) override; \
+ bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ const DeviceMemory<std::complex<double>> &y, int incy, \
+ DeviceMemory<std::complex<double>> *a, int lda) override; \
+ bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<std::complex<float>> *a, int lda) override; \
+ bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<std::complex<double>> &x, \
+ int incx, DeviceMemory<std::complex<double>> *a, int lda) \
+ override; \
+ bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ const DeviceMemory<std::complex<float>> &y, int incy, \
+ DeviceMemory<std::complex<float>> *a, int lda) override; \
+ bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ const DeviceMemory<std::complex<double>> &y, int incy, \
+ DeviceMemory<std::complex<double>> *a, int lda) override; \
+ bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &ap, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *y, int incy) override; \
+ bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &ap, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *y, int incy) override; \
+ bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ DeviceMemory<std::complex<float>> *ap) override; \
+ bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<std::complex<double>> &x, \
+ int incx, DeviceMemory<std::complex<double>> *ap) override; \
+ bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &x, int incx, \
+ const DeviceMemory<std::complex<float>> &y, int incy, \
+ DeviceMemory<std::complex<float>> *ap) override; \
+ bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &x, int incx, \
+ const DeviceMemory<std::complex<double>> &y, int incy, \
+ DeviceMemory<std::complex<double>> *ap) override; \
+ bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \
+ float alpha, const DeviceMemory<float> &a, int lda, \
+ const DeviceMemory<float> &x, int incx, float beta, \
+ DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \
+ double alpha, const DeviceMemory<double> &a, int lda, \
+ const DeviceMemory<double> &x, int incx, double beta, \
+ DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ float alpha, const DeviceMemory<float> &ap, \
+ const DeviceMemory<float> &x, int incx, float beta, \
+ DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<double> &ap, \
+ const DeviceMemory<double> &x, int incx, double beta, \
+ DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<float> *ap) override; \
+ bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<double> *ap) override; \
+ bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ float alpha, const DeviceMemory<float> &x, int incx, \
+ const DeviceMemory<float> &y, int incy, \
+ DeviceMemory<float> *ap) override; \
+ bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<double> &x, int incx, \
+ const DeviceMemory<double> &y, int incy, \
+ DeviceMemory<double> *ap) override; \
+ bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ float alpha, const DeviceMemory<float> &a, int lda, \
+ const DeviceMemory<float> &x, int incx, float beta, \
+ DeviceMemory<float> *y, int incy) override; \
+ bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<double> &a, int lda, \
+ const DeviceMemory<double> &x, int incx, double beta, \
+ DeviceMemory<double> *y, int incy) override; \
+ bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \
+ const DeviceMemory<float> &x, int incx, \
+ DeviceMemory<float> *a, int lda) override; \
+ bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<double> &x, int incx, \
+ DeviceMemory<double> *a, int lda) override; \
+ bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ float alpha, const DeviceMemory<float> &x, int incx, \
+ const DeviceMemory<float> &y, int incy, \
+ DeviceMemory<float> *a, int lda) override; \
+ bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \
+ double alpha, const DeviceMemory<double> &x, int incx, \
+ const DeviceMemory<double> &y, int incy, \
+ DeviceMemory<double> *a, int lda) override; \
+ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<float> &a, int lda, \
+ DeviceMemory<float> *x, int incx) override; \
+ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<double> &a, int lda, \
+ DeviceMemory<double> *x, int incx) override; \
+ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<std::complex<float>> &a, \
+ int lda, DeviceMemory<std::complex<float>> *x, int incx) \
+ override; \
+ bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<std::complex<double>> &a, \
+ int lda, DeviceMemory<std::complex<double>> *x, int incx) \
+ override; \
+ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<float> &a, int lda, \
+ DeviceMemory<float> *x, int incx) override; \
+ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<double> &a, int lda, \
+ DeviceMemory<double> *x, int incx) override; \
+ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<std::complex<float>> &a, \
+ int lda, DeviceMemory<std::complex<float>> *x, int incx) \
+ override; \
+ bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ uint64 k, const DeviceMemory<std::complex<double>> &a, \
+ int lda, DeviceMemory<std::complex<double>> *x, int incx) \
+ override; \
+ bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x, \
+ int incx) override; \
+ bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<double> &ap, DeviceMemory<double> *x, \
+ int incx) override; \
+ bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<float>> &ap, \
+ DeviceMemory<std::complex<float>> *x, int incx) override; \
+ bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<double>> &ap, \
+ DeviceMemory<std::complex<double>> *x, int incx) override; \
+ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x, \
+ int incx) override; \
+ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<double> &ap, DeviceMemory<double> *x, \
+ int incx) override; \
+ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<float>> &ap, \
+ DeviceMemory<std::complex<float>> *x, int incx) override; \
+ bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<double>> &ap, \
+ DeviceMemory<std::complex<double>> *x, int incx) override; \
+ bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<float> &a, int lda, \
+ DeviceMemory<float> *x, int incx) override; \
+ bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<double> &a, int lda, \
+ DeviceMemory<double> *x, int incx) override; \
+ bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ DeviceMemory<std::complex<float>> *x, int incx) override; \
+ bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ DeviceMemory<std::complex<double>> *x, int incx) override; \
+ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<float> &a, int lda, \
+ DeviceMemory<float> *x, int incx) override; \
+ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<double> &a, int lda, \
+ DeviceMemory<double> *x, int incx) override; \
+ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ DeviceMemory<std::complex<float>> *x, int incx) override; \
+ bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, blas::Diagonal diag, uint64 n, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ DeviceMemory<std::complex<double>> *x, int incx) override; \
+ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
+ float alpha, const DeviceMemory<float> &a, int lda, \
+ const DeviceMemory<float> &b, int ldb, float beta, \
+ DeviceMemory<float> *c, int ldc) override; \
+ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
+ double alpha, const DeviceMemory<double> &a, int lda, \
+ const DeviceMemory<double> &b, int ldb, double beta, \
+ DeviceMemory<double> *c, int ldc) override; \
+ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *c, int ldc) override; \
+ bool DoBlasGemm(Stream *stream, blas::Transpose transa, \
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasGemmBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
+ const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \
+ const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \
+ const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \
+ int batch_count) override; \
+ bool DoBlasGemmBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \
+ const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
+ const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \
+ int batch_count) override; \
+ bool DoBlasGemmBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
+ std::complex<float> beta, \
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
+ int batch_count) override; \
+ bool DoBlasGemmBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, \
+ int lda, \
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \
+ int ldb, std::complex<double> beta, \
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \
+ int ldc, int batch_count) override; \
+ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ uint64 m, uint64 n, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *c, int ldc) override; \
+ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ uint64 m, uint64 n, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ float beta, DeviceMemory<std::complex<float>> *c, int ldc) \
+ override; \
+ bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ double beta, DeviceMemory<std::complex<double>> *c, int ldc) \
+ override; \
+ bool DoBlasHer2k( \
+ Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \
+ uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, float beta, \
+ DeviceMemory<std::complex<float>> *c, int ldc) override; \
+ bool DoBlasHer2k( \
+ Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \
+ uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, double beta, \
+ DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ uint64 m, uint64 n, float alpha, \
+ const DeviceMemory<float> &a, int lda, \
+ const DeviceMemory<float> &b, int ldb, float beta, \
+ DeviceMemory<float> *c, int ldc) override; \
+ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ uint64 m, uint64 n, double alpha, \
+ const DeviceMemory<double> &a, int lda, \
+ const DeviceMemory<double> &b, int ldb, double beta, \
+ DeviceMemory<double> *c, int ldc) override; \
+ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ uint64 m, uint64 n, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *c, int ldc) override; \
+ bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ uint64 m, uint64 n, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<float> &a, int lda, float beta, \
+ DeviceMemory<float> *c, int ldc) override; \
+ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, double beta, \
+ DeviceMemory<double> *c, int ldc) override; \
+ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *c, int ldc) override; \
+ bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<float> &a, int lda, \
+ const DeviceMemory<float> &b, int ldb, float beta, \
+ DeviceMemory<float> *c, int ldc) override; \
+ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, \
+ const DeviceMemory<double> &b, int ldb, double beta, \
+ DeviceMemory<double> *c, int ldc) override; \
+ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, \
+ std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, \
+ std::complex<float> beta, \
+ DeviceMemory<std::complex<float>> *c, int ldc) override; \
+ bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \
+ blas::Transpose trans, uint64 n, uint64 k, \
+ std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, \
+ std::complex<double> beta, \
+ DeviceMemory<std::complex<double>> *c, int ldc) override; \
+ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, float alpha, const DeviceMemory<float> &a, \
+ int lda, DeviceMemory<float> *b, int ldb) override; \
+ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, double alpha, const DeviceMemory<double> &a, \
+ int lda, DeviceMemory<double> *b, int ldb) override; \
+ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ DeviceMemory<std::complex<float>> *b, int ldb) override; \
+ bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ DeviceMemory<std::complex<double>> *b, int ldb) override; \
+ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, float alpha, const DeviceMemory<float> &a, \
+ int lda, DeviceMemory<float> *b, int ldb) override; \
+ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, double alpha, const DeviceMemory<double> &a, \
+ int lda, DeviceMemory<double> *b, int ldb) override; \
+ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, \
+ DeviceMemory<std::complex<float>> *b, int ldb) override; \
+ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
+ blas::Transpose transa, blas::Diagonal diag, uint64 m, \
+ uint64 n, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, \
+ DeviceMemory<std::complex<double>> *b, int ldb) override;
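+
+// Illustrative usage sketch (not part of the original interface): a concrete
+// GPU BLAS backend declares every override in one shot by expanding the macro
+// inside its class body. The class name below is hypothetical.
+//
+//   class MyGpuBlasSupport : public blas::BlasSupport {
+//    public:
+//     TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
+//   };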
+
+} // namespace blas
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_activation.cc b/tensorflow/stream_executor/cuda/cuda_activation.cc
new file mode 100644
index 0000000000..32d2c0d424
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_activation.cc
@@ -0,0 +1,30 @@
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
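+// Forward declarations of helpers implemented elsewhere in the CUDA
+// StreamExecutor code; they are not defined in this translation unit.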
+CUcontext ExtractCudaContext(CUDAExecutor *cuda_exec);
+CUDAExecutor *ExtractCudaExecutor(StreamExecutor *stream_exec);
+
+ScopedActivateExecutorContext::ScopedActivateExecutorContext(
+ CUDAExecutor *cuda_exec, MultiOpActivation moa)
+ : cuda_exec_(cuda_exec),
+ driver_scoped_activate_context_(
+ new ScopedActivateContext{ExtractCudaContext(cuda_exec), moa}) {}
+
+ScopedActivateExecutorContext::ScopedActivateExecutorContext(
+ StreamExecutor *stream_exec, MultiOpActivation moa)
+ : ScopedActivateExecutorContext(ExtractCudaExecutor(stream_exec), moa) {}
+
+ScopedActivateExecutorContext::~ScopedActivateExecutorContext() {
+ delete static_cast<ScopedActivateContext *>(driver_scoped_activate_context_);
+}
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_activation.h b/tensorflow/stream_executor/cuda/cuda_activation.h
new file mode 100644
index 0000000000..4181d13d0a
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_activation.h
@@ -0,0 +1,53 @@
+// This file contains APIs that assume a StreamExecutor is backed by CUDA.
+// It reaches into the CUDA implementation to activate an underlying CUDA
+// context.
+//
+// Having this file separate from cuda_gpu_executor.h means that dependent
+// code does not also have to depend on cuda.h.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_
+
+#include "tensorflow/stream_executor/cuda/multi_op_activation.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+class StreamExecutor;
+
+namespace cuda {
+
+class CUDAExecutor;
+class ScopedActivateContext;
+
+// Activates a CUDA context within an enclosing scope.
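+//
+// Illustrative usage (a sketch, not taken from this header): create the
+// object on the stack so the executor's CUDA context stays active for the
+// lifetime of the enclosing scope.
+//
+//   {
+//     ScopedActivateExecutorContext sac{cuda_exec};
+//     // CUDA driver/library calls issued here run against cuda_exec's
+//     // context.
+//   }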
+class ScopedActivateExecutorContext {
+ public:
+ // Form that takes a CUDA executor implementation.
+ explicit ScopedActivateExecutorContext(
+ CUDAExecutor* cuda_exec, MultiOpActivation moa = MultiOpActivation::kNo);
+
+  // Form that takes a pImpl executor and extracts its CUDA implementation --
+  // fails fatally if the underlying implementation is not CUDA.
+ explicit ScopedActivateExecutorContext(
+ StreamExecutor* stream_exec,
+ MultiOpActivation moa = MultiOpActivation::kNo);
+
+ ~ScopedActivateExecutorContext();
+
+ private:
+ // The CUDA executor implementation whose context is activated.
+ CUDAExecutor* cuda_exec_;
+
+ // The cuda.h-using datatype that we wrap.
+ ScopedActivateContext* driver_scoped_activate_context_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivateExecutorContext);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
new file mode 100644
index 0000000000..ef1036bca3
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -0,0 +1,2184 @@
+#include "tensorflow/stream_executor/cuda/cuda_blas.h"
+
+#include <dlfcn.h>
+
+#include <complex>
+
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/status_macros.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "third_party/gpus/cuda/include/cublas_v2.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
+
+namespace dynload {
+
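+// The shim generated below lazily loads the cuBLAS DSO, resolves each wrapped
+// symbol by name via dlsym (CHECK-failing if the symbol is missing), and
+// activates the owning executor's CUDA context before forwarding the call.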
+#define PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name) \
+ struct DynLoadShim__##__name { \
+ static const char *kName; \
+ using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
+ static void *GetDsoHandle() { \
+ static auto status = internal::CachedDsoLoader::GetCublasDsoHandle(); \
+ return status.ValueOrDie(); \
+ } \
+ static FuncPointerT DynLoad() { \
+ static void *f = dlsym(GetDsoHandle(), kName); \
+ CHECK(f != nullptr) << "could not find " << kName \
+ << " in cuBLAS DSO; dlerror: " << dlerror(); \
+ return reinterpret_cast<FuncPointerT>(f); \
+ } \
+ template <typename... Args> \
+ cublasStatus_t operator()(CUDAExecutor * parent, Args... args) { \
+ cuda::ScopedActivateExecutorContext sac{parent}; \
+ return DynLoad()(args...); \
+ } \
+ } __name; \
+ const char *DynLoadShim__##__name::kName = #__name;
+
+#define PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(__name) \
+ PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name)
+
+#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
+ __macro(cublasSnrm2) \
+ __macro(cublasDnrm2) \
+ __macro(cublasScnrm2) \
+ __macro(cublasDznrm2) \
+ __macro(cublasSdot) \
+ __macro(cublasDdot) \
+ __macro(cublasCdotu) \
+ __macro(cublasCdotc) \
+ __macro(cublasZdotu) \
+ __macro(cublasZdotc) \
+ __macro(cublasSscal) \
+ __macro(cublasDscal) \
+ __macro(cublasCscal) \
+ __macro(cublasCsscal) \
+ __macro(cublasZscal) \
+ __macro(cublasZdscal) \
+ __macro(cublasSaxpy) \
+ __macro(cublasDaxpy) \
+ __macro(cublasCaxpy) \
+ __macro(cublasZaxpy) \
+ __macro(cublasScopy) \
+ __macro(cublasDcopy) \
+ __macro(cublasCcopy) \
+ __macro(cublasZcopy) \
+ __macro(cublasSswap) \
+ __macro(cublasDswap) \
+ __macro(cublasCswap) \
+ __macro(cublasZswap) \
+ __macro(cublasIsamax) \
+ __macro(cublasIdamax) \
+ __macro(cublasIcamax) \
+ __macro(cublasIzamax) \
+ __macro(cublasIsamin) \
+ __macro(cublasIdamin) \
+ __macro(cublasIcamin) \
+ __macro(cublasIzamin) \
+ __macro(cublasSasum) \
+ __macro(cublasDasum) \
+ __macro(cublasScasum) \
+ __macro(cublasDzasum) \
+ __macro(cublasSrot) \
+ __macro(cublasDrot) \
+ __macro(cublasCrot) \
+ __macro(cublasCsrot) \
+ __macro(cublasZrot) \
+ __macro(cublasZdrot) \
+ __macro(cublasSrotg) \
+ __macro(cublasDrotg) \
+ __macro(cublasCrotg) \
+ __macro(cublasZrotg) \
+ __macro(cublasSrotm) \
+ __macro(cublasDrotm) \
+ __macro(cublasSrotmg) \
+ __macro(cublasDrotmg) \
+ __macro(cublasSgemv) \
+ __macro(cublasDgemv) \
+ __macro(cublasCgemv) \
+ __macro(cublasZgemv) \
+ __macro(cublasSgbmv) \
+ __macro(cublasDgbmv) \
+ __macro(cublasCgbmv) \
+ __macro(cublasZgbmv) \
+ __macro(cublasStrmv) \
+ __macro(cublasDtrmv) \
+ __macro(cublasCtrmv) \
+ __macro(cublasZtrmv) \
+ __macro(cublasStbmv) \
+ __macro(cublasDtbmv) \
+ __macro(cublasCtbmv) \
+ __macro(cublasZtbmv) \
+ __macro(cublasStpmv) \
+ __macro(cublasDtpmv) \
+ __macro(cublasCtpmv) \
+ __macro(cublasZtpmv) \
+ __macro(cublasStrsv) \
+ __macro(cublasDtrsv) \
+ __macro(cublasCtrsv) \
+ __macro(cublasZtrsv) \
+ __macro(cublasStpsv) \
+ __macro(cublasDtpsv) \
+ __macro(cublasCtpsv) \
+ __macro(cublasZtpsv) \
+ __macro(cublasStbsv) \
+ __macro(cublasDtbsv) \
+ __macro(cublasCtbsv) \
+ __macro(cublasZtbsv) \
+ __macro(cublasSsymv) \
+ __macro(cublasDsymv) \
+ __macro(cublasCsymv) \
+ __macro(cublasZsymv) \
+ __macro(cublasChemv) \
+ __macro(cublasZhemv) \
+ __macro(cublasSsbmv) \
+ __macro(cublasDsbmv) \
+ __macro(cublasChbmv) \
+ __macro(cublasZhbmv) \
+ __macro(cublasSspmv) \
+ __macro(cublasDspmv) \
+ __macro(cublasChpmv) \
+ __macro(cublasZhpmv) \
+ __macro(cublasSger) \
+ __macro(cublasDger) \
+ __macro(cublasCgeru) \
+ __macro(cublasCgerc) \
+ __macro(cublasZgeru) \
+ __macro(cublasZgerc) \
+ __macro(cublasSsyr) \
+ __macro(cublasDsyr) \
+ __macro(cublasCsyr) \
+ __macro(cublasZsyr) \
+ __macro(cublasCher) \
+ __macro(cublasZher) \
+ __macro(cublasSspr) \
+ __macro(cublasDspr) \
+ __macro(cublasChpr) \
+ __macro(cublasZhpr) \
+ __macro(cublasSsyr2) \
+ __macro(cublasDsyr2) \
+ __macro(cublasCsyr2) \
+ __macro(cublasZsyr2) \
+ __macro(cublasCher2) \
+ __macro(cublasZher2) \
+ __macro(cublasSspr2) \
+ __macro(cublasDspr2) \
+ __macro(cublasChpr2) \
+ __macro(cublasZhpr2) \
+ __macro(cublasSgemm) \
+ __macro(cublasDgemm) \
+ __macro(cublasCgemm) \
+ __macro(cublasZgemm) \
+ __macro(cublasSsyrk) \
+ __macro(cublasDsyrk) \
+ __macro(cublasCsyrk) \
+ __macro(cublasZsyrk) \
+ __macro(cublasCherk) \
+ __macro(cublasZherk) \
+ __macro(cublasSsyr2k) \
+ __macro(cublasDsyr2k) \
+ __macro(cublasCsyr2k) \
+ __macro(cublasZsyr2k) \
+ __macro(cublasCher2k) \
+ __macro(cublasZher2k) \
+ __macro(cublasSsyrkx) \
+ __macro(cublasDsyrkx) \
+ __macro(cublasCsyrkx) \
+ __macro(cublasZsyrkx) \
+ __macro(cublasCherkx) \
+ __macro(cublasZherkx) \
+ __macro(cublasSsymm) \
+ __macro(cublasDsymm) \
+ __macro(cublasCsymm) \
+ __macro(cublasZsymm) \
+ __macro(cublasChemm) \
+ __macro(cublasZhemm) \
+ __macro(cublasStrsm) \
+ __macro(cublasDtrsm) \
+ __macro(cublasCtrsm) \
+ __macro(cublasZtrsm) \
+ __macro(cublasStrmm) \
+ __macro(cublasDtrmm) \
+ __macro(cublasCtrmm) \
+ __macro(cublasZtrmm) \
+ __macro(cublasSgeam) \
+ __macro(cublasDgeam) \
+ __macro(cublasCgeam) \
+ __macro(cublasZgeam) \
+ __macro(cublasSdgmm) \
+ __macro(cublasDdgmm) \
+ __macro(cublasCdgmm) \
+ __macro(cublasZdgmm)
+
+PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasCreate)
+PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasDestroy)
+PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetStream)
+PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetPointerMode)
+PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasGetPointerMode)
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmBatched)
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasDgemmBatched)
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasCgemmBatched)
+PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasZgemmBatched)
+CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)
+
+} // namespace dynload
+
+static string ToString(cublasStatus_t status) {
+ switch (status) {
+ case CUBLAS_STATUS_SUCCESS:
+ return "CUBLAS_STATUS_SUCCESS";
+ case CUBLAS_STATUS_NOT_INITIALIZED:
+ return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED:
+ return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE:
+ return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH:
+ return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR:
+ return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED:
+ return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR:
+ return "CUBLAS_STATUS_INTERNAL_ERROR";
+ default:
+ return port::StrCat("<invalid cublas status: ", status, ">");
+ }
+}
+
+// cuBLAS has interfaces that permit pointers to be passed from either the host
+// memory space or the device memory space; however, you must instruct it as to
+// which address space those pointers are in with cublasSetPointerMode.
+//
+// This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
+// you are about to perform in a given scope.
+//
+// The prior cuBLAS pointer mode is retained and restored when this object goes
+// out of scope.
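+//
+// Illustrative usage (a sketch mirroring how DoBlasInternal below uses it):
+//
+//   ScopedCublasPointerMode pointer_mode{parent, blas};
+//   if (!pointer_mode.Init(CUBLAS_POINTER_MODE_HOST)) {
+//     return false;  // An error has already been logged.
+//   }
+//   // ... invoke the cuBLAS routine with host-resident scalar arguments ...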
+class ScopedCublasPointerMode {
+ public:
+ // Note that, because the setting of the cublas pointer mode is fallible,
+ // construction of this scoped datatype must be paired with a call to
+ // Init().
+ //
+ // Parameters:
+ // handle: The cublas library handle to act upon in setting the pointer mode.
+ explicit ScopedCublasPointerMode(CUDAExecutor *parent, cublasHandle_t handle)
+ : parent_(parent), handle_(handle), ok_(false) {}
+
+ // Attempts the switch to the requested scoped pointer mode, new_mode.
+ //
+ // Note that when false is returned, an appropriate error has already been
+ // logged.
+ bool Init(cublasPointerMode_t new_mode) {
+ cublasStatus_t ret =
+ dynload::cublasGetPointerMode_v2(parent_, handle_, &old_mode_);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
+ return ok_ = false;
+ }
+
+ ret = dynload::cublasSetPointerMode_v2(parent_, handle_, new_mode);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
+ return ok_ = false;
+ }
+
+ return ok_ = true;
+ }
+
+ // Switches back to the prior pointer mode, if the switch operation was
+ // successful in the first place.
+ ~ScopedCublasPointerMode() {
+ if (ok_) {
+ cublasStatus_t ret =
+ dynload::cublasSetPointerMode_v2(parent_, handle_, old_mode_);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set former cublas pointer mode: "
+ << ToString(ret);
+ }
+ }
+ }
+
+ private:
+  CUDAExecutor *parent_;   // Executor on whose behalf the pointer mode is set.
+ cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
+ cublasPointerMode_t old_mode_; // Prior cuBLAS pointer mode, to be restored.
+ bool ok_; // Whether the change was successful.
+};
+
+bool CUDABlas::Init() {
+ cublasStatus_t ret = dynload::cublasCreate_v2(parent_, &blas_);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
+ return false;
+ }
+
+ return true;
+}
+
+CUDABlas::CUDABlas(cuda::CUDAExecutor *parent)
+ : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
+
+CUDABlas::~CUDABlas() {
+ if (blas_ != nullptr) {
+ dynload::cublasDestroy_v2(parent_, blas_);
+ }
+}
+
+bool CUDABlas::SetStream(Stream *stream) {
+ CHECK(stream != nullptr);
+ CHECK(AsCUDAStreamValue(stream) != nullptr);
+ CHECK(blas_ != nullptr);
+ cublasStatus_t ret =
+ dynload::cublasSetStream_v2(parent_, blas_, AsCUDAStreamValue(stream));
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
+ return false;
+ }
+
+ return true;
+}
+
+namespace {
+
+// Helper functions transforming blas arguments into cuBLAS arguments.
+
+cublasOperation_t CUDABlasTranspose(blas::Transpose trans) {
+ switch (trans) {
+ case blas::Transpose::kNoTranspose:
+ return CUBLAS_OP_N;
+ case blas::Transpose::kTranspose:
+ return CUBLAS_OP_T;
+ case blas::Transpose::kConjugateTranspose:
+ return CUBLAS_OP_C;
+ default:
+ LOG(FATAL) << "Invalid value of blas::Transpose.";
+ }
+}
+
+cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
+ switch (uplo) {
+ case blas::UpperLower::kUpper:
+ return CUBLAS_FILL_MODE_UPPER;
+ case blas::UpperLower::kLower:
+ return CUBLAS_FILL_MODE_LOWER;
+ default:
+ LOG(FATAL) << "Invalid value of blas::UpperLower.";
+ }
+}
+
+cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
+ switch (diag) {
+ case blas::Diagonal::kUnit:
+ return CUBLAS_DIAG_UNIT;
+ case blas::Diagonal::kNonUnit:
+ return CUBLAS_DIAG_NON_UNIT;
+ default:
+ LOG(FATAL) << "Invalid value of blas::Diagonal.";
+ }
+}
+
+cublasSideMode_t CUDABlasSide(blas::Side side) {
+ switch (side) {
+ case blas::Side::kLeft:
+ return CUBLAS_SIDE_LEFT;
+ case blas::Side::kRight:
+ return CUBLAS_SIDE_RIGHT;
+ default:
+ LOG(FATAL) << "Invalid value of blas::Side.";
+ }
+}
+
+} // namespace
+
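+// Runs a wrapped cuBLAS routine on the given stream: serializes access to the
+// cuBLAS handle, binds the handle to the stream, and selects host or device
+// pointer mode for the scalar arguments before invoking cublas_func. Logs and
+// returns false if any step fails.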
+template <typename FuncT, typename... Args>
+bool CUDABlas::DoBlasInternal(FuncT cublas_func, Stream *stream,
+ bool pointer_mode_host, Args... args) {
+ mutex_lock lock{mu_};
+
+ CHECK(blas_ != nullptr);
+ if (!SetStream(stream)) {
+ return false;
+ }
+
+ ScopedCublasPointerMode pointer_mode{parent_, blas_};
+ if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
+ : CUBLAS_POINTER_MODE_DEVICE)) {
+ return false;
+ }
+
+ cublasStatus_t ret = cublas_func(parent_, blas_, args...);
+ if (ret != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
+ << ToString(ret);
+ return false;
+ }
+
+ return true;
+}
+
+bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *result) {
+ return DoBlasInternal(dynload::cublasSasum, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *result) {
+ return DoBlasInternal(dynload::cublasDasum, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result) {
+ return DoBlasInternal(
+ dynload::cublasScasum, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result) {
+ return DoBlasInternal(
+ dynload::cublasDzasum, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(dynload::cublasSaxpy, stream,
+ true /* = pointer_mode_host */, elem_count, &alpha,
+ CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(dynload::cublasDaxpy, stream,
+ true /* = pointer_mode_host */, elem_count, &alpha,
+ CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(dynload::cublasCaxpy, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(dynload::cublasZaxpy, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(dynload::cublasScopy, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(dynload::cublasDcopy, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(dynload::cublasCcopy, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(dynload::cublasZcopy, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *result) {
+ return DoBlasInternal(
+ dynload::cublasSdot, stream, false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *result) {
+ return DoBlasInternal(
+ dynload::cublasDdot, stream, false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result) {
+ return DoBlasInternal(
+ dynload::cublasCdotc, stream, false /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(result)));
+}
+
+bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result) {
+ return DoBlasInternal(
+ dynload::cublasZdotc, stream, false /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(result)));
+}
+
+bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result) {
+ return DoBlasInternal(
+ dynload::cublasCdotu, stream, false /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(result)));
+}
+
+bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result) {
+ return DoBlasInternal(
+ dynload::cublasZdotu, stream, false /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(result)));
+}
+
+bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *result) {
+ return DoBlasInternal(dynload::cublasSnrm2, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *result) {
+ return DoBlasInternal(dynload::cublasDnrm2, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result) {
+ return DoBlasInternal(
+ dynload::cublasScnrm2, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result) {
+ return DoBlasInternal(
+ dynload::cublasDznrm2, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy, float c, float s) {
+ return DoBlasInternal(
+ dynload::cublasSrot, stream, true /* = pointer_mode_host */, elem_count,
+ CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy, &c, &s);
+}
+
+bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy, double c,
+ double s) {
+ return DoBlasInternal(
+ dynload::cublasDrot, stream, true /* = pointer_mode_host */, elem_count,
+ CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy, &c, &s);
+}
+
+bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<float>> *x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy,
+ float c, float s) {
+ return DoBlasInternal(dynload::cublasCsrot, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemoryMutable(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy, &c, &s);
+}
+
+bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<double>> *x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy,
+ double c, double s) {
+ return DoBlasInternal(dynload::cublasZdrot, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemoryMutable(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy, &c, &s);
+}
+
+bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
+ DeviceMemory<float> *b, DeviceMemory<float> *c,
+ DeviceMemory<float> *s) {
+ return DoBlasInternal(dynload::cublasSrotg, stream,
+ false /* = pointer_mode_host */, CUDAMemoryMutable(a),
+ CUDAMemoryMutable(b), CUDAMemoryMutable(c),
+ CUDAMemoryMutable(s));
+}
+
+bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
+ DeviceMemory<double> *b, DeviceMemory<double> *c,
+ DeviceMemory<double> *s) {
+ return DoBlasInternal(dynload::cublasDrotg, stream,
+ false /* = pointer_mode_host */,
+ CUDAComplex(CUDAMemoryMutable(a)), CUDAMemoryMutable(b),
+ CUDAMemoryMutable(c), CUDAMemoryMutable(s));
+}
+
+bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
+ DeviceMemory<std::complex<float>> *b,
+ DeviceMemory<float> *c,
+ DeviceMemory<std::complex<float>> *s) {
+ return DoBlasInternal(
+ dynload::cublasCrotg, stream, false /* = pointer_mode_host */,
+ CUDAComplex(CUDAMemoryMutable(a)), CUDAComplex(CUDAMemoryMutable(b)),
+ CUDAComplex(CUDAMemoryMutable(c)), CUDAComplex(CUDAMemoryMutable(s)));
+}
+
+bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
+ DeviceMemory<std::complex<double>> *b,
+ DeviceMemory<double> *c,
+ DeviceMemory<std::complex<double>> *s) {
+ return DoBlasInternal(
+ dynload::cublasZrotg, stream, false /* = pointer_mode_host */,
+ CUDAComplex(CUDAMemoryMutable(a)), CUDAComplex(CUDAMemoryMutable(b)),
+ CUDAComplex(CUDAMemoryMutable(c)), CUDAComplex(CUDAMemoryMutable(s)));
+}
+
+bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
+ DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy,
+ const DeviceMemory<float> &param) {
+ return DoBlasInternal(dynload::cublasSrotm, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy,
+ CUDAMemory(param));
+}
+
+bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
+ DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy,
+ const DeviceMemory<double> &param) {
+ return DoBlasInternal(dynload::cublasDrotm, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy,
+ CUDAMemory(param));
+}
+
+bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
+ DeviceMemory<float> *d2, DeviceMemory<float> *x1,
+ const DeviceMemory<float> &y1,
+ DeviceMemory<float> *param) {
+ return DoBlasInternal(dynload::cublasSrotmg, stream,
+ false /* = pointer_mode_host */, CUDAMemoryMutable(d1),
+ CUDAMemoryMutable(d2), CUDAMemoryMutable(x1),
+ CUDAMemory(y1), CUDAMemoryMutable(param));
+}
+
+bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
+ DeviceMemory<double> *d2, DeviceMemory<double> *x1,
+ const DeviceMemory<double> &y1,
+ DeviceMemory<double> *param) {
+ return DoBlasInternal(dynload::cublasDrotmg, stream,
+ false /* = pointer_mode_host */, CUDAMemoryMutable(d1),
+ CUDAMemoryMutable(d2), CUDAMemoryMutable(x1),
+ CUDAMemory(y1), CUDAMemoryMutable(param));
+}
+
+bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
+ DeviceMemory<float> *x, int incx) {
+ return DoBlasInternal(dynload::cublasSscal, stream,
+ true /* = pointer_mode_host */, elem_count, &alpha,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
+ DeviceMemory<double> *x, int incx) {
+ return DoBlasInternal(dynload::cublasDscal, stream,
+ true /* = pointer_mode_host */, elem_count, &alpha,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ return DoBlasInternal(
+ dynload::cublasCsscal, stream, true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ return DoBlasInternal(
+ dynload::cublasZdscal, stream, true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
+ std::complex<float> alpha,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ return DoBlasInternal(
+ dynload::cublasCscal, stream, true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
+ std::complex<double> alpha,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ return DoBlasInternal(
+ dynload::cublasZscal, stream, true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(dynload::cublasSswap, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(dynload::cublasDswap, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<float>> *x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(dynload::cublasCswap, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemoryMutable(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
+ DeviceMemory<std::complex<double>> *x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(dynload::cublasZswap, stream,
+ true /* = pointer_mode_host */, elem_count,
+ CUDAComplex(CUDAMemoryMutable(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<int> *result) {
+ return DoBlasInternal(dynload::cublasIsamax, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<int> *result) {
+ return DoBlasInternal(dynload::cublasIdamax, stream,
+ false /* = pointer_mode_host */, elem_count,
+ CUDAMemory(x), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result) {
+ return DoBlasInternal(
+ dynload::cublasIcamax, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<int> *result) {
+ return DoBlasInternal(
+ dynload::cublasIzamax, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<int> *result) {
+ return DoBlasInternal(
+ dynload::cublasIsamin, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<int> *result) {
+ return DoBlasInternal(
+ dynload::cublasIdamin, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result) {
+ return DoBlasInternal(
+ dynload::cublasIcamin, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
+bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<int> *result) {
+ return DoBlasInternal(
+ dynload::cublasIzamin, stream, false /* = pointer_mode_host */,
+ elem_count, CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
+}
+
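+// BLAS level-2 (matrix-vector) wrappers follow. Matrices and vectors are
+// device-resident, alpha/beta come from the host, and the blas:: enums are
+// converted to their cuBLAS equivalents via CUDABlasTranspose,
+// CUDABlasUpperLower and CUDABlasDiagonal.
+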
+bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasSgbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, kl, ku, &alpha, CUDAMemory(a), lda,
+ CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasDgbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, kl, ku, &alpha, CUDAMemory(a), lda,
+ CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasCgbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, kl, ku, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, uint64 kl, uint64 ku,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasZgbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, kl, ku, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x, int incx,
+ float beta, DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasSgemv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
+ incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasDgemv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
+ incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasCgemv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasZgemv, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(trans), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasSger, stream, true /* = pointer_mode_host */, m, n, &alpha,
+ CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
+}
+
+bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasDger, stream, true /* = pointer_mode_host */, m, n, &alpha,
+ CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
+}
+
+bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasCgerc, stream, true /* = pointer_mode_host */, m, n,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasZgerc, stream, true /* = pointer_mode_host */, m, n,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasCgeru, stream, true /* = pointer_mode_host */, m, n,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasZgeru, stream, true /* = pointer_mode_host */, m, n,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasChbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, k, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasZhbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, k, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasChemv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasZhemv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasCher, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasZher, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasCher2, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda) {
+ return DoBlasInternal(
+ dynload::cublasZher2, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(a)), lda);
+}
+
+bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &ap,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasChpmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(ap)), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &ap,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasZhpmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(ap)), CUDAComplex(CUDAMemory(x)), incx,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
+}
+
+bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *ap) {
+ return DoBlasInternal(
+ dynload::cublasChpr, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemoryMutable(ap)));
+}
+
+bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *ap) {
+ return DoBlasInternal(
+ dynload::cublasZhpr, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemoryMutable(ap)));
+}
+
+bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *ap) {
+ return DoBlasInternal(
+ dynload::cublasChpr2, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(ap)));
+}
+
+bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *ap) {
+ return DoBlasInternal(
+ dynload::cublasZhpr2, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
+ CUDAComplex(CUDAMemoryMutable(ap)));
+}
+
+bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x, int incx,
+ float beta, DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasSsbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, k, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
+ incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(
+ dynload::cublasDsbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, k, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
+ incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &ap,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(dynload::cublasSspmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(ap),
+ CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &ap,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(dynload::cublasDspmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(ap),
+ CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *ap) {
+ return DoBlasInternal(dynload::cublasSspr, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemoryMutable(ap));
+}
+
+bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *ap) {
+ return DoBlasInternal(dynload::cublasDspr, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemoryMutable(ap));
+}
+
+bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *ap) {
+ return DoBlasInternal(dynload::cublasSspr2, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemory(y), incy, CUDAMemoryMutable(ap));
+}
+
+bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *ap) {
+ return DoBlasInternal(dynload::cublasDspr2, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemory(y), incy, CUDAMemoryMutable(ap));
+}
+
+bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ return DoBlasInternal(dynload::cublasSsymv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(a), lda,
+ CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy) {
+ return DoBlasInternal(dynload::cublasDsymv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(a), lda,
+ CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
+}
+
+bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *a, int lda) {
+ return DoBlasInternal(dynload::cublasSsyr, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemoryMutable(a), lda);
+}
+
+bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *a, int lda) {
+ return DoBlasInternal(dynload::cublasDsyr, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemoryMutable(a), lda);
+}
+
+bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ float alpha, const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda) {
+ return DoBlasInternal(dynload::cublasSsyr2, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
+}
+
+bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
+ double alpha, const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda) {
+ return DoBlasInternal(dynload::cublasDsyr2, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
+ incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
+}
+
+bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ return DoBlasInternal(dynload::cublasStbmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ return DoBlasInternal(dynload::cublasDtbmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) {
+ return DoBlasInternal(
+ dynload::cublasCtbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) {
+ return DoBlasInternal(
+ dynload::cublasZtbmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ return DoBlasInternal(dynload::cublasStbsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ return DoBlasInternal(dynload::cublasDtbsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) {
+ return DoBlasInternal(
+ dynload::cublasCtbsv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ uint64 k, const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) {
+ return DoBlasInternal(
+ dynload::cublasZtbsv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx) {
+ return DoBlasInternal(
+ dynload::cublasStpmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap,
+ DeviceMemory<double> *x, int incx) {
+ return DoBlasInternal(
+ dynload::cublasDtpmv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasCtpmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasZtpmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx) {
+ return DoBlasInternal(
+ dynload::cublasStpsv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap,
+ DeviceMemory<double> *x, int incx) {
+ return DoBlasInternal(
+ dynload::cublasDtpsv, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasCtpsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasZtpsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
+ CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ return DoBlasInternal(dynload::cublasStrmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ return DoBlasInternal(dynload::cublasDtrmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasCtrmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
+ lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasZtrmv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
+ lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ return DoBlasInternal(dynload::cublasStrsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ return DoBlasInternal(dynload::cublasDtrsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
+ CUDAMemoryMutable(x), incx);
+}
+
+bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasCtrsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
+ lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ return DoBlasInternal(dynload::cublasZtrsv, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
+ CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
+ lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
+}
+
+bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) {
+ VLOG(1) << port::Printf(
+ "doing cuBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
+ "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
+ "c=%p ldc=%d",
+ static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
+ a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
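+ // cuBLAS stores matrices in column-major order, so the leading dimension of
+ // A must be at least m when A is not transposed and at least k when it is
+ // (and analogously for B). The checks below only warn; the cuBLAS call is
+ // still issued.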
+ if (transa == blas::Transpose::kNoTranspose) {
+ if (lda < static_cast<int64>(m)) {
+ LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
+ "precondition violation";
+ }
+ } else {
+ if (lda < static_cast<int64>(k)) {
+ LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
+ << ") (transpose case); precondition violation";
+ }
+ }
+ if (transb == blas::Transpose::kNoTranspose) {
+ if (ldb < static_cast<int64>(k)) {
+ LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
+ << ") (no transpose case); precondition violation";
+ }
+ } else {
+ if (ldb < static_cast<int64>(n)) {
+ LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
+ "precondition violation";
+ }
+ }
+ return DoBlasInternal(
+ dynload::cublasSgemm, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
+}
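+
+// Typical call path (sketch): client code does not invoke CUDABlas directly
+// but goes through the Stream API, which dispatches to these wrappers.
+// Assuming the Stream::ThenBlasGemm overload declared in stream.h (same
+// argument order as DoBlasGemm above) and device buffers a, b and c that have
+// already been allocated:
+//
+//   stream.ThenBlasGemm(blas::Transpose::kNoTranspose,
+//                       blas::Transpose::kNoTranspose, m, n, k,
+//                       /*alpha=*/1.0f, a, lda, b, ldb,
+//                       /*beta=*/0.0f, &c, ldc);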
+
+bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasDgemm, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
+}
+
+bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasCgemm, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasZgemm, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
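+// Helper for the batched GEMM entry points below. cublas{S,D,C,Z}gemmBatched
+// expects device-resident arrays of pointers to the per-batch A, B and C
+// matrices, so this routine gathers the per-batch base pointers on the host,
+// stages them in temporary device buffers bound to the stream, and then
+// issues a single batched GEMM. Failures in the copies or the BLAS call are
+// reported through the returned port::Status.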
+template <typename T, typename FuncT>
+port::Status CUDABlas::DoBlasGemmBatchedInternal(
+ FuncT cublas_func, Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+ const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
+ const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+ const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
+ int batch_count) {
+ std::vector<T *> a_ptr_vec, b_ptr_vec, c_ptr_vec;
+ for (int i = 0; i < batch_count; ++i) {
+ a_ptr_vec.push_back(static_cast<T *>(a_array[i]->opaque()));
+ b_ptr_vec.push_back(static_cast<T *>(b_array[i]->opaque()));
+ c_ptr_vec.push_back(static_cast<T *>(c_array[i]->opaque()));
+ }
+
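+ // Stage the pointer arrays in temporary device memory; the batched cuBLAS
+ // call dereferences them on the device.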
+ typedef typename CUDAComplexT<T>::type CUDA_T;
+ SE_ASSIGN_OR_RETURN(
+ std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_ptr_array,
+ stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
+ SE_ASSIGN_OR_RETURN(
+ std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_ptr_array,
+ stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
+ SE_ASSIGN_OR_RETURN(
+ std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_ptr_array,
+ stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
+
+ if (!stream->ThenMemcpy(a_ptr_array->mutable_device_memory(),
+ a_ptr_vec.data(), batch_count * sizeof(T *))
+ .ok() ||
+ !stream->ThenMemcpy(b_ptr_array->mutable_device_memory(),
+ b_ptr_vec.data(), batch_count * sizeof(T *))
+ .ok() ||
+ !stream->ThenMemcpy(c_ptr_array->mutable_device_memory(),
+ c_ptr_vec.data(), batch_count * sizeof(T *))
+ .ok()) {
+ return port::Status(port::error::INTERNAL,
+ "failed to copy memory from host to device in "
+ "CUDABlas::DoBlasGemmBatched");
+ }
+
+ bool ok = DoBlasInternal(
+ cublas_func, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha),
+ const_cast<const CUDA_T **>(CUDAMemory(a_ptr_array->device_memory())),
+ lda,
+ const_cast<const CUDA_T **>(CUDAMemory(b_ptr_array->device_memory())),
+ ldb, CUDAComplex(&beta),
+ const_cast<CUDA_T **>(CUDAMemory(c_ptr_array->device_memory())), ldc,
+ batch_count);
+
+ if (ok) {
+ return port::Status::OK();
+ }
+ return port::Status(port::error::INTERNAL,
+ "failed BLAS call, see log for details");
+}
+
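+// The public batched GEMM overloads adapt the port::Status returned by
+// DoBlasGemmBatchedInternal to the bool return type of the BlasSupport
+// interface via SE_RETURN_STATUS_AS_BOOL.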
+bool CUDABlas::DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
+ const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
+ int batch_count) {
+ SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
+ dynload::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha,
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+}
+
+bool CUDABlas::DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha,
+ const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
+ const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
+ double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
+ int ldc, int batch_count) {
+ SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
+ dynload::cublasDgemmBatched, stream, transa, transb, m, n, k, alpha,
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+}
+
+bool CUDABlas::DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
+ int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
+ int ldb, std::complex<float> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
+ int ldc, int batch_count) {
+ SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
+ dynload::cublasCgemmBatched, stream, transa, transb, m, n, k, alpha,
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+}
+
+bool CUDABlas::DoBlasGemmBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
+ int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
+ int ldb, std::complex<double> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
+ int ldc, int batch_count) {
+ SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
+ dynload::cublasZgemmBatched, stream, transa, transb, m, n, k, alpha,
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+}
+
+bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasChemm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasZhemm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ float alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc) {
+ return DoBlasInternal(dynload::cublasCherk, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
+ k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ &beta, CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ double alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc) {
+ return DoBlasInternal(dynload::cublasZherk, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
+ k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ &beta, CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc) {
+ return DoBlasInternal(dynload::cublasCher2k, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
+ k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemory(b)), ldb, &beta,
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc) {
+ return DoBlasInternal(dynload::cublasZher2k, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
+ k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemory(b)), ldb, &beta,
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasSsymm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, &alpha, CUDAMemory(a),
+ lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
+}
+
+bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasDsymm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, &alpha, CUDAMemory(a),
+ lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
+}
+
+bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasCsymm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasZsymm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
+ CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ float beta, DeviceMemory<float> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasSsyrk, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
+ CUDAMemory(a), lda, &beta, CUDAMemoryMutable(c), ldc);
+}
+
+bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ double beta, DeviceMemory<double> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasDsyrk, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
+ CUDAMemory(a), lda, &beta, CUDAMemoryMutable(c), ldc);
+}
+
+bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasCsyrk, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasZsyrk, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasSsyr2k, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
+ CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
+}
+
+bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ double alpha, const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc) {
+ return DoBlasInternal(
+ dynload::cublasDsyr2k, stream, true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
+ CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
+}
+
+bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ return DoBlasInternal(dynload::cublasCsyr2k, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
+ k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
+ blas::Transpose trans, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ return DoBlasInternal(dynload::cublasZsyr2k, stream,
+ true /* = pointer_mode_host */,
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
+ k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
+ CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc);
+}
+
+bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *b, int ldb) {
+ return DoBlasInternal(
+ dynload::cublasStrmm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), lda,
+ CUDAMemoryMutable(b), ldb, CUDAMemoryMutable(b), ldb);
+}
+
+bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *b, int ldb) {
+ return DoBlasInternal(
+ dynload::cublasDtrmm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), lda,
+ CUDAMemoryMutable(b), ldb, CUDAMemoryMutable(b), ldb);
+}
+
+bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb) {
+ return DoBlasInternal(
+ dynload::cublasCtrmm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb,
+ CUDAComplex(CUDAMemoryMutable(b)), ldb);
+}
+
+bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb) {
+ return DoBlasInternal(
+ dynload::cublasZtrmm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb,
+ CUDAComplex(CUDAMemoryMutable(b)), ldb);
+}
+
+bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *b, int ldb) {
+ return DoBlasInternal(dynload::cublasStrsm, stream,
+ true /* = pointer_mode_host */, CUDABlasSide(side),
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a),
+ lda, CUDAMemoryMutable(b), ldb);
+}
+
+bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *b, int ldb) {
+ return DoBlasInternal(dynload::cublasDtrsm, stream,
+ true /* = pointer_mode_host */, CUDABlasSide(side),
+ CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a),
+ lda, CUDAMemoryMutable(b), ldb);
+}
+
+bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb) {
+ return DoBlasInternal(
+ dynload::cublasCtrsm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb);
+}
+
+bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
+ blas::UpperLower uplo, blas::Transpose transa,
+ blas::Diagonal diag, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb) {
+ return DoBlasInternal(
+ dynload::cublasZtrsm, stream, true /* = pointer_mode_host */,
+ CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
+ CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
+ CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb);
+}
+
+} // namespace cuda
+
+namespace gpu = ::perftools::gputools;
+
+void initialize_cublas() {
+ gpu::port::Status status =
+ gpu::PluginRegistry::Instance()
+ ->RegisterFactory<gpu::PluginRegistry::BlasFactory>(
+ gpu::cuda::kCudaPlatformId, gpu::cuda::kCuBlasPlugin, "cuBLAS",
+ [](gpu::internal::StreamExecutorInterface
+ *parent) -> gpu::blas::BlasSupport * {
+ gpu::cuda::CUDAExecutor *cuda_executor =
+ dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
+ if (cuda_executor == nullptr) {
+ LOG(ERROR)
+ << "Attempting to initialize an instance of the cuBLAS "
+ << "support library with a non-CUDA StreamExecutor";
+ return nullptr;
+ }
+
+ gpu::cuda::CUDABlas *blas =
+ new gpu::cuda::CUDABlas(cuda_executor);
+ if (!blas->Init()) {
+ // Note: Init() will log a more specific error.
+ delete blas;
+ return nullptr;
+ }
+ return blas;
+ });
+
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to register cuBLAS factory: "
+ << status.error_message();
+ }
+
+ // Prime the cuBLAS DSO. The loader will log more information.
+ auto statusor = gpu::internal::CachedDsoLoader::GetCublasDsoHandle();
+ if (!statusor.ok()) {
+ LOG(INFO) << "Unable to load cuBLAS DSO.";
+ }
+
+ gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
+ gpu::PluginKind::kBlas,
+ gpu::cuda::kCuBlasPlugin);
+}
+
+} // namespace gputools
+} // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(register_cublas,
+ { perftools::gputools::initialize_cublas(); });
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
new file mode 100644
index 0000000000..1dfec2ebc5
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -0,0 +1,100 @@
+// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
+// capabilities, and is only included into CUDA implementation code -- it will
+// not introduce CUDA headers into other code.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+
+typedef struct cublasContext *cublasHandle_t;
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+
+namespace cuda {
+
+// Opaque and unique identifier for the cuBLAS plugin.
+extern const PluginId kCuBlasPlugin;
+
+class CUDAExecutor;
+
+// BLAS plugin for CUDA platform via cuBLAS library.
+//
+// This satisfies the platform-agnostic BlasSupport interface.
+//
+// Note that the cuBLAS handle that this encapsulates is implicitly tied to the
+// context (and, as a result, the device) that the parent CUDAExecutor is tied
+// to. This simply happens as an artifact of creating the cuBLAS handle when a
+// CUDA context is active.
+//
+// Thread-safe post-initialization.
+class CUDABlas : public blas::BlasSupport {
+ public:
+ explicit CUDABlas(CUDAExecutor *parent);
+
+ // Allocates a cuBLAS handle.
+ bool Init();
+
+ // Releases the cuBLAS handle, if present.
+ ~CUDABlas() override;
+
+ TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
+
+ private:
+ // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
+ //
+ // cuBLAS is stateful, and can only be associated with one stream (in order
+ // to enqueue dispatch) at a given time. As a result, this generally must be
+ // invoked before calling into cuBLAS.
+ bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // A helper function that calls the real cuBLAS function together with error
+ // handling.
+ //
+ // cublas_func: cuBLAS function pointer.
+ // cublas_name: cuBLAS function name.
+ // stream: Stream to enqueue the BLAS operation onto.
+ // pointer_mode_host: Indicates whether the pointer to a scalar value is
+ //   from the host (true) or the device (false).
+ // args: Arguments of cuBLAS function.
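+ //
+ // An illustrative call, mirroring the pattern used throughout cuda_blas.cc
+ // (cublasSscal is shown purely as an example and assumes the corresponding
+ // dynload wrapper is defined; elem_count, alpha, x and incx are
+ // placeholders):
+ //
+ //   DoBlasInternal(dynload::cublasSscal, stream,
+ //                  true /* = pointer_mode_host */, elem_count, &alpha,
+ //                  CUDAMemoryMutable(x), incx);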
+ template <typename FuncT, typename... Args>
+ bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
+ Args... args);
+
+ // A helper function to implement DoBlasGemmBatched interfaces for generic
+ // types.
+ template <typename T, typename FuncT>
+ port::Status DoBlasGemmBatchedInternal(
+ FuncT cublas_func, Stream *stream, blas::Transpose transa,
+ blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+ const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
+ const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+ const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
+ int batch_count);
+
+ // mutex that guards the cuBLAS handle for this device.
+ mutex mu_;
+
+ // CUDAExecutor which instantiated this CUDABlas.
+ // Immutable post-initialization.
+ CUDAExecutor *parent_;
+
+ // cuBLAS library handle on the device.
+ cublasHandle_t blas_ GUARDED_BY(mu_);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
new file mode 100644
index 0000000000..c01c9978a1
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
@@ -0,0 +1,260 @@
+#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
+
+#include <dirent.h>
+#include <limits.h>
+#include <link.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/sysmacros.h>
+#include <unistd.h>
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
+#include "tensorflow/stream_executor/lib/numbers.h"
+#include "tensorflow/stream_executor/lib/process_state.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/str_util.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+static const char *kDriverVersionPath = "/proc/driver/nvidia/version";
+
+string DriverVersionToString(DriverVersion version) {
+ return port::Printf("%d.%d", std::get<0>(version), std::get<1>(version));
+}
+
+string DriverVersionStatusToString(port::StatusOr<DriverVersion> version) {
+ if (!version.ok()) {
+ return version.status().ToString();
+ }
+
+ return DriverVersionToString(version.ValueOrDie());
+}
+
+port::StatusOr<DriverVersion> StringToDriverVersion(const string &value) {
+ std::vector<string> pieces = port::Split(value, '.');
+ if (pieces.size() != 2) {
+ return port::Status{
+ port::error::INVALID_ARGUMENT,
+ port::Printf("expected %%d.%%d form for driver version; got \"%s\"",
+ value.c_str())};
+ }
+
+ int major;
+ int minor;
+ if (!port::safe_strto32(pieces[0], &major)) {
+ return port::Status{
+ port::error::INVALID_ARGUMENT,
+ port::Printf("could not parse major version number \"%s\" as an "
+ "integer from string \"%s\"",
+ pieces[0].c_str(), value.c_str())};
+ }
+ if (!port::safe_strto32(pieces[1], &minor)) {
+ return port::Status{
+ port::error::INVALID_ARGUMENT,
+ port::Printf("could not parse minor version number \"%s\" as an "
+ "integer from string \"%s\"",
+ pieces[1].c_str(), value.c_str())};
+ }
+
+ DriverVersion result{major, minor};
+ VLOG(2) << "version string \"" << value << "\" made value "
+ << DriverVersionToString(result);
+ return result;
+}
+
+// -- class Diagnostician
+
+string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
+ return port::StrCat("/dev/nvidia", dev_node_ordinal);
+}
+
+void Diagnostician::LogDiagnosticInformation() {
+ if (access(kDriverVersionPath, F_OK) != 0) {
+ LOG(INFO) << "kernel driver does not appear to be running on this host "
+ << "(" << port::Hostname() << "): "
+ << "/proc/driver/nvidia/version does not exist";
+ return;
+ }
+ auto dev0_path = GetDevNodePath(0);
+ if (access(dev0_path.c_str(), F_OK) != 0) {
+ LOG(INFO) << "no NVIDIA GPU device is present: " << dev0_path
+ << " does not exist";
+ return;
+ }
+
+ LOG(INFO) << "retrieving CUDA diagnostic information for host: "
+ << port::Hostname();
+
+
+ LogDriverVersionInformation();
+}
+
+/* static */ void Diagnostician::LogDriverVersionInformation() {
+ LOG(INFO) << "hostname: " << port::Hostname();
+
+ if (VLOG_IS_ON(1)) {
+ const char *value = getenv("LD_LIBRARY_PATH");
+ string library_path = value == nullptr ? "" : value;
+ VLOG(1) << "LD_LIBRARY_PATH is: \"" << library_path << "\"";
+
+ std::vector<string> pieces = port::Split(library_path, ':');
+ for (auto piece : pieces) {
+ if (piece.empty()) {
+ continue;
+ }
+ DIR *dir = opendir(piece.c_str());
+ if (dir == nullptr) {
+ VLOG(1) << "could not open \"" << piece << "\"";
+ continue;
+ }
+ while (dirent *entity = readdir(dir)) {
+ VLOG(1) << piece << " :: " << entity->d_name;
+ }
+ closedir(dir);
+ }
+ }
+
+ port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
+ LOG(INFO) << "libcuda reported version is: "
+ << DriverVersionStatusToString(dso_version);
+
+ port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
+ LOG(INFO) << "kernel reported version is: "
+ << DriverVersionStatusToString(kernel_version);
+ if (kernel_version.ok() && dso_version.ok()) {
+ WarnOnDsoKernelMismatch(dso_version, kernel_version);
+ }
+}
+
+// Iterates through the loaded DSOs (via dl_iterate_phdr) to find the
+// driver-interfacing DSO and returns its version as a DriverVersion.
+port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
+ port::StatusOr<DriverVersion> result{port::Status{
+ port::error::NOT_FOUND,
+ "was unable to find libcuda.so DSO loaded into this program"}};
+
+ // Callback used when iterating through DSOs. Looks for the driver-interfacing
+ // DSO and yields its version number into the callback data, when found.
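+ // For example, a loaded library that resolves to a path ending in
+ // ".so.331.79" (an illustrative path) yields the version string "331.79",
+ // which StringToDriverVersion parses into DriverVersion{331, 79}; any
+ // trailing ".ld64" suffix is stripped first.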
+ auto iterate_phdr =
+ [](struct dl_phdr_info *info, size_t size, void *data) -> int {
+ if (strstr(info->dlpi_name, "libcuda.so")) {
+ VLOG(1) << "found DLL info with name: " << info->dlpi_name;
+ char resolved_path[PATH_MAX] = {0};
+ if (realpath(info->dlpi_name, resolved_path) == nullptr) {
+ return 0;
+ }
+ VLOG(1) << "found DLL info with resolved path: " << resolved_path;
+ const char *slash = rindex(resolved_path, '/');
+ if (slash == nullptr) {
+ return 0;
+ }
+ const char *so_suffix = ".so.";
+ const char *dot = strstr(slash, so_suffix);
+ if (dot == nullptr) {
+ return 0;
+ }
+ string dso_version = dot + strlen(so_suffix);
+ // TODO(b/22689637): Eliminate the explicit namespace if possible.
+ auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
+ auto result = static_cast<port::StatusOr<DriverVersion> *>(data);
+ *result = StringToDriverVersion(stripped_dso_version);
+ return 1;
+ }
+ return 0;
+ };
+
+ dl_iterate_phdr(iterate_phdr, &result);
+
+ return result;
+}
+
+port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
+ const string &driver_version_file_contents) {
+ static const char *kDriverFilePrelude = "Kernel Module ";
+ size_t offset = driver_version_file_contents.find(kDriverFilePrelude);
+ if (offset == string::npos) {
+ return port::Status{
+ port::error::NOT_FOUND,
+ port::StrCat("could not find kernel module information in "
+ "driver version file contents: \"",
+ driver_version_file_contents, "\"")};
+ }
+
+ string version_and_rest = driver_version_file_contents.substr(
+ offset + strlen(kDriverFilePrelude), string::npos);
+ size_t space_index = version_and_rest.find(" ");
+ auto kernel_version = version_and_rest.substr(0, space_index);
+ // TODO(b/22689637): Eliminate the explicit namespace if possible.
+ auto stripped_kernel_version =
+ port::StripSuffixString(kernel_version, ".ld64");
+ return StringToDriverVersion(stripped_kernel_version);
+}
+
+void Diagnostician::WarnOnDsoKernelMismatch(
+ port::StatusOr<DriverVersion> dso_version,
+ port::StatusOr<DriverVersion> kernel_version) {
+ if (kernel_version.ok() && dso_version.ok() &&
+ dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
+ LOG(INFO) << "kernel version seems to match DSO: "
+ << DriverVersionToString(kernel_version.ValueOrDie());
+ } else {
+ LOG(ERROR) << "kernel version "
+ << DriverVersionStatusToString(kernel_version)
+ << " does not match DSO version "
+ << DriverVersionStatusToString(dso_version)
+ << " -- cannot find working devices in this configuration";
+ }
+}
+
+
+port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
+ FILE *driver_version_file = fopen(kDriverVersionPath, "r");
+ if (driver_version_file == nullptr) {
+ return port::Status{
+ port::error::PERMISSION_DENIED,
+ port::StrCat("could not open driver version path for reading: ",
+ kDriverVersionPath)};
+ }
+
+ static const int kContentsSize = 1024;
+ port::InlinedVector<char, 4> contents(kContentsSize);
+ size_t retcode =
+ fread(contents.begin(), 1, kContentsSize - 2, driver_version_file);
+ if (retcode < kContentsSize - 1) {
+ contents[retcode] = '\0';
+ }
+ contents[kContentsSize - 1] = '\0';
+
+ if (retcode != 0) {
+ LOG(INFO) << "driver version file contents: \"\"\"" << contents.begin()
+ << "\"\"\"";
+ fclose(driver_version_file);
+ return FindKernelModuleVersion(string{contents.begin()});
+ }
+
+ auto status =
+ port::Status{port::error::INTERNAL,
+ port::StrCat("failed to read driver version file contents: ",
+ kDriverVersionPath, "; ferror: ",
+ ferror(driver_version_file))};
+ fclose(driver_version_file);
+ return status;
+}
+
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.h b/tensorflow/stream_executor/cuda/cuda_diagnostics.h
new file mode 100644
index 0000000000..005b3dc310
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.h
@@ -0,0 +1,85 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
+
+#include <tuple>
+
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// e.g. DriverVersion{331, 79}
+using DriverVersion = std::tuple<int, int>;
+
+// Converts a parsed driver version to string form.
+string DriverVersionToString(DriverVersion version);
+
+// Converts a parsed driver version or status value to natural string form.
+string DriverVersionStatusToString(port::StatusOr<DriverVersion> version);
+
+// Converts a string of a form like "331.79" to a DriverVersion{331, 79}.
+port::StatusOr<DriverVersion> StringToDriverVersion(const string &value);
+
+class Diagnostician {
+ public:
+ // Logs diagnostic information when CUDA appears to be misconfigured (e.g. is
+ // not initializing).
+ //
+ // Note: if we're running on a machine that has no GPUs, we don't want to
+ // produce very much log spew beyond saying, "looks like there's no CUDA
+ // kernel module running".
+ //
+ // Note: we use non-Google-File:: API here because we may be called before
+ // InitGoogle has completed.
+ static void LogDiagnosticInformation();
+
+ // Given the driver version file contents, finds the kernel module version
+ // and returns it as a DriverVersion.
+ //
+ // This is solely used for more informative log messages when the user is
+ // running on a machine that happens to have a libcuda/kernel driver mismatch.
+ static port::StatusOr<DriverVersion> FindKernelModuleVersion(
+ const string &driver_version_file_contents);
+
+ // Extracts the kernel driver version from the current host.
+ static port::StatusOr<DriverVersion> FindKernelDriverVersion();
+
+ // Iterates through the loaded DSOs (via dl_iterate_phdr) to find the
+ // driver-interfacing DSO and returns its version as a DriverVersion.
+ static port::StatusOr<DriverVersion> FindDsoVersion();
+
+ // Logs information about the kernel driver version and userspace driver
+ // library version.
+ static void LogDriverVersionInformation();
+
+ private:
+ // Logs information about the loaded nvidia-related kernel modules.
+ static void LogKernelModuleInformation();
+
+ // Given the DSO version and the kernel driver version, compares them and
+ // warns the user in the case of incompatibility.
+ //
+ // This is solely used for more informative log messages when the user is
+ // running on a machine that happens to have a libcuda/kernel driver mismatch.
+ static void WarnOnDsoKernelMismatch(
+ port::StatusOr<DriverVersion> dso_version,
+ port::StatusOr<DriverVersion> kernel_version);
+
+ // Logs information about the dev nodes present on this machine: their
+ // existence, permissions, accessibility from this uid/gid.
+ static void LogDevNodeDiagnosticInformation();
+
+ static string GetDevNodePath(int dev_node_ordinal);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(Diagnostician);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
new file mode 100644
index 0000000000..6e4403512b
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -0,0 +1,1074 @@
+#include "tensorflow/stream_executor/cuda/cuda_dnn.h"
+
+#include <dlfcn.h>
+#include <functional>
+
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/threadpool.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "third_party/gpus/cuda/include/cudnn.h"
+
+namespace {
+
+// Converts (via narrowing) a WideT value to a NarrowT value, and CHECK-fails
+// if the value is changed by the conversion.
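+//
+// For example, CheckedNarrowing<int64, int>(v) CHECK-fails if v does not fit
+// in an int; this is how the descriptor dimensions below are narrowed for the
+// cudnn calls.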
+template <typename WideT, typename NarrowT>
+NarrowT CheckedNarrowing(const WideT& wide) {
+ NarrowT narrow = wide;
+ CHECK_EQ(narrow, wide)
+ << "checked narrowing failed; values not equal post-conversion";
+ return narrow;
+}
+
+} // namespace
+
+namespace perftools {
+namespace gputools {
+
+using dnn::BatchDescriptor;
+using dnn::FilterDescriptor;
+using dnn::ConvolutionDescriptor;
+using dnn::PoolingDescriptor;
+
+namespace cuda {
+
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
+
+extern CUstream AsCUDAStreamValue(Stream* stream);
+
+string ToString(cudnnStatus_t status) {
+ switch (status) {
+ case CUDNN_STATUS_SUCCESS:
+ return "CUDNN_STATUS_SUCCESS";
+ case CUDNN_STATUS_NOT_INITIALIZED:
+ return "CUDNN_STATUS_NOT_INITIALIZED";
+ case CUDNN_STATUS_ALLOC_FAILED:
+ return "CUDNN_STATUS_ALLOC_FAILED";
+ case CUDNN_STATUS_BAD_PARAM:
+ return "CUDNN_STATUS_BAD_PARAM";
+ case CUDNN_STATUS_INTERNAL_ERROR:
+ return "CUDNN_STATUS_INTERNAL_ERROR";
+ case CUDNN_STATUS_INVALID_VALUE:
+ return "CUDNN_STATUS_INVALID_VALUE";
+ case CUDNN_STATUS_ARCH_MISMATCH:
+ return "CUDNN_STATUS_ARCH_MISMATCH";
+ case CUDNN_STATUS_MAPPING_ERROR:
+ return "CUDNN_STATUS_MAPPING_ERROR";
+ case CUDNN_STATUS_EXECUTION_FAILED:
+ return "CUDNN_STATUS_EXECUTION_FAILED";
+ case CUDNN_STATUS_NOT_SUPPORTED:
+ return "CUDNN_STATUS_NOT_SUPPORTED";
+ case CUDNN_STATUS_LICENSE_ERROR:
+ return "CUDNN_STATUS_LICENSE_ERROR";
+ default:
+ return port::StrCat("<unknown cudnn status: ", static_cast<int>(status),
+ ">");
+ }
+}
+
+namespace dynload {
+
+static port::ThreadPool* InitCudnnThreadpool() {
+ port::ThreadPool* cudnn_threadpool_;
+ port::ThreadOptions options;
+ // TBD(keveman): Conservatively setting the stack size and guard size to 2MB,
+ // until we can get some guarantees from NVIDIA on the minimum stack space
+ // they will work with.
+ options.stack_size = 2 * 1024 * 1024;
+ options.guard_size = 2 * 1024 * 1024;
+ cudnn_threadpool_ = new port::ThreadPool(port::Env::Default(), options,
+ "cudnn_threadpool", 1);
+ CHECK(cudnn_threadpool_);
+ return cudnn_threadpool_;
+}
+
+static mutex cudnn_threadpool_mu(LINKER_INITIALIZED);
+static port::ThreadPool* GetCudaThreadpool() {
+ mutex_lock lock(cudnn_threadpool_mu);
+ static port::ThreadPool* cudnn_threadpool = InitCudnnThreadpool();
+ return cudnn_threadpool;
+}
+
+#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \
+ struct DynLoadShim__##__name { \
+ static const char* kName; \
+ typedef std::add_pointer<decltype(::__name)>::type FuncPointerT; \
+ static void* GetDsoHandle() { \
+ static auto result = internal::CachedDsoLoader::GetCudnnDsoHandle(); \
+ return result.ValueOrDie(); \
+ } \
+ static FuncPointerT DynLoad() { \
+ static void* f = dlsym(GetDsoHandle(), kName); \
+ if (f == nullptr) { \
+ LOG(FATAL) << "could not find " << kName \
+ << " in cudnn DSO; dlerror: " << dlerror(); \
+ } \
+ return reinterpret_cast<FuncPointerT>(f); \
+ } \
+ template <typename... Args> \
+ void CallWrapper(CUDAExecutor* parent, port::Notification* n, \
+ cudnnStatus_t* retval, const Args&... args) { \
+ cuda::ScopedActivateExecutorContext sac{parent}; \
+ *retval = DynLoad()(args...); \
+ n->Notify(); \
+ } \
+ template <typename... Args> \
+ cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \
+ port::Notification n; \
+ cudnnStatus_t retval; \
+ auto call_func_closure = \
+ std::bind(&DynLoadShim__##__name::CallWrapper<Args...>, this, \
+ parent, &n, &retval, args...); \
+ GetCudaThreadpool()->Schedule(call_func_closure); \
+ n.WaitForNotification(); \
+ return retval; \
+ } \
+ } __name; \
+ const char* DynLoadShim__##__name::kName = #__name;
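+
+// As a rough sketch of how the shims above are used (illustrative only):
+// wrapping a routine such as cudnnCreate makes dynload::cudnnCreate a
+// callable object, so call sites look like
+//
+//   cudnnStatus_t status = dynload::cudnnCreate(parent, &handle);
+//
+// The shim resolves the symbol from the cudnn DSO once, runs the real call on
+// the dedicated cudnn threadpool thread (with the executor's CUDA context
+// made active), and blocks on the Notification until the status is available.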
+
+#define CUDNN_DNN_ROUTINE_EACH(__macro) \
+ __macro(cudnnSetTensor4dDescriptor) __macro( \
+ cudnnGetConvolutionNdForwardOutputDim) \
+ __macro(cudnnGetConvolutionForwardAlgorithm) __macro( \
+ cudnnCreateTensorDescriptor) __macro(cudnnDestroyTensorDescriptor) \
+ __macro(cudnnCreateFilterDescriptor) \
+ __macro(cudnnSetFilter4dDescriptor) \
+ __macro(cudnnSetPooling2dDescriptor) \
+ __macro(cudnnDestroyFilterDescriptor) \
+ __macro(cudnnCreateConvolutionDescriptor) \
+ __macro(cudnnCreatePoolingDescriptor) \
+ __macro(cudnnAddTensor) \
+ __macro(cudnnDestroyPoolingDescriptor)
+
+CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+#undef CUDNN_DNN_ROUTINE_EACH
+
+// clang-format off
+#define CUDNN_DNN_ROUTINE_EACH(__macro) \
+ __macro(cudnnSetConvolution2dDescriptor) \
+ __macro(cudnnDestroyConvolutionDescriptor) \
+ __macro(cudnnCreate) \
+ __macro(cudnnDestroy) \
+ __macro(cudnnSetStream) \
+ __macro(cudnnActivationForward) \
+ __macro(cudnnConvolutionForward) \
+ __macro(cudnnConvolutionBackwardData) \
+ __macro(cudnnConvolutionBackwardFilter) \
+ __macro(cudnnGetConvolutionForwardWorkspaceSize) \
+ __macro(cudnnTransformTensor) \
+ __macro(cudnnPoolingForward) \
+ __macro(cudnnPoolingBackward)
+// clang-format on
+
+CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+#undef CUDNN_DNN_ROUTINE_EACH
+
+} // namespace dynload
+
+namespace {
+
+cudnnHandle_t ToHandle(void* opaque_handle) {
+ return static_cast<cudnnHandle_t>(opaque_handle);
+}
+
+} // namespace
+
+CudnnSupport::CudnnSupport(CUDAExecutor* parent)
+ : parent_(parent), dnn_handle_(nullptr) {}
+
+CudnnSupport::~CudnnSupport() {
+ auto status = dynload::cudnnDestroy(parent_, ToHandle(dnn_handle_));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "could not destroy cudnn handle: " << ToString(status);
+ }
+}
+
+port::Status CudnnSupport::Init() {
+ auto status = dynload::cudnnCreate(
+ parent_, reinterpret_cast<cudnnHandle_t*>(&dnn_handle_));
+ if (status == CUDNN_STATUS_SUCCESS) {
+ return port::Status::OK();
+ }
+
+ LOG(ERROR) << "could not create cudnn handle: " << ToString(status);
+ if (status == CUDNN_STATUS_NOT_INITIALIZED) {
+ // This is the error code that the driver returns when we're not running a
+ // sufficient CUDA driver -- cudnn requires 6.5+ compatibility, which
+ // starts with the 340.XX driver series.
+ auto result = cuda::Diagnostician::FindKernelDriverVersion();
+ if (!result.ok()) {
+ LOG(ERROR) << "error retrieving driver version: "
+ << DriverVersionStatusToString(result);
+ } else {
+ const auto& version = result.ValueOrDie();
+ LOG(INFO) << "running driver version: " << DriverVersionToString(version);
+ if (std::get<0>(version) < 340) {
+ LOG(ERROR)
+ << "cudnn library is only supported on 340.XX+ driver versions";
+ }
+ }
+ }
+ return port::Status{port::error::INTERNAL,
+ port::StrCat("cudnn library could not create a handle: ",
+ ToString(status))};
+}
+
+// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
+class ScopedTensorDescriptor {
+ public:
+ ScopedTensorDescriptor(CUDAExecutor* parent,
+ const BatchDescriptor& batch_descriptor,
+ cudnnDataType_t elem_type)
+ : parent_(parent), handle_(nullptr) {
+ cudnnStatus_t status =
+ dynload::cudnnCreateTensorDescriptor(parent_, &handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not create cudnn tensor descriptor: "
+ << ToString(status);
+ }
+
+ cudnnTensorFormat_t format;
+ switch (batch_descriptor.layout()) {
+ case dnn::DataLayout::kBatchYXDepth:
+ format = CUDNN_TENSOR_NHWC;
+ break;
+ case dnn::DataLayout::kBatchDepthYX:
+ format = CUDNN_TENSOR_NCHW;
+ break;
+ default:
+ LOG(FATAL) << "Unsupported tensor format "
+ << DataLayoutString(batch_descriptor.layout());
+ break;
+ }
+
+ status = dynload::cudnnSetTensor4dDescriptor(
+ parent_, handle_, format, elem_type,
+ CheckedNarrowing<int64, int>(batch_descriptor.count()),
+ CheckedNarrowing<int64, int>(batch_descriptor.feature_map_count()),
+ CheckedNarrowing<int64, int>(batch_descriptor.height()),
+ CheckedNarrowing<int64, int>(batch_descriptor.width()));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not set cudnn tensor descriptor: "
+ << ToString(status);
+ }
+ }
+
+ ~ScopedTensorDescriptor() {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyTensorDescriptor(parent_, handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "could not destroy cudnn tensor descriptor: "
+ << ToString(status);
+ }
+ }
+
+ cudnnTensorDescriptor_t handle() const { return handle_; }
+
+ private:
+ CUDAExecutor* parent_; // Parent executor. Not owned.
+ cudnnTensorDescriptor_t handle_; // Owned.
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
+};
+
+// Turns a FilterDescriptor structure into a cudnn filter handle within a scope.
+class ScopedFilterDescriptor {
+ public:
+ ScopedFilterDescriptor(CUDAExecutor* parent,
+ const FilterDescriptor& filter_descriptor,
+ cudnnDataType_t elem_type)
+ : parent_(parent), handle_(nullptr) {
+ cudnnStatus_t status =
+ dynload::cudnnCreateFilterDescriptor(parent_, &handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not create cudnn filter descriptor: "
+ << ToString(status);
+ }
+
+ // TODO(b/23032134): Even if the filter layout is not supported,
+ // cudnnSetFilter4dDescriptor will return CUDNN_STATUS_SUCCESS because it
+ // does not take layout as an input. Maybe force cuDNN by giving wrong
+ // inputs intentionally?
+ switch (filter_descriptor.layout()) {
+ case dnn::FilterLayout::kOutputInputYX:
+ break;
+ default:
+ LOG(FATAL) << "Unsupported filter format "
+ << FilterLayoutString(filter_descriptor.layout());
+ break;
+ }
+
+ status = dynload::cudnnSetFilter4dDescriptor(
+ parent_, handle_, elem_type,
+ CheckedNarrowing<int64, int>(
+ filter_descriptor.output_feature_map_count()),
+ CheckedNarrowing<int64, int>(
+ filter_descriptor.input_feature_map_count()),
+ CheckedNarrowing<int64, int>(filter_descriptor.input_filter_height()),
+ CheckedNarrowing<int64, int>(filter_descriptor.input_filter_width()));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not set cudnn filter descriptor: "
+ << ToString(status);
+ }
+ }
+
+ ~ScopedFilterDescriptor() {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyFilterDescriptor(parent_, handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "could not destroy cudnn filter descriptor: "
+ << ToString(status);
+ }
+ }
+
+ cudnnFilterDescriptor_t handle() const { return handle_; }
+
+ private:
+ // Parent executor object. Not owned.
+ CUDAExecutor* parent_;
+
+ // cudnn filter descriptor this object creates. Owned.
+ cudnnFilterDescriptor_t handle_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
+};
+
+// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
+// within a scope.
+class ScopedConvolutionDescriptor {
+ public:
+ ScopedConvolutionDescriptor(
+ CUDAExecutor* parent, const ConvolutionDescriptor& convolution_descriptor)
+ : parent_(parent), handle_(nullptr) {
+ cudnnStatus_t status =
+ dynload::cudnnCreateConvolutionDescriptor(parent_, &handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not create cudnn convolution descriptor: "
+ << ToString(status);
+ }
+
+ status = dynload::cudnnSetConvolution2dDescriptor(
+ parent_, handle_, CheckedNarrowing<int64, int>(
+ convolution_descriptor.zero_padding_height()),
+ CheckedNarrowing<int64, int>(
+ convolution_descriptor.zero_padding_width()),
+ CheckedNarrowing<int64, int>(
+ convolution_descriptor.vertical_filter_stride()),
+ CheckedNarrowing<int64, int>(
+ convolution_descriptor.horizontal_filter_stride()),
+ // TODO(leary) not sure what the following two params do.
+ 1 /* = upscale_input_x */, 1 /* = upscale_input_y */,
+ // NOTE(keveman): cuDNN supports convolution and cross correlation.
+ // However, almost all the use cases do cross correlation, so just hard
+ // coding it here.
+ CUDNN_CROSS_CORRELATION);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not set cudnn convolution descriptor: "
+ << ToString(status);
+ }
+ }
+
+ ~ScopedConvolutionDescriptor() {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyConvolutionDescriptor(parent_, handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "could not destroy cudnn convolution descriptor: "
+ << ToString(status);
+ }
+ }
+
+ cudnnConvolutionDescriptor_t handle() const { return handle_; }
+
+ private:
+ CUDAExecutor* parent_; // Parent executor. Not owned.
+ cudnnConvolutionDescriptor_t handle_; // Owned.
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
+};
+
+// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
+// within a scope.
+class ScopedPoolingDescriptor {
+ public:
+ ScopedPoolingDescriptor(CUDAExecutor* parent,
+ const PoolingDescriptor& pooling_descriptor)
+ : parent_(parent), handle_(nullptr) {
+ cudnnStatus_t status =
+ dynload::cudnnCreatePoolingDescriptor(parent_, &handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not create cudnn pooling descriptor: "
+ << ToString(status);
+ }
+ status = dynload::cudnnSetPooling2dDescriptor(
+ parent_, handle_,
+ (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
+ ? CUDNN_POOLING_MAX
+ : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING),
+ CheckedNarrowing<int64, int>(pooling_descriptor.window_height()),
+ CheckedNarrowing<int64, int>(pooling_descriptor.window_width()),
+ CheckedNarrowing<int64, int>(pooling_descriptor.vertical_padding()),
+ CheckedNarrowing<int64, int>(pooling_descriptor.horizontal_padding()),
+ CheckedNarrowing<int64, int>(pooling_descriptor.vertical_stride()),
+ CheckedNarrowing<int64, int>(pooling_descriptor.horizontal_stride()));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "could not set cudnn pooling descriptor: "
+ << ToString(status);
+ }
+ }
+ ~ScopedPoolingDescriptor() {
+ cudnnStatus_t status =
+ dynload::cudnnDestroyPoolingDescriptor(parent_, handle_);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "could not destroy cudnn pooling descriptor: "
+ << ToString(status);
+ }
+ }
+
+ cudnnPoolingDescriptor_t handle() const { return handle_; }
+
+ private:
+ CUDAExecutor* parent_; // Parent executor. Not owned.
+ cudnnPoolingDescriptor_t handle_; // Owned.
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
+};
+
+bool CudnnSupport::DoConvolve(
+ Stream* stream, const BatchDescriptor& batch_descriptor,
+ const DeviceMemory<float>& input_data,
+ const FilterDescriptor& filter_descriptor,
+ const DeviceMemory<float>& filter_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const BatchDescriptor& output_descriptor,
+ DeviceMemory<float>* output_data) {
+ ScopedTensorDescriptor input_4d{parent_, batch_descriptor, CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor output_4d{parent_, output_descriptor,
+ CUDNN_DATA_FLOAT};
+ ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT};
+ ScopedConvolutionDescriptor conv{parent_, convolution_descriptor};
+
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
+ }
+ // Alpha is the scaling factor for input.
+ float alpha = 1.0;
+ // Beta is the scaling factor for output.
+ float beta = 0.0;
+
+ // The NO_WORKSPACE versions are possibly slower for certain shapes, but
+ // not so for the shapes currently used by Brain. Also, it seems prudent to
+ // keep cuMemAlloc off the critical path.
+ cudnnConvolutionFwdAlgo_t algo;
+ status = dynload::cudnnGetConvolutionForwardAlgorithm(
+ parent_, ToHandle(dnn_handle_), input_4d.handle(), filter.handle(),
+ conv.handle(), output_4d.handle(), CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0,
+ &algo);
+
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
+ << "Unable to find a suitable algorithm for doing forward convolution";
+
+ status = dynload::cudnnConvolutionForward(
+ parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(),
+ input_data.opaque(), filter.handle(), filter_data.opaque(), conv.handle(),
+ algo, nullptr /* workspace ptr */, 0 /* workspace size */, &beta,
+ output_4d.handle(), output_data->opaque());
+
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ return false;
+ }
+
+ return true;
+}
+
+bool CudnnSupport::DoConvolve(
+ Stream* stream, const BatchDescriptor& batch_descriptor,
+ const DeviceMemory<double>& input_data,
+ const FilterDescriptor& filter_descriptor,
+ const DeviceMemory<double>& filter_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const BatchDescriptor& output_descriptor,
+ DeviceMemory<double>* output_data) {
+ LOG(ERROR) << "double-based DNN not yet implemented";
+ return false;
+}
+
+DeviceMemory<float> CudnnSupport::MaybeTransformLayout(
+ Stream* stream, BatchDescriptor* output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ std::unique_ptr<TemporaryDeviceMemory<float>>* transform_scratch) {
+ if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
+ return backward_output_data;
+ }
+ CHECK(output_descriptor->layout() == dnn::DataLayout::kBatchYXDepth);
+ *transform_scratch =
+ stream->AllocateTemporaryArray<float>(backward_output_data.ElementCount())
+ .ConsumeValueOrDie();
+ BatchDescriptor transformed_output_descriptor;
+ transformed_output_descriptor.CloneFrom(*output_descriptor);
+ transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
+ ScopedTensorDescriptor orig_out_back_4d{parent_, *output_descriptor,
+ CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor transformed_out_back_4d{
+ parent_, transformed_output_descriptor, CUDNN_DATA_FLOAT};
+
+ float alpha = 1.0f;
+ float beta = 0.0f;
+ auto status = dynload::cudnnTransformTensor(
+ parent_, ToHandle(dnn_handle_), &alpha, orig_out_back_4d.handle(),
+ backward_output_data.opaque(), &beta, transformed_out_back_4d.handle(),
+ (*transform_scratch)->mutable_device_memory()->opaque());
+
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "Failed to transform the data layout.";
+ }
+ output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX);
+ return (*transform_scratch)->device_memory();
+}
+
+bool CudnnSupport::DoConvolveBackwardData(
+ Stream* stream, const FilterDescriptor& filter_descriptor,
+ const DeviceMemory<float>& filter_data,
+ const BatchDescriptor& output_descriptor_in,
+ DeviceMemory<float> backward_output_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const BatchDescriptor& input_descriptor,
+ DeviceMemory<float>* backward_input_data) {
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
+ }
+
+ // Alpha is the scaling factor for input.
+ float alpha = 1.0;
+ // Beta is the scaling factor for output.
+ float beta = 0.0;
+
+ // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
+ BatchDescriptor output_descriptor;
+ output_descriptor.CloneFrom(output_descriptor_in);
+ std::unique_ptr<TemporaryDeviceMemory<float>> transform_scratch;
+ backward_output_data = MaybeTransformLayout(
+ stream, &output_descriptor, backward_output_data, &transform_scratch);
+
+ ScopedTensorDescriptor out_back_4d{parent_, output_descriptor,
+ CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor in_back_4d{parent_, input_descriptor,
+ CUDNN_DATA_FLOAT};
+ ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT};
+ ScopedConvolutionDescriptor conv{parent_, convolution_descriptor};
+
+ status = dynload::cudnnConvolutionBackwardData(
+ parent_, ToHandle(dnn_handle_), &alpha, filter.handle(),
+ filter_data.opaque(), out_back_4d.handle(), backward_output_data.opaque(),
+ conv.handle(), &beta, in_back_4d.handle(), backward_input_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+}
+
+bool CudnnSupport::DoConvolveBackwardFilter(
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_descriptor_in,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
+ DeviceMemory<float>* backward_filter_data) {
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
+ }
+
+ // Alpha is the scaling factor for input.
+ float alpha = 1.0;
+ // Beta is the scaling factor for output.
+ float beta = 0.0;
+
+ // TBD(keveman): remove once cuDNN supports kBatchYXDepth for backward pass.
+ BatchDescriptor output_descriptor;
+ output_descriptor.CloneFrom(output_descriptor_in);
+ std::unique_ptr<TemporaryDeviceMemory<float>> transform_scratch;
+ backward_output_data = MaybeTransformLayout(
+ stream, &output_descriptor, backward_output_data, &transform_scratch);
+
+ ScopedTensorDescriptor out_back_4d{parent_, output_descriptor,
+ CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor input_4d{parent_, input_descriptor, CUDNN_DATA_FLOAT};
+ ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT};
+ ScopedConvolutionDescriptor conv{parent_, convolution_descriptor};
+
+ status = dynload::cudnnConvolutionBackwardFilter(
+ parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(),
+ input_data.opaque(), out_back_4d.handle(), backward_output_data.opaque(),
+ conv.handle(), &beta, filter.handle(), backward_filter_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+}
+
+bool CudnnSupport::DoMatMul(Stream* stream,
+ const DeviceMemory<float>& input_data,
+ const DeviceMemory<float>& weights,
+ const dnn::BatchDescriptor& input_dimensions,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) {
+ if (input_dimensions.count() != output_dimensions.count()) {
+ LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
+ return false;
+ }
+
+ // We do not permute the input or output; instead we just
+ // reinterpret the layout. We are working with row-major matrices
+ // and the rows of the input and output correspond to batch, so
+ // batch has to be outermost in both the input and output.
+ //
+ // By adding transposes to the BLAS gemm call we could perhaps make
+ // the kYXDepthBatch layout work as well, but there has been no need
+ // for that so far.
+ if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
+ input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
+ LOG(ERROR) << "Unsupported MatMul input layout.";
+ return false;
+ }
+ if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
+ output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
+ LOG(ERROR) << "Unsupported MatMul output layout.";
+ return false;
+ }
+
+ if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
+ // This is a fast path that also supports the kBatchYXDepth layout.
+
+ // The matrices here are in row-major format while BLAS expects
+ // column-major, i.e. our buffers read as the transposed matrices as far
+ // as BLAS is concerned. The buffer we want BLAS to fill is output^T, and
+ // since output = input * weights, we have output^T = weights^T * input^T.
+ // There is no parameter for transposing the output in BLAS gemm, but the
+ // operands BLAS sees are already weights^T and input^T, so we only need
+ // to swap the order of weights and input in the matrix product to correct
+ // for the row-major versus column-major difference.
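+ //
+ // For instance (illustrative numbers): with a single batch row
+ // x = [x0, x1] and row-major weights {1, 2, 3, 4} (a 2x2 matrix taking 2
+ // input features to 2 output features), column-major BLAS reads the
+ // weights buffer as [[1, 3], [2, 4]] = weights^T and the input buffer as
+ // the column [x0, x1]^T, so it writes {x0 + 3*x1, 2*x0 + 4*x1} -- exactly
+ // the row-major output row.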
+ const float alpha = 1.0f; // Take the matrix product without scaling it.
+ const float beta = 0.0f; // Ignore the original values in output_data.
+ const int64 m = output_dimensions.NodesAcrossFeatureMaps();
+ const int64 n = input_dimensions.count();
+ const int64 k = input_dimensions.NodesAcrossFeatureMaps();
+ stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
+ blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
+ m, input_data, k, beta, output_data, m);
+ } else {
+ // This is a slower and more complex path that supports output
+ // width() * height() > 1, though it only supports the
+ // kBatchYXDepth layout. It does support kBatchDepthYX if output
+ // feature_map_count() == 1, as then there is no difference
+ // between the two layouts.
+ //
+ // The operation here is the same as above, except that we have to
+ // do the matrix multiplication for each (y,x) output coordinate
+ // separately. We interpret weights as containing K = width() *
+ // height() different matrices, each of which we multiply onto the
+ // matrix from input_data, yielding K matrix products. We then
+ // combine these into one matrix by concatenating all the first
+ // rows of these matrices, then all the second rows, and so
+ // on. We can do this with a batched matrix multiplication, where
+ // the result is written to a different submatrix of the output
+ // for each matrix multiplication.
+ //
+ // The reason that we only support the kBatchYXDepth output layout
+ // is that we have to do something in the depth for each (y,x)
+ // coordinate. The kBatchYXDepth layout has the depth information
+ // for each point (y,x) in contiguous memory while the
+ // kBatchDepthYX layout does not.
+ //
+ // TODO(broune): Consider a special case for when output depth ==
+ // 1, as then possibly this could all be done as one matrix
+ // multiplication instead of a batched one, which should be
+ // faster. Another possibility would be to add a weights layout
+ // parameter and then support kBatchDepthYX for a different
+ // weights layout.
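+ //
+ // As a concrete illustration (hypothetical sizes): with depth
+ // feature_map_count() = 4 and width() * height() = 3, batch_count is 3
+ // and gemm i writes its 4 output values for spatial position i at offset
+ // i * 4 within each batch column (column stride ldc = 12), which is
+ // exactly where the kBatchYXDepth layout keeps the depth values for that
+ // (y,x).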
+ if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
+ !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
+ output_dimensions.feature_map_count() == 1)) {
+ LOG(ERROR) << "Unsupported MatMul output layout.";
+ return false;
+ }
+
+ const float alpha = 1.0f; // Take the matrix product without scaling it.
+ const float beta = 0.0f; // Ignore the original values in output_data.
+ const uint64 m = output_dimensions.feature_map_count();
+ const uint64 n = input_dimensions.count();
+ const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
+ const int lda = m;
+ const int ldb = k;
+ const int ldc = output_dimensions.NodesAcrossFeatureMaps();
+ const int batch_count = output_dimensions.NodesPerFeatureMap();
+
+ std::vector<DeviceMemory<float>> a(batch_count);
+ std::vector<DeviceMemory<float>> b(batch_count);
+ std::vector<DeviceMemory<float>> c(batch_count);
+ for (int i = 0; i < batch_count; ++i) {
+ const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
+ output_dimensions.feature_map_count();
+ a[i] = DeviceMemory<float>::MakeFromByteSize(
+ const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
+ weights_offset,
+ weights.ElementCount() - weights_offset);
+
+ b[i] = input_data;
+
+ const int output_offset = i * output_dimensions.feature_map_count();
+ c[i] = DeviceMemory<float>::MakeFromByteSize(
+ const_cast<float*>(
+ reinterpret_cast<const float*>(output_data->opaque())) +
+ output_offset,
+ output_data->ElementCount() - output_offset);
+ }
+ const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
+ std::vector<DeviceMemory<float>*> ptrs;
+ for (auto& mem : v) {
+ ptrs.push_back(&mem);
+ }
+ return ptrs;
+ };
+
+ stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
+ blas::Transpose::kNoTranspose, m, n, k, alpha,
+ toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
+ ldc, batch_count);
+ }
+
+ return stream->ok();
+}
+
+bool CudnnSupport::DoBiasAdd(Stream* stream,
+ const DeviceMemory<float>& input_data,
+ const DeviceMemory<float>& biases,
+ const dnn::BatchDescriptor& dimensions,
+ DeviceMemory<float>* output_data) {
+ ScopedTensorDescriptor input_descriptor{parent_, dimensions,
+ CUDNN_DATA_FLOAT};
+
+ BatchDescriptor bias_dimensions;
+ bias_dimensions.set_count(1)
+ .set_feature_map_count(dimensions.feature_map_count())
+ .set_height(1)
+ .set_width(1)
+ .set_layout(dnn::DataLayout::kBatchYXDepth);
+ ScopedTensorDescriptor bias_descriptor{parent_, bias_dimensions,
+ CUDNN_DATA_FLOAT};
+
+ // cudnnAddTensor is in-place, so we need to copy input_data to
+ // output_data before doing the addition, unless the input and
+ // output are at the same address.
+ if (input_data.opaque() != output_data->opaque()) {
+ stream->ThenMemcpy(output_data, input_data,
+ dimensions.ElementCount() * sizeof(float));
+ if (!stream->ok()) {
+ LOG(ERROR)
+ << "stream " << stream
+ << " could not enqueue a tensor copy as part of bias addition.";
+ return false;
+ }
+ }
+
+ mutex_lock lock{dnn_handle_mutex_};
+
+ const float alpha = 1.0f;
+ const float beta = 1.0f;
+ auto status = dynload::cudnnAddTensor(
+ parent_, ToHandle(dnn_handle_), CUDNN_ADD_SAME_C, &alpha,
+ bias_descriptor.handle(), biases.opaque(), &beta,
+ input_descriptor.handle(), output_data->opaque());
+
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
+ return false;
+ }
+
+ return true;
+}
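
A CPU reference for the bias broadcast performed above may help clarify the CUDNN_ADD_SAME_C semantics: each channel's bias is added at every (batch, y, x) position. This is an illustrative sketch assuming a kBatchYXDepth (NHWC) layout, not part of the StreamExecutor API.

#include <cstddef>
#include <vector>

// Equivalent to the copy-then-in-place-add done above with alpha = 1, beta = 1.
void BiasAddNHWC(const std::vector<float>& input,
                 const std::vector<float>& biases,  // size == depth
                 size_t count, size_t height, size_t width, size_t depth,
                 std::vector<float>* output) {
  output->resize(count * height * width * depth);
  for (size_t n = 0; n < count; ++n) {
    for (size_t y = 0; y < height; ++y) {
      for (size_t x = 0; x < width; ++x) {
        for (size_t c = 0; c < depth; ++c) {
          const size_t i = ((n * height + y) * width + x) * depth + c;
          (*output)[i] = input[i] + biases[c];
        }
      }
    }
  }
}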
+
+bool CudnnSupport::DoActivate(Stream* stream,
+ dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor& dimensions,
+ const DeviceMemory<float>& input_data,
+ DeviceMemory<float>* output_data) {
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
+ return false;
+ }
+ cudnnActivationMode_t mode;
+ switch (activation_mode) {
+ case dnn::ActivationMode::kRelu6:
+ // TODO(leary) should probably do a post-pass to clip at 6?
+ LOG(WARNING) << "user requested Relu6, but providing Relu instead";
+ mode = CUDNN_ACTIVATION_RELU;
+ break;
+ case dnn::ActivationMode::kReluX:
+ // TODO(broune) should probably do a post-pass to clip at X?
+ LOG(WARNING) << "user requested ReluX, but providing Relu instead";
+ mode = CUDNN_ACTIVATION_RELU;
+ break;
+ case dnn::ActivationMode::kRelu:
+ mode = CUDNN_ACTIVATION_RELU;
+ break;
+ case dnn::ActivationMode::kSigmoid:
+ mode = CUDNN_ACTIVATION_SIGMOID;
+ break;
+ case dnn::ActivationMode::kTanh:
+ mode = CUDNN_ACTIVATION_TANH;
+ break;
+ default:
+ LOG(ERROR) << "unrecognized activation mode: "
+ << static_cast<int>(activation_mode);
+ return false;
+ }
+
+ ScopedTensorDescriptor input_4d{parent_, dimensions, CUDNN_DATA_FLOAT};
+ // Alpha is the input scaling factor.
+ float alpha = 1.0;
+ // Beta is the output scaling factor.
+ float beta = 0.0;
+ status = dynload::cudnnActivationForward(
+ parent_, ToHandle(dnn_handle_), mode, &alpha, input_4d.handle(),
+ input_data.opaque(), &beta, input_4d.handle(), output_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "stream " << stream
+ << " could not enqueue activation: " << ToString(status);
+ return false;
+ }
+
+ return true;
+}
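
The TODOs above note that kRelu6 and kReluX currently fall back to plain ReLU. The post-pass they suggest would simply clamp the ReLU output at the cap; a minimal CPU sketch (hypothetical helper, not part of this library):

#include <algorithm>
#include <cstddef>
#include <vector>

// Plain ReLU followed by an upper clamp, e.g. cap = 6.0f for Relu6.
void ReluWithCap(const std::vector<float>& in, float cap,
                 std::vector<float>* out) {
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    (*out)[i] = std::min(std::max(in[i], 0.0f), cap);
  }
}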
+
+bool CudnnSupport::DoPoolForward(
+ Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) {
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
+ return false;
+ }
+
+ // Alpha is the scaling factor for input.
+ float alpha = 1.0;
+ // Beta is the scaling factor for output.
+ float beta = 0.0;
+
+ ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
+ CUDNN_DATA_FLOAT};
+ ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
+ status = dynload::cudnnPoolingForward(
+ parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+ src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
+ output_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to enqueue forward pooling on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+}
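
The alpha and beta arguments here follow the usual cuDNN blending convention, dst = alpha * op(src) + beta * dst, so alpha = 1 and beta = 0 means the pooling result simply overwrites output_data. A one-line illustration (not a library function):

inline float CudnnBlend(float op_result, float old_dst, float alpha,
                        float beta) {
  return alpha * op_result + beta * old_dst;
}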
+
+bool CudnnSupport::DoPoolBackward(
+ Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ const DeviceMemory<float>& output_data,
+ const DeviceMemory<float>& input_diff_data,
+ DeviceMemory<float>* output_diff_data) {
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
+ return false;
+ }
+
+ // Alpha is the scaling factor for input.
+ float alpha = 1.0;
+ // Beta is the scaling factor for output.
+ float beta = 0.0;
+
+ ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_FLOAT};
+ ScopedTensorDescriptor dest_desc{parent_, output_dimensions,
+ CUDNN_DATA_FLOAT};
+ ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
+ status = dynload::cudnnPoolingBackward(
+ parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+ dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
+ input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
+ src_desc.handle(), output_diff_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to enqueue backward pooling on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+}
+
+bool CudnnSupport::DoNormalize(
+ Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
+ const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
+ LOG(FATAL) << "not yet implemented"; // TODO(leary)
+}
+
+bool CudnnSupport::DoDepthConcatenate(
+ Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float>*> input_data,
+ DeviceMemory<float>* output_data) {
+ LOG(FATAL) << "not yet implemented"; // TODO(leary)
+}
+
+bool CudnnSupport::DoElementwiseOperate(
+ Stream* stream, dnn::ElementwiseOperation operation,
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float>*> input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) {
+ LOG(FATAL) << "not yet implemented"; // TODO(leary)
+}
+
+bool CudnnSupport::DoMemcpyD2HQuantized(
+ Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
+ port::MutableArraySlice<uint8> host_dst) {
+ LOG(ERROR) << "quantized memcpy not supported by cuDNN";
+ return false;
+}
+
+bool CudnnSupport::DoMemcpyD2HQuantized(
+ Stream* stream, const DeviceMemory<float>& device_unquantized_src,
+ port::MutableArraySlice<uint16> host_dst) {
+ LOG(ERROR) << "quantized memcpy not supported by cuDNN";
+ return false;
+}
+
+bool CudnnSupport::DoMemcpyD2HQuantized(
+ Stream* stream, const DeviceMemory<float>& device_unquantized_src,
+ port::MutableArraySlice<int32> host_dst) {
+ LOG(ERROR) << "quantized memcpy not supported by cuDNN";
+ return false;
+}
+
+bool CudnnSupport::DoMemcpyH2DQuantized(
+ Stream* stream, port::ArraySlice<uint8> host_src,
+ DeviceMemory<float>* gpu_unquantized_dst) {
+ LOG(ERROR) << "quantized memcpy not supported by cuDNN";
+ return false;
+}
+
+bool CudnnSupport::DeriveOutputBatchDescriptor(
+ const BatchDescriptor& batch_descriptor,
+ const FilterDescriptor& filter_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ dnn::BatchDescriptor* output_batch_descriptor) {
+ ScopedTensorDescriptor input_4d{parent_, batch_descriptor, CUDNN_DATA_FLOAT};
+ ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT};
+ ScopedConvolutionDescriptor conv{parent_, convolution_descriptor};
+
+ int dims[4];
+ auto status = dynload::cudnnGetConvolutionNdForwardOutputDim(
+ parent_, conv.handle(), input_4d.handle(), filter.handle(), 4, dims);
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "could not get output tensor for convolution: "
+ << ToString(status);
+ return false;
+ }
+
+ output_batch_descriptor->set_count(dims[0])
+ .set_feature_map_count(dims[1])
+ .set_height(dims[2])
+ .set_width(dims[3])
+ .set_layout(batch_descriptor.layout());
+ return true;
+}
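
For the common zero-padded, dilation-1 case, the spatial sizes returned by the cuDNN query above follow the standard convolution output formula. A minimal sketch (hypothetical helper; the real code relies on cudnnGetConvolutionNdForwardOutputDim):

#include <cstdint>

// out = (in + 2 * pad - filter) / stride + 1
// Example: ConvOutputDim(224, 7, 3, 2) == 112.
inline int64_t ConvOutputDim(int64_t in, int64_t filter, int64_t pad,
                             int64_t stride) {
  return (in + 2 * pad - filter) / stride + 1;
}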
+
+} // namespace cuda
+
+namespace gpu = ::perftools::gputools;
+
+void initialize_cudnn() {
+ gpu::port::Status status =
+ gpu::PluginRegistry::Instance()
+ ->RegisterFactory<gpu::PluginRegistry::DnnFactory>(
+ gpu::cuda::kCudaPlatformId, gpu::cuda::kCuDnnPlugin, "cuDNN",
+ [](gpu::internal::StreamExecutorInterface*
+ parent) -> gpu::dnn::DnnSupport* {
+ gpu::cuda::CUDAExecutor* cuda_executor =
+ dynamic_cast<gpu::cuda::CUDAExecutor*>(parent);
+ if (cuda_executor == nullptr) {
+ LOG(ERROR)
+                    << "Attempting to initialize an instance of the cuDNN "
+ << "support library with a non-CUDA StreamExecutor";
+ return nullptr;
+ }
+
+ gpu::cuda::CudnnSupport* dnn =
+ new gpu::cuda::CudnnSupport(cuda_executor);
+ if (!dnn->Init().ok()) {
+ // Note: Init() will log a more specific error.
+ delete dnn;
+ return nullptr;
+ }
+ return dnn;
+ });
+
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to register cuDNN factory: "
+ << status.error_message();
+ }
+
+ // Prime the cuDNN DSO. The loader will log more information.
+ auto statusor = gpu::internal::CachedDsoLoader::GetCudnnDsoHandle();
+ if (!statusor.ok()) {
+ LOG(INFO) << "Unable to load cuDNN DSO.";
+ }
+
+ gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
+ gpu::PluginKind::kDnn,
+ gpu::cuda::kCuDnnPlugin);
+}
+
+} // namespace gputools
+} // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(register_cudnn,
+ { perftools::gputools::initialize_cudnn(); });
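
The registration above follows a registry-plus-factory pattern: a factory keyed by platform and plugin id is stored at module-initialization time and invoked later to build one DnnSupport per executor. A stripped-down sketch with hypothetical types (not the real PluginRegistry API):

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct FakeExecutor {};
struct FakeDnnSupport {};

class FakeRegistry {
 public:
  using Factory =
      std::function<std::unique_ptr<FakeDnnSupport>(FakeExecutor*)>;

  // Called at module-initialization time, as initialize_cudnn() does above.
  void RegisterFactory(const std::string& name, Factory f) {
    factories_[name] = std::move(f);
  }

  // Called lazily, once per executor that needs DNN support.
  std::unique_ptr<FakeDnnSupport> Create(const std::string& name,
                                         FakeExecutor* parent) {
    auto it = factories_.find(name);
    return it == factories_.end() ? nullptr : it->second(parent);
  }

 private:
  std::map<std::string, Factory> factories_;
};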
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
new file mode 100644
index 0000000000..08e952cee0
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -0,0 +1,206 @@
+// The CUDA-specific DNN library support, implementing the general DnnSupport
+// interface.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
+
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/temporary_device_memory.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+class CUDAExecutor;
+
+// Opaque and unique identifier for the cuDNN plugin.
+extern const PluginId kCuDnnPlugin;
+
+// cudnn-library based DNN support. For details on overridden interface
+// functions, see dnn.h.
+class CudnnSupport : public dnn::DnnSupport {
+ public:
+ explicit CudnnSupport(CUDAExecutor* parent);
+ ~CudnnSupport() override;
+
+ port::Status Init() override;
+
+ bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<float>& input_data,
+ const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<float>& filter_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
+ const DeviceMemory<double>& input_data,
+ const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<double>& filter_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<double>* output_data) override;
+
+ bool DoSeparableConvolve(
+ Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
+ const DeviceMemory<float>& input_data,
+ const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
+ const DeviceMemory<float>& first_weights,
+ const DeviceMemory<float>& second_weights,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<float>* output_data) override {
+ LOG(ERROR) << "separable convolution not supported by CUDNN";
+ return false;
+ }
+
+ bool DoConvolveBackwardData(
+ Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<float>& filter_data,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& input_descriptor,
+ DeviceMemory<float>* backward_input_data) override;
+
+ bool DoConvolveBackwardFilter(
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
+ DeviceMemory<float>* backward_filter_data) override;
+
+ bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
+ const DeviceMemory<float>& weights,
+ const dnn::BatchDescriptor& input_dimensions,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
+ const DeviceMemory<int8>& quantized_weights,
+ const DeviceMemory<float>& weight_scales,
+ const dnn::BatchDescriptor& input_dimensions,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) override {
+ LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
+ return false;
+ }
+
+ bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
+ const DeviceMemory<int16>& quantized_weights,
+ const DeviceMemory<float>& weight_scales,
+ const dnn::BatchDescriptor& input_dimensions,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) override {
+ LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
+ return false;
+ }
+
+ bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
+ const DeviceMemory<float>& biases,
+ const dnn::BatchDescriptor& dimensions,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor& dimensions,
+ const DeviceMemory<float>& input_data,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoPoolForward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoPoolBackward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ const DeviceMemory<float>& output_data,
+ const DeviceMemory<float>& input_diff_data,
+ DeviceMemory<float>* output_diff_data) override;
+
+ bool DoNormalize(Stream* stream,
+ const dnn::NormalizeDescriptor& normalize_descriptor,
+ const DeviceMemory<float>& input_data,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoDepthConcatenate(
+ Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float>*> input_data,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoElementwiseOperate(
+ Stream* stream, dnn::ElementwiseOperation operation,
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float>*> input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) override;
+
+ bool DoMemcpyD2HQuantized(Stream* stream,
+ const DeviceMemory<float>& device_unquantized_src,
+ port::MutableArraySlice<uint8> host_dst) override;
+
+ bool DoMemcpyD2HQuantized(Stream* stream,
+ const DeviceMemory<float>& device_unquantized_src,
+ port::MutableArraySlice<uint16> host_dst) override;
+
+ bool DoMemcpyD2HQuantized(Stream* stream,
+ const DeviceMemory<float>& device_unquantized_src,
+ port::MutableArraySlice<int32> host_dst) override;
+
+ bool DoMemcpyH2DQuantized(
+ Stream* stream, port::ArraySlice<uint8> host_src,
+ DeviceMemory<float>* device_unquantized_dst) override;
+
+ // Derives an output batch descriptor from an input batch and convolution
+ // descriptors.
+ bool DeriveOutputBatchDescriptor(
+ const dnn::BatchDescriptor& batch_descriptor,
+ const dnn::FilterDescriptor& filter_descriptor,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ dnn::BatchDescriptor* output_batch_descriptor);
+
+ private:
+ // Guards the enqueueing of DNN operations via the dnn_handle_ below.
+ mutex dnn_handle_mutex_;
+
+ CUDAExecutor* parent_; // Parent executor object. Not owned.
+
+ // cudnn library handle. cudnnHandle_t type is not present in this header to
+ // prevent third-party library header inclusions from leaking outside the
+ // single cuda_dnn translation unit.
+ void* dnn_handle_ GUARDED_BY(dnn_handle_mutex_);
+
+ // NOTE(keveman): Temporary data layout transformation until cuDNN supports
+ // kBatchYXDepth for backward pass. This function allocates temporary memory,
+  // lays out the source data into the temporary memory in the kBatchDepthYX
+  // layout, and returns the temporary memory. The caller is responsible for
+ // deallocating the temporary. Since the allocation is done using Stream's
+ // AllocateTemporaryMemory, a later BlockHostUntilDone could be used for
+ // deallocation.
+ //
+ // transform_scratch is populated with a legitimate temporary allocation iff
+ // the original output data needs to be transformed.
+ DeviceMemory<float> MaybeTransformLayout(
+ Stream* stream, dnn::BatchDescriptor* output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ std::unique_ptr<TemporaryDeviceMemory<float>>* transform_scratch)
+ EXCLUSIVE_LOCKS_REQUIRED(dnn_handle_mutex_);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
new file mode 100644
index 0000000000..8c4316b4c1
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -0,0 +1,1608 @@
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+
+#include <dlfcn.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <set>
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/lib/casts.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/human_readable.h"
+#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/threadpool.h"
+#include "tensorflow/stream_executor/lib/stacktrace.h"
+#include "tensorflow/stream_executor/lib/static_threadlocal.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
+
+bool FLAGS_gpuexec_cuda_driver_inject_init_error = false;
+bool FLAGS_gpuexec_cuda_sync_around_driver_calls = false;
+bool FLAGS_gpuexec_cuda_device_0_only = false;
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+namespace dynload {
+
+#define PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(__name) \
+ struct DynLoadShim__##__name { \
+ static const char *kName; \
+ using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
+ static void *GetDsoHandle() { \
+ static auto status = internal::CachedDsoLoader::GetLibcudaDsoHandle(); \
+ return status.ValueOrDie(); \
+ } \
+ static FuncPointerT DynLoad() { \
+ static void *f = dlsym(GetDsoHandle(), kName); \
+ CHECK(f != nullptr) << "could not find " << kName \
+                        << " in libcuda DSO; dlerror: " << dlerror();    \
+ return reinterpret_cast<FuncPointerT>(f); \
+ } \
+ template <typename... Args> \
+ CUresult operator()(Args... args) { \
+ return DynLoad()(args...); \
+ } \
+ } __name; \
+ const char *DynLoadShim__##__name::kName = #__name;
+
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxCreate_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxDestroy);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxEnablePeerAccess);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxGetCurrent);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxGetDevice);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxGetSharedMemConfig);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxPopCurrent_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxSetCurrent);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxSetSharedMemConfig);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuCtxSynchronize);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceComputeCapability);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceCanAccessPeer);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGet);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetAttribute);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetCount);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetName);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetPCIBusId);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceGetProperties);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDeviceTotalMem);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuDriverGetVersion);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventCreate);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventDestroy_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventElapsedTime);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventQuery);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuEventRecord);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuFuncGetAttribute);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuFuncSetCacheConfig);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuGetErrorName);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuGetErrorString);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuInit);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuLaunchKernel);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemAlloc_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoD_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoH_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyHtoD_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoDAsync_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyDtoHAsync_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemcpyHtoDAsync_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemGetAddressRange_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemFree_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemFreeHost);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemGetInfo_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostAlloc);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostRegister_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostUnregister);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32Async);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD8_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetFunction);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetGlobal_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleLoadDataEx);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleLoadFatBinary);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleUnload);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuOccupancyMaxActiveBlocksPerMultiprocessor);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuPointerGetAttribute);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamAddCallback);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamCreate);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamDestroy_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamQuery);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamSynchronize);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuStreamWaitEvent);
+
+} // namespace dynload
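
Each PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP above expands to a small trampoline that resolves the driver symbol lazily with dlsym and then forwards the call. A hand-written sketch of the same idea for a single symbol (hypothetical function type and symbol name):

#include <dlfcn.h>
#include <cstdio>

using ExampleFn = int (*)(unsigned int);

// Resolve the symbol once from an already-opened DSO handle; log and return
// nullptr on failure, mirroring the CHECK in the macro above.
static ExampleFn LoadExampleSymbol(void* dso_handle, const char* name) {
  void* f = dlsym(dso_handle, name);
  if (f == nullptr) {
    std::fprintf(stderr, "could not find %s in DSO; dlerror: %s\n", name,
                 dlerror());
    return nullptr;
  }
  return reinterpret_cast<ExampleFn>(f);
}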
+
+namespace {
+
+// Manages the singleton set of contexts that we've created. This is used for
+// checking that no CUDA-runtime-created contexts have been generated
+// accidentally. If triple-angle-bracket launches are required, CUDA-runtime-
+// created contexts are avoided by using the scoped activations in
+// cuda_activation.h.
+class CreatedContexts {
+ public:
+ // Returns whether context is a member of the live set.
+ static bool Has(CUcontext context) {
+ shared_lock lock{mu_};
+ return Live()->find(context) != Live()->end();
+ }
+
+ // Adds context to the live set.
+ static void Add(CUcontext context) {
+ CHECK(context != nullptr);
+ mutex_lock lock{mu_};
+ Live()->emplace(context);
+ }
+
+ // Removes context from the live set.
+ static void Remove(CUcontext context) {
+ CHECK(context != nullptr);
+ mutex_lock lock{mu_};
+ Live()->erase(context);
+ }
+
+ private:
+ // Returns the live set singleton.
+ static std::set<CUcontext> *Live() {
+ static auto singleton = new std::set<CUcontext>;
+ return singleton;
+ }
+
+ // Lock that guards access-to/mutation-of the live set.
+ static mutex mu_;
+};
+
+/* static */ mutex CreatedContexts::mu_{LINKER_INITIALIZED};
+
+// Formats CUresult to output prettified values into a log stream.
+// Error summaries taken from:
+// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc6c391505e117393cc2558fff6bfc2e9
+//
+// TODO(leary) switch to cuGetErrorName when updated cuda.h is available.
+string ToString(CUresult result) {
+#define OSTREAM_CUDA_ERROR(__name) \
+ case CUDA_ERROR_##__name: \
+ return "CUDA_ERROR_" #__name;
+
+///////////////
+// NOTE: here we specify return code values outside of the enum explicitly
+// because our in-tree cuda.h is from the CUDA 5.5 SDK, but CUDA 6.0+ driver
+// libraries are deployed in the fleet. These error codes are backwards
+// compatible, but if we see a "new" one, we want to be able to identify it in
+// the logs.
+//
+// Once we get a cuda.h that has cuGetErrorName (TODO is above) we can
+// eliminate this function and just rely on the driver to provide us these
+// strings.
+//
+// NOTE: "Must reboot all context" below is shorthand for, "must
+// destroy/recreate the offending context and any allocations which come from
+// it if you are to continue using CUDA."
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wswitch"
+ switch (result) {
+ OSTREAM_CUDA_ERROR(INVALID_VALUE)
+ OSTREAM_CUDA_ERROR(OUT_OF_MEMORY)
+ OSTREAM_CUDA_ERROR(NOT_INITIALIZED)
+ OSTREAM_CUDA_ERROR(DEINITIALIZED)
+ OSTREAM_CUDA_ERROR(NO_DEVICE)
+ OSTREAM_CUDA_ERROR(INVALID_DEVICE)
+ OSTREAM_CUDA_ERROR(INVALID_IMAGE)
+ OSTREAM_CUDA_ERROR(INVALID_CONTEXT)
+ OSTREAM_CUDA_ERROR(INVALID_HANDLE)
+ OSTREAM_CUDA_ERROR(NOT_FOUND)
+ OSTREAM_CUDA_ERROR(NOT_READY)
+ OSTREAM_CUDA_ERROR(NO_BINARY_FOR_GPU)
+
+ // Encountered an uncorrectable ECC error during execution.
+ OSTREAM_CUDA_ERROR(ECC_UNCORRECTABLE)
+
+ // Load/store on an invalid address. Must reboot all context.
+ case 700:
+ return "CUDA_ERROR_ILLEGAL_ADDRESS";
+ // Passed too many / wrong arguments, too many threads for register count.
+ case 701:
+ return "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES";
+ // Kernel took too long to execute.
+ case 702:
+ return "CUDA_ERROR_LAUNCH_TIMEOUT";
+ // Kernel launch uses an incompatible texturing mode.
+ case 703:
+ return "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING";
+ // Trying to re-enable peer access that already has it enabled.
+ case 704:
+ return "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED";
+ // Trying to disable peer access that has not yet been enabled.
+ case 705:
+ return "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED";
+ // Primary context for the specified device has already been initialized.
+ case 708:
+ return "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE";
+ // Context current to calling thread has been destroyed or is a primary
+ // context that has not yet been initialized.
+ case 709:
+ return "CUDA_ERROR_CONTEXT_IS_DESTROYED";
+ // Device-side assert triggered during kernel execution. Must reboot all
+ // context.
+ case 710:
+ return "CUDA_ERROR_ASSERT";
+ // Hardware resources to enable peer access have been exhausted.
+ case 711:
+ return "CUDA_ERROR_TOO_MANY_PEERS";
+ // Memory range has already been registered.
+ case 712:
+ return "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED";
+ // Pointer does not correspond to any currently registered memory region.
+ case 713:
+ return "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED";
+ // Due to stack corruption or exceeding stack size limit. Must reboot all
+ // context.
+ case 714:
+ return "CUDA_ERROR_HARDWARE_STACK_ERROR";
+ case 715:
+ return "CUDA_ERROR_ILLEGAL_INSTRUCTION";
+ // Load/store on an unaligned memory address. Must reboot all context.
+ case 716:
+ return "CUDA_ERROR_MISALIGNED_ADDRESS";
+ // Device instruction with specific address space given address not
+ // belonging to allowed address space. Must reboot all context.
+ case 717:
+ return "CUDA_ERROR_INVALID_ADDRESS_SPACE";
+ // Device program counter wrapped its address space. Must reboot all
+ // context.
+ case 718:
+ return "CUDA_ERROR_INVALID_PC";
+ // Exception on device while executing a kernel; e.g. deref invalid device
+ // pointer, accessing OOB shared memory. Must reboot all context.
+ case 719:
+ return "CUDA_ERROR_LAUNCH_FAILED";
+
+ OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE)
+ OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED)
+ OSTREAM_CUDA_ERROR(NOT_PERMITTED)
+ OSTREAM_CUDA_ERROR(NOT_SUPPORTED)
+ OSTREAM_CUDA_ERROR(UNKNOWN) // Unknown internal error to CUDA.
+ default:
+ return port::StrCat("CUresult(", static_cast<int>(result), ")");
+ }
+#pragma GCC diagnostic pop
+}
+
+// Returns the current context and checks that it is in the set of CUDA contexts
+// created by StreamExecutor (to ensure that the CUDA runtime didn't create a
+// context behind our backs).
+CUcontext CurrentContext() {
+ CUcontext current = nullptr;
+ CUresult result = dynload::cuCtxGetCurrent(&current);
+ if (result != CUDA_SUCCESS) {
+ LOG(FATAL) << "failed to query current context: " << ToString(result);
+ }
+ if (current != nullptr && !CreatedContexts::Has(current)) {
+ LOG(FATAL) << "current context was not created by the StreamExecutor "
+ "cuda_driver API: "
+ << current
+ << "; a CUDA runtime call "
+ "was likely performed without using a StreamExecutor context";
+ }
+ return current;
+}
+
+// "Pops" the current context, checks that it matches expected, and checks the
+// postcondition that the current context is nullptr.
+//
+// This is not done when we're nested within a MultiOpActivation, as we want to
+// persist the active context until the MultiOpActivation is popped.
+void PopContextAndCheckNowNull(CUcontext expected) {
+ CUcontext actual = CurrentContext();
+ CHECK_EQ(expected, actual) << "would pop unexpected context";
+ CUcontext popped;
+ CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxPopCurrent_v2(&popped));
+ CHECK_EQ(expected, popped);
+ CHECK(nullptr == CurrentContext());
+ VLOG(3) << "popped context " << expected
+ << " and current context is now null";
+}
+
+// CUDA driver routines may require a large amount of stack (particularly
+// cuModuleLoadDataEx, in our experience). To avoid stack overflow when using
+// stack-limited threads (such as those spawned by a default-argument
+// thread::ThreadPool on some platforms), we run certain routines in this pool
+// and wait for completion.
+static mutex driver_executor_threadpool_mu(LINKER_INITIALIZED);
+static port::ThreadPool *InitializeDriverExecutor() {
+ return new port::ThreadPool(port::Env::Default(), port::ThreadOptions(),
+ "cuda_driver", 1);
+}
+
+port::ThreadPool *GetDriverExecutor() {
+ mutex_lock lock(driver_executor_threadpool_mu);
+ static port::ThreadPool *thread_pool = InitializeDriverExecutor();
+ return thread_pool;
+}
+
+} // namespace
+
+
+// Thread-local storage that indicates whether a CUDA context activation is
+// being nested within an outer, MultiOpActivation. In that case, we should not
+// pop the context to nullptr when we are done with the current activation.
+SE_STATIC_THREAD_LOCAL_POD(bool, tls_in_multi_op_activation);
+
+string MemorySpaceString(MemorySpace memory_space) {
+ switch (memory_space) {
+ case MemorySpace::kHost:
+ return "host";
+ case MemorySpace::kDevice:
+ return "device";
+ default:
+ LOG(FATAL) << "impossible memory space";
+ }
+}
+
+// Implementation note: the CUDA context is held, per-thread, in TLS. We avoid
+// setting all the time because it's not clear what side effects might occur for
+// a "set" operation, whereas a "get" operation we can reasonably assume is a
+// TLS read.
+//
+// We cannot race here because CUcontext is associated with a particular thread
+// and stored in TLS; and these interfaces should not be used from signal
+// handlers.
+ScopedActivateContext::ScopedActivateContext(CUcontext context,
+ MultiOpActivation moa)
+ : context_(CHECK_NOTNULL(context)),
+ previously_in_multi_op_activation_(tls_in_multi_op_activation.get()) {
+ if (static_cast<bool>(moa)) {
+ tls_in_multi_op_activation.get() = true;
+ }
+
+ CUcontext current = prior_context_ = CurrentContext();
+ if (current != context) {
+ VLOG(3) << "ScopedActivateContext switching context from " << current
+ << " to " << context;
+ CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(context));
+ if (FLAGS_gpuexec_cuda_sync_around_driver_calls) {
+ auto res = dynload::cuCtxSynchronize();
+ if (res != CUDA_SUCCESS) {
+ LOG(FATAL) << "gpuexec_cuda_sync_around_driver_calls found "
+ << ToString(res)
+ << " immediately after establishing the device context "
+ << context << " :: " << port::CurrentStackTrace();
+ }
+ }
+ }
+}
+
+ScopedActivateContext::~ScopedActivateContext() {
+ if (tls_in_multi_op_activation.get()) {
+ CHECK_EQ(context_, CurrentContext());
+ if (FLAGS_gpuexec_cuda_sync_around_driver_calls) {
+ auto res = dynload::cuCtxSynchronize();
+ if (res != CUDA_SUCCESS) {
+ LOG(FATAL) << "gpuexec_cuda_sync_around_driver_calls found "
+ << ToString(res)
+ << " immediately after de-establishing the device context "
+ << context_ << " :: " << port::CurrentStackTrace();
+ }
+ }
+ CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxSetCurrent(prior_context_));
+ } else {
+ PopContextAndCheckNowNull(context_);
+ }
+ tls_in_multi_op_activation.get() = previously_in_multi_op_activation_;
+}
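
Stripped of the CUDA specifics, ScopedActivateContext is a save/activate/restore RAII guard over a thread-local "current" value. A minimal sketch with a hypothetical integer slot standing in for CUcontext:

class ScopedSetter {
 public:
  explicit ScopedSetter(int new_value) : prior_(GetCurrent()) {
    SetCurrent(new_value);
  }
  ~ScopedSetter() { SetCurrent(prior_); }

 private:
  // Thread-local storage, analogous to the driver's per-thread current context.
  static int& Slot() {
    static thread_local int v = 0;
    return v;
  }
  static int GetCurrent() { return Slot(); }
  static void SetCurrent(int v) { Slot() = v; }

  int prior_;
};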
+
+namespace {
+
+// Returns a stringified device number associated with pointer, primarily for
+// logging purposes. Returns "?" if the device could not be successfully
+// queried.
+string CUDAPointerToDeviceString(CUdeviceptr pointer) {
+ auto value = CUDADriver::GetPointerDevice(pointer);
+ if (value.ok()) {
+ return port::StrCat(value.ValueOrDie());
+ }
+ LOG(ERROR) << "could not query device: " << value.status();
+ return "?";
+}
+
+// Returns a stringified memory space associated with pointer, primarily for
+// logging purposes. Returns "?" if the memory space could not be successfully
+// queried.
+string CUDAPointerToMemorySpaceString(CUdeviceptr pointer) {
+ auto value = CUDADriver::GetPointerMemorySpace(pointer);
+ if (value.ok()) {
+ return MemorySpaceString(value.ValueOrDie());
+ }
+  LOG(ERROR) << "could not query memory space: " << value.status();
+ return "?";
+}
+
+// Returns a stringified representation of whether or not peer access is
+// permitted between the "from" and "to" pointers' associated contexts,
+// primarily for logging purposes. Returns "error" if an error is encountered
+// in the process of querying.
+string CUDAPointersToCanAccessString(CUdeviceptr from, CUdeviceptr to) {
+ auto from_context = CUDADriver::GetPointerContext(from);
+ if (!from_context.ok()) {
+ LOG(ERROR) << "could not retrieve source pointer's context: "
+ << from_context.status();
+ return "error";
+ }
+ auto to_context = CUDADriver::GetPointerContext(to);
+ if (!to_context.ok()) {
+ LOG(ERROR) << "could not retrieve destination pointer's context: "
+ << to_context.status();
+ return "error";
+ }
+ return CUDADriver::CanEnablePeerAccess(from_context.ValueOrDie(),
+ to_context.ValueOrDie())
+ ? "true"
+ : "false";
+}
+
+
+// Actually performs the work of CUDA initialization. Wrapped up in one-time
+// execution guard.
+static port::Status InternalInit() {
+ CUresult res = CUDA_ERROR_NO_DEVICE;
+ if (FLAGS_gpuexec_cuda_driver_inject_init_error) {
+ LOG(ERROR) << "injecting CUDA init error; initialization will fail";
+ } else if (internal::CachedDsoLoader::GetLibcudaDsoHandle().ok()) {
+ // We only call cuInit if we can dynload libcuda.
+
+ res = dynload::cuInit(0 /* = flags */);
+ }
+
+ if (res == CUDA_SUCCESS) {
+ return port::Status::OK();
+ }
+
+ LOG(ERROR) << "failed call to cuInit: " << ToString(res);
+ Diagnostician::LogDiagnosticInformation();
+ return port::Status{port::error::ABORTED,
+ port::StrCat("failed call to cuInit: ", ToString(res))};
+}
+
+} // namespace
+
+/* static */ port::Status CUDADriver::Init() {
+ // Cached return value from calling InternalInit(), as cuInit need only be
+ // called once, but CUDADriver::Init may be called many times.
+ static port::Status init_retval;
+ static bool set = false;
+ static mutex init_mu(LINKER_INITIALIZED);
+
+ mutex_lock lock(init_mu);
+ if (!set) {
+ init_retval = InternalInit();
+ set = true;
+ }
+
+ return init_retval;
+}
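
The cached-status guard above can also be expressed with std::call_once; a sketch, with InitOnceForIllustration standing in for InternalInit():

#include <mutex>

static bool InitOnceForIllustration() { return true; }

// Runs the expensive initialization exactly once and caches its result for
// all later callers.
static bool CachedInitResult() {
  static bool result = false;
  static std::once_flag flag;
  std::call_once(flag, [] { result = InitOnceForIllustration(); });
  return result;
}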
+
+/* static */ port::Status CUDADriver::GetDevice(int device_ordinal,
+ CUdevice *device) {
+ CUresult res = dynload::cuDeviceGet(device, device_ordinal);
+ if (res == CUDA_SUCCESS) {
+ return port::Status::OK();
+ }
+
+ return port::Status{
+ port::error::INTERNAL,
+ port::StrCat("failed call to cuDeviceGet: ", ToString(res))};
+}
+
+/* static */ bool CUDADriver::GetDeviceName(CUdevice device,
+ string *device_name) {
+ static const size_t kCharLimit = 64;
+ port::InlinedVector<char, 4> chars(kCharLimit);
+ CUresult res =
+ dynload::cuDeviceGetName(chars.begin(), kCharLimit - 1, device);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to get device name for " << device << ": "
+ << ToString(res);
+ return false;
+ }
+ chars[kCharLimit - 1] = '\0';
+ *device_name = chars.begin();
+ return true;
+}
+
+bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
+ static_assert(DeviceOptions::kMask == 0xf,
+ "needs update for new device options");
+
+ if (device_options.flags() & DeviceOptions::kDoNotReclaimStackAllocation) {
+ *flags |= CU_CTX_LMEM_RESIZE_TO_MAX;
+ }
+
+ // If no flags are set the default is CU_CTX_SCHED_AUTO, which
+ // in Google environments is very likely to mean SPIN.
+ if (device_options.flags() & DeviceOptions::kScheduleSpin) {
+ *flags |= CU_CTX_SCHED_SPIN;
+ }
+ if (device_options.flags() & DeviceOptions::kScheduleYield) {
+ *flags |= CU_CTX_SCHED_YIELD;
+ }
+ if (device_options.flags() & DeviceOptions::kScheduleBlockingSync) {
+ *flags |= CU_CTX_SCHED_BLOCKING_SYNC;
+ }
+
+ return true;
+}
+
+/* static */ port::Status CUDADriver::CreateContext(
+ CUdevice device, DeviceOptions device_options, CUcontext *context) {
+ CUcontext former_context = CurrentContext();
+ if (former_context != nullptr) {
+ LOG(WARNING) << "creating context when one is currently active; existing: "
+ << former_context;
+ }
+
+ int flags = 0;
+ if (!DeviceOptionsToContextFlags(device_options, &flags)) {
+ LOG(WARNING) << "could not convert all device options into context flags";
+ }
+
+ CUresult res;
+ {
+ // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
+ // context creation: see http://b/13248943
+
+ res = dynload::cuCtxCreate_v2(context, flags, device);
+ }
+ if (res == CUDA_SUCCESS) {
+ CreatedContexts::Add(*context);
+ PopContextAndCheckNowNull(*context);
+ CHECK(*context != nullptr)
+ << "success in this call must entail non-null result";
+ VLOG(2) << "created context " << context << " for this thread";
+ return port::Status::OK();
+ }
+
+ string message = "failed call to cuCtxCreate: " + ToString(res);
+ if (res == CUDA_ERROR_OUT_OF_MEMORY) {
+ uint64 total_memory;
+ if (GetDeviceTotalMemory(device, &total_memory)) {
+ port::StrAppend(&message, "; total memory reported: ", total_memory);
+ } else {
+ port::StrAppend(&message, "; could not query total memory");
+ }
+ }
+
+ return port::Status{port::error::INTERNAL, message};
+}
+
+/* static */ void CUDADriver::DestroyContext(CUcontext context) {
+ if (context == nullptr) {
+ return;
+ }
+
+ CUresult res = dynload::cuCtxDestroy_v2(context);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to destroy CUDA context; leaking: " << ToString(res);
+ }
+
+ CreatedContexts::Remove(context);
+}
+
+/* static */ bool CUDADriver::FuncGetAttribute(CUfunction_attribute attribute,
+ CUfunction func,
+ int *attribute_value) {
+ CUresult res = dynload::cuFuncGetAttribute(attribute_value, attribute, func);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query kernel attribute. kernel: " << func
+ << ", attribute: " << attribute;
+ return false;
+ }
+ return true;
+}
+
+/* static */ bool CUDADriver::FuncSetCacheConfig(CUfunction function,
+ CUfunc_cache cache_config) {
+ CUresult res = dynload::cuFuncSetCacheConfig(function, cache_config);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to set CUDA kernel cache config. kernel: " << function
+ << ", config: " << cache_config << ", result: " << ToString(res);
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ port::StatusOr<CUsharedconfig>
+CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
+ CUsharedconfig shared_mem_config;
+ ScopedActivateContext activation{context};
+ CUresult result = dynload::cuCtxGetSharedMemConfig(&shared_mem_config);
+ if (result != CUDA_SUCCESS) {
+ CUdevice device;
+ dynload::cuCtxGetDevice(&device);
+ LOG(ERROR) << "failed to get CUDA device shared memory config. "
+ << "Context device ID: " << device
+ << ", result: " << ToString(result);
+ return port::Status{
+ port::error::INTERNAL,
+ port::StrCat("failed to get shared memory config: ", ToString(result))};
+ }
+ return shared_mem_config;
+}
+
+/* static */ port::Status CUDADriver::ContextSetSharedMemConfig(
+ CUcontext context, CUsharedconfig shared_mem_config) {
+ ScopedActivateContext activation{context};
+ CUresult result = dynload::cuCtxSetSharedMemConfig(shared_mem_config);
+ if (result != CUDA_SUCCESS) {
+ CUdevice device;
+ dynload::cuCtxGetDevice(&device);
+ LOG(ERROR) << "failed to set CUDA device shared memory config. "
+ << "Context device ID: " << device
+ << ", config: " << shared_mem_config
+ << ", result: " << ToString(result);
+ return port::Status{
+ port::error::INTERNAL,
+ port::StrCat("failed to set shared memory config: ", ToString(result))};
+ }
+ return port::Status::OK();
+}
+
+/* static */ bool CUDADriver::LaunchKernel(
+ CUcontext context, CUfunction function, unsigned int grid_dim_x,
+ unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x,
+ unsigned int block_dim_y, unsigned int block_dim_z,
+ unsigned int shared_mem_bytes, CUstream stream, void **kernel_params,
+ void **extra) {
+ ScopedActivateContext activation{context};
+ VLOG(2) << "launching kernel: " << function << "; gdx: " << grid_dim_x
+ << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
+ << " bdx: " << block_dim_x << " bdy: " << block_dim_y
+ << " bdz: " << block_dim_z;
+ CUresult res = dynload::cuLaunchKernel(
+ function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
+ block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to launch CUDA kernel: " << function
+ << "; result: " << ToString(res);
+ return false;
+ }
+ VLOG(2) << "successfully launched kernel";
+ return true;
+}
+
+/* static */ port::Status CUDADriver::LoadCubin(CUcontext context,
+ const char *cubin_bytes,
+ CUmodule *module) {
+ ScopedActivateContext activation{context};
+ CUresult result = dynload::cuModuleLoadFatBinary(module, cubin_bytes);
+ if (result != CUDA_SUCCESS) {
+ return port::Status{port::error::INTERNAL,
+ "failed to load in-memory CUBIN: " + ToString(result)};
+ }
+
+ return port::Status::OK();
+}
+
+/* static */ bool CUDADriver::LoadPtx(CUcontext context,
+ const char *ptx_contents,
+ CUmodule *module) {
+ port::Notification notification;
+ bool ret = true;
+ GetDriverExecutor()->Schedule([context, ptx_contents, module, &ret,
+ &notification]() {
+ ScopedActivateContext activation{context};
+ void *ptx_data = const_cast<char *>(ptx_contents);
+ static const unsigned int kLogBufferBytesLimit = 1024;
+ unsigned int error_log_buffer_bytes = kLogBufferBytesLimit;
+ unsigned int info_log_buffer_bytes = kLogBufferBytesLimit;
+ port::InlinedVector<char, 4> error_log_buffer(error_log_buffer_bytes);
+ port::InlinedVector<char, 4> info_log_buffer(info_log_buffer_bytes);
+ bool log_verbose = true;
+ CUjit_option options[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
+ CU_JIT_ERROR_LOG_BUFFER,
+ CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
+ CU_JIT_INFO_LOG_BUFFER, CU_JIT_LOG_VERBOSE};
+    // Note that the driver API wants the contents of these values to be stored
+ // in an array of void*s, so we coerce them accordingly.
+ void *option_values[] = {
+ port::bit_cast<void *>(uintptr_t(error_log_buffer_bytes)),
+ port::bit_cast<void *>(error_log_buffer.data()),
+ port::bit_cast<void *>(uintptr_t(info_log_buffer_bytes)),
+ port::bit_cast<void *>(info_log_buffer.data()),
+ port::bit_cast<void *>(uintptr_t(log_verbose))};
+ CHECK(ARRAYSIZE(options) == ARRAYSIZE(option_values));
+
+ CUresult res;
+ {
+ // TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
+ // module loading: see http://b/13248943
+
+ res = dynload::cuModuleLoadDataEx(module, ptx_data, ARRAYSIZE(options),
+ options, option_values);
+ }
+
+ // The PTX JIT mutates the values in the option values array to reflect the
+ // size of the logs it output; now that we've made the call, read the values
+ // back out.
+ error_log_buffer_bytes = reinterpret_cast<uintptr_t>(option_values[0]);
+ info_log_buffer_bytes = reinterpret_cast<uintptr_t>(option_values[2]);
+ CHECK_LE(error_log_buffer_bytes, kLogBufferBytesLimit);
+ CHECK_LE(info_log_buffer_bytes, kLogBufferBytesLimit);
+
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to load PTX text as a module: " << ToString(res);
+ // As a precaution for null termination of the API-provided value, ensure
+ // that at least the last byte is null.
+ error_log_buffer[error_log_buffer_bytes ?
+ error_log_buffer_bytes - 1 : 0] = '\0';
+ LOG(ERROR) << "error log buffer (" << error_log_buffer_bytes
+ << " bytes): " << error_log_buffer.data();
+ ret = false;
+ notification.Notify();
+ }
+
+ VLOG(3) << "PTX compilation info log (" << info_log_buffer_bytes
+ << " bytes): " << info_log_buffer.data();
+ VLOG(3) << "PTX compilation error log (" << error_log_buffer_bytes
+ << " bytes): " << error_log_buffer.data();
+ CHECK(module != nullptr);
+ notification.Notify();
+ });
+ notification.WaitForNotification();
+
+ return ret;
+}
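
The option-value coercion above smuggles scalar JIT options through a void* array by value and reads the mutated sizes back afterwards. A plain-C++ sketch of that round trip (no driver call involved):

#include <cassert>
#include <cstdint>

int main() {
  unsigned int log_bytes = 1024;
  // Pack the integer option into a void* slot, as the option_values array does.
  void* option_values[1] = {
      reinterpret_cast<void*>(static_cast<uintptr_t>(log_bytes))};

  // ... the JIT would overwrite option_values[0] with the bytes it wrote ...

  // Read the (possibly mutated) value back out.
  log_bytes = static_cast<unsigned int>(
      reinterpret_cast<uintptr_t>(option_values[0]));
  assert(log_bytes <= 1024);
  return 0;
}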
+
+/* static */ bool CUDADriver::SynchronousMemsetUint8(CUcontext context,
+ CUdeviceptr location,
+ uint8 value, size_t size) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemsetD8_v2(location, value, size);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to memset memory: " << ToString(res);
+ return false;
+ }
+ return true;
+}
+
+/* static */ bool CUDADriver::SynchronousMemsetUint32(CUcontext context,
+ CUdeviceptr location,
+ uint32 value,
+ size_t uint32_count) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemsetD32_v2(location, value, uint32_count);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to memset memory: " << ToString(res);
+ return false;
+ }
+ return true;
+}
+
+/* static */ bool CUDADriver::AsynchronousMemsetUint32(CUcontext context,
+ CUdeviceptr location,
+ uint32 value,
+ size_t uint32_count,
+ CUstream stream) {
+ ScopedActivateContext activation{context};
+ CUresult res =
+ dynload::cuMemsetD32Async(location, value, uint32_count, stream);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
+ return false;
+ }
+ VLOG(2) << "successfully enqueued async memset operation";
+ return true;
+}
+
+/* static */ bool CUDADriver::AddStreamCallback(CUcontext context,
+ CUstream stream,
+ StreamCallback callback,
+ void *data) {
+ // Note: flags param is required to be zero according to CUDA 6.0.
+ CUresult res =
+ dynload::cuStreamAddCallback(stream, callback, data, 0 /* = flags */);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "unable to add host callback: " << ToString(res);
+ return false;
+ }
+ return true;
+}
+
+/* static */ bool CUDADriver::GetModuleFunction(CUcontext context,
+ CUmodule module,
+ const char *kernel_name,
+ CUfunction *function) {
+ ScopedActivateContext activated{context};
+ CHECK(module != nullptr && kernel_name != nullptr);
+ CUresult res = dynload::cuModuleGetFunction(function, module, kernel_name);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to get PTX kernel \"" << kernel_name
+ << "\" from module: " << ToString(res);
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ bool CUDADriver::GetModuleSymbol(CUcontext context,
+ CUmodule module,
+ const char *symbol_name,
+ CUdeviceptr *dptr,
+ size_t *bytes) {
+ ScopedActivateContext activated{context};
+ CHECK(module != nullptr && symbol_name != nullptr &&
+ (dptr != nullptr || bytes != nullptr));
+ CUresult res =
+ dynload::cuModuleGetGlobal_v2(dptr, bytes, module, symbol_name);
+ if (res != CUDA_SUCCESS) {
+ // symbol may not be found in the current module, but it may reside in
+ // another module.
+ VLOG(2) << "failed to get symbol \"" << symbol_name
+ << "\" from module: " << ToString(res);
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ void CUDADriver::UnloadModule(CUcontext context, CUmodule module) {
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuModuleUnload(module);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to unload module " << module
+ << "; leaking: " << ToString(res);
+ }
+}
+
+/* static */ port::StatusOr<CUdevice> CUDADriver::DeviceFromContext(
+ CUcontext context) {
+ ScopedActivateContext activated{context};
+ CUdevice device = -1;
+ CUresult result = dynload::cuCtxGetDevice(&device);
+ if (result == CUDA_SUCCESS) {
+ return device;
+ }
+
+ return port::Status{
+ port::error::INTERNAL,
+ port::StrCat("failed to get device for context: ", ToString(result))};
+}
+
+/* static */ bool CUDADriver::CreateStream(CUcontext context, CUstream *out) {
+ // TODO(leary) can we switch this to CU_STREAM_NON_BLOCKING or will that mess
+ // up synchronization with respect to memsets and any other things that have
+ // to occur on the default stream?
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuStreamCreate(out, 0);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "could not allocate CUDA stream for context " << context
+ << ": " << ToString(res);
+ return false;
+ }
+
+ VLOG(2) << "successfully created stream " << *out << " for context "
+ << context << " on thread";
+ return true;
+}
+
+/* static */ void CUDADriver::DestroyStream(CUcontext context,
+ CUstream *stream) {
+ if (*stream == nullptr) {
+ return;
+ }
+
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuStreamDestroy_v2(*stream);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to destroy CUDA stream for context " << context
+ << ": " << ToString(res);
+ } else {
+ VLOG(2) << "successfully destroyed stream " << *stream << " for context "
+ << context;
+ *stream = nullptr;
+ }
+}
+
+/* static */ void *CUDADriver::DeviceAllocate(CUcontext context, uint64 bytes) {
+ ScopedActivateContext activated{context};
+ CUdeviceptr result = 0;
+ CUresult res = dynload::cuMemAlloc_v2(&result, bytes);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to allocate "
+ << port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes
+ << " bytes) from device: " << ToString(res);
+ return nullptr;
+ }
+ void *ptr = reinterpret_cast<void *>(result);
+ VLOG(2) << "allocated " << ptr << " for context " << context << " of "
+ << bytes << " bytes";
+ return ptr;
+}
+
+/* static */ void CUDADriver::DeviceDeallocate(CUcontext context,
+ void *location) {
+ ScopedActivateContext activation{context};
+ CUdeviceptr pointer = port::bit_cast<CUdeviceptr>(location);
+ CUresult res = dynload::cuMemFree_v2(pointer);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to free device memory at " << location
+ << "; result: " << ToString(res);
+ } else {
+ VLOG(2) << "deallocated " << location << " for context " << context;
+ }
+}
+
+/* static */ void *CUDADriver::HostAllocate(CUcontext context, uint64 bytes) {
+ ScopedActivateContext activation{context};
+ void *host_mem = nullptr;
+ // "Portable" memory is visible to all CUDA contexts. Safe for our use model.
+ CUresult res =
+ dynload::cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to alloc " << bytes
+ << " bytes on host: " << ToString(res);
+ }
+ return host_mem;
+}
+
+/* static */ void CUDADriver::HostDeallocate(CUcontext context,
+ void *location) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemFreeHost(location);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "error deallocating host memory at " << location << ": "
+ << ToString(res);
+ }
+}
+
+/* static */ bool CUDADriver::HostRegister(CUcontext context, void *location,
+ uint64 bytes) {
+ ScopedActivateContext activation{context};
+ // "Portable" memory is visible to all CUDA contexts. Safe for our use model.
+ CUresult res =
+ dynload::cuMemHostRegister(location, bytes, CU_MEMHOSTREGISTER_PORTABLE);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "error registering host memory at " << location << ": "
+ << ToString(res);
+ return false;
+ }
+ return true;
+}
+
+/* static */ bool CUDADriver::HostUnregister(CUcontext context,
+ void *location) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemHostUnregister(location);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "error unregistering host memory at " << location << ": "
+ << ToString(res);
+ return false;
+ }
+ return true;
+}
+
+/* static */ port::Status CUDADriver::DestroyEvent(CUcontext context,
+ CUevent *event) {
+ if (*event == nullptr) {
+ return port::Status{port::error::INVALID_ARGUMENT,
+ "input event cannot be null"};
+ }
+
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuEventDestroy_v2(*event);
+ *event = nullptr;
+
+ switch (res) {
+ case CUDA_SUCCESS:
+ return port::Status::OK();
+ case CUDA_ERROR_DEINITIALIZED:
+ case CUDA_ERROR_NOT_INITIALIZED:
+ return port::Status{
+ port::error::FAILED_PRECONDITION,
+ port::Printf("error destroying CUDA event in context %p: %s", context,
+ ToString(res).c_str())};
+ default:
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf("error destroying CUDA event in context %p: %s", context,
+ ToString(res).c_str())};
+ }
+}
+
+/* static */ port::Status CUDADriver::RecordEvent(CUcontext context,
+ CUevent event,
+ CUstream stream) {
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuEventRecord(event, stream);
+ switch (res) {
+ case CUDA_SUCCESS:
+ return port::Status::OK();
+ case CUDA_ERROR_DEINITIALIZED:
+ case CUDA_ERROR_NOT_INITIALIZED:
+ return port::Status{
+ port::error::FAILED_PRECONDITION,
+ port::Printf("error recording CUDA event on stream %p: %s", stream,
+ ToString(res).c_str())};
+ default:
+ return port::Status{
+ port::error::INVALID_ARGUMENT,
+ port::Printf("error recording CUDA event on stream %p: %s", stream,
+ ToString(res).c_str())};
+ }
+}
+
+/* static */ port::StatusOr<CUresult> CUDADriver::QueryEvent(CUcontext context,
+ CUevent event) {
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuEventQuery(event);
+ if (res != CUDA_SUCCESS && res != CUDA_ERROR_NOT_READY) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf("failed to query event: %s", ToString(res).c_str())};
+ }
+
+ return res;
+}
+
+/* static */ bool CUDADriver::GetEventElapsedTime(CUcontext context,
+ float *elapsed_milliseconds,
+ CUevent start, CUevent stop) {
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuEventElapsedTime(elapsed_milliseconds, start, stop);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to get elapsed time between events: "
+ << ToString(res);
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ bool CUDADriver::WaitStreamOnEvent(CUcontext context,
+ CUstream stream,
+ CUevent event) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuStreamWaitEvent(stream, event, 0 /* = flags */);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "could not wait stream on event: " << ToString(res);
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ bool CUDADriver::SynchronizeContext(CUcontext context) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuCtxSynchronize();
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "could not synchronize on CUDA context: " << ToString(res)
+ << " :: " << port::CurrentStackTrace();
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ bool CUDADriver::SynchronizeStream(CUcontext context,
+ CUstream stream) {
+ ScopedActivateContext activated{context};
+ CHECK(stream != nullptr);
+ CUresult res = dynload::cuStreamSynchronize(stream);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "could not synchronize on CUDA stream: " << ToString(res)
+ << " :: " << port::CurrentStackTrace();
+ return false;
+ }
+ VLOG(2) << "successfully synchronized stream " << stream << " on context "
+ << context;
+ return true;
+}
+
+/* static */ bool CUDADriver::IsStreamIdle(CUcontext context, CUstream stream) {
+ ScopedActivateContext activated{context};
+ CHECK(stream != nullptr);
+ CUresult res = dynload::cuStreamQuery(stream);
+ if (res == CUDA_SUCCESS) {
+ return true;
+ }
+
+ if (res != CUDA_ERROR_NOT_READY) {
+ LOG(ERROR) << "stream in bad state on status query: " << ToString(res);
+ }
+ return false;
+}
+
+/* static */ bool CUDADriver::SynchronousMemcpyD2H(CUcontext context,
+ void *host_dst,
+ CUdeviceptr gpu_src,
+ uint64 size) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemcpyDtoH_v2(host_dst, gpu_src, size);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << port::Printf(
+ "failed to synchronous memcpy from device to host: %s; "
+ "host dst: %p; GPU src: %p; size: %llu=0x%llx",
+ ToString(res).c_str(), host_dst, port::bit_cast<void *>(gpu_src), size, size);
+ return false;
+ }
+ VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to "
+ << host_dst;
+ return true;
+}
+
+/* static */ bool CUDADriver::SynchronousMemcpyH2D(CUcontext context,
+ CUdeviceptr gpu_dst,
+ const void *host_src,
+ uint64 size) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemcpyHtoD_v2(gpu_dst, host_src, size);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << port::Printf(
+ "failed to synchronous memcpy from host to device: %s; GPU dst: %p;"
+ " host src: %p; size: %llu=0x%llx",
+ ToString(res).c_str(), port::bit_cast<void *>(gpu_dst), host_src, size, size);
+ return false;
+ }
+ VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes";
+ return true;
+}
+
+/* static */ bool CUDADriver::SynchronousMemcpyD2D(CUcontext context,
+ CUdeviceptr gpu_dst,
+ CUdeviceptr gpu_src,
+ uint64 size) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemcpyDtoD_v2(gpu_dst, gpu_src, size);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << port::Printf(
+ "failed to synchronous memcpy from host to device: %s; GPU dst: %p; "
+ "GPU src: %p; size: %llu=0x%llx",
+ ToString(res).c_str(), port::bit_cast<void *>(gpu_dst),
+ port::bit_cast<void *>(gpu_src), size, size);
+ return false;
+ }
+ VLOG(2) << "successfully sync memcpy'd d2d of " << size << " bytes";
+ return true;
+}
+
+/* static */ bool CUDADriver::AsynchronousMemcpyD2H(CUcontext context,
+ void *host_dst,
+ CUdeviceptr gpu_src,
+ uint64 size,
+ CUstream stream) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemcpyDtoHAsync_v2(host_dst, gpu_src, size, stream);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << port::Printf(
+ "failed to enqueue async memcpy from device to host: %s; host dst: %p; "
+ "GPU src: %p; size: %llu=0x%llx",
+ ToString(res).c_str(), host_dst, port::bit_cast<void *>(gpu_src), size, size);
+ return false;
+ }
+ VLOG(2) << "successfully enqueued async memcpy d2h of " << size
+ << " bytes from " << port::bit_cast<void *>(gpu_src) << " to " << host_dst
+ << " on stream " << stream;
+ return true;
+}
+
+/* static */ bool CUDADriver::AsynchronousMemcpyH2D(CUcontext context,
+ CUdeviceptr gpu_dst,
+ const void *host_src,
+ uint64 size,
+ CUstream stream) {
+ ScopedActivateContext activation{context};
+ CUresult res = dynload::cuMemcpyHtoDAsync_v2(gpu_dst, host_src, size, stream);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << port::Printf(
+ "failed to enqueue async memcpy from host to device: %s; GPU dst: %p; "
+ "host src: %p; size: %llu=0x%llx",
+ ToString(res).c_str(), port::bit_cast<void *>(gpu_dst), host_src, size, size);
+ return false;
+ }
+ VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes"
+ << " on stream " << stream;
+ return true;
+}
+
+/* static */ bool CUDADriver::AsynchronousMemcpyD2D(CUcontext context,
+ CUdeviceptr gpu_dst,
+ CUdeviceptr gpu_src,
+ uint64 size,
+ CUstream stream) {
+ ScopedActivateContext activation{context};
+ CUresult result =
+ dynload::cuMemcpyDtoDAsync_v2(gpu_dst, gpu_src, size, stream);
+ if (result != CUDA_SUCCESS) {
+ LOG(ERROR) << port::Printf(
+ "failed to enqueue async memcpy from device to device: %s"
+ "; GPU dst: %p on %s %s"
+ "; GPU src: %p on %s %s"
+ "; can access? %s; size: %llu=0x%llx",
+ ToString(result).c_str(), port::bit_cast<void *>(gpu_dst),
+ CUDAPointerToMemorySpaceString(gpu_dst).c_str(),
+ CUDAPointerToDeviceString(gpu_dst).c_str(), port::bit_cast<void *>(gpu_src),
+ CUDAPointerToMemorySpaceString(gpu_src).c_str(),
+ CUDAPointerToDeviceString(gpu_src).c_str(),
+ CUDAPointersToCanAccessString(gpu_src, gpu_dst).c_str(), size, size);
+
+ return false;
+ }
+ VLOG(2) << "successfully enqueued async memcpy d2d of " << size << " bytes";
+ return true;
+}
+
+/* static */ port::Status CUDADriver::CreateEvent(CUcontext context,
+ CUevent *result,
+ EventFlags flags) {
+ int cuflags;
+ switch (flags) {
+ case EventFlags::kDefault:
+ cuflags = CU_EVENT_DEFAULT;
+ break;
+ case EventFlags::kDisableTiming:
+ cuflags = CU_EVENT_DISABLE_TIMING;
+ break;
+ default:
+ LOG(FATAL) << "impossible event flags: " << int(flags);
+ }
+
+ ScopedActivateContext activated{context};
+ CUresult res = dynload::cuEventCreate(result, cuflags);
+
+ if (res == CUDA_SUCCESS) {
+ return port::Status::OK();
+ } else if (res == CUDA_ERROR_OUT_OF_MEMORY) {
+ return port::Status{port::error::RESOURCE_EXHAUSTED,
+ "could not create CUDA event: out of device memory"};
+ } else {
+ return port::Status{
+ port::error::FAILED_PRECONDITION,
+ port::StrCat("could not create CUDA event: ", ToString(res))};
+ }
+}
+
+/* static */ int CUDADriver::GetDeviceCount() {
+ int device_count = 0;
+ CUresult res = dynload::cuDeviceGetCount(&device_count);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "could not retrieve CUDA device count: " << ToString(res);
+ return 0;
+ }
+
+ if (FLAGS_gpuexec_cuda_device_0_only && device_count > 1) {
+ device_count = 1;
+ }
+ return device_count;
+}
+
+/* static */ port::StatusOr<CUcontext> CUDADriver::GetPointerContext(
+ CUdeviceptr pointer) {
+ CUcontext context = nullptr;
+ CUresult result = dynload::cuPointerGetAttribute(
+ &context, CU_POINTER_ATTRIBUTE_CONTEXT, pointer);
+ if (result == CUDA_SUCCESS) {
+ CHECK(context != nullptr) << "success should entail non-null context";
+ return context;
+ }
+
+ return port::Status{
+ port::error::INTERNAL,
+ port::StrCat("failed to query device pointer for context: ",
+ ToString(result))};
+}
+
+/* static */ port::StatusOr<MemorySpace> CUDADriver::GetPointerMemorySpace(
+ CUdeviceptr pointer) {
+ unsigned int value;
+ CUresult result = dynload::cuPointerGetAttribute(
+ &value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer);
+ if (result == CUDA_SUCCESS) {
+ switch (value) {
+ case CU_MEMORYTYPE_DEVICE:
+ return MemorySpace::kDevice;
+ case CU_MEMORYTYPE_HOST:
+ return MemorySpace::kHost;
+ default:
+ return port::Status{
+ port::error::INTERNAL,
+ port::StrCat("unknown memory space provided by CUDA API: ", value)};
+ }
+ }
+
+ return port::Status{
+ port::error::INTERNAL,
+ port::StrCat("failed to query device pointer for memory space: ",
+ ToString(result))};
+}
+
+/* static */ port::Status CUDADriver::GetPointerAddressRange(CUdeviceptr dptr,
+ CUdeviceptr *base,
+ size_t *size) {
+ CUresult result = dynload::cuMemGetAddressRange(base, size, dptr);
+ if (result == CUDA_SUCCESS) {
+ return port::Status::OK();
+ } else if (result == CUDA_ERROR_NOT_FOUND) {
+ // We differentiate between "this pointer is unknown" (return here) and
+ // "there was an internal error while performing this operation" (return
+ // below).
+ return port::Status{
+ port::error::NOT_FOUND,
+ port::Printf("not a device pointer %p; %s",
+ reinterpret_cast<void *>(dptr), ToString(result).c_str())};
+ }
+
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf("failed to get pointer into for device pointer %p; %s",
+ reinterpret_cast<void *>(dptr), ToString(result).c_str())};
+}
+
+/* static */ port::StatusOr<CUdevice> CUDADriver::GetPointerDevice(
+ CUdeviceptr pointer) {
+ auto result = GetPointerContext(pointer);
+ if (!result.ok()) {
+ return result.status();
+ }
+
+ return DeviceFromContext(result.ValueOrDie());
+}
+
+/* static */ port::Status CUDADriver::GetComputeCapability(int *cc_major,
+ int *cc_minor,
+ CUdevice device) {
+ *cc_major = 0;
+ *cc_minor = 0;
+ CUresult result =
+ dynload::cuDeviceComputeCapability(cc_major, cc_minor, device);
+ if (result == CUDA_SUCCESS) {
+ return port::Status::OK();
+ }
+
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf("failed to get compute capability for device: %s; %d",
+ ToString(result).c_str(), device)};
+}
+
+// Helper function that turns the integer output of cuDeviceGetAttribute to type
+// T and wraps it in a StatusOr.
+template <typename T>
+static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
+ CUdevice_attribute attribute) {
+ int value = -1;
+ CUresult result = dynload::cuDeviceGetAttribute(&value, attribute, device);
+ if (result != CUDA_SUCCESS) {
+ return port::Status{
+ port::error::NOT_FOUND,
+ port::StrCat("could not retrieve CUDA device attribute (", attribute,
+ "): ", ToString(result))};
+ }
+ T converted = value;
+ return converted;
+}
+
+/* static */ port::StatusOr<int> CUDADriver::GetMultiprocessorCount(
+ CUdevice device) {
+ return GetSimpleAttribute<int>(device,
+ CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT);
+}
+
+/* static */ port::StatusOr<int64> CUDADriver::GetMaxSharedMemoryPerCore(
+ CUdevice device) {
+ return GetSimpleAttribute<int64>(
+ device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR);
+}
+
+/* static */ port::StatusOr<int64> CUDADriver::GetMaxSharedMemoryPerBlock(
+ CUdevice device) {
+ return GetSimpleAttribute<int64>(
+ device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK);
+}
+
+/* static */ port::StatusOr<int64> CUDADriver::GetMaxThreadsPerMultiprocessor(
+ CUdevice device) {
+ return GetSimpleAttribute<int64>(
+ device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR);
+}
+
+/* static */ port::StatusOr<int64> CUDADriver::GetMaxThreadsPerBlock(
+ CUdevice device) {
+ return GetSimpleAttribute<int64>(device,
+ CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK);
+}
+
+/* static */ port::StatusOr<int64> CUDADriver::GetMaxRegistersPerBlock(
+ CUdevice device) {
+ return GetSimpleAttribute<int64>(device,
+ CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK);
+}
+
+/* static */ port::StatusOr<int64> CUDADriver::GetThreadsPerWarp(
+ CUdevice device) {
+ return GetSimpleAttribute<int64>(device, CU_DEVICE_ATTRIBUTE_WARP_SIZE);
+}
+
+/* static */ bool CUDADriver::GetGridLimits(int *x, int *y, int *z,
+ CUdevice device) {
+ int value;
+ CUresult res = dynload::cuDeviceGetAttribute(
+ &value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query max grid dim x: " << ToString(res);
+ return false;
+ }
+ *x = value;
+
+ res = dynload::cuDeviceGetAttribute(
+ &value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query max grid dim y: " << ToString(res);
+ return false;
+ }
+ *y = value;
+
+ res = dynload::cuDeviceGetAttribute(
+ &value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query max grid dim z: " << ToString(res);
+ return false;
+ }
+ *z = value;
+ return true;
+}
+
+/* static */ bool CUDADriver::GetDriverVersion(int *driver_version) {
+ CUresult res = dynload::cuDriverGetVersion(driver_version);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query driver version: " << ToString(res);
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ bool CUDADriver::GetDeviceProperties(CUdevprop *device_properties,
+ int device_ordinal) {
+ CUresult res =
+ dynload::cuDeviceGetProperties(device_properties, device_ordinal);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query device properties: " << ToString(res);
+ return false;
+ }
+
+ return true;
+}
+
+/* static */ bool CUDADriver::IsEccEnabled(CUdevice device, bool *result) {
+ int value = -1;
+ CUresult res = dynload::cuDeviceGetAttribute(
+ &value, CU_DEVICE_ATTRIBUTE_ECC_ENABLED, device);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query ECC status: " << ToString(res);
+ return false;
+ }
+
+ *result = value;
+ return true;
+}
+
+/* static */ bool CUDADriver::GetDeviceMemoryInfo(CUcontext context,
+ int64 *free_out,
+ int64 *total_out) {
+ ScopedActivateContext activation{context};
+ size_t free = 0;
+ size_t total = 0;
+ CUresult res = dynload::cuMemGetInfo_v2(&free, &total);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query device memory info: " << ToString(res);
+ return false;
+ }
+
+ *free_out = free;
+ *total_out = total;
+ return true;
+}
+
+/* static */ bool CUDADriver::GetDeviceTotalMemory(CUdevice device,
+ uint64 *result) {
+ size_t value = -1;
+ CUresult res = dynload::cuDeviceTotalMem_v2(&value, device);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query total available memory: " << ToString(res);
+ return false;
+ }
+
+ *result = value;
+ return true;
+}
+
+/* static */ string CUDADriver::GetPCIBusID(CUdevice device) {
+ string pci_bus_id;
+ static const int kBufferSize = 64;
+ port::InlinedVector<char, 4> chars(kBufferSize);
+ chars[kBufferSize - 1] = '\0';
+ CUresult res =
+ dynload::cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res);
+ return pci_bus_id;
+ }
+ pci_bus_id = chars.begin();
+ return pci_bus_id;
+}
+
+/* static */ bool CUDADriver::CanEnablePeerAccess(CUcontext from,
+ CUcontext to) {
+ if (from == to) {
+ return true; // A context can always access its own memory.
+ }
+
+ int can_access_peer = -1;
+ auto from_device = DeviceFromContext(from);
+ if (!from_device.ok()) {
+ LOG(ERROR) << "failed to resolve 'from' peer access context to a device: "
+ << from_device.status();
+ return false;
+ }
+ auto to_device = DeviceFromContext(to);
+ if (!to_device.ok()) {
+ LOG(ERROR) << "failed to resolve 'to' peer access context to a device: "
+ << to_device.status();
+ return false;
+ }
+ CUresult res = dynload::cuDeviceCanAccessPeer(
+ &can_access_peer, from_device.ValueOrDie(), to_device.ValueOrDie());
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to detect peer access capability: " << ToString(res);
+ return false;
+ }
+
+ return can_access_peer;
+}
+
+/* static */ port::Status CUDADriver::EnablePeerAccess(CUcontext from,
+ CUcontext to) {
+ if (from == to) {
+ return port::Status::OK(); // A context can always access its own memory.
+ }
+
+ ScopedActivateContext activated{from};
+ CUresult result = dynload::cuCtxEnablePeerAccess(to, 0 /* = flags */);
+ if (result != CUDA_SUCCESS &&
+ result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf("failed to enable peer access from %p to %p: %s", from, to,
+ ToString(result).c_str())};
+ }
+
+ return port::Status::OK();
+}
+
+/* static */ port::StatusOr<int> CUDADriver::GetMaxOccupiedBlocksPerCore(
+ CUcontext context, CUfunction kernel, int threads_per_block,
+ size_t dynamic_shared_memory_bytes) {
+ ScopedActivateContext activation{context};
+
+ int max_blocks;
+ CUresult result = dynload::cuOccupancyMaxActiveBlocksPerMultiprocessor(
+ &max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes);
+ if (result != CUDA_SUCCESS) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf("failed to calculate occupancy of kernel %p: %s", kernel,
+ ToString(result).c_str())};
+ }
+
+ return max_blocks;
+}
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
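
As a hedged illustration of how the wrappers above compose (not part of the patch): CreateEvent/RecordEvent return port::Status, while the synchronization and elapsed-time calls return bool and log internally. The sketch below times a section of work on a stream; `context` and `stream` are assumed to be valid handles obtained elsewhere, and the helper name is invented for the example.

// Illustrative only: time work enqueued on `stream` between two events.
// (Error paths are simplified; a real caller would clean up partially
// created events on the early-return path.)
bool TimeStreamSection(CUcontext context, CUstream stream,
                       float *elapsed_milliseconds) {
  CUevent start = nullptr;
  CUevent stop = nullptr;
  if (!CUDADriver::CreateEvent(context, &start,
                               CUDADriver::EventFlags::kDefault).ok() ||
      !CUDADriver::CreateEvent(context, &stop,
                               CUDADriver::EventFlags::kDefault).ok()) {
    return false;
  }
  bool ok = CUDADriver::RecordEvent(context, start, stream).ok();
  // ... the caller enqueues the work to be timed on `stream` here ...
  ok = ok && CUDADriver::RecordEvent(context, stop, stream).ok();
  ok = ok && CUDADriver::SynchronizeStream(context, stream);
  ok = ok && CUDADriver::GetEventElapsedTime(context, elapsed_milliseconds,
                                             start, stop);
  (void)CUDADriver::DestroyEvent(context, &start);
  (void)CUDADriver::DestroyEvent(context, &stop);
  return ok;
}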
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h
new file mode 100644
index 0000000000..007db222d9
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_driver.h
@@ -0,0 +1,460 @@
+// CUDA userspace driver library wrapper functionality.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
+
+#include <stddef.h>
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/cuda/multi_op_activation.h"
+#include "tensorflow/stream_executor/device_options.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// Identifies the memory space where an allocation resides. See
+// CUDADriver::GetPointerMemorySpace().
+enum class MemorySpace { kHost, kDevice };
+
+ // Returns a human-readable string, such as "host", for the provided memory space.
+string MemorySpaceString(MemorySpace memory_space);
+
+// CUDADriver contains wrappers for calls to the userspace library driver. It's
+// useful to isolate these calls and put basic wrappers around them to separate
+// userspace library driver behaviors from the rest of the program.
+//
+// At the moment it's simply used as a namespace.
+//
+// The calls log any specific errors internally and return whether the operation
+// was successful to the caller.
+//
+// The order of parameters is generally kept symmetric with the underlying CUDA
+// driver API.
+//
+// Links on functions are to specific documentation under
+// http://docs.nvidia.com/cuda/cuda-driver-api/
+//
+// Thread safety: these functions should not be used from signal handlers.
+class CUDADriver {
+ public:
+ // Wraps a call to cuInit with logging to help indicate what has gone wrong in
+ // the case of failure. Safe to call multiple times; will be fast on all calls
+ // after the first.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html#group__CUDA__INITIALIZE_1g0a2f1517e1bd8502c7194c3a8c134bc3
+ static port::Status Init();
+
+ // Returns the device associated with the given context.
+ // device is an outparam owned by the caller, must not be null.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g4e84b109eba36cdaaade167f34ae881e
+ static port::StatusOr<CUdevice> DeviceFromContext(CUcontext context);
+
+ // Creates a new CUDA stream associated with the given context via
+ // cuStreamCreate.
+ // stream is an outparam owned by the caller, must not be null.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1ga581f0c5833e21ded8b5a56594e243f4
+ static bool CreateStream(CUcontext context, CUstream *stream);
+
+ // Destroys a CUDA stream associated with the given context.
+ // stream is owned by the caller, must not be null, and *stream is set to null
+ // if the stream is successfully destroyed.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g244c8833de4596bcd31a06cdf21ee758
+ static void DestroyStream(CUcontext context, CUstream *stream);
+
+ // CUDA events can explicitly disable event TSC retrieval for some presumed
+ // performance improvement if timing is unnecessary.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db
+ enum class EventFlags { kDefault, kDisableTiming };
+
+ // Creates a new event associated with the given context.
+ // result is an outparam owned by the caller and must not be null.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db
+ static port::Status CreateEvent(CUcontext context, CUevent *result,
+ EventFlags flags);
+
+ // Destroys *event and sets it to nullptr via cuEventDestroy. event must not
+ // be null, but *event may be.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g593ec73a8ec5a5fc031311d3e4dca1ef
+ static port::Status DestroyEvent(CUcontext context, CUevent *event);
+
+ // Allocates a GPU memory space of size bytes associated with the given
+ // context via cuMemAlloc.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb82d2a09844a58dd9e744dc31e8aa467
+ static void *DeviceAllocate(CUcontext context, uint64 bytes);
+
+ // Deallocates a GPU memory space of size bytes associated with the given
+ // context via cuMemFree.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a
+ static void DeviceDeallocate(CUcontext context, void *location);
+
+ // Allocates page-locked and CUDA-registered memory on the host via
+ // cuMemAllocHost.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gdd8311286d2c2691605362c689bc64e0
+ static void *HostAllocate(CUcontext context, uint64 bytes);
+
+ // Deallocates a location created by HostAllocate, via cuMemFreeHost.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g62e0fdbe181dab6b1c90fa1a51c7b92c
+ static void HostDeallocate(CUcontext context, void *location);
+
+ // Registers a memory region at location of size bytes via cuMemHostRegister.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223
+ static bool HostRegister(CUcontext context, void *location, uint64 bytes);
+
+ // Unregisters a memory region that was previously registered at location via
+ // cuMemHostUnregister.
+ //
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g63f450c8125359be87b7623b1c0b2a14
+ //
+ // TODO(leary) verify an error will be returned if the location wasn't
+ // previously registered.
+ static bool HostUnregister(CUcontext context, void *location);
+
+ // Given a device ordinal, returns a device handle into the device outparam,
+ // which must not be null.
+ //
+ // N.B. these device handles do not have a corresponding destroy function in
+ // the CUDA driver API.
+ static port::Status GetDevice(int device_ordinal, CUdevice *device);
+
+ // Given a device handle, returns the name reported by the driver for the
+ // device.
+ static bool GetDeviceName(CUdevice device, string *name_out);
+
+ // Given a device to create a context for, returns a context handle into the
+ // context outparam, which must not be null.
+ //
+ // N.B. CUDA contexts are weird. They are implicitly associated with the
+ // calling thread. Current documentation on contexts and their influence on
+ // userspace processes is given here:
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf
+ static port::Status CreateContext(CUdevice device,
+ DeviceOptions device_options,
+ CUcontext *context);
+
+ // Destroys the provided context via cuCtxDestroy.
+ // Don't do this while clients could still be using the context; per the docs,
+ // bad things will happen.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e
+ static void DestroyContext(CUcontext context);
+
+ // Queries the runtime for the specified attribute of the specified function.
+ // cuFuncGetAttribute (the underlying CUDA driver API routine) only operates
+ // in terms of integer-sized values, so there's no potential for overrun (as
+ // of CUDA 5.5).
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b
+ static bool FuncGetAttribute(CUfunction_attribute attribute,
+ CUfunction function, int *attribute_value);
+
+ // Sets the preferred cache configuration for the specified function.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g40f8c11e81def95dc0072a375f965681
+ static bool FuncSetCacheConfig(CUfunction function,
+ CUfunc_cache cache_config);
+
+ // Gets the preferred shared memory bank configuration for the specified
+ // CONTEXT (not function!), either default or four- or eight-byte bank size.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g17153a1b8b8c756f7ab8505686a4ad74
+ static port::StatusOr<CUsharedconfig> ContextGetSharedMemConfig(
+ CUcontext context);
+
+ // Sets the preferred shared memory bank configuration for the specified
+ // CONTEXT (not function!), either default or four- or eight-byte bank size.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g2574235fa643f8f251bf7bc28fac3692
+ static port::Status ContextSetSharedMemConfig(
+ CUcontext context, CUsharedconfig shared_mem_config);
+
+ // Launches a CUDA kernel via cuLaunchKernel.
+ // TODO(leary) describe the structure of kernel_params and extra in a readable
+ // way.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
+ static bool LaunchKernel(CUcontext context, CUfunction function,
+ unsigned int grid_dim_x, unsigned int grid_dim_y,
+ unsigned int grid_dim_z, unsigned int block_dim_x,
+ unsigned int block_dim_y, unsigned int block_dim_z,
+ unsigned int shared_mem_bytes, CUstream stream,
+ void **kernel_params, void **extra);
+
+ // Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting
+ // handle in "module". Any error logs that are produced are logged internally.
+ static bool LoadPtx(CUcontext context, const char *ptx_contents,
+ CUmodule *module);
+
+ // Loads cubin_bytes with the CUDA driver's blob loading interface and stores
+ // the resulting handle in "module".
+ static port::Status LoadCubin(CUcontext context, const char *cubin_bytes,
+ CUmodule *module);
+
+ // Retrieves a named kernel from a loaded module, and places the resulting
+ // handle into function (outparam) on success. Neither kernel_name nor
+ // function may be null. No ownership is taken of kernel_name.
+ static bool GetModuleFunction(CUcontext context, CUmodule module,
+ const char *kernel_name, CUfunction *function);
+
+ // Retrieves a named global/constant symbol from a loaded module, and returns
+ // a device pointer and size of the symbol on success. symbol_name may not be
+ // null. At least one of dptr or bytes should not be null. No ownership is
+ // taken of symbol_name.
+ static bool GetModuleSymbol(CUcontext context, CUmodule module,
+ const char *symbol_name, CUdeviceptr *dptr,
+ size_t *bytes);
+
+ // Unloads module from the current context via cuModuleUnload.
+ // TODO(leary) the documentation doesn't say what kind of disasters happen
+ // if you try to unload a module while its CUfunctions are in use.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html#group__CUDA__MODULE_1g8ea3d716524369de3763104ced4ea57b
+ static void UnloadModule(CUcontext context, CUmodule module);
+
+ // Performs a synchronous memset of the device memory segment via cuMemsetD8.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6e582bf866e9e2fb014297bfaf354d7b
+ static bool SynchronousMemsetUint8(CUcontext context, CUdeviceptr location,
+ uint8 value, size_t size);
+
+ // Performs a synchronous memset of the device memory segment via cuMemsetD32.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g983e8d8759acd1b64326317481fbf132
+ static bool SynchronousMemsetUint32(CUcontext context, CUdeviceptr location,
+ uint32 value, size_t uint32_count);
+
+ // Performs an asynchronous memset of the device memory segment via
+ // cuMemsetD32Async.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g58229da5d30f1c0cdf667b320ec2c0f5
+ static bool AsynchronousMemsetUint32(CUcontext context, CUdeviceptr location,
+ uint32 value, size_t uint32_count,
+ CUstream stream);
+
+ // -- Synchronous memcopies.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g4d32266788c440b0220b1a9ba5795169
+
+ static bool SynchronousMemcpyD2H(CUcontext context, void *host_dst,
+ CUdeviceptr gpu_src, uint64 size);
+ static bool SynchronousMemcpyH2D(CUcontext context, CUdeviceptr gpu_dst,
+ const void *host_src, uint64 size);
+ static bool SynchronousMemcpyD2D(CUcontext context, CUdeviceptr gpu_dst,
+ CUdeviceptr gpu_src, uint64 size);
+
+ // -- Asynchronous memcopies.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g56f30236c7c5247f8e061b59d3268362
+
+ static bool AsynchronousMemcpyD2H(CUcontext context, void *host_dst,
+ CUdeviceptr gpu_src, uint64 size,
+ CUstream stream);
+ static bool AsynchronousMemcpyH2D(CUcontext context, CUdeviceptr gpu_dst,
+ const void *host_src, uint64 size,
+ CUstream stream);
+ static bool AsynchronousMemcpyD2D(CUcontext context, CUdeviceptr gpu_dst,
+ CUdeviceptr gpu_src, uint64 size,
+ CUstream stream);
+
+ // The CUDA stream callback type signature.
+ // The data passed to AddStreamCallback is subsequently passed to this
+ // callback when it fires.
+ //
+ // Some notable things:
+ // * Callbacks must not make any CUDA API calls.
+ // * Callbacks from independent streams execute in an undefined order and may
+ // be serialized.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g613d97a277d7640f4cb1c03bd51c2483
+ typedef void (*StreamCallback)(CUstream stream, CUresult status, void *data);
+
+ // Enqueues a callback operation into stream.
+ // See StreamCallback above and the NVIDIA documentation for additional
+ // details.
+ static bool AddStreamCallback(CUcontext context, CUstream stream,
+ StreamCallback callback, void *data);
+
+ // Causes stream to wait for event to trigger before proceeding via
+ // cuStreamWaitEvent.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#axzz334nAXAhM
+ static bool WaitStreamOnEvent(CUcontext context, CUstream stream,
+ CUevent event);
+
+ // Blocks the calling thread until the operations enqueued onto stream have
+ // been completed, via cuStreamSynchronize.
+ //
+ // TODO(leary) if a pathological thread enqueues operations onto the stream
+ // while another thread blocks like this, can you wind up waiting an unbounded
+ // amount of time?
+ //
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g15e49dd91ec15991eb7c0a741beb7dad
+ static bool SynchronizeStream(CUcontext context, CUstream stream);
+
+ // Blocks the calling thread until the operations associated with the context
+ // have been completed, via cuCtxSynchronize.
+ //
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g7a54725f28d34b8c6299f0c6ca579616
+ static bool SynchronizeContext(CUcontext context);
+
+ // Returns true if all stream tasks have completed at time of the call. Note
+ // the potential for races around this call (if another thread adds work to
+ // the stream immediately after this returns).
+ static bool IsStreamIdle(CUcontext context, CUstream stream);
+
+ // Returns whether code in the from context can access memory in the to
+ // context via cuDeviceCanAccessPeer.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g496bdaae1f632ebfb695b99d2c40f19e
+ static bool CanEnablePeerAccess(CUcontext from, CUcontext to);
+
+ // Enables peer access per CanEnablePeerAccess, via cuCtxEnablePeerAccess.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g0889ec6728e61c05ed359551d67b3f5a
+ static port::Status EnablePeerAccess(CUcontext from, CUcontext to);
+
+ // Returns the elapsed milliseconds between start and stop via
+ // cuEventElapsedTime.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1gdfb1178807353bbcaa9e245da497cf97
+ static bool GetEventElapsedTime(CUcontext context,
+ float *elapsed_milliseconds, CUevent start,
+ CUevent stop);
+
+ // Records that an event occurred when execution reaches the current point in
+ // the stream via cuEventRecord.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g95424d3be52c4eb95d83861b70fb89d1
+ static port::Status RecordEvent(CUcontext context, CUevent event,
+ CUstream stream);
+
+ // Polls (without blocking) to determine the status of an event - pending or
+ // complete (or an error status).
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g6f0704d755066b0ee705749ae911deef
+ static port::StatusOr<CUresult> QueryEvent(CUcontext context, CUevent event);
+
+ // -- Pointer-specific calls.
+
+ // Returns the context in which pointer was allocated or registered.
+ static port::StatusOr<CUcontext> GetPointerContext(CUdeviceptr pointer);
+
+ // Returns the device associated with the context from GetPointerContext().
+ static port::StatusOr<CUdevice> GetPointerDevice(CUdeviceptr pointer);
+
+ // Returns the memory space addressed by pointer.
+ static port::StatusOr<MemorySpace> GetPointerMemorySpace(CUdeviceptr pointer);
+
+ // Returns the base address and size of the device pointer dptr.
+ static port::Status GetPointerAddressRange(CUdeviceptr dptr,
+ CUdeviceptr *base, size_t *size);
+
+ // -- Device-specific calls.
+
+ // Returns the compute capability for the device, e.g. (3, 5).
+ // This is currently done via the deprecated device API.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1ge2091bbac7e1fb18c2821612115607ea
+ static port::Status GetComputeCapability(int *cc_major, int *cc_minor,
+ CUdevice device);
+
+ // Returns the number of multiprocessors on the device (note that the device
+ // may be multi-GPU-per-board).
+ static port::StatusOr<int> GetMultiprocessorCount(CUdevice device);
+
+ // Returns the limit on number of threads that can be resident in a single
+ // multiprocessor.
+ static port::StatusOr<int64> GetMaxThreadsPerMultiprocessor(CUdevice device);
+
+ // Returns the limit on number of threads which may be resident for a single
+ // block (cooperative thread array).
+ static port::StatusOr<int64> GetMaxThreadsPerBlock(CUdevice device);
+
+ // Returns the amount of shared memory available on a single GPU core (i.e.
+ // SM on NVIDIA devices).
+ static port::StatusOr<int64> GetMaxSharedMemoryPerCore(CUdevice device);
+
+ // Returns the amount of shared memory available for a single block
+ // (cooperative thread array).
+ static port::StatusOr<int64> GetMaxSharedMemoryPerBlock(CUdevice device);
+
+ // Returns the maximum supported number of registers per block.
+ static port::StatusOr<int64> GetMaxRegistersPerBlock(CUdevice device);
+
+ // Returns the number of threads per warp.
+ static port::StatusOr<int64> GetThreadsPerWarp(CUdevice device);
+
+ // Queries the grid limits for device with cuDeviceGetAttribute calls.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266
+ static bool GetGridLimits(int *x, int *y, int *z, CUdevice device);
+
+ // Returns a grab-bag of device properties in a caller-owned device_properties
+ // structure for device_ordinal via cuDeviceGetProperties.
+ // This call is deprecated in the NVIDIA driver API.
+ //
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1g65a5b4e25186bd257df80b98c98cffe6
+ static bool GetDeviceProperties(CUdevprop *device_properties,
+ int device_ordinal);
+
+ // Returns whether ECC is enabled for the given CUdevice via
+ // cuDeviceGetAttribute with CU_DEVICE_ATTRIBUTE_ECC_ENABLED.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266
+ static bool IsEccEnabled(CUdevice device, bool *result);
+
+ // Returns the total amount of memory available for allocation by the CUDA
+ // context, in bytes, via cuDeviceTotalMem.
+ static bool GetDeviceTotalMemory(CUdevice device, uint64 *result);
+
+ // Returns the free amount of memory and total amount of memory, as reported
+ // by cuMemGetInfo.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g808f555540d0143a331cc42aa98835c0
+ static bool GetDeviceMemoryInfo(CUcontext context, int64 *free, int64 *total);
+
+ // Returns a PCI bus id string for the device.
+ // [domain]:[bus]:[device].[function]
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g85295e7d9745ab8f0aa80dd1e172acfc
+ static string GetPCIBusID(CUdevice device);
+
+ // -- Context- and device-independent calls.
+
+ // Returns the number of visible CUDA devices via cuDeviceGetCount.
+ // This should correspond to the set of device ordinals available.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g52b5ce05cb8c5fb6831b2c0ff2887c74
+ static int GetDeviceCount();
+
+ // Returns the driver version number via cuDriverGetVersion.
+ // This is, surprisingly, NOT the actual driver version (e.g. 331.79) but,
+ // instead, the CUDA toolkit release number that this driver is compatible
+ // with; e.g. 6000 (for a CUDA 6.0 compatible driver) or 6050 (for a CUDA 6.5
+ // compatible driver).
+ //
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VERSION.html#group__CUDA__VERSION_1g8b7a10395392e049006e61bcdc8ebe71
+ static bool GetDriverVersion(int *driver_version);
+
+ // -- Other calls
+
+ // Returns the maximum number of blocks (per multiprocessor) occupied by the
+ // specified kernel/CUfunction when launched with the specified parameters.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__OCCUPANCY.html#group__CUDA__OCCUPANCY_1gcc6e1094d05cba2cee17fe33ddd04a98
+ static port::StatusOr<int> GetMaxOccupiedBlocksPerCore(
+ CUcontext context, CUfunction kernel, int threads_per_block,
+ size_t dynamic_shared_memory_bytes);
+
+ // Seam for injecting an error at CUDA initialization time for testing
+ // purposes.
+ static bool driver_inject_init_error_;
+};
+
+// Ensures a context is activated within a scope.
+class ScopedActivateContext {
+ public:
+ // Activates the context via cuCtxSetCurrent, if it is not the currently
+ // active context (a la cuCtxGetCurrent). Note the alternative push/pop
+ // mechanism is said by NVIDIA to be relatively slow and deprecated.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1gbe562ee6258b4fcc272ca6478ca2a2f7
+ explicit ScopedActivateContext(
+ CUcontext context, MultiOpActivation moa = MultiOpActivation::kNo);
+
+ // Checks that the context has remained activated for the duration of the
+ // scope.
+ ~ScopedActivateContext();
+
+ private:
+ CUcontext context_; // context being activated.
+
+ CUcontext prior_context_; // context that was active when we were activated.
+
+ // Stores whether this was instantiated during a MultiOpActivation, in which
+ // case we will not pop the context when we're destroyed (we will leave it to
+ // the parent MultiOpActivation that we were nested within).
+ bool previously_in_multi_op_activation_;
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
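
A hedged sketch of client-side use of the peer-access entry points declared above (not part of the patch; the helper name is invented). It mirrors the documented pairing of CanEnablePeerAccess and EnablePeerAccess and reuses the Status construction style of the rest of the file.

// Illustrative only: enable peer access in both directions between contexts.
port::Status EnableBidirectionalPeerAccess(CUcontext a, CUcontext b) {
  if (!CUDADriver::CanEnablePeerAccess(a, b) ||
      !CUDADriver::CanEnablePeerAccess(b, a)) {
    return port::Status{port::error::INTERNAL,
                        "peer access is not supported between these contexts"};
  }
  port::Status status = CUDADriver::EnablePeerAccess(a, b);
  if (!status.ok()) {
    return status;
  }
  return CUDADriver::EnablePeerAccess(b, a);
}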
diff --git a/tensorflow/stream_executor/cuda/cuda_event.cc b/tensorflow/stream_executor/cuda/cuda_event.cc
new file mode 100644
index 0000000000..a87c868c6b
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_event.cc
@@ -0,0 +1,56 @@
+#include "tensorflow/stream_executor/cuda/cuda_event.h"
+
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+CUDAEvent::CUDAEvent(CUDAExecutor* parent)
+ : parent_(parent), cuda_event_(nullptr) {}
+
+CUDAEvent::~CUDAEvent() {}
+
+port::Status CUDAEvent::Init() {
+ return CUDADriver::CreateEvent(parent_->cuda_context(), &cuda_event_,
+ CUDADriver::EventFlags::kDisableTiming);
+}
+
+port::Status CUDAEvent::Destroy() {
+ return CUDADriver::DestroyEvent(parent_->cuda_context(), &cuda_event_);
+}
+
+port::Status CUDAEvent::Record(CUDAStream* stream) {
+ return CUDADriver::RecordEvent(parent_->cuda_context(), cuda_event_,
+ stream->cuda_stream());
+}
+
+Event::Status CUDAEvent::PollForStatus() {
+ port::StatusOr<CUresult> status =
+ CUDADriver::QueryEvent(parent_->cuda_context(), cuda_event_);
+ if (!status.ok()) {
+ LOG(ERROR) << "Error polling for event status: "
+ << status.status().error_message();
+ return Event::Status::kError;
+ }
+
+ switch (status.ValueOrDie()) {
+ case CUDA_SUCCESS:
+ return Event::Status::kComplete;
+ case CUDA_ERROR_NOT_READY:
+ return Event::Status::kPending;
+ default:
+ LOG(INFO) << "Error condition returned for event status: "
+ << status.ValueOrDie();
+ return Event::Status::kError;
+ }
+}
+
+const CUevent& CUDAEvent::cuda_event() {
+ return cuda_event_;
+}
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_event.h b/tensorflow/stream_executor/cuda/cuda_event.h
new file mode 100644
index 0000000000..c5b65662db
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_event.h
@@ -0,0 +1,49 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
+
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/lib/status.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// CUDAEvent wraps a CUevent in the platform-independent EventInterface
+// interface.
+class CUDAEvent : public internal::EventInterface {
+ public:
+ explicit CUDAEvent(CUDAExecutor* parent);
+
+ ~CUDAEvent() override;
+
+ // Populates the CUDA-platform-specific elements of this object.
+ port::Status Init();
+
+ // Deallocates any platform-specific elements of this object. This is broken
+ // out (not part of the destructor) to allow for error reporting.
+ port::Status Destroy();
+
+ // Inserts the event at the current position into the specified stream.
+ port::Status Record(CUDAStream* stream);
+
+ // Polls the CUDA platform for the event's current status.
+ Event::Status PollForStatus();
+
+ // The underlying CUDA event element.
+ const CUevent& cuda_event();
+
+ private:
+ // The CUDAExecutor to which this object and its CUevent are bound.
+ CUDAExecutor* parent_;
+
+ // The underlying CUDA event element.
+ CUevent cuda_event_;
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
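
For orientation, a minimal sketch of the CUDAEvent lifecycle described above (not part of the patch). It assumes `executor` and `stream` are a live CUDAExecutor and CUDAStream created elsewhere; the helper name is invented.

// Illustrative only: Init -> Record -> PollForStatus -> Destroy.
bool RecordAndWait(CUDAExecutor *executor, CUDAStream *stream) {
  CUDAEvent event(executor);
  if (!event.Init().ok() || !event.Record(stream).ok()) {
    return false;
  }
  // Busy-wait for demonstration only; production code would block on the
  // stream rather than poll in a tight loop.
  Event::Status status = event.PollForStatus();
  while (status == Event::Status::kPending) {
    status = event.PollForStatus();
  }
  bool complete = (status == Event::Status::kComplete);
  return event.Destroy().ok() && complete;
}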
diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc
new file mode 100644
index 0000000000..59c3159895
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_fft.cc
@@ -0,0 +1,327 @@
+#include "tensorflow/stream_executor/cuda/cuda_fft.h"
+
+#include <dlfcn.h>
+
+#include <complex>
+
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuFftPlugin);
+
+namespace dynload {
+
+// This macro wraps a global identifier, given by __name, in a callable
+// structure that loads the DLL symbol out of the DSO handle in a thread-safe
+// manner on first use. This dynamic loading technique is used to avoid DSO
+// dependencies on vendor libraries which may or may not be available in the
+// deployed binary environment.
+#define PERFTOOLS_GPUTOOLS_CUFFT_WRAP(__name) \
+ struct DynLoadShim__##__name { \
+ static const char *kName; \
+ using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
+ static void *GetDsoHandle() { \
+ static auto status = internal::CachedDsoLoader::GetCufftDsoHandle(); \
+ return status.ValueOrDie(); \
+ } \
+ static FuncPointerT DynLoad() { \
+ static void *f = dlsym(GetDsoHandle(), kName); \
+ CHECK(f != nullptr) << "could not find " << kName \
+ << " in cuFFT DSO; dlerror: " << dlerror(); \
+ return reinterpret_cast<FuncPointerT>(f); \
+ } \
+ template <typename... Args> \
+ cufftResult operator()(CUDAExecutor * parent, Args... args) { \
+ cuda::ScopedActivateExecutorContext sac{parent}; \
+ return DynLoad()(args...); \
+ } \
+ } __name; \
+ const char *DynLoadShim__##__name::kName = #__name;
+
+#define CUFFT_ROUTINE_EACH(__macro) \
+ __macro(cufftDestroy) __macro(cufftSetStream) __macro(cufftPlan1d) \
+ __macro(cufftPlan2d) __macro(cufftPlan3d) __macro(cufftPlanMany) \
+ __macro(cufftExecD2Z) __macro(cufftExecZ2D) __macro(cufftExecC2C) \
+ __macro(cufftExecC2R) __macro(cufftExecZ2Z) \
+ __macro(cufftExecR2C)
+
+CUFFT_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUFFT_WRAP)
+
+} // namespace dynload
+
+namespace {
+
+// A helper function transforming gpu_fft arguments into cuFFT arguments.
+cufftType CUDAFftType(fft::Type type) {
+ switch (type) {
+ case fft::Type::kC2CForward:
+ case fft::Type::kC2CInverse:
+ return CUFFT_C2C;
+ case fft::Type::kC2R:
+ return CUFFT_C2R;
+ case fft::Type::kR2C:
+ return CUFFT_R2C;
+ case fft::Type::kZ2ZForward:
+ case fft::Type::kZ2ZInverse:
+ return CUFFT_Z2Z;
+ case fft::Type::kZ2D:
+ return CUFFT_Z2D;
+ case fft::Type::kD2Z:
+ return CUFFT_D2Z;
+ default:
+ LOG(FATAL) << "Invalid value of fft::Type.";
+ }
+}
+
+// Associates the given stream with the given cuFFT plan.
+bool SetStream(CUDAExecutor *parent, cufftHandle plan, Stream *stream) {
+ auto ret = dynload::cufftSetStream(parent, plan, AsCUDAStreamValue(stream));
+ if (ret != CUFFT_SUCCESS) {
+ LOG(ERROR) << "failed to run cuFFT routine cufftSetStream: " << ret;
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, fft::Type type)
+ : parent_(parent), fft_type_(type) {
+ auto ret = dynload::cufftPlan1d(parent, &plan_, num_x, CUDAFftType(type),
+ 1 /* = batch */);
+ if (ret != CUFFT_SUCCESS) {
+ LOG(ERROR) << "failed to create cuFFT 1d plan:" << ret;
+ }
+}
+
+CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y,
+ fft::Type type)
+ : parent_(parent), fft_type_(type) {
+ auto ret =
+ dynload::cufftPlan2d(parent, &plan_, num_x, num_y, CUDAFftType(type));
+ if (ret != CUFFT_SUCCESS) {
+ LOG(ERROR) << "failed to create cuFFT 2d plan:" << ret;
+ }
+}
+
+CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y,
+ uint64 num_z, fft::Type type)
+ : parent_(parent), fft_type_(type) {
+ auto ret = dynload::cufftPlan3d(parent, &plan_, num_x, num_y, num_z,
+ CUDAFftType(type));
+ if (ret != CUFFT_SUCCESS) {
+ LOG(ERROR) << "failed to create cuFFT 3d plan:" << ret;
+ }
+}
+
+CUDAFftPlan::CUDAFftPlan(CUDAExecutor *parent, int rank, uint64 *elem_count,
+ uint64 *input_embed, uint64 input_stride,
+ uint64 input_distance, uint64 *output_embed,
+ uint64 output_stride, uint64 output_distance,
+ fft::Type type, int batch_count)
+ : parent_(parent), fft_type_(type) {
+ int elem_count_[3], input_embed_[3], output_embed_[3];
+ for (int i = 0; i < rank; ++i) {
+ elem_count_[i] = elem_count[i];
+ if (input_embed) {
+ input_embed_[i] = input_embed[i];
+ }
+ if (output_embed) {
+ output_embed_[i] = output_embed[i];
+ }
+ }
+ auto ret = dynload::cufftPlanMany(
+ parent, &plan_, rank, elem_count_, input_embed ? input_embed_ : nullptr,
+ input_stride, input_distance, output_embed ? output_embed_ : nullptr,
+ output_stride, output_distance, CUDAFftType(type), batch_count);
+ if (ret != CUFFT_SUCCESS) {
+ LOG(ERROR) << "failed to create cuFFT batched plan:" << ret;
+ }
+}
+
+CUDAFftPlan::~CUDAFftPlan() { dynload::cufftDestroy(parent_, plan_); }
+
+int CUDAFftPlan::GetFftDirection() const {
+ switch (fft_type_) {
+ case fft::Type::kC2CForward:
+ case fft::Type::kZ2ZForward:
+ case fft::Type::kR2C:
+ case fft::Type::kD2Z:
+ return CUFFT_FORWARD;
+ case fft::Type::kC2CInverse:
+ case fft::Type::kZ2ZInverse:
+ case fft::Type::kC2R:
+ case fft::Type::kZ2D:
+ return CUFFT_INVERSE;
+ default:
+ LOG(FATAL) << "Invalid value of fft::Type.";
+ }
+}
+
+std::unique_ptr<fft::Plan> CUDAFft::Create1dPlan(Stream *stream, uint64 num_x,
+ fft::Type type,
+ bool in_place_fft) {
+ std::unique_ptr<fft::Plan> plan{new CUDAFftPlan(parent_, num_x, type)};
+ return plan;
+}
+
+std::unique_ptr<fft::Plan> CUDAFft::Create2dPlan(Stream *stream, uint64 num_x,
+ uint64 num_y, fft::Type type,
+ bool in_place_fft) {
+ std::unique_ptr<fft::Plan> plan{new CUDAFftPlan(parent_, num_x, num_y, type)};
+ return plan;
+}
+
+std::unique_ptr<fft::Plan> CUDAFft::Create3dPlan(Stream *stream, uint64 num_x,
+ uint64 num_y, uint64 num_z,
+ fft::Type type,
+ bool in_place_fft) {
+ std::unique_ptr<fft::Plan> plan{
+ new CUDAFftPlan(parent_, num_x, num_y, num_z, type)};
+ return plan;
+}
+
+std::unique_ptr<fft::Plan> CUDAFft::CreateBatchedPlan(
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed,
+ uint64 output_stride, uint64 output_distance, fft::Type type,
+ bool in_place_fft, int batch_count) {
+ std::unique_ptr<fft::Plan> plan{new CUDAFftPlan(
+ parent_, rank, elem_count, input_embed, input_stride, input_distance,
+ output_embed, output_stride, output_distance, type, batch_count)};
+ return plan;
+}
+
+template <typename FuncT, typename InputT, typename OutputT>
+bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec,
+ const DeviceMemory<InputT> &input,
+ DeviceMemory<OutputT> *output) {
+ CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
+ if (cuda_fft_plan == nullptr) {
+ LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
+ return false;
+ }
+
+ if (!SetStream(parent_, cuda_fft_plan->GetPlan(), stream)) {
+ return false;
+ }
+
+ auto ret = cufftExec(parent_, cuda_fft_plan->GetPlan(),
+ CUDAComplex(const_cast<InputT *>(CUDAMemory(input))),
+ CUDAComplex(CUDAMemoryMutable(output)));
+
+ if (ret != CUFFT_SUCCESS) {
+ LOG(ERROR) << "failed to run cuFFT routine: " << ret;
+ return false;
+ }
+
+ return true;
+}
+
+template <typename FuncT, typename InputT, typename OutputT>
+bool CUDAFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
+ FuncT cufftExec,
+ const DeviceMemory<InputT> &input,
+ DeviceMemory<OutputT> *output) {
+ CUDAFftPlan *cuda_fft_plan = dynamic_cast<CUDAFftPlan *>(plan);
+ if (cuda_fft_plan == nullptr) {
+ LOG(ERROR) << "the passed-in plan is not a CUDAFftPlan object.";
+ return false;
+ }
+
+ if (!SetStream(parent_, cuda_fft_plan->GetPlan(), stream)) {
+ return false;
+ }
+
+ auto ret = cufftExec(parent_, cuda_fft_plan->GetPlan(),
+ CUDAComplex(const_cast<InputT *>(CUDAMemory(input))),
+ CUDAComplex(CUDAMemoryMutable(output)),
+ cuda_fft_plan->GetFftDirection());
+
+ if (ret != CUFFT_SUCCESS) {
+ LOG(ERROR) << "failed to run cuFFT routine: " << ret;
+ return false;
+ }
+
+ return true;
+}
+
+#define PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
+ __fft_type3) \
+ bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<__type>> &input, \
+ DeviceMemory<std::complex<__type>> *output) { \
+ return DoFftWithDirectionInternal( \
+ stream, plan, dynload::cufftExec##__fft_type1, input, output); \
+ } \
+ bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<__type> &input, \
+ DeviceMemory<std::complex<__type>> *output) { \
+ return DoFftInternal(stream, plan, dynload::cufftExec##__fft_type2, input, \
+ output); \
+ } \
+ bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<__type>> &input, \
+ DeviceMemory<__type> *output) { \
+ return DoFftInternal(stream, plan, dynload::cufftExec##__fft_type3, input, \
+ output); \
+ }
+
+PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(float, C2C, R2C, C2R)
+PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
+
+#undef PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+namespace gpu = ::perftools::gputools;
+
+REGISTER_MODULE_INITIALIZER(register_cufft, {
+ gpu::port::Status status =
+ gpu::PluginRegistry::Instance()
+ ->RegisterFactory<gpu::PluginRegistry::FftFactory>(
+ gpu::cuda::kCudaPlatformId, gpu::cuda::kCuFftPlugin, "cuFFT",
+ [](gpu::internal::StreamExecutorInterface
+ *parent) -> gpu::fft::FftSupport * {
+ gpu::cuda::CUDAExecutor *cuda_executor =
+ dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
+ if (cuda_executor == nullptr) {
+ LOG(ERROR)
+ << "Attempting to initialize an instance of the cuFFT "
+ << "support library with a non-CUDA StreamExecutor";
+ return nullptr;
+ }
+
+ return new gpu::cuda::CUDAFft(cuda_executor);
+ });
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to register cuFFT factory: "
+ << status.error_message();
+ }
+
+ // Prime the cuFFT DSO. The loader will log more information.
+ auto statusor = gpu::internal::CachedDsoLoader::GetCufftDsoHandle();
+ if (!statusor.ok()) {
+ LOG(INFO) << "Unable to load cuFFT DSO.";
+ }
+
+ gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
+ gpu::PluginKind::kFft,
+ gpu::cuda::kCuFftPlugin);
+});
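
For readers skimming the wrapping macro above, this is approximately what PERFTOOLS_GPUTOOLS_CUFFT_WRAP(cufftPlan1d) expands to (hand-expanded here for illustration only; in the patch the expansion lives inside namespace dynload, and the macro itself is authoritative).

// Approximate expansion: resolve the cuFFT symbol from the DSO once, then
// forward calls with the parent executor's CUDA context made active.
struct DynLoadShim__cufftPlan1d {
  static const char *kName;
  using FuncPointerT = std::add_pointer<decltype(::cufftPlan1d)>::type;
  static void *GetDsoHandle() {
    static auto status = internal::CachedDsoLoader::GetCufftDsoHandle();
    return status.ValueOrDie();
  }
  static FuncPointerT DynLoad() {
    static void *f = dlsym(GetDsoHandle(), kName);
    CHECK(f != nullptr) << "could not find " << kName
                        << " in cuFFT DSO; dlerror: " << dlerror();
    return reinterpret_cast<FuncPointerT>(f);
  }
  template <typename... Args>
  cufftResult operator()(CUDAExecutor *parent, Args... args) {
    cuda::ScopedActivateExecutorContext sac{parent};
    return DynLoad()(args...);
  }
} cufftPlan1d;
const char *DynLoadShim__cufftPlan1d::kName = "cufftPlan1d";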
diff --git a/tensorflow/stream_executor/cuda/cuda_fft.h b/tensorflow/stream_executor/cuda/cuda_fft.h
new file mode 100644
index 0000000000..2577c2952e
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_fft.h
@@ -0,0 +1,95 @@
+// CUDA-specific support for FFT functionality -- this wraps the cuFFT library
+// capabilities, and is only included into CUDA implementation code -- it will
+// not introduce cuda headers into other code.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_
+
+#include "tensorflow/stream_executor/fft.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "third_party/gpus/cuda/include/cufft.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+
+namespace cuda {
+
+class CUDAExecutor;
+
+ // Opaque and unique identifier for the cuFFT plugin.
+extern const PluginId kCuFftPlugin;
+
+class CUDAFftPlan : public fft::Plan {
+ public:
+ // Constructor creating 1d FFT plan.
+ CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, fft::Type type);
+ // Constructor creating 2d FFT plan.
+ CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y, fft::Type type);
+ // Constructor creating 3d FFT plan.
+ CUDAFftPlan(CUDAExecutor *parent, uint64 num_x, uint64 num_y, uint64 num_z,
+ fft::Type type);
+ // Constructor creating batched FFT plan.
+ CUDAFftPlan(CUDAExecutor *parent, int rank, uint64 *elem_count,
+ uint64 *input_embed, uint64 input_stride, uint64 input_distance,
+ uint64 *output_embed, uint64 output_stride,
+ uint64 output_distance, fft::Type type, int batch_count);
+ ~CUDAFftPlan() override;
+
+ // Get FFT direction in cuFFT based on FFT type.
+ int GetFftDirection() const;
+ cufftHandle GetPlan() const { return plan_; }
+
+ private:
+ CUDAExecutor *parent_;
+ cufftHandle plan_;
+ fft::Type fft_type_;
+};
+
+// FFT support for CUDA platform via cuFFT library.
+//
+// This satisfies the platform-agnostic FftSupport interface.
+//
+// Note that the cuFFT handle that this encapsulates is implicitly tied to the
+// context (and, as a result, the device) that the parent CUDAExecutor is tied
+// to. This simply happens as an artifact of creating the cuFFT handle when a
+// CUDA context is active.
+//
+// Thread-safe. The CUDA context associated with all operations is the CUDA
+// context of parent_, so all context is explicit.
+class CUDAFft : public fft::FftSupport {
+ public:
+ explicit CUDAFft(CUDAExecutor *parent) : parent_(parent) {}
+ ~CUDAFft() override {}
+
+ TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES
+
+ private:
+ CUDAExecutor *parent_;
+
+ // Two helper functions that execute dynload::cufftExec?2?.
+
+ // This is for complex to complex FFT, when the direction is required.
+ template <typename FuncT, typename InputT, typename OutputT>
+ bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
+ FuncT cufft_exec,
+ const DeviceMemory<InputT> &input,
+ DeviceMemory<OutputT> *output);
+
+ // This is for complex to real or real to complex FFT, when the direction
+ // is implied.
+ template <typename FuncT, typename InputT, typename OutputT>
+ bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufft_exec,
+ const DeviceMemory<InputT> &input,
+ DeviceMemory<OutputT> *output);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CUDAFft);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
new file mode 100644
index 0000000000..77f16e2a6e
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -0,0 +1,1082 @@
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+
+#include <unistd.h>
+
+#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_event.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+#include "tensorflow/stream_executor/cuda/cuda_timer.h"
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/kernel_cache_config.h"
+#include "tensorflow/stream_executor/lib/casts.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/mathutil.h"
+#include "tensorflow/stream_executor/lib/path.h"
+#include "tensorflow/stream_executor/lib/process_state.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/lib/str_util.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/timer.h"
+#include "tensorflow/stream_executor/lib/numbers.h"
+
+#ifdef PLATFORMS_GPUS_CUDA_DYNAMIC_LIBCUDA_DYNAMIC_LIBCUDA_H_
+#error \
+ "No driver calls in this file, wrap driver functionality in cuda_driver.cc."
+#endif
+
+#ifdef __CUDA_RUNTIME_H__
+#error \
+ "CUDA runtime being included into CUDA GPU executor; should be driver only."
+#endif
+
+extern bool FLAGS_check_gpu_leaks;
+tensorflow::int32 FLAGS_register_occupancy_warning_threshold;
+bool FLAGS_prefer_cubin_to_ptx = true;
+
+namespace perftools {
+namespace gputools {
+namespace rng {
+class RngSupport;
+} // namespace rng
+} // namespace gputools
+} // namespace perftools
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// Hook that can be used to CUBIN-ate PTX before it is loaded into the driver.
+// It has been observed that loading both PTX and cubins into the driver library
+// can cause it to crash, but loading only CUBINs avoids those crashes;
+// therefore, it's useful to have this hook to hack in uniform CUBIN-ation of
+// PTX code.
+//
+// As this is an implementation-detail workaround, the usage is to declare this
+// variable with extern linkage and populate it from another translation unit.
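+//
+// A minimal sketch of that wiring from another translation unit (illustrative
+// only -- CompilePtxToCubin is a hypothetical helper, not part of this code):
+//
+//   extern std::function<string(const string &)> g_cubinate;
+//   ...
+//   g_cubinate = [](const string &ptx) { return CompilePtxToCubin(ptx); };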
+std::function<string(const string &)> g_cubinate;
+
+static CUDAEvent *AsCUDAEvent(Event *event) {
+ DCHECK(event != nullptr);
+ return static_cast<CUDAEvent *>(event->implementation());
+}
+
+// Given a platform-independent stream datatype, returns the internal CUDA
+// platform implementation pointer.
+static CUDAStream *AsCUDAStream(Stream *stream) {
+ DCHECK(stream != nullptr);
+ return static_cast<CUDAStream *>(stream->implementation());
+}
+
+// Given a platform-independent stream datatype, returns the platform
+// implementation's internal value, suitable for passing directly to libcuda
+// APIs.
+CUstream AsCUDAStreamValue(Stream *stream) {
+ DCHECK(stream != nullptr);
+ return AsCUDAStream(stream)->cuda_stream();
+}
+
+// Given a platform-independent timer datatype, returns the internal CUDA
+// platform implementation pointer.
+static CUDATimer *AsCUDATimer(Timer *timer) {
+ DCHECK(timer != nullptr);
+ return static_cast<CUDATimer *>(timer->implementation());
+}
+
+// Given const GPU memory, returns a libcuda device pointer datatype, suitable
+// for passing directly to libcuda APIs.
+//
+// N.B. we must lose constness in order to pass a suitable type to the existing
+// libcuda APIs, so the caller should take care to only pass the result of const
+// GPU memory conversions to libcuda functions which will honor constness.
+static CUdeviceptr AsCudaDevicePtr(const DeviceMemoryBase &gpu_mem) {
+ return reinterpret_cast<CUdeviceptr>(gpu_mem.opaque());
+}
+
+// See description on const version above.
+static CUdeviceptr AsCudaDevicePtr(DeviceMemoryBase *gpu_mem) {
+ return AsCudaDevicePtr(*gpu_mem);
+}
+
+static CUcontext GetCudaContext(Stream *stream) {
+ return static_cast<CUDAExecutor *>(stream->parent()->implementation())
+ ->cuda_context();
+}
+
+CUcontext ExtractCudaContext(CUDAExecutor *cuda_exec) {
+ CHECK(cuda_exec != nullptr);
+ return cuda_exec->cuda_context();
+}
+
+CUDAExecutor *ExtractCudaExecutor(StreamExecutor *stream_exec) {
+ return static_cast<CUDAExecutor *>(stream_exec->implementation());
+}
+
+CUDAExecutor::~CUDAExecutor() {
+ for (auto &it : disk_modules_) {
+ CUDADriver::UnloadModule(context_, it.second);
+ }
+ for (auto &it : in_memory_modules_) {
+ CUDADriver::UnloadModule(context_, it.second);
+ }
+ if (context_ != nullptr) {
+ CUDADriver::DestroyContext(context_);
+ }
+}
+
+port::Status CUDAExecutor::Init(int device_ordinal,
+ DeviceOptions device_options) {
+ device_ordinal_ = device_ordinal;
+
+ auto status = CUDADriver::Init();
+ if (!status.ok()) {
+ return status;
+ }
+
+ status = CUDADriver::GetDevice(device_ordinal_, &device_);
+ if (!status.ok()) {
+ return status;
+ }
+
+ status = CUDADriver::CreateContext(device_, device_options, &context_);
+ if (!status.ok()) {
+ return status;
+ }
+
+ return CUDADriver::GetComputeCapability(&cc_major_, &cc_minor_, device_);
+}
+
+bool CUDAExecutor::FindOnDiskForComputeCapability(
+ port::StringPiece filename, port::StringPiece canonical_suffix,
+ string *found_filename) const {
+ if (cc_major_ == 0 && cc_minor_ == 0) {
+ return false;
+ }
+
+ // TODO(22689637): Eliminate unnecessary ToString()s when all dependencies
+ // have been migrated.
+ string cc_specific = port::StrCat(filename.ToString(), ".cc", cc_major_,
+ cc_minor_, canonical_suffix.ToString());
+ if (port::FileExists(cc_specific)) {
+ VLOG(2) << "found compute-capability-specific file, using that: "
+ << cc_specific;
+ *found_filename = cc_specific;
+ return true;
+ }
+
+ VLOG(2) << "could not find compute-capability specific file at: "
+ << cc_specific;
+ if (port::FileExists(filename.ToString())) {
+ *found_filename = filename.ToString();
+ return true;
+ }
+
+ return false;
+}
+
+// Returns the path to the running executable.
+// N.B. Derived from //knowledge/smalltalk/background_kb.cc
+// Arg: strip_exe: if true, remove the name of the executable itself from the
+// returned string. Example: calling this from /usr/bin/foo
+// would return /usr/bin.
+static string GetBinaryDir(bool strip_exe) {
+ char exe_path[PATH_MAX] = {0};
+ CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1));
+ // Make sure it's null-terminated:
+ exe_path[sizeof(exe_path) - 1] = 0;
+
+ if (strip_exe) {
+ // The exe is the last component of the path, so remove one component.
+    std::vector<string> components = port::Split(exe_path, '/');
+ components.pop_back();
+ return port::Join(components, "/");
+ }
+ return exe_path;
+}
+
+// Returns the location of the runfiles directory.
+// This is the directory which "bazel run" sets as the current working directory
+// before the program starts.
+// N.B. This doesn't have to be running under "bazel run" in order to get the
+// appropriate runfiles directory.
+static string GetRunfilesDir() {
+ return port::StrCat(GetBinaryDir(false), ".runfiles");
+}
+
+bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
+ KernelBase *kernel) {
+ CUDAKernel *cuda_kernel = AsCUDAKernel(kernel);
+ CUmodule module = nullptr;
+ const string *kernelname;
+
+ const OnDiskKernelLoaderSpec *on_disk_spec = nullptr;
+ bool has_ptx = spec.has_cuda_ptx_on_disk();
+ bool has_cubin = spec.has_cuda_cubin_on_disk();
+ if (has_cubin && (!has_ptx || FLAGS_prefer_cubin_to_ptx)) {
+ on_disk_spec = &spec.cuda_cubin_on_disk();
+ } else if (has_ptx) {
+ on_disk_spec = &spec.cuda_ptx_on_disk();
+ }
+
+ if (on_disk_spec != nullptr) {
+ } else if (spec.has_cuda_ptx_in_memory()) {
+ kernelname = &spec.cuda_ptx_in_memory().kernelname();
+
+ if (cc_major_ == 0 && cc_minor_ == 0) {
+ return false;
+ }
+
+    // Note that the original PTX may be compressed, and the PTX we get below
+    // is the decompressed result. To cache the module we should use the
+    // original (compressed) PTX as the key, because the same compressed PTX
+    // may decompress to buffers with different pointer values.
+ const char *ptx = spec.cuda_ptx_in_memory().text(cc_major_, cc_minor_);
+ const char *orig_ptx =
+ spec.cuda_ptx_in_memory().original_text(cc_major_, cc_minor_);
+ if (ptx == nullptr || orig_ptx == nullptr) {
+ ptx = spec.cuda_ptx_in_memory().default_text();
+ orig_ptx = spec.cuda_ptx_in_memory().original_default_text();
+ }
+ if (ptx == nullptr || orig_ptx == nullptr) {
+ LOG(FATAL) << "could not load ptx for kernel " << kernelname;
+ return false;
+ }
+
+ mutex_lock lock{in_memory_modules_mu_};
+ module = in_memory_modules_[orig_ptx];
+
+ if (module == nullptr) {
+ if (g_cubinate == nullptr) {
+ if (!CUDADriver::LoadPtx(context_, ptx, &module)) {
+ return false;
+ }
+ } else {
+ string cubin = g_cubinate(ptx);
+ auto load_status =
+ CUDADriver::LoadCubin(context_, cubin.c_str(), &module);
+ if (!load_status.ok()) {
+ LOG(ERROR) << "failed to load cubin via hook: " << load_status;
+ return false;
+ }
+ }
+ in_memory_modules_[orig_ptx] = module;
+ }
+ } else if (spec.has_cuda_cubin_in_memory()) {
+ kernelname = &spec.cuda_cubin_in_memory().kernelname();
+ const char *cubin = spec.cuda_cubin_in_memory().bytes();
+ mutex_lock lock{in_memory_modules_mu_};
+ module = in_memory_modules_[cubin];
+
+ if (module == nullptr) {
+ auto load_status = CUDADriver::LoadCubin(context_, cubin, &module);
+ if (!load_status.ok()) {
+ LOG(ERROR) << "failed to load CUBIN: " << load_status;
+ return false;
+ }
+
+ in_memory_modules_[cubin] = module;
+ }
+ } else {
+ LOG(WARNING) << "no method of loading CUDA kernel provided";
+ return false;
+ }
+
+ VLOG(2) << "getting function " << kernelname << " from module " << module;
+ if (!CUDADriver::GetModuleFunction(context_, module, kernelname->c_str(),
+ cuda_kernel->cuda_function_ptr())) {
+ return false;
+ }
+
+ // We have to trust the kernel loader spec arity because there doesn't appear
+ // to be a way to reflect on the number of expected arguments w/the CUDA API.
+ cuda_kernel->set_arity(spec.arity());
+
+ KernelMetadata kernel_metadata;
+ if (!GetKernelMetadata(cuda_kernel, &kernel_metadata)) {
+ LOG(WARNING) << "Unable to get metadata for kernel " << kernelname;
+ }
+ kernel->set_metadata(kernel_metadata);
+ kernel->set_name(*kernelname);
+ return true;
+}
+
+bool CUDAExecutor::GetKernelMetadata(CUDAKernel *cuda_kernel,
+ KernelMetadata *kernel_metadata) {
+ int value;
+ if (!CUDADriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_NUM_REGS,
+ *cuda_kernel->cuda_function_ptr(),
+ &value)) {
+ return false;
+ }
+ kernel_metadata->set_registers_per_thread(value);
+
+ if (!CUDADriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
+ *cuda_kernel->cuda_function_ptr(),
+ &value)) {
+ return false;
+ }
+ kernel_metadata->set_shared_memory_bytes(value);
+
+ return true;
+}
+
+bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims, const KernelBase &kernel,
+ const std::vector<KernelArg> &args) {
+ CHECK_EQ(kernel.Arity(), args.size());
+ CUstream custream = AsCUDAStreamValue(stream);
+ const CUDAKernel *cuda_kernel = AsCUDAKernel(&kernel);
+ CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue();
+
+ std::vector<void *> addrs;
+ addrs.reserve(args.size());
+ int shmem_bytes = 0;
+ for (size_t i = 0; i < args.size(); i++) {
+ switch (args[i].type) {
+ case KernelArg::kNormal:
+ addrs.push_back(const_cast<void *>(
+ static_cast<const void *>(args[i].data.begin())));
+ break;
+ case KernelArg::kSharedMemory:
+ shmem_bytes += args[i].bytes;
+ break;
+ default:
+ LOG(ERROR) << "Invalid kernel arg type passed (" << args[i].type
+ << ") for arg " << i;
+ return false;
+ }
+ }
+
+  // Only perform/print the occupancy check once per kernel.
+ launched_kernels_mu_.lock();
+ if (launched_kernels_.find(cufunc) == launched_kernels_.end()) {
+ OccupancyCheck(kernel, thread_dims, block_dims);
+ // TODO(rspringer): Remove elements from launched_kernels_...if we ever
+ // expose a kernel/module deallocation method.
+ launched_kernels_.insert(cufunc);
+ }
+ launched_kernels_mu_.unlock();
+
+ if (cuda_kernel->GetPreferredCacheConfig() !=
+ KernelCacheConfig::kNoPreference) {
+ CUDADriver::FuncSetCacheConfig(cufunc, cuda_kernel->GetCUDACacheConfig());
+ }
+
+ if (!CUDADriver::LaunchKernel(
+ GetCudaContext(stream), cufunc, block_dims.x, block_dims.y,
+ block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z,
+ shmem_bytes, custream, addrs.data(), nullptr /* = extra */)) {
+ LOG(ERROR) << "failed to launch CUDA kernel with args: " << args.size()
+ << "; thread dim: " << thread_dims.ToString()
+ << "; block dim: " << block_dims.ToString();
+ return false;
+ }
+
+ return true;
+}
+
+// This is a non-essential operation; if there's a failure, proceed without
+// logging an error. It's nearly certain that in case of failures, we'd never
+// get here in the first place; these are very low-impact routines.
+void CUDAExecutor::OccupancyCheck(const KernelBase &kernel,
+ const ThreadDim &thread_dims,
+ const BlockDim &block_dims) {
+ VLOG(2) << "Computing kernel occupancy for kernel "
+ << kernel.demangled_name();
+ VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y
+ << ", " << thread_dims.z << ")";
+
+ int regs_per_thread;
+ if (!kernel.metadata().registers_per_thread(&regs_per_thread)) {
+ return;
+ }
+
+ int smem_per_block;
+ if (!kernel.metadata().shared_memory_bytes(&smem_per_block)) {
+ return;
+ }
+
+ const DeviceDescription &device_description =
+ kernel.parent()->GetDeviceDescription();
+
+ uint64 blocks_per_sm = CalculateOccupancy(
+ device_description, regs_per_thread, smem_per_block, thread_dims);
+ VLOG(2) << "Resident blocks per SM is " << blocks_per_sm;
+
+ // To increase occupancy, there must be a sufficient number of blocks
+  // available to spread across the SMs at this new, improved occupancy level.
+ int multiprocessor_count = device_description.core_count();
+ int block_count = block_dims.x * block_dims.y * block_dims.z;
+ int available_blocks_per_sm =
+ port::MathUtil::CeilOfRatio(block_count, multiprocessor_count);
+ if (available_blocks_per_sm <= static_cast<int64>(blocks_per_sm)) {
+ VLOG(2) << "Occupancy is limited by number of blocks available per sm.";
+ return;
+ }
+
+ uint64 improved_regs_per_thread = CalculateRegisterLimitForTargetOccupancy(
+ device_description, smem_per_block, thread_dims, blocks_per_sm + 1);
+ if (improved_regs_per_thread != 0) {
+ VLOG(2) << "Reducing register usage from " << regs_per_thread
+ << " to " << improved_regs_per_thread
+ << " could increase resident blocks per SM by one.";
+
+ uint64 reg_reduction = regs_per_thread - improved_regs_per_thread;
+ if (reg_reduction <=
+ static_cast<uint64>(FLAGS_register_occupancy_warning_threshold)) {
+ LOG(INFO) << "Notice: occupancy would increase if register usage was"
+ << " reduced from " << regs_per_thread
+ << " to " << improved_regs_per_thread
+ << " registers per thread for kernel: "
+ << kernel.demangled_name();
+ }
+ } else {
+ VLOG(2) << "Resident blocks per SM cannot be increased by reducing "
+ "register usage.";
+ }
+}
+
+void *CUDAExecutor::Allocate(uint64 size) {
+ return CUDADriver::DeviceAllocate(context_, size);
+}
+
+void *CUDAExecutor::AllocateSubBuffer(DeviceMemoryBase *mem,
+ uint64 offset_bytes, uint64 size_bytes) {
+ // offset and size are in bytes, so char* works as the pointer type.
+ return reinterpret_cast<char *>(mem->opaque()) + offset_bytes;
+}
+
+void CUDAExecutor::Deallocate(DeviceMemoryBase *mem) {
+ // CUDA "sub-buffers" are just pointer + offset, so no dealloc is necessary.
+ if (!mem->is_sub_buffer()) {
+ CUDADriver::DeviceDeallocate(context_, mem->opaque());
+ }
+}
+
+bool CUDAExecutor::HostMemoryRegister(void *location, uint64 size) {
+ if (location == nullptr || size == 0) {
+ LOG(WARNING) << "attempting to register null or zero-sized memory: "
+ << location << "; size " << size;
+ }
+ VLOG(2) << "registering " << location << " size " << size;
+ return CUDADriver::HostRegister(context_, location, size);
+}
+
+bool CUDAExecutor::HostMemoryUnregister(void *location) {
+ VLOG(2) << "unregistering " << location;
+ return CUDADriver::HostUnregister(context_, location);
+}
+
+bool CUDAExecutor::SynchronizeAllActivity() {
+ return CUDADriver::SynchronizeContext(context_);
+}
+
+bool CUDAExecutor::SynchronousMemZero(DeviceMemoryBase *location, uint64 size) {
+ if (reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
+ size % 4 == 0) {
+ return CUDADriver::SynchronousMemsetUint32(
+ context_, AsCudaDevicePtr(location), 0x0, size / 4);
+ }
+ return CUDADriver::SynchronousMemsetUint8(context_, AsCudaDevicePtr(location),
+ 0x0, size);
+}
+
+bool CUDAExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) {
+ if (reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
+ size % 4 == 0) {
+ // cudaMemset reinterprets "value" as a uint8.
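+    // e.g. value 0xAB is replicated into the 32-bit pattern 0xABABABAB below.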
+ uint8 byte_value = static_cast<uint8>(value);
+ uint32 pattern = (byte_value << 24) | (byte_value << 16) |
+ (byte_value << 8) | byte_value;
+ return CUDADriver::SynchronousMemsetUint32(
+ context_, AsCudaDevicePtr(location), pattern, size / 4);
+ }
+ return CUDADriver::SynchronousMemsetUint8(context_, AsCudaDevicePtr(location),
+ value, size);
+}
+
+bool CUDAExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
+ const void *host_src, uint64 size) {
+ return CUDADriver::SynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst),
+ host_src, size);
+}
+
+bool CUDAExecutor::SynchronousMemcpy(void *host_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) {
+ return CUDADriver::SynchronousMemcpyD2H(context_, host_dst,
+ AsCudaDevicePtr(gpu_src), size);
+}
+
+bool CUDAExecutor::SynchronousMemcpyDeviceToDevice(
+ DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64 size) {
+ return CUDADriver::SynchronousMemcpyD2D(context_, AsCudaDevicePtr(gpu_dst),
+ AsCudaDevicePtr(gpu_src), size);
+}
+
+bool CUDAExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) {
+ return Memset32(stream, location, 0x0, size);
+}
+
+bool CUDAExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
+ uint32 pattern, uint64 size) {
+ VLOG(2) << "enqueueing memset32 operation onto stream " << stream
+ << " at location " << location << " with size " << size
+ << " and pattern " << std::hex << pattern;
+ CHECK(reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
+ size % 4 == 0);
+ return CUDADriver::AsynchronousMemsetUint32(
+ context_, AsCudaDevicePtr(location), pattern, size / 4,
+ AsCUDAStreamValue(stream));
+}
+
+bool CUDAExecutor::Memcpy(Stream *stream, void *host_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size) {
+ return CUDADriver::AsynchronousMemcpyD2H(context_, host_dst,
+ AsCudaDevicePtr(gpu_src), size,
+ AsCUDAStreamValue(stream));
+}
+
+bool CUDAExecutor::Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
+ const void *host_src, uint64 size) {
+ return CUDADriver::AsynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst),
+ host_src, size,
+ AsCUDAStreamValue(stream));
+}
+
+bool CUDAExecutor::MemcpyDeviceToDevice(Stream *stream,
+ DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) {
+ return CUDADriver::AsynchronousMemcpyD2D(context_, AsCudaDevicePtr(gpu_dst),
+ AsCudaDevicePtr(gpu_src), size,
+ AsCUDAStreamValue(stream));
+}
+
+bool CUDAExecutor::HostCallback(Stream *stream,
+ std::function<void()> callback) {
+ auto callback_ptr = new std::function<void()>(callback);
+ return CUDADriver::AddStreamCallback(context_, AsCUDAStreamValue(stream),
+ InternalHostCallback, callback_ptr);
+}
+
+/* static */ void CUDAExecutor::InternalHostCallback(CUstream stream,
+ CUresult status,
+ void *data) {
+ std::function<void()> *callback =
+ reinterpret_cast<std::function<void()> *>(data);
+ (*callback)();
+ delete callback;
+}
+
+port::Status CUDAExecutor::AllocateEvent(Event *event) {
+ return AsCUDAEvent(event)->Init();
+}
+
+port::Status CUDAExecutor::DeallocateEvent(Event *event) {
+ return AsCUDAEvent(event)->Destroy();
+}
+
+port::Status CUDAExecutor::RecordEvent(Stream *stream, Event *event) {
+ return AsCUDAEvent(event)->Record(AsCUDAStream(stream));
+}
+
+port::Status CUDAExecutor::WaitForEvent(Stream *stream, Event *event) {
+ if (CUDADriver::WaitStreamOnEvent(context_,
+ AsCUDAStream(stream)->cuda_stream(),
+ AsCUDAEvent(event)->cuda_event())) {
+ return port::Status::OK();
+ } else {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf("error recording waiting for CUDA event on stream %p",
+ stream)};
+ }
+}
+
+Event::Status CUDAExecutor::PollForEventStatus(Event *event) {
+ return AsCUDAEvent(event)->PollForStatus();
+}
+
+bool CUDAExecutor::AllocateStream(Stream *stream) {
+ return AsCUDAStream(stream)->Init();
+}
+
+void CUDAExecutor::DeallocateStream(Stream *stream) {
+ CUDAStream *cuda_stream = AsCUDAStream(stream);
+ if (!cuda_stream->IsIdle()) {
+ LOG(ERROR) << "Deallocating stream with pending work";
+ }
+ cuda_stream->Destroy();
+}
+
+bool CUDAExecutor::AllocateTimer(Timer *timer) {
+ return AsCUDATimer(timer)->Init();
+}
+
+void CUDAExecutor::DeallocateTimer(Timer *timer) {
+ AsCUDATimer(timer)->Destroy();
+}
+
+bool CUDAExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
+ CUevent other_completed_event;
+ bool ok =
+ AsCUDAStream(other)->GetOrCreateCompletedEvent(&other_completed_event);
+ if (!ok) {
+ LOG(ERROR) << "failed to get completion event from other; "
+ "therefore, failed to create inter-stream dependency";
+ return false;
+ }
+
+ ok = CUDADriver::RecordEvent(context_, other_completed_event,
+ AsCUDAStreamValue(other))
+ .ok();
+ if (!ok) {
+ LOG(ERROR) << "failed to record completion event; "
+ "therefore, failed to create inter-stream dependency";
+ return false;
+ }
+
+ return CUDADriver::WaitStreamOnEvent(context_, AsCUDAStreamValue(dependent),
+ other_completed_event);
+}
+
+bool CUDAExecutor::StartTimer(Stream *stream, Timer *timer) {
+ return AsCUDATimer(timer)->Start(AsCUDAStream(stream));
+}
+
+bool CUDAExecutor::StopTimer(Stream *stream, Timer *timer) {
+ return AsCUDATimer(timer)->Stop(AsCUDAStream(stream));
+}
+
+bool CUDAExecutor::BlockHostUntilDone(Stream *stream) {
+ return CUDADriver::SynchronizeStream(context_, AsCUDAStreamValue(stream));
+}
+
+blas::BlasSupport *CUDAExecutor::CreateBlas() {
+ PluginRegistry *registry = PluginRegistry::Instance();
+ port::StatusOr<PluginRegistry::BlasFactory> status =
+ registry->GetFactory<PluginRegistry::BlasFactory>(kCudaPlatformId,
+ plugin_config_.blas());
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to retrieve BLAS factory: "
+ << status.status().error_message();
+ return nullptr;
+ }
+
+ return status.ValueOrDie()(this);
+}
+
+dnn::DnnSupport *CUDAExecutor::CreateDnn() {
+ PluginRegistry *registry = PluginRegistry::Instance();
+ port::StatusOr<PluginRegistry::DnnFactory> status =
+ registry->GetFactory<PluginRegistry::DnnFactory>(kCudaPlatformId,
+ plugin_config_.dnn());
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to retrieve DNN factory: "
+ << status.status().error_message();
+ return nullptr;
+ }
+
+ return status.ValueOrDie()(this);
+}
+
+fft::FftSupport *CUDAExecutor::CreateFft() {
+ PluginRegistry *registry = PluginRegistry::Instance();
+ port::StatusOr<PluginRegistry::FftFactory> status =
+ registry->GetFactory<PluginRegistry::FftFactory>(kCudaPlatformId,
+ plugin_config_.fft());
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to retrieve FFT factory: "
+ << status.status().error_message();
+ return nullptr;
+ }
+
+ return status.ValueOrDie()(this);
+}
+
+rng::RngSupport *CUDAExecutor::CreateRng() {
+ PluginRegistry *registry = PluginRegistry::Instance();
+ port::StatusOr<PluginRegistry::RngFactory> status =
+ registry->GetFactory<PluginRegistry::RngFactory>(kCudaPlatformId,
+ plugin_config_.rng());
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to retrieve RNG factory: "
+ << status.status().error_message();
+ return nullptr;
+ }
+
+ return status.ValueOrDie()(this);
+}
+
+// TODO(rspringer): Remove in b/18544742.
+bool CUDAExecutor::SupportsDnn() const {
+ return true;
+}
+
+bool CUDAExecutor::CanEnablePeerAccessTo(StreamExecutorInterface *other) {
+ CUDAExecutor *cuda_other = static_cast<CUDAExecutor *>(other);
+ return CUDADriver::CanEnablePeerAccess(context_, cuda_other->context_);
+}
+
+port::Status CUDAExecutor::EnablePeerAccessTo(StreamExecutorInterface *other) {
+ CUDAExecutor *cuda_other = static_cast<CUDAExecutor *>(other);
+ return CUDADriver::EnablePeerAccess(context_, cuda_other->context_);
+}
+
+SharedMemoryConfig CUDAExecutor::GetDeviceSharedMemoryConfig() {
+ port::StatusOr<CUsharedconfig> cuda_config =
+ CUDADriver::ContextGetSharedMemConfig(context_);
+ if (!cuda_config.ok()) {
+ // Don't log; the failed call will log necessary output.
+ return SharedMemoryConfig::kDefault;
+ }
+
+ switch (cuda_config.ValueOrDie()) {
+ case CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE:
+ return SharedMemoryConfig::kDefault;
+ case CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE:
+ return SharedMemoryConfig::kFourByte;
+ case CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE:
+ return SharedMemoryConfig::kEightByte;
+ default:
+ LOG(FATAL) << "Invalid shared memory configuration returned: "
+ << cuda_config.ValueOrDie();
+ }
+}
+
+port::Status CUDAExecutor::SetDeviceSharedMemoryConfig(
+ SharedMemoryConfig config) {
+ CUsharedconfig cuda_config;
+ switch (config) {
+ case SharedMemoryConfig::kDefault:
+ cuda_config = CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE;
+ break;
+ case SharedMemoryConfig::kFourByte:
+ cuda_config = CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE;
+ break;
+ case SharedMemoryConfig::kEightByte:
+ cuda_config = CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE;
+ break;
+ default:
+ LOG(FATAL) << "Invalid shared memory configuration specified: "
+ << static_cast<int>(config);
+ }
+ return CUDADriver::ContextSetSharedMemConfig(context_, cuda_config);
+}
+
+bool CUDAExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
+ return CUDADriver::GetDeviceMemoryInfo(context_, free, total);
+}
+
+bool CUDAExecutor::GetSymbol(const string& symbol_name, void **mem,
+ size_t *bytes) {
+ { // give limited scope to mutex_lock
+ mutex_lock lock{disk_modules_mu_};
+ for (auto &it : disk_modules_) {
+ if (CUDADriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(),
+ reinterpret_cast<CUdeviceptr *>(mem),
+ bytes)) {
+ return true;
+ }
+ }
+ }
+
+ { // give limited scope to mutex_lock
+ mutex_lock lock{in_memory_modules_mu_};
+ for (auto &it : in_memory_modules_) {
+ if (CUDADriver::GetModuleSymbol(context_, it.second, symbol_name.c_str(),
+ reinterpret_cast<CUdeviceptr *>(mem),
+ bytes)) {
+ return true;
+ }
+ }
+ }
+
+ LOG(INFO) << "Falied to find symbol in any modules: " << symbol_name;
+ return false;
+}
+
+bool CUDAExecutor::FillBlockDimLimit(BlockDim *block_dim_limit) const {
+ // The BlockDim name is a mismatch against these GRID_DIM_* queries because
+ // we use BlockDims to express the dimensions of blocks within a grid
+ // (as opposed to ThreadDim which expresses the dimensions of threads
+ // within a block).
+ int x, y, z;
+ if (!CUDADriver::GetGridLimits(&x, &y, &z, device_)) {
+ return false;
+ }
+
+ block_dim_limit->x = x;
+ block_dim_limit->y = y;
+ block_dim_limit->z = z;
+ return true;
+}
+
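+// Packs the bytes of the device pointer value itself (not the data it points
+// to) into a KernelArg, so that Launch() can pass the argument by address to
+// the driver's kernel-launch call.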
+KernelArg CUDAExecutor::DeviceMemoryToKernelArg(
+ const DeviceMemoryBase &gpu_mem) const {
+ const void* arg = gpu_mem.opaque();
+ const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg);
+
+ KernelArg kernel_arg;
+ kernel_arg.type = KernelArg::kNormal;
+ kernel_arg.data = port::InlinedVector<uint8, 4>(arg_ptr, arg_ptr + sizeof(arg));
+ kernel_arg.bytes = sizeof(arg);
+ return kernel_arg;
+}
+
+bool CUDAExecutor::SupportsBlas() const { return true; }
+
+bool CUDAExecutor::SupportsFft() const { return true; }
+
+bool CUDAExecutor::SupportsRng() const { return true; }
+
+void *CUDAExecutor::CudaContextHack() { return context_; }
+
+CUcontext CUDAExecutor::cuda_context() { return context_; }
+
+// Attempts to read the NUMA node corresponding to the GPU device's PCI bus out
+// of SysFS. Returns -1 if it cannot.
+//
+// For anything more complicated/prod-focused than this, you'll likely want to
+// turn to gsys' topology modeling.
+static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) {
+ VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal;
+ static const int kUnknownNumaNode = -1;
+
+ if (pci_bus_id.empty()) {
+ LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal;
+ return kUnknownNumaNode;
+ }
+
+ string filename =
+ port::Printf("/sys/bus/pci/devices/%s/numa_node", pci_bus_id.c_str());
+
+ // We have to use fopen/fread here so that the device properties can be
+  // populated before the InitGoogle procedure has completed (at which point we
+ // could use the file::* utilities).
+ FILE *file = fopen(filename.c_str(), "r");
+ if (file == nullptr) {
+ LOG(ERROR) << "could not open file to read NUMA node: " << filename;
+ return kUnknownNumaNode;
+ }
+
+ string content;
+ char buf[32];
+ size_t did_read = fread(buf, sizeof(buf[0]), sizeof(buf) - 1, file);
+ buf[did_read] = '\0';
+  content = buf;
+  fclose(file);
+
+ int32 value;
+ if (port::safe_strto32(content, &value)) {
+ if (value < 0) { // See http://b/18228951 for details on this path.
+ LOG(INFO) << "successful NUMA node read from SysFS had negative value ("
+ << value << "), but there must be at least one NUMA node"
+ ", so returning NUMA node zero";
+ return 0;
+ }
+ return value;
+ }
+
+ LOG(WARNING)
+ << "could not convert SysFS file contents to integral NUMA node value: "
+ << content;
+
+ return kUnknownNumaNode;
+}
+
+// Set of compute capability specific device parameters that cannot be
+// queried from the driver API. These values instead are baked into a
+// lookup table indexed by compute capability version.
+struct UnqueryableDeviceParams {
+ int cc_major;
+ int cc_minor;
+ uint64 blocks_per_core_limit;
+ uint64 registers_per_core_limit;
+ uint64 registers_per_thread_limit;
+ uint64 warp_alloc_granularity;
+ uint64 register_alloc_granularity;
+ uint64 shared_memory_alloc_granularity;
+};
+
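+// The single entry below covers compute capability 3.5 (Kepler GK110-class
+// devices); the values match the published occupancy limits for that
+// architecture. New compute capabilities would need additional entries here.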
+static const UnqueryableDeviceParams kAllUnqueryableDeviceParams[] = {
+ {
+ 3, 5, // compute capability (3.5)
+ 16, // blocks_per_core_limit
+ 64 * 1024, // registers_per_core_limit
+ 255, // registers_per_thread_limit
+ 4, // warp_alloc_granularity
+ 256, // register_alloc_granularity
+ 256 // shared_memory_alloc_granularity
+ }
+};
+
+DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
+ internal::DeviceDescriptionBuilder builder;
+
+ {
+ int driver_version = 0;
+ (void)CUDADriver::GetDriverVersion(&driver_version);
+ string augmented_driver_version = port::Printf(
+ "%d (%s)", driver_version,
+ DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
+ builder.set_driver_version(augmented_driver_version);
+ }
+
+ {
+ string pci_bus_id = CUDADriver::GetPCIBusID(device_);
+
+ // Lower the hex characters to match sysfs.
+ pci_bus_id = port::Lowercase(pci_bus_id);
+ builder.set_pci_bus_id(pci_bus_id);
+
+ // Read the NUMA node corresponding to the PCI bus ID out of sysfs.
+ int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal_);
+ builder.set_numa_node(numa_node);
+ }
+
+ CUdevprop prop;
+ if (CUDADriver::GetDeviceProperties(&prop, device_ordinal_)) {
+ builder.set_threads_per_block_limit(prop.maxThreadsPerBlock);
+
+ ThreadDim thread_dim_limit;
+ thread_dim_limit.x = prop.maxThreadsDim[0];
+ thread_dim_limit.y = prop.maxThreadsDim[1];
+ thread_dim_limit.z = prop.maxThreadsDim[2];
+ builder.set_thread_dim_limit(thread_dim_limit);
+
+ float clock_rate_ghz = static_cast<float>(prop.clockRate) / 1e6;
+ builder.set_clock_rate_ghz(clock_rate_ghz);
+ }
+
+ {
+ bool ecc_enabled = false;
+ (void)CUDADriver::IsEccEnabled(device_, &ecc_enabled);
+ builder.set_ecc_enabled(ecc_enabled);
+ }
+
+ {
+ uint64 device_memory_size = -1;
+ (void)CUDADriver::GetDeviceTotalMemory(device_, &device_memory_size);
+ builder.set_device_memory_size(device_memory_size);
+ }
+
+ {
+ BlockDim block_dim_limit;
+ FillBlockDimLimit(&block_dim_limit);
+ builder.set_block_dim_limit(block_dim_limit);
+ }
+
+ {
+ string device_name;
+ (void)CUDADriver::GetDeviceName(device_, &device_name);
+ builder.set_name(device_name);
+ }
+
+ for (size_t i = 0; i < ARRAYSIZE(kAllUnqueryableDeviceParams); i++) {
+ const auto &params = kAllUnqueryableDeviceParams[i];
+ if (params.cc_major == cc_major_ && params.cc_minor == cc_minor_) {
+ builder.set_blocks_per_core_limit(params.blocks_per_core_limit);
+ builder.set_registers_per_core_limit(params.registers_per_core_limit);
+ builder.set_registers_per_thread_limit(params.registers_per_thread_limit);
+ builder.set_warp_alloc_granularity(params.warp_alloc_granularity);
+ builder.set_register_alloc_granularity(params.register_alloc_granularity);
+ builder.set_shared_memory_alloc_granularity(
+ params.shared_memory_alloc_granularity);
+ }
+ }
+
+ builder.set_platform_version(
+ port::StrCat("Compute Capability ", cc_major_, ".", cc_minor_));
+
+ // TODO(leary) should be a way to query this from the driver, but this is
+ // unlikely to change for us any time soon.
+ builder.set_device_address_bits(64);
+
+ builder.set_device_vendor("NVIDIA Corporation");
+ builder.set_cuda_compute_capability(cc_major_, cc_minor_);
+ builder.set_shared_memory_per_core(
+ CUDADriver::GetMaxSharedMemoryPerCore(device_).ValueOrDie());
+ builder.set_shared_memory_per_block(
+ CUDADriver::GetMaxSharedMemoryPerBlock(device_).ValueOrDie());
+ builder.set_core_count(
+ CUDADriver::GetMultiprocessorCount(device_).ValueOrDie());
+ builder.set_threads_per_core_limit(
+ CUDADriver::GetMaxThreadsPerMultiprocessor(device_).ValueOrDie());
+ builder.set_registers_per_block_limit(
+ CUDADriver::GetMaxRegistersPerBlock(device_).ValueOrDie());
+ builder.set_threads_per_warp(
+ CUDADriver::GetThreadsPerWarp(device_).ValueOrDie());
+
+ auto built = builder.Build();
+ return built.release();
+}
+
+} // namespace cuda
+
+namespace gpu = ::perftools::gputools;
+
+void initialize_cuda_gpu_executor() {
+ port::StatusOr<void *> status =
+ gpu::internal::CachedDsoLoader::GetLibcudaDsoHandle();
+ if (!status.ok()) {
+ gpu::cuda::Diagnostician::LogDriverVersionInformation();
+ LOG(INFO) << "LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
+ LOG(INFO) << "failed to find libcuda.so on this system: "
+ << status.status();
+ }
+
+ // TODO(b/22689637): Temporary until users are migrated off of PlatformKind.
+ gpu::PluginRegistry::Instance()->MapPlatformKindToId(
+ gpu::PlatformKind::kCuda, gpu::cuda::kCudaPlatformId);
+
+ *gpu::internal::MakeCUDAExecutorImplementation() = [](
+ const gpu::PluginConfig &config) {
+ return new gpu::cuda::CUDAExecutor{config};
+ };
+
+ *gpu::internal::MakeCUDAKernelImplementation() = []() {
+ return new gpu::cuda::CUDAKernel;
+ };
+
+ *gpu::internal::MakeCUDAEventImplementation() = [](
+ gpu::StreamExecutor *parent) {
+ gpu::cuda::CUDAExecutor *cuda_executor =
+ static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
+ return new gpu::cuda::CUDAEvent{cuda_executor};
+ };
+
+ *gpu::internal::MakeCUDAStreamImplementation() = [](
+ gpu::StreamExecutor *parent) {
+ gpu::cuda::CUDAExecutor *cuda_executor =
+ static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
+ return new gpu::cuda::CUDAStream{cuda_executor};
+ };
+ *gpu::internal::MakeCUDATimerImplementation() = [](
+ gpu::StreamExecutor *parent) {
+ gpu::cuda::CUDAExecutor *cuda_executor =
+ static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
+ return new gpu::cuda::CUDATimer{cuda_executor};
+ };
+}
+
+} // namespace gputools
+} // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(
+ cuda_gpu_executor, {perftools::gputools::initialize_cuda_gpu_executor();});
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
new file mode 100644
index 0000000000..fda89b9738
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -0,0 +1,270 @@
+// The CUDA implementation of the StreamExecutorInterface functionality.
+// CUDA inclusions are ideally confined to this implementation file.
+//
+// The notions from the StreamExecutor basically correspond to the CUDA streams
+// programming model provided by the libcuda.so driver APIs, so we don't have
+// to do much more than wrap the calls to the libraries appropriately.
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_GPU_EXECUTOR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_GPU_EXECUTOR_H_
+
+#include <map>
+#include <set>
+
+#include "tensorflow/stream_executor/cuda/cuda_kernel.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace perftools {
+namespace gputools {
+namespace blas {
+class BlasSupport;
+}
+namespace internal {
+class RngSupport;
+} // namespace internal
+} // namespace gputools
+} // namespace perftools
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// CUDA-platform implementation of the platform-agnostic
+// StreamExecutorInterface.
+class CUDAExecutor : public internal::StreamExecutorInterface {
+ public:
+ // sub_platform indicates the subplatform used in this executor; it must
+ // be a CUDA type.
+ explicit CUDAExecutor(const PluginConfig &plugin_config)
+ : device_(0),
+ context_(nullptr),
+ device_ordinal_(0),
+ cc_major_(0),
+ cc_minor_(0),
+ plugin_config_(plugin_config) {}
+
+ // See the corresponding StreamExecutor methods for method comments on the
+ // following overrides.
+
+ ~CUDAExecutor() override;
+
+ port::Status Init(int device_ordinal, DeviceOptions device_options) override;
+
+ bool GetKernel(const MultiKernelLoaderSpec &spec,
+ KernelBase *kernel) override;
+
+ bool Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims, const KernelBase &k,
+ const std::vector<KernelArg> &args) override;
+
+ void *Allocate(uint64 size) override;
+
+ void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
+ uint64 size_bytes) override;
+
+ void Deallocate(DeviceMemoryBase *mem) override;
+
+ // CUDA allocation/registration functions are necessary because the driver
+ // internally sets up buffers for DMA operations (and page locks them).
+ // There's no external interface for us to otherwise control these DMA
+ // settings.
+ void *HostMemoryAllocate(uint64 size) override {
+ return CUDADriver::HostAllocate(context_, size);
+ }
+
+ void HostMemoryDeallocate(void *location) override {
+ return CUDADriver::HostDeallocate(context_, location);
+ }
+
+ bool HostMemoryRegister(void *location, uint64 size) override;
+
+ bool HostMemoryUnregister(void *location) override;
+
+ bool SynchronizeAllActivity() override;
+
+ bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override;
+
+ bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) override;
+
+ bool SynchronousMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size) override;
+
+ bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size) override;
+
+ bool SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) override;
+
+ bool MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) override;
+ bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+ uint64 size) override;
+
+ bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size) override;
+
+ bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size) override;
+
+ bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) override;
+
+ bool HostCallback(Stream *stream, std::function<void()> callback) override;
+
+ bool AllocateStream(Stream *stream) override;
+
+ void DeallocateStream(Stream *stream) override;
+
+ bool CreateStreamDependency(Stream *dependent, Stream *other) override;
+
+ bool AllocateTimer(Timer *timer) override;
+
+ void DeallocateTimer(Timer *timer) override;
+
+ bool StartTimer(Stream *stream, Timer *timer) override;
+
+ bool StopTimer(Stream *stream, Timer *timer) override;
+
+ port::Status AllocateEvent(Event *event) override;
+
+ port::Status DeallocateEvent(Event *event) override;
+
+ port::Status RecordEvent(Stream *stream, Event *event) override;
+
+ port::Status WaitForEvent(Stream *stream, Event *event) override;
+
+ Event::Status PollForEventStatus(Event *event) override;
+
+ bool BlockHostUntilDone(Stream *stream) override;
+
+ int PlatformDeviceCount() override { return CUDADriver::GetDeviceCount(); }
+
+ port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override;
+
+ bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override;
+
+ SharedMemoryConfig GetDeviceSharedMemoryConfig() override;
+
+ port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override;
+
+ bool DeviceMemoryUsage(int64 *free, int64 *total) const override;
+
+  // Searches for the symbol and returns a device pointer and size.
+  // Returns false if the symbol does not exist.
+ bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) override;
+
+ DeviceDescription *PopulateDeviceDescription() const override;
+
+ // Populates the block_dim_limit by querying the device driver API. If an
+ // error occurs at any point while asking the driver for block dim limits, it
+ // will be only partially populated as a result, and an error will be logged.
+ bool FillBlockDimLimit(BlockDim *block_dim_limit) const;
+
+ KernelArg DeviceMemoryToKernelArg(
+ const DeviceMemoryBase &gpu_mem) const override;
+
+ bool SupportsBlas() const override;
+
+ blas::BlasSupport *CreateBlas() override;
+
+ bool SupportsFft() const override;
+
+ fft::FftSupport *CreateFft() override;
+
+ bool SupportsRng() const override;
+
+ rng::RngSupport *CreateRng() override;
+
+ bool SupportsDnn() const override;
+
+ dnn::DnnSupport *CreateDnn() override;
+
+ void *CudaContextHack() override;
+
+ CUcontext cuda_context();
+
+ private:
+ // Attempts to find a more specific version of the file indicated by
+  // filename by looking for compute-capability-specific suffixed versions; e.g.
+ // looking for "foo.ptx" will check to see if "foo.ptx.cc30.ptx" is present if
+ // we're on a compute capability 3.0 machine.
+ bool FindOnDiskForComputeCapability(port::StringPiece filename,
+ port::StringPiece canonical_suffix,
+ string *found_filename) const;
+
+ // Host callback landing routine invoked by CUDA.
+ // data: User-provided callback provided to HostCallback() above, captured
+ // as a std::function<void()>. Allocated/initialized inside
+ // HostCallback() and owned and deleted by this call.
+ static void InternalHostCallback(CUstream stream, CUresult status,
+ void *data);
+
+ // Collects metadata for the specified kernel.
+ bool GetKernelMetadata(CUDAKernel *cuda_kernel,
+ KernelMetadata *kernel_metadata);
+
+ // Determines if the given kernel's occupancy could be improved by only
+ // slightly reducing its register usage. If so, a message is emitted to the
+ // INFO log. The warning threshold is controlled by the flag
+ // register_occupancy_warning_threshold.
+ void OccupancyCheck(const KernelBase &kernel, const ThreadDim &thread_dims,
+ const BlockDim &block_dims);
+
+ // Guards the on-disk-module mapping.
+ mutex disk_modules_mu_;
+
+ // Mapping from filename to CUmodule, if it was already retrieved.
+  // Multiple CUfunctions are usually obtained from a single CUmodule, so we
+  // check this mapping first before loading the module again.
+ std::map<string, CUmodule> disk_modules_ GUARDED_BY(disk_modules_mu_);
+
+ // Guards the in-memory-module mapping.
+ mutex in_memory_modules_mu_;
+
+ std::map<const char *, CUmodule> in_memory_modules_
+ GUARDED_BY(in_memory_modules_mu_);
+
+ // Guards the launched kernel set.
+ mutex launched_kernels_mu_;
+
+ // Keeps track of the set of launched kernels. Currently used to suppress the
+ // occupancy check on subsequent launches.
+ std::set<CUfunction> launched_kernels_ GUARDED_BY(launched_kernels_mu_);
+
+ // Handle for the CUDA device being operated on. Immutable
+ // post-initialization.
+ CUdevice device_;
+
+ // Handle for session with the library/driver. Immutable post-initialization.
+ CUcontext context_;
+
+ // The device ordinal value that this executor was initialized with; recorded
+ // for use in getting device metadata. Immutable post-initialization.
+ int device_ordinal_;
+
+  // The major version of the compute capability for device_.
+ int cc_major_;
+
+  // The minor version of the compute capability for device_.
+ int cc_minor_;
+
+ // The plugin configuration associated with this instance.
+ PluginConfig plugin_config_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CUDAExecutor);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_GPU_EXECUTOR_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_helpers.h b/tensorflow/stream_executor/cuda/cuda_helpers.h
new file mode 100644
index 0000000000..2c5311cb3b
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_helpers.h
@@ -0,0 +1,95 @@
+// Common helper functions used for dealing with CUDA API datatypes.
+//
+// These are typically placed here for use by multiple source components (for
+// example, BLAS and executor components).
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_
+
+#include <stddef.h>
+#include <complex>
+
+#include "third_party/gpus/cuda/include/cuComplex.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+template <typename ElemT>
+class DeviceMemory;
+
+namespace cuda {
+
+// Converts a const DeviceMemory reference to its underlying typed pointer in
+// CUDA device memory.
+template <typename T>
+const T *CUDAMemory(const DeviceMemory<T> &mem) {
+ return static_cast<const T *>(mem.opaque());
+}
+
+// Converts a (non-const) DeviceMemory pointer reference to its underlying
+// typed pointer in CUDA device memory.
+template <typename T>
+T *CUDAMemoryMutable(DeviceMemory<T> *mem) {
+ return static_cast<T *>(mem->opaque());
+}
+
+CUstream AsCUDAStreamValue(Stream *stream);
+
+static_assert(sizeof(std::complex<float>) == sizeof(cuComplex),
+ "std::complex<float> and cuComplex should have the same size");
+static_assert(offsetof(cuComplex, x) == 0,
+ "The real part of cuComplex should appear first.");
+static_assert(sizeof(std::complex<double>) == sizeof(cuDoubleComplex),
+ "std::complex<double> and cuDoubleComplex should have the same "
+ "size");
+static_assert(offsetof(cuDoubleComplex, x) == 0,
+ "The real part of cuDoubleComplex should appear first.");
+
+// Type traits to get CUDA complex types from std::complex<>.
+
+template <typename T>
+struct CUDAComplexT {
+ typedef T type;
+};
+
+template <>
+struct CUDAComplexT<std::complex<float>> {
+ typedef cuComplex type;
+};
+
+template <>
+struct CUDAComplexT<std::complex<double>> {
+ typedef cuDoubleComplex type;
+};
+
+// Converts pointers of std::complex<> to pointers of
+// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
+
+template <typename T>
+inline const typename CUDAComplexT<T>::type *CUDAComplex(const T *p) {
+ return reinterpret_cast<const typename CUDAComplexT<T>::type *>(p);
+}
+
+template <typename T>
+inline typename CUDAComplexT<T>::type *CUDAComplex(T *p) {
+ return reinterpret_cast<typename CUDAComplexT<T>::type *>(p);
+}
+
+// Converts values of std::complex<float/double> to values of
+// cuComplex/cuDoubleComplex.
+inline cuComplex CUDAComplexValue(std::complex<float> val) {
+ return {val.real(), val.imag()};
+}
+
+inline cuDoubleComplex CUDAComplexValue(std::complex<double> val) {
+ return {val.real(), val.imag()};
+}
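+
+// Illustrative usage (a sketch, not a prescribed API): converting a host
+// scalar and a typed device pointer for a CUDA library call that expects
+// cuComplex arguments; "mem" is an assumed DeviceMemory<std::complex<float>>.
+//
+//   std::complex<float> alpha(1.0f, 2.0f);
+//   cuComplex cu_alpha = CUDAComplexValue(alpha);
+//   const cuComplex *device_ptr = CUDAComplex(CUDAMemory(mem));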
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_kernel.h b/tensorflow/stream_executor/cuda/cuda_kernel.h
new file mode 100644
index 0000000000..e8ad3955e9
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_kernel.h
@@ -0,0 +1,115 @@
+// The CUDA implementation of the StreamExecutorInterface functionality.
+// CUDA inclusions are ideally confined to this implementation file.
+//
+// The notions from the StreamExecutor basically correspond to the CUDA streams
+// programming model provided by the libcuda.so driver APIs, so we don't have
+// to do much more than wrap the calls to the libraries appropriately.
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_
+
+#include "tensorflow/stream_executor/kernel_cache_config.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/lib/casts.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "third_party/gpus/cuda/include/cuda.h"
+
+#ifdef PLATFORMS_GPUS_CUDA_DYNAMIC_LIBCUDA_DYNAMIC_LIBCUDA_H_
+#error \
+ "No driver calls in this file, wrap driver functionality in cuda_driver.cc."
+#endif
+
+#ifdef __CUDA_RUNTIME_H__
+#error \
+ "CUDA runtime being included into CUDA GPU executor; should be driver only."
+#endif
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// Wraps a CUfunction to implement the platform-independent KernelInterface.
+class CUDAKernel : public internal::KernelInterface {
+ public:
+ CUDAKernel() : cuda_function_(nullptr), arity_(0),
+ preferred_cache_config_(KernelCacheConfig::kNoPreference) {}
+
+ // Note that the function is unloaded when the module is unloaded, and the
+ // module that the function is contained in is owned by the CUDAExecutor.
+ ~CUDAKernel() override {}
+
+ // As arity cannot be reflected upon using the CUDA API, the arity is
+ // explicitly set during the CUDAExecutor::GetKernel initialization process.
+ void set_arity(unsigned arity) { arity_ = arity; }
+ unsigned Arity() const override { return arity_; }
+
+ // Returns the CUfunction value for passing to the CUDA API.
+ CUfunction AsCUDAFunctionValue() const {
+ DCHECK(cuda_function_ != nullptr);
+ return const_cast<CUfunction>(cuda_function_);
+ }
+
+ // Returns the slot that the CUfunction is stored within for this object,
+ // for the CUDA API which wants to load into a CUfunction*.
+ CUfunction *cuda_function_ptr() { return &cuda_function_; }
+
+  // CUDA supports setting the preferred cache configuration of a CUfunction
+  // (more-or-less equivalent to a CUDAKernel). We support this via the
+  // functions below; users can set a preference, and it is applied when the
+  // kernel is [lazily] loaded (in CUDAExecutor::Launch). The alternative would
+  // be to load the kernel and set the preference when the user calls the
+  // setter below; either approach is valid.
+
+  // Sets the current kernel cache configuration preference.
+ void SetPreferredCacheConfig(KernelCacheConfig config) override {
+ preferred_cache_config_ = config;
+ }
+
+ // Returns the current kernel cache configuration preference.
+ KernelCacheConfig GetPreferredCacheConfig() const override {
+ return preferred_cache_config_;
+ }
+
+ // Returns the current kernel cache configuration preference as a
+ // CUfunc_cache.
+ CUfunc_cache GetCUDACacheConfig() const {
+ switch (preferred_cache_config_) {
+ case KernelCacheConfig::kNoPreference:
+ return CU_FUNC_CACHE_PREFER_NONE;
+ case KernelCacheConfig::kPreferShared:
+ return CU_FUNC_CACHE_PREFER_SHARED;
+ case KernelCacheConfig::kPreferL1:
+ return CU_FUNC_CACHE_PREFER_L1;
+ case KernelCacheConfig::kPreferEqual:
+ return CU_FUNC_CACHE_PREFER_EQUAL;
+ default:
+ LOG(FATAL) << "Unknown KernelCacheConfig"
+ << static_cast<int32>(preferred_cache_config_);
+ }
+ }
+
+ private:
+ CUfunction cuda_function_; // Wrapped CUDA kernel handle.
+ unsigned arity_; // Number of formal parameters the kernel takes.
+
+ // Preferred (but not required) cache configuration for this kernel.
+ KernelCacheConfig preferred_cache_config_;
+};
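+
+// Illustrative usage (a sketch of the existing flow, not new API): a caller
+// records a cache preference on the kernel, and CUDAExecutor::Launch later
+// applies it via the driver before launching:
+//
+//   cuda_kernel->SetPreferredCacheConfig(KernelCacheConfig::kPreferShared);
+//   ...
+//   CUDADriver::FuncSetCacheConfig(cufunc, cuda_kernel->GetCUDACacheConfig());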
+
+// Given a platform-independent kernel datatype, returns the (const) internal
+// CUDA platform implementation pointer.
+inline const CUDAKernel *AsCUDAKernel(const KernelBase *kernel) {
+ return static_cast<const CUDAKernel *>(kernel->implementation());
+}
+
+// Given a platform-independent kernel datatype, returns the (non-const)
+// internal CUDA platform implementation pointer.
+inline CUDAKernel *AsCUDAKernel(KernelBase *kernel) {
+ return static_cast<CUDAKernel *>(kernel->implementation());
+}
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc
new file mode 100644
index 0000000000..ef88b89eda
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_platform.cc
@@ -0,0 +1,172 @@
+#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+PLATFORM_DEFINE_ID(kCudaPlatformId);
+
+CudaPlatform::CudaPlatform()
+ : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {}
+
+CudaPlatform::~CudaPlatform() {}
+
+// Due to legacy issues in user code, we can't currently call InspectNumaNodes
+// at module initialization time, because non-GPU programs still include this
+// plugin via various methods. Instead, it has to be init-on-reference.
+void CudaPlatform::InspectNumaNodes() {
+ // To get NUMA node information, we need to create all executors, so we can
+ // examine their device descriptions to see their bus assignments.
+ static bool initialized = false;
+ static mutex numa_mutex(LINKER_INITIALIZED);
+ mutex_lock lock(numa_mutex);
+ if (initialized) {
+ return;
+ }
+
+ StreamExecutorConfig config;
+ for (int i = 0; i < VisibleDeviceCount(); i++) {
+ config.ordinal = i;
+ StreamExecutor* exec = GetExecutor(config).ValueOrDie();
+ if (i == 0) {
+ // NUMA nodes may not start at 0, so set the minimum node based on the
+ // first executor we see.
+ min_numa_node_ = exec->GetDeviceDescription().numa_node();
+ limit_numa_node_ = min_numa_node_ + 1;
+ } else {
+ min_numa_node_ =
+ std::min(min_numa_node_, exec->GetDeviceDescription().numa_node());
+ limit_numa_node_ = std::max(limit_numa_node_,
+ exec->GetDeviceDescription().numa_node() + 1);
+ }
+ }
+ initialized = true;
+}
+
+int CudaPlatform::BusCount() {
+ InspectNumaNodes();
+ return limit_numa_node_ - min_numa_node_;
+}
+
+int CudaPlatform::DeviceToBus(int device_ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = device_ordinal;
+ StreamExecutor* exec = GetExecutor(config).ValueOrDie();
+ return exec->GetDeviceDescription().numa_node() - min_numa_node_;
+}
+
+port::StatusOr<StreamExecutor*> CudaPlatform::FirstExecutorForBus(
+ int bus_ordinal) {
+ InspectNumaNodes();
+ CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range";
+ for (int i = 0; i < VisibleDeviceCount(); i++) {
+ if (DeviceToBus(i) == bus_ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = i;
+ return GetExecutor(config).ValueOrDie();
+ }
+ }
+
+ return port::Status{
+ port::error::NOT_FOUND,
+ port::Printf("Executor for bus %d not found.", bus_ordinal)};
+}
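+
+// Usage sketch (illustrative only): querying the bus/NUMA topology exposed
+// above. "platform" is assumed to have been obtained via MultiPlatformManager.
+//
+//   CudaPlatform *platform = ...;
+//   for (int bus = 0; bus < platform->BusCount(); ++bus) {
+//     auto first = platform->FirstExecutorForBus(bus);
+//     if (first.ok()) {
+//       LOG(INFO) << "bus " << bus << " starts with device "
+//                 << first.ValueOrDie()->GetDeviceDescription().name();
+//     }
+//   }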
+
+Platform::Id CudaPlatform::id() const { return kCudaPlatformId; }
+
+int CudaPlatform::VisibleDeviceCount() const {
+ // Throw away the result -- it logs internally, and this [containing]
+ // function isn't in the path of user control. It's safe to call this more
+ // than once.
+ if (!cuda::CUDADriver::Init().ok()) {
+ return -1;
+ }
+
+ return CUDADriver::GetDeviceCount();
+}
+
+const string& CudaPlatform::Name() const { return name_; }
+
+port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDevice(int ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ config.plugin_config = PluginConfig();
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
+ int device_ordinal, const PluginConfig& plugin_config) {
+ StreamExecutorConfig config;
+ config.ordinal = device_ordinal;
+ config.plugin_config = plugin_config;
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
+ const StreamExecutorConfig& config) {
+ mutex_lock lock(mu_);
+
+ port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
+ if (status.ok()) {
+ return status.ValueOrDie();
+ }
+
+ port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
+ GetUncachedExecutor(config);
+ if (!executor.ok()) {
+ return executor.status();
+ }
+
+ StreamExecutor* naked_executor = executor.ValueOrDie().get();
+ executor_cache_.Insert(config, executor.ConsumeValueOrDie());
+ return naked_executor;
+}
+
+port::StatusOr<std::unique_ptr<StreamExecutor>>
+CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
+ auto executor = port::MakeUnique<StreamExecutor>(PlatformKind::kCuda,
+ config.plugin_config);
+ auto init_status = executor->Init(config.ordinal, config.device_options);
+ if (!init_status.ok()) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf(
+ "failed initializing StreamExecutor for CUDA device ordinal %d: %s",
+ config.ordinal, init_status.ToString().c_str())};
+ }
+
+ return std::move(executor);
+}
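+
+// Usage sketch (illustrative only): the common entry point is the cached
+// lookup, which funnels through GetExecutor() above, so repeated requests for
+// the same ordinal yield the same StreamExecutor instance.
+//
+//   port::StatusOr<StreamExecutor*> executor = platform->ExecutorForDevice(0);
+//   if (!executor.ok()) {
+//     LOG(ERROR) << executor.status();
+//   } else {
+//     StreamExecutor *se = executor.ValueOrDie();
+//     // ... create streams / allocate device memory via "se" ...
+//   }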
+
+void CudaPlatform::RegisterTraceListener(
+ std::unique_ptr<TraceListener> listener) {
+ LOG(FATAL) << "not yet implemented: register CUDA trace listener";
+}
+
+void CudaPlatform::UnregisterTraceListener(TraceListener* listener) {
+ LOG(FATAL) << "not yet implemented: unregister CUDA trace listener";
+}
+
+} // namespace cuda
+
+static void InitializeCudaPlatform() {
+ // Leak checking is disabled here because MultiPlatformManager does not
+ // destroy its registered platforms.
+
+ std::unique_ptr<cuda::CudaPlatform> platform(new cuda::CudaPlatform);
+ SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
+}
+
+} // namespace gputools
+} // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(cuda_platform,
+ perftools::gputools::InitializeCudaPlatform());
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.h b/tensorflow/stream_executor/cuda/cuda_platform.h
new file mode 100644
index 0000000000..966d7343f7
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_platform.h
@@ -0,0 +1,98 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_
+
+#include <memory>
+#include "tensorflow/stream_executor/platform/port.h"
+#include <vector>
+
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// Opaque and unique identifier for the CUDA platform plugin.
+// This is needed so that plugins can refer to/identify this platform without
+// instantiating a CudaPlatform object.
+extern const Platform::Id kCudaPlatformId;
+
+// CUDA-specific platform plugin, registered as a singleton value via module
+// initializer.
+class CudaPlatform : public Platform {
+ public:
+ CudaPlatform();
+ ~CudaPlatform() override;
+
+ // CudaPlatform-specific functionality
+ // Returns the number of distinct buses / NUMA nodes on the machine.
+ int BusCount();
+
+ // Returns the bus/NUMA node for the specified device ordinal.
+ int DeviceToBus(int device_ordinal);
+
+ // Returns the lowest-ordinal-number StreamExecutor on the specified bus.
+ port::StatusOr<StreamExecutor*> FirstExecutorForBus(int bus_ordinal);
+
+ // Platform interface implementation:
+ // Returns the same value as kCudaPlatformId above.
+ Platform::Id id() const override;
+
+ // Returns -1 as a sentinel on internal failure (and logs the error).
+ int VisibleDeviceCount() const override;
+
+ const string& Name() const override;
+
+ port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+
+ port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+ int ordinal, const PluginConfig& config) override;
+
+ port::StatusOr<StreamExecutor*> GetExecutor(
+ const StreamExecutorConfig& config) override;
+
+ port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ const StreamExecutorConfig& config) override;
+
+ void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override;
+
+ void UnregisterTraceListener(TraceListener* listener) override;
+
+ private:
+ // Determines the number of NUMA nodes and the assignment of executors to
+ // each.
+ void InspectNumaNodes();
+
+ // This platform's name.
+ string name_;
+
+ // mutex that guards internal state.
+ mutable mutex mu_;
+
+ // Cache of created executors.
+ ExecutorCache executor_cache_;
+
+ // The smallest NUMA node value for any device managed by this machine
+ // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
+ // ordinals. The NUMA node space occupied by GPUs is assumed to be dense.
+ int min_numa_node_;
+
+ // One greater than the largest NUMA node value for any device managed by
+ // this machine manager.
+ int limit_numa_node_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CudaPlatform);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc
new file mode 100644
index 0000000000..ad48c8b59a
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_rng.cc
@@ -0,0 +1,317 @@
+#include "tensorflow/stream_executor/cuda/cuda_rng.h"
+
+#include <dlfcn.h>
+
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "third_party/gpus/cuda/include/curand.h"
+
+// Formats curandStatus_t to output prettified values into a log stream.
+std::ostream &operator<<(std::ostream &in, const curandStatus_t &status) {
+#define OSTREAM_CURAND_STATUS(__name) \
+ case CURAND_STATUS_##__name: \
+ in << "CURAND_STATUS_" #__name; \
+ return in;
+
+ switch (status) {
+ OSTREAM_CURAND_STATUS(SUCCESS)
+ OSTREAM_CURAND_STATUS(VERSION_MISMATCH)
+ OSTREAM_CURAND_STATUS(NOT_INITIALIZED)
+ OSTREAM_CURAND_STATUS(ALLOCATION_FAILED)
+ OSTREAM_CURAND_STATUS(TYPE_ERROR)
+ OSTREAM_CURAND_STATUS(OUT_OF_RANGE)
+ OSTREAM_CURAND_STATUS(LENGTH_NOT_MULTIPLE)
+ OSTREAM_CURAND_STATUS(LAUNCH_FAILURE)
+ OSTREAM_CURAND_STATUS(PREEXISTING_FAILURE)
+ OSTREAM_CURAND_STATUS(INITIALIZATION_FAILED)
+ OSTREAM_CURAND_STATUS(ARCH_MISMATCH)
+ OSTREAM_CURAND_STATUS(INTERNAL_ERROR)
+ default:
+ in << "curandStatus_t(" << static_cast<int>(status) << ")";
+ return in;
+ }
+}
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuRandPlugin);
+
+namespace dynload {
+
+#define PERFTOOLS_GPUTOOLS_CURAND_WRAP(__name) \
+ struct DynLoadShim__##__name { \
+ static const char *kName; \
+ using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
+ static void *GetDsoHandle() { \
+ static auto status = internal::CachedDsoLoader::GetCurandDsoHandle(); \
+ return status.ValueOrDie(); \
+ } \
+ static FuncPointerT DynLoad() { \
+ static void *f = dlsym(GetDsoHandle(), kName); \
+ CHECK(f != nullptr) << "could not find " << kName \
+ << " in curand DSO; dlerror: " << dlerror(); \
+ return reinterpret_cast<FuncPointerT>(f); \
+ } \
+ template <typename... Args> \
+ curandStatus_t operator()(CUDAExecutor * parent, Args... args) { \
+ cuda::ScopedActivateExecutorContext sac{parent}; \
+ return DynLoad()(args...); \
+ } \
+ } __name; \
+ const char *DynLoadShim__##__name::kName = #__name;
+
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandCreateGenerator);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandDestroyGenerator);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetStream);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniform);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniformDouble);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetGeneratorOffset);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormal);
+PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormalDouble);
+
+} // namespace dynload
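+
+// Each wrapped entry point above is invoked through the shim's operator(),
+// which activates the owning executor's CUDA context and then forwards to the
+// lazily-resolved function pointer, e.g.
+//
+//   curandStatus_t ret =
+//       dynload::curandSetStream(parent_, rng_, AsCUDAStreamValue(stream));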
+
+template <typename T>
+string TypeString();
+
+template <>
+string TypeString<float>() {
+ return "float";
+}
+
+template <>
+string TypeString<double>() {
+ return "double";
+}
+
+template <>
+string TypeString<std::complex<float>>() {
+ return "std::complex<float>";
+}
+
+template <>
+string TypeString<std::complex<double>>() {
+ return "std::complex<double>";
+}
+
+CUDARng::CUDARng(CUDAExecutor *parent) : parent_(parent), rng_(nullptr) {}
+
+CUDARng::~CUDARng() {
+ if (rng_ != nullptr) {
+ dynload::curandDestroyGenerator(parent_, rng_);
+ }
+}
+
+bool CUDARng::Init() {
+ mutex_lock lock{mu_};
+ CHECK(rng_ == nullptr);
+
+ curandStatus_t ret =
+ dynload::curandCreateGenerator(parent_, &rng_, CURAND_RNG_PSEUDO_DEFAULT);
+ if (ret != CURAND_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to create random number generator: " << ret;
+ return false;
+ }
+
+ CHECK(rng_ != nullptr);
+ return true;
+}
+
+bool CUDARng::SetStream(Stream *stream) {
+ curandStatus_t ret =
+ dynload::curandSetStream(parent_, rng_, AsCUDAStreamValue(stream));
+ if (ret != CURAND_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for random generation: " << ret;
+ return false;
+ }
+
+ return true;
+}
+
+// Returns true if std::complex stores its contents as two consecutive
+// elements. Tests int, float and double, as the last two are independent
+// specializations.
+constexpr bool ComplexIsConsecutiveFloats() {
+ return sizeof(std::complex<int>) == 8 && sizeof(std::complex<float>) == 8 &&
+ sizeof(std::complex<double>) == 16;
+}
+
+template <typename T>
+bool CUDARng::DoPopulateRandUniformInternal(Stream *stream,
+ DeviceMemory<T> *v) {
+ mutex_lock lock{mu_};
+ static_assert(ComplexIsConsecutiveFloats(),
+ "std::complex values are not stored as consecutive values");
+
+ if (!SetStream(stream)) {
+ return false;
+ }
+
+ // std::complex<T> is currently implemented as two consecutive T variables.
+ uint64 element_count = v->ElementCount();
+ if (std::is_same<T, std::complex<float>>::value ||
+ std::is_same<T, std::complex<double>>::value) {
+ element_count *= 2;
+ }
+
+ curandStatus_t ret;
+ if (std::is_same<T, float>::value ||
+ std::is_same<T, std::complex<float>>::value) {
+ ret = dynload::curandGenerateUniform(
+ parent_, rng_, reinterpret_cast<float *>(CUDAMemoryMutable(v)),
+ element_count);
+ } else {
+ ret = dynload::curandGenerateUniformDouble(
+ parent_, rng_, reinterpret_cast<double *>(CUDAMemoryMutable(v)),
+ element_count);
+ }
+ if (ret != CURAND_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to do uniform generation of " << v->ElementCount()
+ << " " << TypeString<T>() << "s at " << v->opaque() << ": "
+ << ret;
+ return false;
+ }
+
+ return true;
+}
+
+bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory<float> *v) {
+ return DoPopulateRandUniformInternal(stream, v);
+}
+
+bool CUDARng::DoPopulateRandUniform(Stream *stream, DeviceMemory<double> *v) {
+ return DoPopulateRandUniformInternal(stream, v);
+}
+
+bool CUDARng::DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<float>> *v) {
+ return DoPopulateRandUniformInternal(stream, v);
+}
+
+bool CUDARng::DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<double>> *v) {
+ return DoPopulateRandUniformInternal(stream, v);
+}
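+
+// Usage sketch (illustrative only): filling a device buffer with uniform
+// variates through this support class. "executor", "stream", and "memory" are
+// assumed to be an initialized CUDAExecutor, a Stream bound to it, and a
+// DeviceMemory<float> allocated from it.
+//
+//   CUDARng rng(executor);
+//   if (rng.Init() && rng.DoPopulateRandUniform(stream, &memory)) {
+//     // Generation is enqueued on "stream"; synchronize before reading back.
+//   }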
+
+template <typename ElemT, typename FuncT>
+bool CUDARng::DoPopulateRandGaussianInternal(Stream *stream, ElemT mean,
+ ElemT stddev,
+ DeviceMemory<ElemT> *v,
+ FuncT func) {
+ mutex_lock lock{mu_};
+
+ if (!SetStream(stream)) {
+ return false;
+ }
+
+ uint64 element_count = v->ElementCount();
+ curandStatus_t ret =
+ func(parent_, rng_, CUDAMemoryMutable(v), element_count, mean, stddev);
+
+ if (ret != CURAND_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to do gaussian generation of " << v->ElementCount()
+ << " floats at " << v->opaque() << ": " << ret;
+ return false;
+ }
+
+ return true;
+}
+
+bool CUDARng::DoPopulateRandGaussian(Stream *stream, float mean, float stddev,
+ DeviceMemory<float> *v) {
+ return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
+ dynload::curandGenerateNormal);
+}
+
+bool CUDARng::DoPopulateRandGaussian(Stream *stream, double mean, double stddev,
+ DeviceMemory<double> *v) {
+ return DoPopulateRandGaussianInternal(stream, mean, stddev, v,
+ dynload::curandGenerateNormalDouble);
+}
+
+bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) {
+ mutex_lock lock{mu_};
+ CHECK(rng_ != nullptr);
+
+ if (!CheckSeed(seed, seed_bytes)) {
+ return false;
+ }
+
+ if (!SetStream(stream)) {
+ return false;
+ }
+
+ // curand consumes 8 bytes of seed data here; RngSupport::CheckSeed (called
+ // above) requires 16 bytes for API consistency with host RNG fallbacks.
+ curandStatus_t ret = dynload::curandSetPseudoRandomGeneratorSeed(
+ parent_, rng_, *(reinterpret_cast<const uint64 *>(seed)));
+ if (ret != CURAND_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set rng seed: " << ret;
+ return false;
+ }
+
+ ret = dynload::curandSetGeneratorOffset(parent_, rng_, 0);
+ if (ret != CURAND_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to reset rng position: " << ret;
+ return false;
+ }
+ return true;
+}
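+
+// Usage sketch (illustrative only): only the leading 8 bytes of the seed
+// buffer are consumed here, but the buffer must still satisfy
+// RngSupport::CheckSeed's length requirements.
+//
+//   uint8 seed[16] = {/* caller-chosen bytes */};
+//   rng.SetSeed(stream, seed, sizeof(seed));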
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+namespace gpu = ::perftools::gputools;
+
+REGISTER_MODULE_INITIALIZER(register_curand, {
+ gpu::port::Status status =
+ gpu::PluginRegistry::Instance()
+ ->RegisterFactory<gpu::PluginRegistry::RngFactory>(
+ gpu::cuda::kCudaPlatformId, gpu::cuda::kCuRandPlugin, "cuRAND",
+ [](gpu::internal::StreamExecutorInterface
+ *parent) -> gpu::rng::RngSupport * {
+ gpu::cuda::CUDAExecutor *cuda_executor =
+ dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
+ if (cuda_executor == nullptr) {
+ LOG(ERROR)
+ << "Attempting to initialize an instance of the cuRAND "
+ << "support library with a non-CUDA StreamExecutor";
+ return nullptr;
+ }
+
+ gpu::cuda::CUDARng *rng = new gpu::cuda::CUDARng(cuda_executor);
+ if (!rng->Init()) {
+ // Note: Init() will log a more specific error.
+ delete rng;
+ return nullptr;
+ }
+ return rng;
+ });
+
+ if (!status.ok()) {
+ LOG(ERROR) << "Unable to register cuRAND factory: "
+ << status.error_message();
+ }
+
+ // Prime the cuRAND DSO. The loader will log more information.
+ auto statusor = gpu::internal::CachedDsoLoader::GetCurandDsoHandle();
+ if (!statusor.ok()) {
+ LOG(INFO) << "Unable to load cuRAND DSO.";
+ }
+
+ gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
+ gpu::PluginKind::kRng,
+ gpu::cuda::kCuRandPlugin);
+});
diff --git a/tensorflow/stream_executor/cuda/cuda_rng.h b/tensorflow/stream_executor/cuda/cuda_rng.h
new file mode 100644
index 0000000000..4e1b82969b
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_rng.h
@@ -0,0 +1,89 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_
+
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/rng.h"
+
+typedef struct curandGenerator_st *curandGenerator_t;
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+template <typename ElemT>
+class DeviceMemory;
+
+namespace cuda {
+
+// Opaque and unique identifier for the cuRAND plugin.
+extern const PluginId kCuRandPlugin;
+
+class CUDAExecutor;
+
+// CUDA-platform implementation of the random number generation support
+// interface.
+//
+// Thread-safe post-initialization.
+class CUDARng : public rng::RngSupport {
+ public:
+ explicit CUDARng(CUDAExecutor *parent);
+
+ // Retrieves a curand library generator handle. This is necessary for
+ // enqueuing random number generation work onto the device.
+ // TODO(leary) provide a way for users to select the RNG algorithm.
+ bool Init();
+
+ // Releases a curand library generator handle, if one was acquired.
+ ~CUDARng() override;
+
+ // See rng::RngSupport for details on the following overrides.
+ bool DoPopulateRandUniform(Stream *stream, DeviceMemory<float> *v) override;
+ bool DoPopulateRandUniform(Stream *stream, DeviceMemory<double> *v) override;
+ bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<float>> *v) override;
+ bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<double>> *v) override;
+ bool DoPopulateRandGaussian(Stream *stream, float mean, float stddev,
+ DeviceMemory<float> *v) override;
+ bool DoPopulateRandGaussian(Stream *stream, double mean, double stddev,
+ DeviceMemory<double> *v) override;
+
+ bool SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) override;
+
+ private:
+ // Actually performs the work of generating random numbers - the public
+ // methods are thin wrappers to this interface.
+ template <typename T>
+ bool DoPopulateRandUniformInternal(Stream *stream, DeviceMemory<T> *v);
+ template <typename ElemT, typename FuncT>
+ bool DoPopulateRandGaussianInternal(Stream *stream, ElemT mean, ElemT stddev,
+ DeviceMemory<ElemT> *v, FuncT func);
+
+ // Sets the stream for the internal curand generator.
+ //
+ // This is a stateful operation, as the handle can only have one stream set
+ // at a time, so it is typically called immediately before enqueuing
+ // random-number-generation work.
+ bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // mutex that guards the cuRAND handle for this device.
+ mutex mu_;
+
+ // CUDAExecutor which instantiated this CUDARng.
+ // Immutable post-initialization.
+ CUDAExecutor *parent_;
+
+ // cuRAND library handle on the device.
+ curandGenerator_t rng_ GUARDED_BY(mu_);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CUDARng);
+};
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_stream.cc b/tensorflow/stream_executor/cuda/cuda_stream.cc
new file mode 100644
index 0000000000..e70579b55c
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_stream.cc
@@ -0,0 +1,51 @@
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+
+#include "tensorflow/stream_executor/lib/status.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+bool CUDAStream::Init() {
+ return CUDADriver::CreateStream(parent_->cuda_context(), &cuda_stream_);
+}
+
+void CUDAStream::Destroy() {
+ {
+ mutex_lock lock{mu_};
+ if (completed_event_ != nullptr) {
+ port::Status status =
+ CUDADriver::DestroyEvent(parent_->cuda_context(), &completed_event_);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
+ }
+ }
+
+ CUDADriver::DestroyStream(parent_->cuda_context(), &cuda_stream_);
+}
+
+bool CUDAStream::IsIdle() const {
+ return CUDADriver::IsStreamIdle(parent_->cuda_context(), cuda_stream_);
+}
+
+bool CUDAStream::GetOrCreateCompletedEvent(CUevent *completed_event) {
+ mutex_lock lock{mu_};
+ if (completed_event_ != nullptr) {
+ *completed_event = completed_event_;
+ return true;
+ }
+
+ if (!CUDADriver::CreateEvent(parent_->cuda_context(), &completed_event_,
+ CUDADriver::EventFlags::kDisableTiming)
+ .ok()) {
+ return false;
+ }
+
+ *completed_event = completed_event_;
+ return true;
+}
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h
new file mode 100644
index 0000000000..f6db64a1bf
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_stream.h
@@ -0,0 +1,74 @@
+// Defines the CUDAStream type - the CUDA-specific implementation of the generic
+// StreamExecutor Stream interface.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_
+
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+class CUDAExecutor;
+
+// Wraps a CUstream in order to satisfy the platform-independent
+// StreamInterface.
+//
+// Thread-safe post-initialization.
+class CUDAStream : public internal::StreamInterface {
+ public:
+ explicit CUDAStream(CUDAExecutor *parent)
+ : parent_(parent), cuda_stream_(nullptr), completed_event_(nullptr) {}
+
+ // Note: teardown is handled by a parent's call to DeallocateStream.
+ ~CUDAStream() override {}
+
+ void *CudaStreamHack() override { return cuda_stream_; }
+ void **CudaStreamMemberHack() override {
+ return reinterpret_cast<void **>(&cuda_stream_);
+ }
+
+ // Explicitly initialize the CUDA resources associated with this stream, used
+ // by StreamExecutor::AllocateStream().
+ bool Init();
+
+ // Explicitly destroy the CUDA resources associated with this stream, used by
+ // StreamExecutor::DeallocateStream().
+ void Destroy();
+
+ // Returns true if no work is pending or executing on the stream.
+ bool IsIdle() const;
+
+ // Retrieves an event which indicates that all work enqueued into the stream
+ // has completed. Ownership of the event is not transferred to the caller, the
+ // event is owned by this stream.
+ bool GetOrCreateCompletedEvent(CUevent *completed_event);
+
+ // Returns the CUstream value for passing to the CUDA API.
+ //
+ // Precond: this CUDAStream has been allocated (otherwise passing a nullptr
+ // into the NVIDIA library causes difficult-to-understand faults).
+ CUstream cuda_stream() const {
+ DCHECK(cuda_stream_ != nullptr);
+ return const_cast<CUstream>(cuda_stream_);
+ }
+
+ CUDAExecutor *parent() const { return parent_; }
+
+ private:
+ mutex mu_; // mutex that guards the completion event.
+ CUDAExecutor *parent_; // Executor that spawned this stream.
+ CUstream cuda_stream_; // Wrapped CUDA stream handle.
+
+ // Event that indicates this stream has completed.
+ CUevent completed_event_ GUARDED_BY(mu_);
+};
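+
+// Usage sketch (illustrative only): the lifecycle as driven by the owning
+// CUDAExecutor. No CUDA resources exist until Init() succeeds, and Destroy()
+// must run before the wrapper is discarded.
+//
+//   CUDAStream stream(executor);
+//   if (stream.Init()) {
+//     CUstream raw = stream.cuda_stream();  // pass to CUDA driver API calls
+//     // ... enqueue work on the stream ...
+//     if (stream.IsIdle()) { /* nothing pending or executing */ }
+//     stream.Destroy();
+//   }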
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_timer.cc b/tensorflow/stream_executor/cuda/cuda_timer.cc
new file mode 100644
index 0000000000..ad5e13ab6b
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_timer.cc
@@ -0,0 +1,73 @@
+#include "tensorflow/stream_executor/cuda/cuda_timer.h"
+
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+#include "tensorflow/stream_executor/lib/status.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+bool CUDATimer::Init() {
+ CHECK(start_event_ == nullptr && stop_event_ == nullptr);
+ CUcontext context = parent_->cuda_context();
+ if (!CUDADriver::CreateEvent(context, &start_event_,
+ CUDADriver::EventFlags::kDefault)
+ .ok()) {
+ return false;
+ }
+
+ if (!CUDADriver::CreateEvent(context, &stop_event_,
+ CUDADriver::EventFlags::kDefault)
+ .ok()) {
+ port::Status status = CUDADriver::DestroyEvent(context, &start_event_);
+ if (!status.ok()) {
+ LOG(ERROR) << status;
+ }
+ return false;
+ }
+
+ CHECK(start_event_ != nullptr && stop_event_ != nullptr);
+ return true;
+}
+
+void CUDATimer::Destroy() {
+ CUcontext context = parent_->cuda_context();
+ port::Status status = CUDADriver::DestroyEvent(context, &start_event_);
+ if (!status.ok()) {
+ LOG(ERROR) << status;
+ }
+
+ status = CUDADriver::DestroyEvent(context, &stop_event_);
+ if (!status.ok()) {
+ LOG(ERROR) << status;
+ }
+}
+
+float CUDATimer::GetElapsedMilliseconds() const {
+ CHECK(start_event_ != nullptr && stop_event_ != nullptr);
+ // TODO(leary) provide a way to query timer resolution?
+ // CUDA docs say a resolution of about 0.5us
+ float elapsed_milliseconds = NAN;
+ (void)CUDADriver::GetEventElapsedTime(parent_->cuda_context(),
+ &elapsed_milliseconds, start_event_,
+ stop_event_);
+ return elapsed_milliseconds;
+}
+
+bool CUDATimer::Start(CUDAStream *stream) {
+ return CUDADriver::RecordEvent(parent_->cuda_context(), start_event_,
+ stream->cuda_stream())
+ .ok();
+}
+
+bool CUDATimer::Stop(CUDAStream *stream) {
+ return CUDADriver::RecordEvent(parent_->cuda_context(), stop_event_,
+ stream->cuda_stream())
+ .ok();
+}
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_timer.h b/tensorflow/stream_executor/cuda/cuda_timer.h
new file mode 100644
index 0000000000..e49e212403
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_timer.h
@@ -0,0 +1,69 @@
+// Defines the CUDATimer type - the CUDA-specific implementation of the generic
+// StreamExecutor Timer interface.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_
+
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+class CUDAExecutor;
+class CUDAStream;
+
+// Wraps a pair of CUevents in order to satisfy the platform-independent
+// TimerInterface -- both a start and a stop event are present, and they may
+// be recorded in a stream.
+class CUDATimer : public internal::TimerInterface {
+ public:
+ explicit CUDATimer(CUDAExecutor *parent)
+ : parent_(parent), start_event_(nullptr), stop_event_(nullptr) {}
+
+ // Note: teardown is explicitly handled in this API by a call to
+ // StreamExecutor::DeallocateTimer(), which invokes Destroy().
+ ~CUDATimer() override {}
+
+ // Allocates the platform-specific pieces of the timer, called as part of
+ // StreamExecutor::AllocateTimer().
+ bool Init();
+
+ // Deallocates the platform-specific pieces of the timer, called as part of
+ // StreamExecutor::DeallocateTimer().
+ void Destroy();
+
+ // Records the "timer start" event at the current point in the stream.
+ bool Start(CUDAStream *stream);
+
+ // Records the "timer stop" event at the current point in the stream.
+ bool Stop(CUDAStream *stream);
+
+ // Returns the elapsed time, in milliseconds, between the start and stop
+ // events.
+ float GetElapsedMilliseconds() const;
+
+ // See perftools::gputools::Timer::Microseconds().
+ // TODO(leary) make this into an error code interface...
+ uint64 Microseconds() const override {
+ return GetElapsedMilliseconds() * 1e3;
+ }
+
+ // See perftools::gputools::Timer::Nanoseconds().
+ uint64 Nanoseconds() const override { return GetElapsedMilliseconds() * 1e6; }
+
+ private:
+ CUDAExecutor *parent_;
+ CUevent start_event_; // Event recorded to indicate the "start" timestamp
+ // executing in a stream.
+ CUevent stop_event_; // Event recorded to indicate the "stop" timestamp
+ // executing in a stream.
+};
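+
+// Usage sketch (illustrative only): timing work enqueued on a CUDAStream. The
+// elapsed value is only meaningful once the stop event has executed, so the
+// caller is expected to synchronize the stream before reading it.
+//
+//   CUDATimer timer(executor);
+//   if (timer.Init() && timer.Start(stream)) {
+//     // ... enqueue kernels / copies on "stream" ...
+//     timer.Stop(stream);
+//     // ... synchronize "stream" ...
+//     uint64 usecs = timer.Microseconds();
+//   }
+//   timer.Destroy();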
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_
diff --git a/tensorflow/stream_executor/cuda/multi_op_activation.h b/tensorflow/stream_executor/cuda/multi_op_activation.h
new file mode 100644
index 0000000000..ba2bcd3a91
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/multi_op_activation.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_MULTI_OP_ACTIVATION_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_MULTI_OP_ACTIVATION_H_
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// Type-safe boolean wrapper: denotes whether a ScopedActivateExecutorContext
+// may have other ScopedActivateExecutorContexts nested within it.
+enum class MultiOpActivation { kNo = false, kYes = true };
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_MULTI_OP_ACTIVATION_H_
diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc
new file mode 100644
index 0000000000..23c110c2f3
--- /dev/null
+++ b/tensorflow/stream_executor/device_description.cc
@@ -0,0 +1,221 @@
+#include "tensorflow/stream_executor/device_description.h"
+
+#include <algorithm>
+
+#include "tensorflow/stream_executor/lib/human_readable.h"
+#include "tensorflow/stream_executor/lib/mathutil.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+
+namespace perftools {
+namespace gputools {
+
+static const uint64 kUninitializedUint64 = -1ULL;
+/* static */ const char *DeviceDescription::kUndefinedString = "<undefined>";
+
+DeviceDescription::DeviceDescription()
+ : device_vendor_(kUndefinedString),
+ platform_version_(kUndefinedString),
+ driver_version_(kUndefinedString),
+ runtime_version_(kUndefinedString),
+ pci_bus_id_(kUndefinedString),
+ name_(kUndefinedString),
+ thread_dim_limit_(kUninitializedUint64, kUninitializedUint64,
+ kUninitializedUint64),
+ block_dim_limit_(kUninitializedUint64, kUninitializedUint64,
+ kUninitializedUint64),
+ blocks_per_core_limit_(kUninitializedUint64),
+ threads_per_core_limit_(kUninitializedUint64),
+ threads_per_block_limit_(kUninitializedUint64),
+ threads_per_warp_(kUninitializedUint64),
+ registers_per_core_limit_(kUninitializedUint64),
+ registers_per_block_limit_(kUninitializedUint64),
+ registers_per_thread_limit_(kUninitializedUint64),
+ warp_alloc_granularity_(1),
+ register_alloc_granularity_(1),
+ shared_memory_alloc_granularity_(1),
+ device_address_bits_(kUninitializedUint64),
+ device_memory_size_(kUninitializedUint64),
+ shared_memory_per_core_(kUninitializedUint64),
+ shared_memory_per_block_(kUninitializedUint64),
+ clock_rate_ghz_(-1.0),
+ cuda_compute_capability_major_(-1),
+ cuda_compute_capability_minor_(-1),
+ numa_node_(-1),
+ core_count_(-1),
+ ecc_enabled_(false) {}
+
+std::unique_ptr<std::map<string, string>> DeviceDescription::ToMap() const {
+ std::unique_ptr<std::map<string, string>> owned_result{
+ new std::map<string, string>};
+ std::map<string, string> &result = *owned_result;
+ result["Device Vendor"] = device_vendor();
+ result["Platform Version"] = platform_version();
+ result["Driver Version"] = driver_version();
+ result["Runtime Version"] = runtime_version();
+ result["PCI bus ID"] = pci_bus_id_;
+ result["Device Name"] = name_;
+
+ const ThreadDim &thread_dim = thread_dim_limit();
+ result["ThreadDim Limit"] =
+ port::StrCat(thread_dim.x, ",", thread_dim.y, ",", thread_dim.z);
+ const BlockDim &block_dim = block_dim_limit();
+ result["BlockDim Limit"] =
+ port::StrCat(block_dim.x, ",", block_dim.y, ",", block_dim.z);
+
+ result["Threads Per Core Limit"] = port::StrCat(threads_per_core_limit());
+ result["Threads Per Block Limit"] = port::StrCat(threads_per_block_limit());
+ result["Registers Per Block Limit"] =
+ port::StrCat(registers_per_block_limit());
+
+ result["Device Address Bits"] = port::StrCat(device_address_bits());
+ result["Device Memory Size"] =
+ port::HumanReadableNumBytes::ToString(device_memory_size());
+
+ result["Shared Memory Per Core"] =
+ port::HumanReadableNumBytes::ToString(shared_memory_per_core_);
+ result["Shared Memory Per Block"] =
+ port::HumanReadableNumBytes::ToString(shared_memory_per_block_);
+
+ result["Clock Rate GHz"] = port::StrCat(clock_rate_ghz());
+
+ result["CUDA Compute Capability"] = port::StrCat(
+ cuda_compute_capability_major_, ".", cuda_compute_capability_minor_);
+
+ result["NUMA Node"] = port::StrCat(numa_node());
+ result["Core Count"] = port::StrCat(core_count());
+ result["ECC Enabled"] = port::StrCat(ecc_enabled());
+ return owned_result;
+}
+
+namespace internal {
+
+DeviceDescriptionBuilder::DeviceDescriptionBuilder()
+ : device_description_(new DeviceDescription) {}
+
+} // namespace internal
+
+bool DeviceDescription::cuda_compute_capability(int *major, int *minor) const {
+ *major = cuda_compute_capability_major_;
+ *minor = cuda_compute_capability_minor_;
+ return cuda_compute_capability_major_ != 0;
+}
+
+bool ThreadDimOk(const DeviceDescription &device_description,
+ const ThreadDim &thread_dim) {
+ auto total_threads = thread_dim.x * thread_dim.y * thread_dim.z;
+ auto threads_per_block_limit = device_description.threads_per_block_limit();
+ if (total_threads > threads_per_block_limit) {
+ VLOG(2) << "exceeded total-thread-per-block limit: " << total_threads
+ << " vs limit " << threads_per_block_limit;
+ return false;
+ }
+
+ const auto &limit = device_description.thread_dim_limit();
+ bool ok = thread_dim.x <= limit.x && thread_dim.y <= limit.y &&
+ thread_dim.z <= limit.z;
+ if (!ok) {
+ VLOG(2) << "thread dim " << thread_dim.ToString()
+ << " exceeds limit contraints of " << limit.ToString();
+ }
+ return ok;
+}
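+
+// Worked example (hypothetical limits): with threads_per_block_limit = 1024
+// and thread_dim_limit = (1024, 1024, 64), ThreadDim(32, 32, 1) passes (1024
+// total threads), while ThreadDim(32, 32, 2) fails the total-thread check
+// (2048 > 1024).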
+
+uint64 DivideCeil(uint64 x, uint64 y) {
+ return port::MathUtil::CeilOfRatio(x, y);
+}
+
+void CalculateDimensionality(const DeviceDescription &device_description,
+ uint64 element_count, uint64 *threads_per_block,
+ uint64 *block_count) {
+ *threads_per_block = device_description.threads_per_block_limit();
+ *block_count = DivideCeil(element_count, *threads_per_block);
+ if (*block_count == 1) {
+ CHECK_LE(element_count, *threads_per_block);
+ *threads_per_block = element_count;
+ }
+}
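+
+// Worked example (hypothetical limit of 1024 threads per block):
+// element_count = 130000 yields threads_per_block = 1024 and block_count =
+// DivideCeil(130000, 1024) = 127, while element_count = 500 yields
+// block_count = 1 and threads_per_block trimmed to 500.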
+
+// Round value up to a multiple of n.
+static uint64 RoundUp(uint64 value, uint64 n) {
+ return port::MathUtil::CeilOfRatio(value, n) * n;
+}
+
+// Round value down to a multiple of n.
+static uint64 RoundDown(uint64 value, uint64 n) {
+ return port::MathUtil::FloorOfRatio(value, n) * n;
+}
+
+uint64 CalculateOccupancy(const DeviceDescription &device_description,
+ uint64 registers_per_thread,
+ uint64 shared_memory_per_block,
+ const ThreadDim &thread_dims) {
+ // Don't try to compute occupancy if necessary values are not initialized.
+ uint64 required_fields[] = { device_description.registers_per_thread_limit(),
+ device_description.threads_per_warp(),
+ device_description.warp_alloc_granularity(),
+ device_description.register_alloc_granularity(),
+ device_description.registers_per_block_limit(),
+ device_description.shared_memory_per_core(),
+ device_description.blocks_per_core_limit() };
+ for (auto value : required_fields) {
+ if (value == kUninitializedUint64) {
+ return 0;
+ }
+ }
+
+ if (registers_per_thread > device_description.registers_per_thread_limit()) {
+ return 0;
+ }
+
+ uint64 warps_per_block =
+ port::MathUtil::CeilOfRatio(thread_dims.x * thread_dims.y * thread_dims.z,
+ device_description.threads_per_warp());
+
+ // Warp resources are allocated at a particular granularity. This value is
+ // the effective number of warps for resource allocation purposes.
+ uint64 alloc_warps_per_block =
+ RoundUp(warps_per_block, device_description.warp_alloc_granularity());
+
+ uint64 alloc_regs_per_warp =
+ RoundUp(device_description.threads_per_warp() * registers_per_thread,
+ device_description.register_alloc_granularity());
+ uint64 regs_per_block = alloc_warps_per_block * alloc_regs_per_warp;
+ uint64 reg_limit =
+ device_description.registers_per_block_limit() / regs_per_block;
+
+ uint64 alloc_smem_per_block = RoundUp(
+ shared_memory_per_block,
+ device_description.shared_memory_alloc_granularity());
+ uint64 smem_limit = alloc_smem_per_block > 0 ?
+ device_description.shared_memory_per_core() / alloc_smem_per_block :
+ device_description.blocks_per_core_limit();
+
+ uint64 thread_limit = device_description.threads_per_core_limit()
+ / (warps_per_block * device_description.threads_per_warp());
+
+ return std::min({ device_description.blocks_per_core_limit(),
+ reg_limit, smem_limit, thread_limit });
+}
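+
+// Worked example (hypothetical device values): threads_per_warp = 32,
+// warp_alloc_granularity = 4, register_alloc_granularity = 256,
+// shared_memory_alloc_granularity = 256, registers_per_block_limit = 65536,
+// shared_memory_per_core = 49152, blocks_per_core_limit = 16,
+// threads_per_core_limit = 2048, registers_per_thread_limit = 255. For a
+// 256-thread block using 32 registers per thread and 4096 bytes of shared
+// memory: warps_per_block = 8, alloc_warps_per_block = 8,
+// alloc_regs_per_warp = 1024, regs_per_block = 8192, so reg_limit =
+// 65536 / 8192 = 8; smem_limit = 49152 / 4096 = 12; thread_limit =
+// 2048 / 256 = 8; the result is min(16, 8, 12, 8) = 8 blocks per core.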
+
+uint64 CalculateRegisterLimitForTargetOccupancy(
+ const DeviceDescription &device_description, uint64 shared_memory_per_block,
+ const ThreadDim &thread_dims, uint64 target_blocks_per_core) {
+ // Linear search from maximum number of registers down until the target
+ // blocks per SM is found.
+ // TODO(meheff): Compute this using a closed form solution.
+ int reg_step = device_description.register_alloc_granularity() /
+ device_description.threads_per_warp();
+ for (int r = device_description.registers_per_thread_limit(); r > 0;
+ r = RoundDown(r - 1, reg_step)) {
+ uint64 occupancy = CalculateOccupancy(
+ device_description, r, shared_memory_per_block, thread_dims);
+ if (occupancy >= target_blocks_per_core) {
+ return r;
+ }
+ }
+ return 0;
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
new file mode 100644
index 0000000000..e7b7102da5
--- /dev/null
+++ b/tensorflow/stream_executor/device_description.h
@@ -0,0 +1,370 @@
+// Describes the underlying platform for a StreamExecutor; e.g. OpenCL or CUDA
+// device and platform properties. Also contains convenience functions for
+// checking/calculating launch dimensionality based on device properties.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+#define TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+
+#include <map>
+#include <memory>
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+class DeviceDescriptionBuilder;
+} // namespace internal
+
+// Data that describes the execution target of the StreamExecutor, in terms of
+// important logical parameters. These include dimensionality limits and
+// physical parameters of interest, such as number of cores present on the
+// device.
+//
+// Thread-safe: immutable post-initialization.
+class DeviceDescription {
+ public:
+ // Returns the platform being run on; this value is primarily intended for
+ // printing, and comes out something like "OpenCL 1.2" or "Compute Capability
+ // 3.5".
+ const string &platform_version() const { return platform_version_; }
+
+ // Returns the driver version interfacing with the underlying platform. Vendor
+ // dependent format.
+ const string &driver_version() const { return driver_version_; }
+
+ // Returns the runtime version, if one is provided by the underlying platform.
+ // Vendor dependent format / usefulness.
+ const string &runtime_version() const { return runtime_version_; }
+
+ // Returns the name that the device reports. Vendor dependent.
+ const string &name() const { return name_; }
+
+ // Returns the PCI bus identifier for this device, of the form
+ // [domain]:[bus]:[device].[function]
+ const string &pci_bus_id() const { return pci_bus_id_; }
+
+ // Returns the NUMA node associated with this device, for use in
+ // determining socket locality. If the NUMA node could not be determined, -1
+ // is returned.
+ int numa_node() const { return numa_node_; }
+
+ // Number of cores (traditional notion of core; i.e. an SM on an NVIDIA
+ // device or a Compute Unit on an AMD device).
+ int core_count() const { return core_count_; }
+
+ // Returns the limit on the thread dimensionality values in each of the
+ // respective dimensions. These limits affect what constitutes a legitimate
+ // kernel launch request.
+ const ThreadDim &thread_dim_limit() const { return thread_dim_limit_; }
+
+ // Returns the limit on the block dimensionality values in each of the
+ // respective dimensions. These limits may affect what constitutes a
+ // legitimate kernel launch request.
+ const BlockDim &block_dim_limit() const { return block_dim_limit_; }
+
+ // Returns the limit on the number of simultaneously resident blocks
+ // on a multiprocessor.
+ uint64 blocks_per_core_limit() const { return blocks_per_core_limit_; }
+
+ // Returns the limit on the total number of threads that can be launched in a
+ // single block; i.e. the limit on x * y * z dimensions of a ThreadDim.
+ // This limit affects what constitutes a legitimate kernel launch request.
+ const uint64 &threads_per_block_limit() const {
+ return threads_per_block_limit_;
+ }
+
+ // Returns the limit on the total number of threads that can be simultaneously
+ // launched on a given multiprocessor.
+ const uint64 &threads_per_core_limit() const {
+ return threads_per_core_limit_;
+ }
+
+ // Returns the number of threads per warp/wavefront.
+ const uint64 &threads_per_warp() const { return threads_per_warp_; }
+
+ // Returns the limit on the total number of registers per core.
+ const uint64 &registers_per_core_limit() const {
+ return registers_per_core_limit_;
+ }
+
+ // Returns the limit on the total number of registers that can be
+ // simultaneously used by a block.
+ const uint64 &registers_per_block_limit() const {
+ return registers_per_block_limit_;
+ }
+
+ // Returns the limit on the total number of registers that can be
+ // allocated to a thread.
+ const uint64 &registers_per_thread_limit() const {
+ return registers_per_thread_limit_;
+ }
+
+ // Returns the granularity at which warps are allocated resources.
+ const uint64 &warp_alloc_granularity() const {
+ return warp_alloc_granularity_;
+ }
+
+ // Returns the granularity at which registers are allocated to warps.
+ const uint64 &register_alloc_granularity() const {
+ return register_alloc_granularity_;
+ }
+
+ // Returns the granularity at which shared memory is allocated to warps.
+ const uint64 &shared_memory_alloc_granularity() const {
+ return shared_memory_alloc_granularity_;
+ }
+
+ // Returns the number of address bits available to kernel code running on the
+ // platform. This affects things like the maximum allocation size and perhaps
+ // types used in kernel code such as size_t.
+ const uint64 &device_address_bits() const { return device_address_bits_; }
+
+ // Returns the device memory size in bytes.
+ uint64 device_memory_size() const { return device_memory_size_; }
+
+ // Returns the device's core clock rate in GHz.
+ float clock_rate_ghz() const { return clock_rate_ghz_; }
+
+ // Returns whether ECC is enabled.
+ bool ecc_enabled() const { return ecc_enabled_; }
+
+ // Returns the device vendor string, e.g., "NVIDIA Corporation", "Advanced
+ // Micro Devices, Inc.", or "GenuineIntel".
+ const string &device_vendor() const { return device_vendor_; }
+
+ // Returns the CUDA compute capability if we're running on the CUDA platform.
+ // If a CUDA compute capability is not available, the major version will be
+ // zero, and the return value will be false.
+ bool cuda_compute_capability(int *major, int *minor) const;
+
+ // Returns the maximum amount of shared memory present on a single core
+ // (i.e. Streaming Multiprocessor on NVIDIA GPUs; Compute Unit for OpenCL
+ // devices). Note that some devices, such as NVIDIA's, have a configurable
+ // partitioning between shared memory and L1 cache.
+ uint64 shared_memory_per_core() const { return shared_memory_per_core_; }
+
+ // Returns the maximum amount of shared memory available for a single block.
+ uint64 shared_memory_per_block() const { return shared_memory_per_block_; }
+
+ // TODO(leary): resident blocks per core will be useful.
+
+ // Convenience typedef for the string-based DeviceDescription mapping.
+ typedef std::map<string, string> Map;
+
+ // Returns a mapping from readable names to readable values that describe the
+ // device. This is useful for things like printing.
+ std::unique_ptr<Map> ToMap() const;
+
+ // For string values that are not available via the underlying platform, this
+ // value will be provided.
+ static const char *kUndefinedString;
+
+ private:
+ friend class internal::DeviceDescriptionBuilder;
+
+ DeviceDescription();
+
+ // For description of the following members, see the corresponding accessor
+ // above.
+ //
+ // N.B. If another field is added, update ToMap() above.
+ string device_vendor_;
+ string platform_version_;
+ string driver_version_;
+ string runtime_version_;
+ string pci_bus_id_;
+ string name_;
+
+ ThreadDim thread_dim_limit_;
+ BlockDim block_dim_limit_;
+
+ uint64 blocks_per_core_limit_;
+
+ uint64 threads_per_core_limit_;
+ uint64 threads_per_block_limit_;
+ uint64 threads_per_warp_;
+
+ uint64 registers_per_core_limit_;
+ uint64 registers_per_block_limit_;
+ uint64 registers_per_thread_limit_;
+
+ uint64 warp_alloc_granularity_;
+ uint64 register_alloc_granularity_;
+ uint64 shared_memory_alloc_granularity_;
+
+ uint64 device_address_bits_;
+ uint64 device_memory_size_;
+
+ // Shared memory limits on a given device.
+ uint64 shared_memory_per_core_;
+ uint64 shared_memory_per_block_;
+
+ float clock_rate_ghz_;
+
+ // CUDA "CC" major value, -1 if not available.
+ int cuda_compute_capability_major_;
+ int cuda_compute_capability_minor_;
+
+ int numa_node_;
+ int core_count_;
+ bool ecc_enabled_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(DeviceDescription);
+};
+
+namespace internal {
+
+// Helper class that builds a device description, given that it has a large
+// number of fields that would be easily confused in constructor form.
+class DeviceDescriptionBuilder {
+ public:
+ DeviceDescriptionBuilder();
+
+ // For descriptions of the following fields, see comments on the corresponding
+ // DeviceDescription::* accessors above.
+
+ void set_device_vendor(const string &value) {
+ device_description_->device_vendor_ = value;
+ }
+ void set_platform_version(const string &value) {
+ device_description_->platform_version_ = value;
+ }
+ void set_driver_version(const string &value) {
+ device_description_->driver_version_ = value;
+ }
+ void set_runtime_version(const string &value) {
+ device_description_->runtime_version_ = value;
+ }
+ void set_pci_bus_id(const string &value) {
+ device_description_->pci_bus_id_ = value;
+ }
+ void set_name(const string &value) { device_description_->name_ = value; }
+
+ void set_thread_dim_limit(const ThreadDim &value) {
+ device_description_->thread_dim_limit_ = value;
+ }
+ void set_block_dim_limit(const BlockDim &value) {
+ device_description_->block_dim_limit_ = value;
+ }
+
+ void set_blocks_per_core_limit(uint64 value) {
+ device_description_->blocks_per_core_limit_ = value;
+ }
+
+ void set_threads_per_core_limit(uint64 value) {
+ device_description_->threads_per_core_limit_ = value;
+ }
+ void set_threads_per_block_limit(uint64 value) {
+ device_description_->threads_per_block_limit_ = value;
+ }
+ void set_threads_per_warp(uint64 value) {
+ device_description_->threads_per_warp_ = value;
+ }
+
+ void set_registers_per_core_limit(uint64 value) {
+ device_description_->registers_per_core_limit_ = value;
+ }
+ void set_registers_per_block_limit(uint64 value) {
+ device_description_->registers_per_block_limit_ = value;
+ }
+ void set_registers_per_thread_limit(uint64 value) {
+ device_description_->registers_per_thread_limit_ = value;
+ }
+
+ void set_warp_alloc_granularity(uint64 value) {
+ device_description_->warp_alloc_granularity_ = value;
+ }
+ void set_register_alloc_granularity(uint64 value) {
+ device_description_->register_alloc_granularity_ = value;
+ }
+ void set_shared_memory_alloc_granularity(uint64 value) {
+ device_description_->shared_memory_alloc_granularity_ = value;
+ }
+
+ void set_device_address_bits(uint64 value) {
+ device_description_->device_address_bits_ = value;
+ }
+ void set_device_memory_size(uint64 value) {
+ device_description_->device_memory_size_ = value;
+ }
+
+ void set_shared_memory_per_core(int64 value) {
+ device_description_->shared_memory_per_core_ = value;
+ }
+ void set_shared_memory_per_block(int64 value) {
+ device_description_->shared_memory_per_block_ = value;
+ }
+
+ void set_clock_rate_ghz(float value) {
+ device_description_->clock_rate_ghz_ = value;
+ }
+
+ void set_cuda_compute_capability(int major, int minor) {
+ device_description_->cuda_compute_capability_major_ = major;
+ device_description_->cuda_compute_capability_minor_ = minor;
+ }
+
+ void set_numa_node(int value) { device_description_->numa_node_ = value; }
+ void set_core_count(int value) { device_description_->core_count_ = value; }
+ void set_ecc_enabled(bool value) {
+ device_description_->ecc_enabled_ = value;
+ }
+
+ // Returns a built DeviceDescription with ownership transferred to the
+ // caller. There are currently no restrictions on which fields must be set in
+ // order to build the descriptor.
+ //
+ // Once the description is built, this builder object should be discarded.
+ std::unique_ptr<DeviceDescription> Build() {
+ return std::move(device_description_);
+ }
+
+ private:
+ std::unique_ptr<DeviceDescription> device_description_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(DeviceDescriptionBuilder);
+};
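+
+// Usage sketch (illustrative only; the values below are made up): how a
+// platform plugin typically populates a description.
+//
+//   internal::DeviceDescriptionBuilder builder;
+//   builder.set_name("Example GPU");
+//   builder.set_threads_per_warp(32);
+//   builder.set_threads_per_block_limit(1024);
+//   builder.set_device_memory_size(8ULL << 30);
+//   std::unique_ptr<DeviceDescription> description = builder.Build();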
+
+} // namespace internal
+
+// Returns whether the given thread_dim is acceptable given the limits described
+// in device_description. For detailed reasons for failing the predicate, enable
+// VLOG(2) for this module.
+bool ThreadDimOk(const DeviceDescription &device_description,
+ const ThreadDim &thread_dim);
+
+// [deprecated] Use MathUtil::CeilOfRatio directly instead.
+//
+// Equivalent to ceil(static_cast<double>(x) / y).
+uint64 DivideCeil(uint64 x, uint64 y);
+
+// Calculate the number of threads/blocks required to process element_count
+// elements. Note that you can still end up with more threads than
+// element_count due to rounding, so kernels often start with an "is this
+// thread id in the element_count range?" test.
+void CalculateDimensionality(const DeviceDescription &device_description,
+ uint64 element_count, uint64 *threads_per_block,
+ uint64 *block_count);
+
+// Compute and return maximum blocks per core (occupancy) based on the
+// device description, some kernel characteristics and the number of threads per
+// block. If unable to compute occupancy, zero is returned.
+uint64 CalculateOccupancy(const DeviceDescription &device_description,
+ uint64 registers_per_thread,
+ uint64 shared_memory_per_block,
+ const ThreadDim &thread_dims);
+
+// Compute and return the maximum number of registers per thread which
+// achieves the target occupancy. If the target is not possible then
+// zero is returned.
+uint64 CalculateRegisterLimitForTargetOccupancy(
+ const DeviceDescription &device_description, uint64 shared_memory_per_block,
+ const ThreadDim &thread_dims, uint64 target_blocks_per_core);
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
diff --git a/tensorflow/stream_executor/device_memory.h b/tensorflow/stream_executor/device_memory.h
new file mode 100644
index 0000000000..9e88180316
--- /dev/null
+++ b/tensorflow/stream_executor/device_memory.h
@@ -0,0 +1,284 @@
+// Suite of types that represent device memory allocations. These are
+// allocated by the StreamExecutor interface, which produces values appropriate
+// for the underlying platform (whether it be CUDA or OpenCL).
+//
+// The untyped base class (like a device void*) is DeviceMemoryBase, which can
+// be specialized for a given allocation type (like a device T*) using
+// DeviceMemory<T>.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_DEVICE_MEMORY_H_
+#define TENSORFLOW_STREAM_EXECUTOR_DEVICE_MEMORY_H_
+
+#include <stddef.h>
+
+#include "tensorflow/stream_executor/lib/casts.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+class StreamExecutor;
+
+// void*-analogous device memory allocation. For the typed variation, see
+// DeviceMemory<T>.
+//
+// This is effectively a two-tuple of a pointer and size; however, note that the
+// pointer may not be to the virtual address itself -- in OpenCL the pointer is
+// to a cl_mem handle that describes the device allocation. Therefore,
+// DeviceMemoryBase::opaque does not necessarily produce a pointer that can be
+// referenced directly, so use it with caution.
+//
+// Thread-compatible.
+class DeviceMemoryBase {
+ public:
+ // Default constructor instantiates a null-pointed, zero-sized device memory
+ // region. An opaque pointer may be provided -- see header for details on the
+ // opacity of that pointer.
+ explicit DeviceMemoryBase(void *opaque = nullptr, uint64 size = 0,
+ bool is_sub_buffer = false)
+ : opaque_(opaque), size_(size), is_sub_buffer_(is_sub_buffer) {}
+
+ // Returns whether the backing memory is the null pointer.
+ // A `== nullptr` convenience method is also provided.
+ bool is_null() const { return opaque_ == nullptr; }
+ bool operator==(std::nullptr_t other) const { return is_null(); }
+ bool operator!=(std::nullptr_t other) const { return !is_null(); }
+
+ // Provides a partial order between device memory values.
+ //
+ // This operator is provided so that this object can be used as a key in an
+ // ordered map.
+ bool operator<(const DeviceMemoryBase &other) const {
+ return opaque() < other.opaque();
+ }
+
+ // Returns the size, in bytes, for the backing memory.
+ uint64 size() const { return size_; }
+
+ // Warning: note that the pointer returned is not necessarily directly to
+ // device virtual address space, but is platform-dependent.
+ void *opaque() { return opaque_; }
+ const void *opaque() const { return opaque_; }
+
+ // Returns true if this is an offset into another primary allocation.
+ bool is_sub_buffer() const { return is_sub_buffer_; }
+
+ // Returns whether the two DeviceMemoryBase segments are identical (both in
+ // their opaque pointer and size).
+ bool IsSameAs(const DeviceMemoryBase &other) const {
+ return opaque() == other.opaque() && size() == other.size();
+ }
+
+ protected:
+ friend class StreamExecutor;
+
+ // Resets the internal values of the opaque pointer and number of bytes in the
+ // memory region, just as in the constructor.
+ void Reset(void *opaque, uint64 bytes) {
+ opaque_ = opaque;
+ size_ = bytes;
+ }
+
+ private:
+ void *opaque_; // Platform-dependent value representing allocated memory.
+ uint64 size_; // Size in bytes of this allocation.
+ bool is_sub_buffer_; // Is this a primary allocation or a sub-buffer?
+};
+
+// Typed wrapper around "void *"-like DeviceMemoryBase.
+//
+// For example, DeviceMemory<int> is a simple wrapper around DeviceMemoryBase
+// that represents one or more integers in Device memory.
+//
+// Thread-compatible.
+template <typename ElemT>
+class DeviceMemory final : public DeviceMemoryBase {
+ public:
+ // Default constructor instantiates a null-pointed, zero-sized memory region.
+ DeviceMemory() : DeviceMemoryBase(nullptr, 0) {}
+
+ // Typed device memory regions may be constructed from untyped device memory
+ // regions, this effectively amounts to a cast from a void*.
+ explicit DeviceMemory(const DeviceMemoryBase &other)
+ : DeviceMemoryBase(const_cast<DeviceMemoryBase &>(other).opaque(),
+ other.size(), other.is_sub_buffer()) {}
+
+ static constexpr size_t kElemSize = sizeof(ElemT);
+
+ // Returns the number of elements of type ElemT that constitute this
+ // allocation.
+ uint64 ElementCount() const { return size() / kElemSize; }
+
+ // Returns whether this is a single-element allocation.
+ bool IsScalar() const { return ElementCount() == 1; }
+
+ // Create a typed area of DeviceMemory with a given opaque pointer and the
+ // quantity of bytes in the allocation. This function is broken out to
+ // distinguish bytes from an element count.
+ static DeviceMemory<ElemT> MakeFromByteSize(void *opaque, uint64 bytes) {
+ return DeviceMemory<ElemT>(opaque, bytes);
+ }
+
+ // Resets the DeviceMemory data, in MakeFromByteSize fashion.
+ // This simply clobbers the prior values.
+ void ResetFromByteSize(void *opaque, uint64 bytes) {
+ // TODO(leary) when NVCC is eliminated we can add this check (and the
+ // logging include it requires).
+ // CHECK_EQ(0, bytes % kElemSize);
+ DeviceMemoryBase::Reset(opaque, bytes);
+ }
+
+ // ------------------------------------------------------------
+ // DO NOT USE - FASTR TEAM-INTERNAL FUNCTIONS
+ // Used internally by gcudacc.
+#ifdef __GCUDACC__
+ // Implicit conversion operators needed to support mixed mode. Since buffer
+ // sizes aren't used in the CUDA launching process, and since the constructed
+ // objects are all temporary, this is safe.
+ // Linter warning disabled as we require an implicit conversion.
+ DeviceMemory(const ElemT *opaque) : // NOLINT
+ DeviceMemoryBase(reinterpret_cast<void *>(const_cast<ElemT *>(opaque)),
+ 0) {}
+
+ operator ElemT *() { return reinterpret_cast<ElemT *>(opaque()); }
+ operator const ElemT *() {
+ return const_cast<const ElemT *>(reinterpret_cast<ElemT *>(opaque()));
+ }
+#endif
+ // ------------------------------------------------------------
+
+ protected:
+ // This constructor is solely used from derived classes; it is made protected
+ // because it accepts a byte-size instead of an element count, which could
+ // potentially be misused given the ElementCount() nature of this interface.
+ //
+ // In order to specify the desire to use byte size instead of element count
+ // explicitly, use MakeFromByteSize.
+ DeviceMemory(void *opaque, uint64 size) : DeviceMemoryBase(opaque, size) {}
+};
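+
+// Example usage (an illustrative sketch; the opaque pointer and byte count are
+// placeholders):
+//
+//   void *device_ptr = ...;  // Points at 128 bytes of device memory.
+//   auto floats = DeviceMemory<float>::MakeFromByteSize(device_ptr, 128);
+//   uint64 n = floats.ElementCount();  // 128 / sizeof(float) == 32.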
+
+// A class to encapsulate the type and size of a dynamic shared memory
+// buffer. Because the buffer exists solely on the device and is not copyable
+// to the host, memory objects of this type do not maintain buffer pointers
+// on the host.
+template <typename ElemT>
+class SharedDeviceMemory final : public DeviceMemoryBase {
+ public:
+ explicit SharedDeviceMemory(uint64 elem_count)
+ : DeviceMemoryBase(nullptr, elem_count * kElemSize) {}
+
+ static constexpr size_t kElemSize = sizeof(ElemT);
+
+ // Returns the number of elements of type ElemT that constitute this
+ // allocation.
+ uint64 ElementCount() const { return size() / kElemSize; }
+
+ // Returns whether this is a single-element allocation.
+ bool IsScalar() const { return ElementCount() == 1; }
+};
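+
+// Example usage (an illustrative sketch): describing 256 floats of dynamic
+// shared memory to be made available to a kernel launch.
+//
+//   SharedDeviceMemory<float> shared_floats(/*elem_count=*/256);
+//   uint64 bytes = shared_floats.size();  // 256 * sizeof(float) == 1024.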
+
+// Similar to the typed DeviceMemory, but is the unique owner of its
+// memory, if any. ScopedDeviceMemory is thread-compatible. It is also
+// movable and uncopyable to represent unique ownership.
+template <typename ElemT>
+class ScopedDeviceMemory {
+ public:
+ // Parameters:
+ // parent: Executor used to deallocate memory when this instance goes
+ // out of scope.
+ // value: Already-allocated device memory value for this scoped mechanism to
+ // deallocate. This memory must have been allocated by parent.
+ ScopedDeviceMemory(StreamExecutor *parent, DeviceMemoryBase value);
+
+ // Constructor overload that places a literal array into device memory
+ ScopedDeviceMemory(StreamExecutor *parent,
+ std::initializer_list<ElemT> values);
+
+ // Moves ownership of the memory from other to the constructed
+ // object.
+ //
+ // Postcondition: other == nullptr.
+ ScopedDeviceMemory(ScopedDeviceMemory &&other) noexcept:
+ ScopedDeviceMemory(other.parent_, other.Release()) {}
+
+ // Releases the memory that was provided in the constructor, through the
+ // "parent" StreamExecutor.
+ ~ScopedDeviceMemory();
+
+ // Moves ownership of the memory from other to this object.
+ //
+ // Postcondition: other == nullptr.
+ ScopedDeviceMemory& operator=(ScopedDeviceMemory &&other) {
+ Reset(other.Release());
+ parent_ = other.parent_;
+ return *this;
+ }
+
+ // Returns the memory that backs this scoped allocation converted to
+ // DeviceMemory<T> apparent type. This is useful for cases where the
+ // DeviceMemory must be passed by const-ref, as the ScopedDeviceMemory doesn't
+ // allow copying, for scoped-object-lifetime reasons.
+ const DeviceMemory<ElemT> &cref() const { return wrapped_; }
+
+ // Returns a pointer to the DeviceMemory<T> apparent type for use in mutable
+ // operations. The value returned should not be used outside the scope of this
+ // ScopedDeviceMemory object's lifetime.
+ DeviceMemory<ElemT> *ptr() { return &wrapped_; }
+ const DeviceMemory<ElemT> *ptr() const { return &wrapped_; }
+
+ // Smart-pointer-like operators for the wrapped DeviceMemory.
+ // This reference must not be used outside the lifetime of this
+ // ScopedDeviceMemory.
+ const DeviceMemory<ElemT> &operator*() const { return cref(); }
+ DeviceMemory<ElemT> *operator->() { return ptr(); }
+ const DeviceMemory<ElemT> *operator->() const { return ptr(); }
+ bool operator==(std::nullptr_t other) const { return wrapped_.is_null(); }
+ bool operator!=(std::nullptr_t other) const { return !wrapped_.is_null(); }
+
+ // Analogous to std::unique_ptr::reset, frees the existing memory held in
+ // this scoped memory container and replaces it with updated. Ownership
+ // of updated is transferred to this object.
+ void Reset(DeviceMemory<ElemT> updated);
+ void Reset(std::nullptr_t);
+
+ // Analogous to std::unique_ptr::release, releases ownership of the held
+ // memory and transfers it to the caller.
+ //
+ // Postcondition: *this == nullptr
+ DeviceMemory<ElemT> Release() {
+ auto tmp = wrapped_;
+ wrapped_.ResetFromByteSize(nullptr, 0);
+ return tmp;
+ }
+
+ private:
+ DeviceMemory<ElemT> wrapped_; // Value we wrap with scoped-release.
+ StreamExecutor *parent_; // See constructor.
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ScopedDeviceMemory);
+};
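+
+// Example usage (an illustrative sketch; 'executor' is a StreamExecutor
+// pointer that performed the allocation and 'allocation' is the
+// DeviceMemoryBase it returned -- both are assumed, not defined here):
+//
+//   ScopedDeviceMemory<float> scoped(executor, allocation);
+//   const DeviceMemory<float> &typed = scoped.cref();  // Borrow, don't copy.
+//   // The memory is deallocated via 'executor' when 'scoped' goes out of
+//   // scope, unless Release() is called first.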
+
+// Host-side representation of packed-and-aligned vector datatypes on the
+// device side. Since these can appear in device kernel signatures, we support
+// using them as kernel launch argument types.
+
+struct Float2 {
+ float x, y;
+};
+
+struct Float4 {
+ Float2 xz, yw;
+};
+
+struct Double2 {
+ double x, y;
+};
+
+static_assert(sizeof(Float2) == 2 * sizeof(float), "Float2 must be packed");
+static_assert(sizeof(Float4) == 4 * sizeof(float), "Float4 must be packed");
+static_assert(sizeof(Double2) == 2 * sizeof(double), "Double2 must be packed");
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_MEMORY_H_
diff --git a/tensorflow/stream_executor/device_options.h b/tensorflow/stream_executor/device_options.h
new file mode 100644
index 0000000000..bd393a6efb
--- /dev/null
+++ b/tensorflow/stream_executor/device_options.h
@@ -0,0 +1,70 @@
+// Contains device-level options that can be specified at a platform level.
+// Example usage:
+// auto device_options = DeviceOptions::Default();
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_DEVICE_OPTIONS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_DEVICE_OPTIONS_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/platform/logging.h"
+
+namespace perftools {
+namespace gputools {
+
+// Indicates a set of options for a device's usage, which generally must be
+// provided at StreamExecutor device-initialization time.
+//
+// These are intended to be useful-but-not-mandatorily-supported options for
+// using devices on the underlying platform. Presently, if the option requested
+// is not available on the target platform, a warning will be emitted.
+struct DeviceOptions {
+ public:
+  // When the device has to grow the memory allocated for thread stacks, this
+  // flag prevents that memory from ever being deallocated. This potentially
+  // saves thrashing the thread stack memory allocation, at the cost of some
+  // memory space.
+ static const unsigned kDoNotReclaimStackAllocation = 0x1;
+
+ // The following options refer to synchronization options when
+ // using SynchronizeStream or SynchronizeContext.
+
+ // Synchronize with spinlocks.
+ static const unsigned kScheduleSpin = 0x02;
+ // Synchronize with spinlocks that also call CPU yield instructions.
+ static const unsigned kScheduleYield = 0x04;
+ // Synchronize with a "synchronization primitive" (e.g. mutex).
+ static const unsigned kScheduleBlockingSync = 0x08;
+
+ static const unsigned kMask = 0xf; // Mask of all available flags.
+
+  // Constructs a set of device options from flags OR'd together.
+ explicit DeviceOptions(unsigned flags) : flags_(flags) {
+ CHECK((flags & kMask) == flags);
+ }
+
+ // Factory for the default set of device options.
+ static DeviceOptions Default() { return DeviceOptions(0); }
+
+ unsigned flags() const { return flags_; }
+
+ bool operator==(const DeviceOptions& other) const {
+ return flags_ == other.flags_;
+ }
+
+ bool operator!=(const DeviceOptions& other) const {
+ return !(*this == other);
+ }
+
+  string ToString() const {
+    if (flags_ == 0) return "none";
+    string result;
+    if (flags_ & kDoNotReclaimStackAllocation) {
+      result += "kDoNotReclaimStackAllocation ";
+    }
+    if (flags_ & kScheduleSpin) result += "kScheduleSpin ";
+    if (flags_ & kScheduleYield) result += "kScheduleYield ";
+    if (flags_ & kScheduleBlockingSync) result += "kScheduleBlockingSync ";
+    return result;
+  }
+
+ private:
+ unsigned flags_;
+};
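+
+// Example usage (illustrative): flags may be OR'd together, subject to kMask.
+//
+//   DeviceOptions options(DeviceOptions::kDoNotReclaimStackAllocation |
+//                         DeviceOptions::kScheduleBlockingSync);
+//   unsigned flags = options.flags();  // 0x1 | 0x8 == 0x9.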
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_OPTIONS_H_
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
new file mode 100644
index 0000000000..020de7f7bb
--- /dev/null
+++ b/tensorflow/stream_executor/dnn.cc
@@ -0,0 +1,297 @@
+#include "tensorflow/stream_executor/dnn.h"
+
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace perftools {
+namespace gputools {
+namespace dnn {
+
+string ActivationModeString(ActivationMode mode) {
+ switch (mode) {
+ case ActivationMode::kSigmoid:
+ return "sigmoid";
+ case ActivationMode::kRelu:
+ return "relu";
+ case ActivationMode::kRelu6:
+ return "relu6";
+ case ActivationMode::kReluX:
+ return "reluX";
+ case ActivationMode::kTanh:
+ return "tanh";
+ default:
+ LOG(FATAL) << "Unknown activation_mode " << static_cast<int32>(mode);
+ }
+}
+
+string ElementwiseOperationString(ElementwiseOperation op) {
+ switch (op) {
+ case ElementwiseOperation::kAdd:
+ return "add";
+ case ElementwiseOperation::kMultiply:
+ return "multiply";
+ default:
+ LOG(FATAL) << "Unknown elementwise op " << static_cast<int32>(op);
+ }
+}
+
+string DataLayoutString(DataLayout layout) {
+ switch (layout) {
+ case DataLayout::kYXDepthBatch:
+ return "YXDepthBatch";
+ case DataLayout::kYXBatchDepth:
+ return "YXBatchDepth";
+ case DataLayout::kBatchYXDepth:
+ return "BatchYXDepth";
+ case DataLayout::kBatchDepthYX:
+ return "BatchDepthYX";
+ default:
+ LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
+ }
+}
+
+string FilterLayoutString(FilterLayout layout) {
+ switch (layout) {
+ case FilterLayout::kOutputInputYX:
+ return "OutputInputYX";
+ case FilterLayout::kInputYXOutput:
+ return "InputYXOutput";
+ case FilterLayout::kYXInputOutput:
+ return "YXInputOutput";
+ default:
+ LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(layout);
+ }
+}
+
+// -- BatchDescriptor
+
+BatchDescriptor::BatchDescriptor()
+ : count_(0),
+ feature_map_count_(0),
+ height_(0),
+ width_(0),
+ value_max_(0.0),
+ value_min_(0.0),
+ layout_(DataLayout::kYXDepthBatch),
+ quantized_activation_mode_(QuantizedActivationMode::k8Bit) {}
+
+void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
+ count_ = other.count_;
+ feature_map_count_ = other.feature_map_count_;
+ height_ = other.height_;
+ width_ = other.width_;
+ value_max_ = other.value_max_;
+ value_min_ = other.value_min_;
+ layout_ = other.layout_;
+ quantized_activation_mode_ = other.quantized_activation_mode_;
+}
+
+string BatchDescriptor::ToString() const {
+ return port::Printf(
+ "{count: %lld feature_map_count: %lld height: %lld width: %lld "
+ "value_min: %f value_max: %f layout: %s}",
+ count_, feature_map_count_, height_, width_, value_min_, value_max_,
+ DataLayoutString(layout_).c_str());
+}
+
+string BatchDescriptor::ToShortString() const {
+ // All the constituent strings are less than 15 characters, so the
+ // small string optimization ensures that there will be at most one
+ // heap memory allocation.
+ string x = port::StrCat("x", width());
+ string y = port::StrCat("y", height());
+ string depth = port::StrCat("d", feature_map_count());
+ string batch = port::StrCat("b", count());
+
+ string suffix;
+ if (value_min() != value_max()) {
+ port::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
+ }
+ if (quantized_activation_mode() == QuantizedActivationMode::k16Bit) {
+ suffix += "_16bit";
+ }
+
+ switch (layout()) {
+ case DataLayout::kYXDepthBatch:
+ return port::StrCat(y, x, depth, batch, suffix);
+ case DataLayout::kYXBatchDepth:
+ return port::StrCat(y, x, batch, depth, suffix);
+ case DataLayout::kBatchYXDepth:
+ return port::StrCat(batch, y, x, depth, suffix);
+ case DataLayout::kBatchDepthYX:
+ return port::StrCat(batch, depth, y, x, suffix);
+ default:
+ LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
+ }
+}
+
+int64 BatchDescriptor::NodesPerFeatureMap() const { return width_ * height_; }
+
+int64 BatchDescriptor::NodesAcrossFeatureMaps() const {
+ return NodesPerFeatureMap() * feature_map_count_;
+}
+
+int64 BatchDescriptor::ElementCount() const {
+ return count_ * feature_map_count_ * height_ * width_;
+}
+
+int64 BatchDescriptor::FullyConnectedWeightCount(
+ const BatchDescriptor& input, const BatchDescriptor& output) {
+ return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps();
+}
+
+int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) {
+ return output.NodesAcrossFeatureMaps();
+}
+
+// -- FilterDescriptor
+
+FilterDescriptor::FilterDescriptor()
+ : output_feature_map_count_(0),
+ input_feature_map_count_(0),
+ input_filter_height_(0),
+ input_filter_width_(0),
+ layout_(FilterLayout::kOutputInputYX) {}
+
+FilterDescriptor::~FilterDescriptor() {}
+
+void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
+ set_output_feature_map_count(other.output_feature_map_count())
+ .set_input_feature_map_count(other.input_feature_map_count())
+ .set_input_filter_height(other.input_filter_height())
+ .set_input_filter_width(other.input_filter_width())
+ .set_layout(other.layout());
+}
+
+string FilterDescriptor::ToString() const {
+ return port::Printf(
+ "{output_feature_map_count: %lld input_feature_map_count: %lld "
+ "input_filter_height: %lld input_filter_width: %lld layout: %s}",
+ output_feature_map_count_, input_feature_map_count_, input_filter_height_,
+ input_filter_width_, FilterLayoutString(layout_).c_str());
+}
+
+string FilterDescriptor::ToShortString() const {
+ // All the constituent strings are less than 15 characters, so the
+ // small string optimization ensures that there will be at most one
+ // heap memory allocation.
+ string od = port::StrCat("od", output_feature_map_count_);
+ string id = port::StrCat("id", input_feature_map_count_);
+ string y = port::StrCat("y", input_filter_height_);
+ string x = port::StrCat("x", input_filter_width_);
+
+ switch (layout_) {
+ case FilterLayout::kOutputInputYX:
+ return port::StrCat(od, id, y, x);
+ case FilterLayout::kInputYXOutput:
+ return port::StrCat(id, y, x, od);
+ case FilterLayout::kYXInputOutput:
+ return port::StrCat(y, x, id, od);
+ default:
+ LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_);
+ }
+}
+
+int64 FilterDescriptor::ComputeWeightCount() const {
+ return output_feature_map_count_ * input_feature_map_count_ *
+ input_filter_height_ * input_filter_width_;
+}
+
+// -- ConvolutionDescriptor
+
+ConvolutionDescriptor::ConvolutionDescriptor()
+ : zero_padding_height_(0),
+ zero_padding_width_(0),
+ vertical_filter_stride_(1),
+ horizontal_filter_stride_(1) {}
+
+ConvolutionDescriptor::~ConvolutionDescriptor() {}
+
+string ConvolutionDescriptor::ToString() const {
+ return port::Printf(
+ "{zero_padding_height: %lld zero_padding_width: %lld "
+ "vertical_filter_stride: %lld horizontal_filter_stride: %lld}",
+ zero_padding_height_, zero_padding_width_, vertical_filter_stride_,
+ horizontal_filter_stride_);
+}
+
+string ConvolutionDescriptor::ToShortString() const {
+ return port::StrCat("py:", zero_padding_height_, "_px:", zero_padding_width_,
+ "_sy:", vertical_filter_stride_, "_sx:",
+ horizontal_filter_stride_);
+}
+
+// -- PoolingDescriptor
+
+PoolingDescriptor::PoolingDescriptor()
+ : mode_(dnn::PoolingMode::kMaximum),
+ window_height_(0),
+ window_width_(0),
+ vertical_padding_(0),
+ horizontal_padding_(0),
+ vertical_stride_(0),
+ horizontal_stride_(0) {}
+
+void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
+ mode_ = other.mode_;
+ window_height_ = other.window_height_;
+ window_width_ = other.window_width_;
+ vertical_padding_ = other.vertical_padding_;
+ horizontal_padding_ = other.horizontal_padding_;
+ vertical_stride_ = other.vertical_stride_;
+ horizontal_stride_ = other.horizontal_stride_;
+}
+
+string PoolingDescriptor::ToString() const {
+ const char* mode_string =
+ mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
+ return port::Printf(
+ "{mode: %s window_height: %lld window_width: %lld vertical_stride: %lld "
+ "horizontal_stride: %lld vertical padding: %lld horizontal padding: "
+ "%lld}",
+ mode_string, window_height_, window_width_, vertical_stride_,
+ horizontal_stride_, vertical_padding_, horizontal_padding_);
+}
+
+string PoolingDescriptor::ToShortString() const {
+ return port::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
+ "_y:", window_height_, "_x:", window_width_, "_py:",
+ vertical_padding_, "_px:", horizontal_padding_, "_sy:",
+ vertical_stride_, "_sx:", horizontal_stride_);
+}
+
+// -- NormalizeDescriptor
+
+NormalizeDescriptor::NormalizeDescriptor()
+ : bias_(0.0),
+ range_(0),
+ alpha_(0.0),
+ beta_(0.0),
+ wrap_around_(false),
+ segment_size_(0) {}
+
+void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
+ bias_ = other.bias_;
+ range_ = other.range_;
+ alpha_ = other.alpha_;
+ beta_ = other.beta_;
+ wrap_around_ = other.wrap_around_;
+ segment_size_ = other.segment_size_;
+}
+
+string NormalizeDescriptor::ToString() const {
+ return port::Printf(
+ "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
+ "segment_size: %d}",
+ bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
+}
+
+string NormalizeDescriptor::ToShortString() const {
+ return port::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
+ "_beta:", beta_, "_wrap:", wrap_around_, "_size:",
+ segment_size_);
+}
+
+} // namespace dnn
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
new file mode 100644
index 0000000000..e737d1c78f
--- /dev/null
+++ b/tensorflow/stream_executor/dnn.h
@@ -0,0 +1,895 @@
+// Neural Net operation support for StreamExecutor instances.
+//
+// This is an abstract interface for a platform to optionally support common
+// neural net operations; it accommodates implementations such as the cudnn
+// library operations.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_DNN_H_
+#define TENSORFLOW_STREAM_EXECUTOR_DNN_H_
+
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+
+namespace dnn {
+
+// Describes how an input or output layer's data is formatted.
+// Specify int64 so there's no padding in BatchDescriptor.
+enum class DataLayout : int64 {
+ kYXDepthBatch = 0, // Same as dist_belief::DF_DEPTH_MAJOR.
+ kYXBatchDepth, // Same as dist_belief::DF_BATCH_MAJOR.
+ kBatchYXDepth, // Same as run_brain output, and tensorflow's layout.
+  kBatchDepthYX,      // cuDNN's NCHW layout, data laid out as image, feature
+                      // maps, rows, columns.
+};
+
+// Returns a string representation of the given data layout.
+string DataLayoutString(DataLayout layout);
+
+// Specifies a quantization for activations in a given BatchDescriptor.
+enum class QuantizedActivationMode {
+ k8Bit = 1,
+ k16Bit = 2,
+ k32Bit = 4,
+};
+
+// Describes the dimensions that a layer consumes/produces.
+//
+// This is a matrix (height, width), its "depth" (feature_map_count),
+// how many of these matrices are present (count),
+// and the maximum and minimum values expected in the matrix (value_max,
+// value_min).
+// If input is quantized, all values greater
+// than value_max will be clipped to value_max and all values less than
+// value_min will be clipped to value_min.
+// When quantized output is dequantized no value will be greater than
+// value_max or less than value_min.
+//
+// Uses the named argument construction form:
+//
+// auto input_batch_dimensions =
+// BatchDescriptor().set_count(42).set_feature_map_count(7)...
+//
+// Details:
+//
+// For a convolutional layer, a single inference takes a 3-dimensional matrix
+// of input and produces a 3-dimensional matrix of output. We call the three
+// dimensions height, width and feature_map_count, where for an image, the
+// height and width correspond to the Y and X pixel indices, respectively, and
+// the feature_map_count corresponds to the RGB dimension of the input data.
+// Then the count indicates how many 3D matrices are being presented to be
+// processed at once; this corresponds to the neural network concept of
+// minibatch size.
+//
+// For a fully connected layer, it's better to put the nodes of the layer in
+// the feature_map_count, and leave the height and width as degenerate (== 1).
+// Count indicates how many input vectors (degenerate 3D matrices) are to be
+// processed.
+//
+// If unspecified, value_max and value_min default to 0.0.
+// If value_max == value_min the Stream will attempt to derive valid values -
+// for example the output of Relu6 activation will always be in the range
+// [0.0, 6.0].
+//
+// If unspecified, layout defaults to kYXDepthBatch.
+class BatchDescriptor {
+ public:
+ // Creates a "blank" batch descriptor, which should be initialized via the
+ // named argument helpers.
+ BatchDescriptor();
+
+ // Clones values from 'other' for initialization.
+ void CloneFrom(const BatchDescriptor& other);
+
+ string ToString() const;
+ string ToShortString() const;
+
+ // Accessors.
+ int64 count() const { return count_; }
+ int64 feature_map_count() const { return feature_map_count_; }
+ int64 height() const { return height_; }
+ int64 width() const { return width_; }
+ float value_max() const { return value_max_; }
+ float value_min() const { return value_min_; }
+ DataLayout layout() const { return layout_; }
+ QuantizedActivationMode quantized_activation_mode() const {
+ return quantized_activation_mode_;
+ }
+
+ // Named-argument helpers for avoiding user error during construction.
+ BatchDescriptor& set_count(int64 value) {
+ count_ = value;
+ return *this;
+ }
+ BatchDescriptor& set_feature_map_count(int64 value) {
+ feature_map_count_ = value;
+ return *this;
+ }
+ BatchDescriptor& set_height(int64 value) {
+ height_ = value;
+ return *this;
+ }
+ BatchDescriptor& set_width(int64 value) {
+ width_ = value;
+ return *this;
+ }
+ BatchDescriptor& set_value_max(float value) {
+ value_max_ = value;
+ return *this;
+ }
+ BatchDescriptor& set_value_min(float value) {
+ value_min_ = value;
+ return *this;
+ }
+ BatchDescriptor& set_layout(DataLayout layout) {
+ layout_ = layout;
+ return *this;
+ }
+ BatchDescriptor& set_quantized_activation_mode(
+ QuantizedActivationMode quantized_activation_mode) {
+ quantized_activation_mode_ = quantized_activation_mode;
+ return *this;
+ }
+
+ // Return the number of nodes in a single feature map.
+ int64 NodesPerFeatureMap() const;
+
+ // Return the number of nodes across all feature maps. Note that this is not
+ // affected by the batch count.
+ int64 NodesAcrossFeatureMaps() const;
+
+ // Returns the number of elements (e.g. RGB pixel values) required to hold a
+ // given batch descriptor, given a no-padding assumption. Note that this is
+ // affected by the batch count.
+ int64 ElementCount() const;
+
+ // Return the number of weights required to fully connect a layer with
+ // dimensions given by the 'input' descriptor with a layer with dimensions
+ // given by the 'output' descriptor.
+ static int64 FullyConnectedWeightCount(const BatchDescriptor& input,
+ const BatchDescriptor& output);
+
+ // Return the number of biases required to fully connect to an output layer
+ // with dimensions given the 'output' descriptor.
+ static int64 FullyConnectedBiasCount(const BatchDescriptor& output);
+
+ private:
+ int64 count_;
+ int64 feature_map_count_;
+ int64 height_;
+ int64 width_;
+ float value_max_;
+ float value_min_;
+ DataLayout layout_;
+ QuantizedActivationMode quantized_activation_mode_;
+};
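+
+// Example usage (an illustrative sketch; the dimensions are placeholders): a
+// minibatch of 32 RGB images of size 28x28 in the kBatchDepthYX layout.
+//
+//   BatchDescriptor input;
+//   input.set_count(32)
+//       .set_feature_map_count(3)
+//       .set_height(28)
+//       .set_width(28)
+//       .set_layout(DataLayout::kBatchDepthYX);
+//   // input.ElementCount() == 32 * 3 * 28 * 28 == 75264.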
+
+// Describes how a filter is laid out in the memory.
+// Specify int64 so there's no padding in FilterDescriptor.
+enum class FilterLayout : int64 {
+ kOutputInputYX = 0, // cuDNN's default filter layout, laid out as:
+ // (major) output feature maps >> input feature maps >>
+ // rows >> columns (minor).
+ kInputYXOutput, // Same as dist_belief's default filter layout.
+ kYXInputOutput, // Same as tensorflow's default filter layout.
+};
+
+// Returns a string representation of the given filter layout.
+string FilterLayoutString(FilterLayout layout);
+
+// Describes a filter for the convolution. This is the "window" from
+// height-by-width patches of each of the feature maps in the input layer to the
+// cells within the output feature map.
+//
+// Uses the named argument construction form:
+//
+// FilterDescriptor filter_dimensions;
+// filter_dimensions
+// .set_output_feature_map_count(42)
+// .set_input_feature_map_count(7)
+// ...
+//
+// Arguments:
+// - output_feature_map_count: number of feature maps in the output layer.
+// - input_feature_map_count: number of feature maps in the input layer (from
+// which the filter patch is taken).
+// - input_filter_height: "height" number of neurons used in the sliding window
+// over the input layer.
+// - input_filter_width: "width" number of neurons used in the sliding window
+// over the input layer.
+//
+// Sometimes names like "filter input height" are referred to by synonymous
+// terminology, such as "kernel y size".
+//
+// If unspecified, layout defaults to kOutputInputYX.
+class FilterDescriptor {
+ public:
+  // On default construction, all dimensions are set to zero, so they should
+  // all be populated by the user via the named-argument helpers below. (See
+  // the class comment for details.)
+ FilterDescriptor();
+
+ ~FilterDescriptor();
+
+ // Named-argument helpers for avoiding user error during construction.
+ FilterDescriptor& set_output_feature_map_count(int64 value) {
+ output_feature_map_count_ = value;
+ return *this;
+ }
+ FilterDescriptor& set_input_feature_map_count(int64 value) {
+ input_feature_map_count_ = value;
+ return *this;
+ }
+ FilterDescriptor& set_input_filter_height(int64 value) {
+ input_filter_height_ = value;
+ return *this;
+ }
+ FilterDescriptor& set_input_filter_width(int64 value) {
+ input_filter_width_ = value;
+ return *this;
+ }
+ FilterDescriptor& set_layout(FilterLayout layout) {
+ layout_ = layout;
+ return *this;
+ }
+
+ void CloneFrom(const FilterDescriptor& other);
+
+ string ToString() const;
+ string ToShortString() const;
+
+ // Returns the number of weights required as parameters for a convolution
+ // using this filter descriptor.
+ int64 ComputeWeightCount() const;
+
+ // Returns the number of biases required as parameters for a convolution using
+ // this filter descriptor.
+ int64 bias_count() const { return output_feature_map_count_; }
+
+ int64 output_feature_map_count() const { return output_feature_map_count_; }
+ int64 input_feature_map_count() const { return input_feature_map_count_; }
+ int64 input_filter_height() const { return input_filter_height_; }
+ int64 input_filter_width() const { return input_filter_width_; }
+ FilterLayout layout() const { return layout_; }
+
+ private:
+ int64 output_feature_map_count_;
+ int64 input_feature_map_count_;
+ int64 input_filter_height_;
+ int64 input_filter_width_;
+ FilterLayout layout_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(FilterDescriptor);
+};
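+
+// Example usage (an illustrative sketch; the dimensions are placeholders): a
+// 5x5 filter mapping 3 input feature maps to 64 output feature maps.
+//
+//   FilterDescriptor filter;
+//   filter.set_output_feature_map_count(64)
+//       .set_input_feature_map_count(3)
+//       .set_input_filter_height(5)
+//       .set_input_filter_width(5);
+//   // filter.ComputeWeightCount() == 64 * 3 * 5 * 5 == 4800.
+//   // filter.bias_count() == 64.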
+
+// Describes a convolution.
+//
+// Uses the named argument construction form:
+//
+// ConvolutionDescriptor convolution_dimensions;
+// convolution_dimensions
+// .set_vertical_filter_stride(2)
+// .set_horizontal_filter_stride(2)
+// ...
+//
+// Arguments:
+// - zero_padding_height: padding of the "y dimension" of the input data. Note
+// that this is different from the height of the filter.
+// - zero_padding_width: analogous to the height above, but in the "x
+// dimension".
+// - vertical_filter_stride: the convolution slides a 2-dimensional window of
+// filter-height-by-filter-width over the input layer -- the center of that
+// window is moved in the "y dimension" according to this stride value.
+// - horizontal_filter_stride: analogous to the vertical stride above, but in
+// the "x dimension".
+class ConvolutionDescriptor {
+ public:
+  // On default construction, there is no zero-padding and the filter stride is
+ // 1x1 (centering the filter on every cell in the input layer's
+ // width-by-height area).
+ ConvolutionDescriptor();
+ ~ConvolutionDescriptor();
+
+ string ToString() const;
+ string ToShortString() const;
+
+ ConvolutionDescriptor& set_zero_padding_height(int64 value) {
+ zero_padding_height_ = value;
+ return *this;
+ }
+ ConvolutionDescriptor& set_zero_padding_width(int64 value) {
+ zero_padding_width_ = value;
+ return *this;
+ }
+ ConvolutionDescriptor& set_vertical_filter_stride(int64 value) {
+ vertical_filter_stride_ = value;
+ return *this;
+ }
+ ConvolutionDescriptor& set_horizontal_filter_stride(int64 value) {
+ horizontal_filter_stride_ = value;
+ return *this;
+ }
+
+ int64 zero_padding_height() const { return zero_padding_height_; }
+ int64 zero_padding_width() const { return zero_padding_width_; }
+ int64 vertical_filter_stride() const { return vertical_filter_stride_; }
+ int64 horizontal_filter_stride() const { return horizontal_filter_stride_; }
+
+ private:
+ int64 zero_padding_height_;
+ int64 zero_padding_width_;
+ int64 vertical_filter_stride_;
+ int64 horizontal_filter_stride_;
+ // TODO(leary) cudnn provides these fields, but need to characterize what
+ // their effect is -- they may be boolean rather than integral.
+ // int64 upscale_input_x;
+ // int64 upscale_input_y;
+};
+
+// A patch of values in the input can be pooled via either a max or an average
+// operation.
+// Specify int64 so there's no padding in PoolingDescriptor.
+enum class PoolingMode : int64 {
+ kMaximum,
+ kAverage,
+};
+
+// Describes a pooling operation to be enqueued onto a stream via a platform's
+// DnnSupport.
+//
+// TODO(broune): describe how padding works and what happens if the
+// window height/width is not divisible by the vertical/horizontal
+// stride.
+//
+// Arguments:
+// pooling_mode: pooling operator to use on the input patch
+// window_height: height of input window
+// window_width: width of input window
+// vertical_stride: vertical delta for center of the input patch
+// horizontal_stride: horizontal delta for center of the input patch
+class PoolingDescriptor {
+ public:
+ PoolingDescriptor();
+
+ PoolingDescriptor& set_pooling_mode(PoolingMode value) {
+ mode_ = value;
+ return *this;
+ }
+ PoolingDescriptor& set_window_height(int64 value) {
+ window_height_ = value;
+ return *this;
+ }
+ PoolingDescriptor& set_window_width(int64 value) {
+ window_width_ = value;
+ return *this;
+ }
+ PoolingDescriptor& set_vertical_padding(int64 value) {
+ vertical_padding_ = value;
+ return *this;
+ }
+ PoolingDescriptor& set_horizontal_padding(int64 value) {
+ horizontal_padding_ = value;
+ return *this;
+ }
+ PoolingDescriptor& set_vertical_stride(int64 value) {
+ vertical_stride_ = value;
+ return *this;
+ }
+ PoolingDescriptor& set_horizontal_stride(int64 value) {
+ horizontal_stride_ = value;
+ return *this;
+ }
+
+ void CloneFrom(const PoolingDescriptor& other);
+
+ string ToString() const;
+ string ToShortString() const;
+
+ PoolingMode mode() const { return mode_; }
+ int64 window_height() const { return window_height_; }
+ int64 window_width() const { return window_width_; }
+ int64 vertical_padding() const { return vertical_padding_; }
+ int64 horizontal_padding() const { return horizontal_padding_; }
+ int64 vertical_stride() const { return vertical_stride_; }
+ int64 horizontal_stride() const { return horizontal_stride_; }
+
+ private:
+ PoolingMode mode_;
+ int64 window_height_;
+ int64 window_width_;
+ int64 vertical_padding_;
+ int64 horizontal_padding_;
+ int64 vertical_stride_;
+ int64 horizontal_stride_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(PoolingDescriptor);
+};
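+
+// Example usage (an illustrative sketch): 2x2 max pooling with stride 2 and no
+// padding.
+//
+//   PoolingDescriptor pooling;
+//   pooling.set_pooling_mode(PoolingMode::kMaximum)
+//       .set_window_height(2)
+//       .set_window_width(2)
+//       .set_vertical_stride(2)
+//       .set_horizontal_stride(2);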
+
+// Describes a dist_belief local response normalization.
+// The normalization equation is:
+// y_i = x_i / (bias + alpha * (sum_j_{i - range}^{i + range} x_j^2)) ^ beta
+// where x_i is the input in feature map i, y_i is the output.
+// Each feature map is split into segment_size segments for performing the
+// sum_j_. If wrap_around is true, the sum_j_ for y_i on the left and right of
+// a segment wraps around at the edges of the segment; if wrap_around is false,
+// zeros are inserted instead.
+class NormalizeDescriptor {
+ public:
+ NormalizeDescriptor();
+
+ NormalizeDescriptor& set_bias(float bias) {
+ bias_ = bias;
+ return *this;
+ }
+
+ NormalizeDescriptor& set_range(int32 range) {
+ range_ = range;
+ return *this;
+ }
+
+ NormalizeDescriptor& set_alpha(float alpha) {
+ alpha_ = alpha;
+ return *this;
+ }
+
+ NormalizeDescriptor& set_beta(float beta) {
+ beta_ = beta;
+ return *this;
+ }
+
+ NormalizeDescriptor& set_wrap_around(bool wrap_around) {
+ wrap_around_ = wrap_around;
+ return *this;
+ }
+
+ NormalizeDescriptor& set_segment_size(int32 segment_size) {
+ segment_size_ = segment_size;
+ return *this;
+ }
+
+ void CloneFrom(const NormalizeDescriptor& other);
+
+ string ToString() const;
+ string ToShortString() const;
+
+ float bias() const { return bias_; }
+ int32 range() const { return range_; }
+ float alpha() const { return alpha_; }
+ float beta() const { return beta_; }
+ bool wrap_around() const { return wrap_around_; }
+ int32 segment_size() const { return segment_size_; }
+
+ private:
+ float bias_;
+ int32 range_;
+ float alpha_;
+ float beta_;
+ bool wrap_around_;
+ int32 segment_size_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(NormalizeDescriptor);
+};
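+
+// Example usage (an illustrative sketch; the parameter values are placeholders
+// rather than recommended settings):
+//
+//   NormalizeDescriptor normalize;
+//   normalize.set_bias(1.0f)
+//       .set_range(2)
+//       .set_alpha(1e-4f)
+//       .set_beta(0.75f)
+//       .set_wrap_around(false)
+//       .set_segment_size(16);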
+
+// Describes a kind of non-linearity (threshold-like mathematical function).
+enum class ActivationMode {
+ kSigmoid,
+ // Rectified linear activation: f(x) = x < 0 ? 0 : x
+ kRelu,
+ // Rectified linear activation, where upper maximum is 6.0.
+ kRelu6,
+ // Rectified linear activation, where upper maximum specified by
+ // BatchDescriptor::value_max().
+ kReluX,
+ kTanh,
+};
+
+// Returns a string representation of the given activation mode.
+string ActivationModeString(ActivationMode mode);
+
+// Describes the operation that DoElementwiseOperation should perform on its
+// inputs.
+enum class ElementwiseOperation {
+ kAdd,
+ kMultiply
+};
+
+string ElementwiseOperationString(ElementwiseOperation op);
+
+// Suite of operations typically used for implementing Deep/Convolutional Neural
+// Nets.
+class DnnSupport {
+ public:
+ DnnSupport() {}
+ virtual ~DnnSupport() {}
+
+ virtual port::Status Init() = 0;
+
+ // Enqueues a single-precision convolution operation onto the stream.
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'convolve' operation
+ // should be enqueued onto.
+ // input_descriptor: dimensions of the input layer.
+ // input_data: un-owned device memory region which contains the
+ // convolution input.
+ // filter_descriptor: dimensions of the convolution filter.
+ // weights: coefficients for the convolution filter, these are multiplied
+ // against values in the input that the filter convolves over.
+ // convolution_descriptor: stride of the convolution filter.
+ // output_descriptor: dimensions of the output layer.
+ // output_data: un-owned device memory region in which to place the
+ // convolution result.
+ //
+ // input_descriptor, filter_descriptor, convolution_descriptor and
+ // output_descriptor together specify exactly how the convolution is aligned
+ // with the input data:
+ //
+ // * (input dimensions - filter size + 1) / filter stride == output dimensions
+ // corresponds to dist_belief padding = VALID, i.e. the input is not padded.
+ // * input dimensions / filter stride == output dimensions
+ // corresponds to dist_belief padding = SAME, i.e. input and output are the
+ // same size - this requires padding the input.
+ // * (input dimensions + filter size - 1) / filter stride == output dimensions
+ // corresponds to dist_belief padding = FULL, i.e. the output is sized so
+ // that if the inverse of the filter is applied to the output in VALID mode
+ // the result is the same size as the input - this requires even more
+  //   padding of the input.
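+  //
+  // As a worked example with illustrative numbers: a 5x5 input convolved with
+  // a 3x3 filter at stride 1 yields a 3x3 output under VALID ((5 - 3 + 1) / 1),
+  // a 5x5 output under SAME (5 / 1, using one row/column of zero-padding on
+  // each edge), and a 7x7 output under FULL ((5 + 3 - 1) / 1).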
+ virtual bool DoConvolve(
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<float>& input_data,
+ const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<float>& filter_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Enqueues a double-precision convolution operation onto the stream.
+ // See DoConvolve above for argument details.
+ virtual bool DoConvolve(
+ Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
+ const DeviceMemory<double>& input_data,
+ const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<double>& filter_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<double>* output_data) = 0;
+
+ // Variation of the above with the weight matrix split into two matrices.
+ // first_weights: Coefficients of the first matrix.
+ // second_weights: Coefficients of the second matrix.
+ // depth_multiplier: specifies the columns of the first matrix and rows
+ // of the second one - first_weights columns = depth_multiplier,
+ // second_weights rows = depth_multiplier *
+ // filter_descriptor.input_feature_map_count().
+ // see go/separable for documentation on separable convolutions.
+ virtual bool DoSeparableConvolve(
+ Stream* stream, const BatchDescriptor& input_descriptor,
+ const DeviceMemory<float>& input_data,
+ const FilterDescriptor& filter_descriptor, int depth_multiplier,
+ const DeviceMemory<float>& first_weights,
+ const DeviceMemory<float>& second_weights,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const BatchDescriptor& output_descriptor,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Enqueues a single-precision backward convolution (for data) operation onto
+ // the stream.
+ //
+ // Arguments:
+ // stream: borrowed pointer to the stream that the 'convolve' operation
+ // should be enqueued onto.
+ // filter_descriptor: dimensions of the convolution filter.
+ // filter_data: coefficients for the convolution filter.
+  //  output_descriptor: dimensions of the output gradients, which are the
+  //    same as the dimensions of the output.
+ // backward_output_data: un-owned device memory region which contains the
+ // backprop of the output.
+ // convolution_descriptor: stride of the convolution filter.
+ // input_descriptor: dimensions of the input layer.
+ // backward_input_data: un-owned device memory region in which to place the
+ // backprop of the input.
+ virtual bool DoConvolveBackwardData(
+ Stream* stream, const FilterDescriptor& filter_descriptor,
+ const DeviceMemory<float>& filter_data,
+ const BatchDescriptor& output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const BatchDescriptor& input_descriptor,
+ DeviceMemory<float>* backward_input_data) = 0;
+
+  // Enqueues a single-precision backward convolution (for filter) operation
+  // onto the stream.
+ //
+ // Arguments:
+ // stream: borrowed pointer to the stream that the 'convolve' operation
+ // should be enqueued onto.
+ // input_descriptor: dimensions of the input layer.
+ // input_data: un-owned device memory region which contains the
+ // convolution input.
+  //  output_descriptor: dimensions of the output gradients, which are the
+  //    same as the dimensions of the output.
+ // backward_output_data: un-owned device memory region which contains the
+ // backprop of the output.
+ // convolution_descriptor: stride of the convolution filter.
+ // filter_descriptor: dimensions of the convolution filter.
+ // backward_filter_data: un-owned device memory region in which to place the
+ // backprop of the filter.
+ virtual bool DoConvolveBackwardFilter(
+ Stream* stream, const BatchDescriptor& input_descriptor,
+ const DeviceMemory<float>& input_data,
+ const BatchDescriptor& output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const FilterDescriptor& filter_descriptor,
+ DeviceMemory<float>* backward_filter_data) = 0;
+
+ // Fully connects the "nodes" (float values) in input_data with
+ // shape input_dimensions to output_data with output_dimensions
+ // using provided weights. This is equivalent to computing a matrix
+ // product, hence the name MatMul.
+ //
+ // A BatchDescriptor has four dimensions: batch, y, x, depth. Matrix products
+ // happen in two dimensions. To get down to two dimensions, we consider the
+ // input y, x and depth dimension as one combined dimension T. For now,
+ // assume that the output height and width are 1 and let OD be the output
+ // depth.
+ //
+ // There are three device memory buffers passed in to this
+ // function. We can now view all three as matrices:
+ //
+ // input_data: A batch x T matrix
+ // weights: A T x OD matrix
+ // output_data: A batch x OD matrix
+ //
+ // This function then computes the matrix product of input_data and
+ // weights and writes the result into output_data.
+ //
+ // Here the weights buffer is in row major order, i.e. the first OD
+ // entries in weights are the first row, the second OD entries in
+ // weights are the second row and so on.
+ //
+ // The case for output width*height > 1 is more complicated. Let K =
+ // OY * OX where OY is the output height and OX is the output
+ // width. Then weights is divided into K sub-arrays W_i, for
+ // i=0,...,k-1, that each represent a T x OD matrix. This function
+ // then computes the K matrix multiplications of input_data with
+ // each W_i. This creates K matrices with dimensions batch x
+ // OD. These K matrices are concatenated horizontally to form one
+ // larger matrix with dimensions batch x (K*OD); note that this is
+ // not the same as concatenating the bytes of the matrices. The
+ // combined matrix can then be interpreted as a tensor with
+ // dimensions (batch, OY, OX, OD). If the output tensor format is
+ // not kBatchYXDepth, this function would then need to arrange for
+ // the output to be in the requested layout, if that is
+ // supported. Note that the case K=1 is equivalent to the
+ // description above. It is recommended to prefer the case K=1.
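+  //
+  // As a shape sketch with illustrative numbers: with batch = 32, input
+  // y = x = 1 and input depth = 100 (so T = 100), and output depth OD = 10,
+  // input_data is a 32 x 100 matrix, weights is a 100 x 10 matrix stored in
+  // row major order, and output_data is the resulting 32 x 10 matrix.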
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'fully connect' operation
+ // should be enqueued onto.
+ // output_data: un-owned device memory region in which to place the
+ // fully connected result.
+ virtual bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
+ const DeviceMemory<float>& weights,
+ const dnn::BatchDescriptor& input_dimensions,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Version of DoMatMul that uses pre-quantized 8 bit weights.
+ // weight_scales specifies the scaling of each column of weights:
+ // original float weight[row * num_columns + column] =
+  //   quantized_weight[row * num_columns + column] * weight_scales[column].
+ virtual bool DoMatMulQuantized(Stream* stream,
+ const DeviceMemory<float>& input_data,
+ const DeviceMemory<int8>& quantized_weights,
+ const DeviceMemory<float>& weight_scales,
+ const dnn::BatchDescriptor& input_dimensions,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Version of DoMatMul that uses pre-quantized 16 bit weights.
+ // weight_scales specifies the scaling of each column of weights:
+ // original float weight[row * num_columns + column] =
+  //   quantized_weight[row * num_columns + column] * weight_scales[column].
+ virtual bool DoMatMulQuantized(Stream* stream,
+ const DeviceMemory<float>& input_data,
+ const DeviceMemory<int16>& quantized_weights,
+ const DeviceMemory<float>& weight_scales,
+ const dnn::BatchDescriptor& input_dimensions,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Adds biases to the feature maps in input_data producing
+ // output_data. input_data can equal output_data, but must not
+ // partially overlap it.
+ //
+ // Let K = count() * height() * width() and N = feature_map_count()
+ // on dimensions. Then input_value contains K*N values and biases
+ // contains N values. We can thus logically consider input_value to
+ // contain K vectors of N elements each. This function adds biases
+ // to each of those N vectors.
+ //
+ // TODO(broune): This works differently when width() * height() > 1
+ // and the call to ThenBiasAdd() follows a call to ThenMatMul(). In
+ // that case there should be width() * height() *
+ // feature_map_count() biases, but this is not implemented on all
+ // StreamExecutors.
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'bias add' operation
+ // should be enqueued onto.
+ // input_data: un-owned device memory region containing the input.
+ // biases: un-owned device memory region containing biases to add to the
+ // input.
+ // dimensions: dimensions of input_data and output_data.
+ // output_data: un-owned device memory region in which to place the result.
+ virtual bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
+ const DeviceMemory<float>& biases,
+ const dnn::BatchDescriptor& dimensions,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Performs a forward pooling operation on input_data, writing to
+ // output_data. See PoolingDescriptor for how to configure the
+ // pooling operation.
+ //
+ // Pooling happens as a window that moves across the Y and X
+ // dimensions of input_data, where each position of the window
+ // yields one output value. E.g. for max pooling, the computed value
+ // is the maximum element in the window. The operation is applied
+ // independently to each batch and at each feature map (depth), so
+ // that the output depth and feature_map_count are the same as for
+ // the input. The output width and height can be different.
+ //
+ // See PoolingDescriptor for how to configure the pooling operation.
+ virtual bool DoPoolForward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Performs differentiation of the pooling operation.
+ virtual bool DoPoolBackward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<float>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ const DeviceMemory<float>& output_data,
+ const DeviceMemory<float>& input_diff_data,
+ DeviceMemory<float>* output_diff_data) = 0;
+
+ // Applies local response normalization to all of the values
+ // held on the device in 'input_data'.
+ virtual bool DoNormalize(Stream* stream,
+ const dnn::NormalizeDescriptor& normalize_descriptor,
+ const DeviceMemory<float>& input_data,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Applies an activation function (see ActivationMode) to all of the values
+ // held on the device in 'input_data', whose dimensions are described by
+ // 'dimensions'.
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'activate' operation
+ // should be enqueued onto.
+ // activation_mode: Type of activation to perform.
+ // input_data: un-owned device memory region which contains the
+ // activate input.
+ // output_data: un-owned device memory region in which to place the
+ // activate result.
+ virtual bool DoActivate(Stream* stream, ActivationMode activation_mode,
+ const BatchDescriptor& dimensions,
+ const DeviceMemory<float>& input_data,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Concatenates several layers into one, by concatenating the depth of each
+ // layer at matching x and y coordinates.
+ // The inputs must all have the same width and height, the output will have
+ // the same width and height as the inputs and its depth will be the sum of
+ // the input depths.
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'depth concatenate'
+ // operation should be enqueued onto.
+ // input_dimensions: The dimensions of each input.
+ // input_data: un-owned device memory region which contains the
+ // input data for each input layer.
+ // output_data: un-owned device memory region in which to place the
+ // depth concatenate result.
+ virtual bool DoDepthConcatenate(
+ Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float>*> input_data,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Computes the specified operation (e.g. addition or multiplication)
+ // between corresponding elements in the inputs and stores the result in the
+ // output element.
+ // The inputs and output must all have the same dimensions, but may have
+ // different quantization parameters (min_value and max_value).
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'elementwise operation'
+ // should be enqueued onto.
+ // operation: The operation to perform.
+ // input_dimensions: The dimensions of each input.
+ // input_data: un-owned device memory region which contains the
+ // input data for each input layer.
+ // output_dimensions: The dimensions of the output.
+ // output_data: un-owned device memory region in which to place the
+ // operation result.
+ virtual bool DoElementwiseOperate(
+ Stream* stream, ElementwiseOperation operation,
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float>*> input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<float>* output_data) = 0;
+
+ // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
+ // is, bytes instead of scaled floats) into 'host_dst' if they are available
+ // for the underlying DNN implementation. If this quantized output is not
+ // available, false is returned, which will place 'stream' into an error
+ // state.
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'quantized memcpy'
+ // operation should be enqueued onto.
+ // gpu_unquantized_src: the device memory that contains the unquantized data
+ // -- this data should also have a corresponding quantized representation
+ // on the device for this operation to succeed.
+ // host_dst: un-owned host memory region that is mutated in place,
+ // it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
+ // (asynchronous) memcpy operation is performed.
+ // TODO(wgulland) Merge all these versions of DoMemcpyD2HQuantized.
+ virtual bool DoMemcpyD2HQuantized(
+ Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
+ port::MutableArraySlice<uint8> host_dst) = 0;
+
+ // As above, but for 16-bit values.
+ virtual bool DoMemcpyD2HQuantized(
+ Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
+ port::MutableArraySlice<uint16> host_dst) = 0;
+
+ // As above, but for signed 32-bit values.
+ virtual bool DoMemcpyD2HQuantized(
+ Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
+ port::MutableArraySlice<int32> host_dst) = 0;
+
+ // Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input
+ // of a layer (that is, bytes instead of scaled floats) if they are supported
+ // by the underlying DNN implementation. If this quantized input is not
+ // supported, false is returned, which will place 'stream' into an error
+ // state.
+ //
+ // Arguments (all borrowed):
+ // stream: borrowed pointer to the stream that the 'quantized memcpy'
+ // operation should be enqueued onto.
+ // host_src: un-owned host memory region that contains the quantized data.
+  //  gpu_unquantized_dst: the device memory that is clobbered by the values in
+  //    'host_src' when the enqueued (asynchronous) memcpy operation is
+  //    performed -- this data should also have a corresponding quantized
+  //    representation on the device for this operation to succeed.
+ virtual bool DoMemcpyH2DQuantized(
+ Stream* stream, port::ArraySlice<uint8> host_src,
+ DeviceMemory<float>* gpu_unquantized_dst) = 0;
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(DnnSupport);
+};
+
+} // namespace dnn
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_DNN_H_
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
new file mode 100644
index 0000000000..4ac14ea30b
--- /dev/null
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -0,0 +1,208 @@
+#include "tensorflow/stream_executor/dso_loader.h"
+
+#include <dlfcn.h>
+#include <limits.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <initializer_list>
+#include <vector>
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/str_util.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+
+/* static */ port::Status DsoLoader::GetCublasDsoHandle(void** dso_handle) {
+ return GetDsoHandle(FindDsoPath("libcublas.so.7.0",
+ "third_party/gpus/cuda/lib64"),
+ dso_handle);
+}
+
+/* static */ port::Status DsoLoader::GetCudnnDsoHandle(void** dso_handle) {
+ // libcudnn is versioned differently than the other libraries. See b/22397368
+ // for some details about the complications surrounding this.
+ return GetDsoHandle(FindDsoPath("libcudnn.so.6.5",
+ "third_party/gpus/cuda/lib64"),
+ dso_handle);
+}
+
+/* static */ port::Status DsoLoader::GetCufftDsoHandle(void** dso_handle) {
+ return GetDsoHandle(FindDsoPath("libcufft.so.7.0",
+ "third_party/gpus/cuda/lib64"),
+ dso_handle);
+}
+
+/* static */ port::Status DsoLoader::GetCurandDsoHandle(void** dso_handle) {
+ return GetDsoHandle(FindDsoPath("libcurand.so.7.0",
+ "third_party/gpus/cuda/lib64"),
+ dso_handle);
+}
+
+/* static */ port::Status DsoLoader::GetLibcudaDsoHandle(void** dso_handle) {
+ return GetDsoHandle(FindDsoPath("libcuda.so",
+ "third_party/gpus/cuda/driver/lib64"),
+ dso_handle);
+}
+
+/* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) {
+ return GetDsoHandle(
+ FindDsoPath("libcupti.so.7.0",
+ "third_party/gpus/cuda/extras/CUPTI/lib64"),
+ dso_handle);
+}
+
+/* static */ void DsoLoader::RegisterRpath(port::StringPiece path) {
+ mutex_lock lock{rpath_mutex_};
+ GetRpaths()->push_back(path.ToString());
+}
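+
+// Example usage (an illustrative sketch; the rpath shown is a placeholder):
+//
+//   DsoLoader::RegisterRpath("third_party/gpus/cuda/lib64");
+//   void* cublas_handle = nullptr;
+//   port::Status status = DsoLoader::GetCublasDsoHandle(&cublas_handle);
+//   if (!status.ok()) { /* the dlopen failed; handle the error */ }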
+
+/* static */ port::Status DsoLoader::GetDsoHandle(port::StringPiece path,
+ void** dso_handle,
+ LoadKind load_kind) {
+
+ int dynload_flags =
+ RTLD_LAZY | (load_kind == LoadKind::kLocal ? RTLD_LOCAL : RTLD_GLOBAL);
+ string path_string = path.ToString();
+ *dso_handle = dlopen(path_string.c_str(), dynload_flags);
+ if (*dso_handle == nullptr) {
+ LOG(INFO) << "LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
+ // TODO(b/22689637): Eliminate unnecessary ToString once StrCat has been
+ // moved to the open-sourceable version.
+ return port::Status(
+ port::error::FAILED_PRECONDITION,
+ port::StrCat("could not dlopen DSO: ", path, "; dlerror: ", dlerror()));
+ }
+
+ VLOG(2) << "loaded path \"" << path << "\" "
+ << (load_kind == LoadKind::kLocal ? "locally" : "globally");
+ return port::Status::OK();
+}
+
+/* static */ string DsoLoader::GetBinaryDirectory(bool strip_executable_name) {
+ char exe_path[PATH_MAX] = {0};
+ CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1));
+ // Make sure it's null-terminated:
+ exe_path[sizeof(exe_path) - 1] = 0;
+
+ if (strip_executable_name) {
+ // The exe is the last component of the path, so remove one component.
+ std::vector<string> components = port::Split(exe_path, '/');
+ components.pop_back();
+ return port::Join(components, "/");
+ }
+ return exe_path;
+}
+
+// Creates a heap-allocated vector for initial rpaths.
+// Ownership is transferred to the caller.
+static std::vector<string>* CreatePrimordialRpaths() {
+ auto rpaths = new std::vector<string>;
+ rpaths->push_back(
+ "driver/driver_sh.runfiles/third_party/gpus/cuda/lib64");
+ return rpaths;
+}
+
+/* static */ mutex DsoLoader::rpath_mutex_{LINKER_INITIALIZED};
+/* static */ std::vector<string>* DsoLoader::GetRpaths() {
+ static std::vector<string>* rpaths = CreatePrimordialRpaths();
+ return rpaths;
+}
+
+/* static */ bool DsoLoader::TrySymbolicDereference(string* candidate) {
+ char buf[PATH_MAX];
+ char* result = realpath(candidate->c_str(), buf);
+ if (result == nullptr) {
+ return false;
+ }
+ VLOG(3) << "realpath resolved candidate path \"" << *candidate << "\" to \""
+ << result << "\"";
+ *candidate = result;
+ return true;
+}
+
+/* static */ string DsoLoader::FindDsoPath(port::StringPiece library_name,
+                                           port::StringPiece runfiles_relpath) {
+  // Keep a record of the paths we attempted so we can dump out meaningful
+  // diagnostics if no path is found.
+  std::vector<string> attempted;
+
+  using StringPieces = std::vector<port::StringPiece>;
+  string candidate;
+
+  // Try binary-plus-rpath locations.
+  string binary_directory =
+      GetBinaryDirectory(true /* = strip_executable_name */);
+  mutex_lock lock{rpath_mutex_};
+  for (const string& rpath : *GetRpaths()) {
+    candidate =
+        port::Join(StringPieces{binary_directory, rpath, library_name}, "/");
+    if (TrySymbolicDereference(&candidate)) {
+      return candidate;
+    }
+    attempted.push_back(candidate);
+  }
+
+  VLOG(1) << "could not find DSO at any of the attempted paths: "
+          << port::Join(attempted, ", ") << "; falling back to dlopen of \""
+          << library_name << "\"";
+  return library_name.ToString();
+}
+
+// -- CachedDsoLoader
+
+/* static */ port::StatusOr<void*> CachedDsoLoader::GetCublasDsoHandle() {
+ static port::StatusOr<void*> result =
+ FetchHandleResult(DsoLoader::GetCublasDsoHandle);
+ return result;
+}
+
+/* static */ port::StatusOr<void*> CachedDsoLoader::GetCurandDsoHandle() {
+ static port::StatusOr<void*> result =
+ FetchHandleResult(DsoLoader::GetCurandDsoHandle);
+ return result;
+}
+
+/* static */ port::StatusOr<void*> CachedDsoLoader::GetCudnnDsoHandle() {
+ static port::StatusOr<void*> result =
+ FetchHandleResult(DsoLoader::GetCudnnDsoHandle);
+ return result;
+}
+
+/* static */ port::StatusOr<void*> CachedDsoLoader::GetCufftDsoHandle() {
+ static port::StatusOr<void*> result =
+ FetchHandleResult(DsoLoader::GetCufftDsoHandle);
+ return result;
+}
+
+/* static */ port::StatusOr<void*> CachedDsoLoader::GetLibcudaDsoHandle() {
+ static port::StatusOr<void*> result =
+ FetchHandleResult(DsoLoader::GetLibcudaDsoHandle);
+ return result;
+}
+
+/* static */ port::StatusOr<void*> CachedDsoLoader::GetLibcuptiDsoHandle() {
+ static port::StatusOr<void*> result =
+ FetchHandleResult(DsoLoader::GetLibcuptiDsoHandle);
+ return result;
+}
+
+/* static */ port::StatusOr<void*> CachedDsoLoader::FetchHandleResult(
+ std::function<port::Status(void**)> load_dso) {
+ void* handle;
+ auto status = load_dso(&handle);
+ if (!status.ok()) {
+ return status;
+ }
+ return handle;
+}
+
+} // namespace internal
+} // namespace gputools
+} // namespace perftools
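
The rpath registry above is additive: a client may widen the search before the first DSO lookup. A minimal sketch using only the RegisterRpath() API shown above; the directory name is a hypothetical example, not a path this code installs itself.

    #include "tensorflow/stream_executor/dso_loader.h"

    namespace se = perftools::gputools;

    // Registers an extra binary-relative directory; later FindDsoPath() calls
    // will also probe <binary dir>/third_party/vendor/cuda/lib64/<library>.
    void RegisterVendorCudaRpath() {
      se::internal::DsoLoader::RegisterRpath("third_party/vendor/cuda/lib64");
    }
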
diff --git a/tensorflow/stream_executor/dso_loader.h b/tensorflow/stream_executor/dso_loader.h
new file mode 100644
index 0000000000..4dcc48d231
--- /dev/null
+++ b/tensorflow/stream_executor/dso_loader.h
@@ -0,0 +1,107 @@
+// Common DSO loading functionality: exposes callables that dlopen DSOs
+// in either the runfiles directories or the system default library search path.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+#include <vector>
+
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+
+// Permits StreamExecutor code to dynamically load a pre-determined set of
+// relevant DSOs via dlopen.
+//
+// Thread-safe.
+class DsoLoader {
+ public:
+  // The following methods either load the DSO of interest and return a dlopen
+  // handle, or return an error status in the canonical namespace.
+
+ static port::Status GetCublasDsoHandle(void** dso_handle);
+ static port::Status GetCudnnDsoHandle(void** dso_handle);
+ static port::Status GetCufftDsoHandle(void** dso_handle);
+ static port::Status GetCurandDsoHandle(void** dso_handle);
+ static port::Status GetLibcudaDsoHandle(void** dso_handle);
+ static port::Status GetLibcuptiDsoHandle(void** dso_handle);
+
+ // Registers a new binary-relative path to use as a dlopen search path.
+ static void RegisterRpath(port::StringPiece path);
+
+ private:
+ // Registered rpaths (singleton vector) and a mutex that guards it.
+ static std::vector<string>* GetRpaths();
+ static mutex rpath_mutex_;
+
+  // Descriptive enum (standing in for a bool) that indicates whether the
+  // loaded DSO's symbols are made available for resolution in later-loaded
+  // libraries.
+ enum class LoadKind { kLocal, kGlobal };
+
+ // Loads a DSO from the given "path" (which can technically be any dlopen-able
+ // name). If the load kind is global, the symbols in the loaded DSO are
+ // visible to subsequent DSO loading operations.
+ static port::Status GetDsoHandle(port::StringPiece path, void** dso_handle,
+ LoadKind load_kind = LoadKind::kLocal);
+
+ // Returns the binary directory (or binary path) associated with the currently
+ // executing program. If strip_executable_name is true, the executable file is
+ // stripped off of the path.
+ static string GetBinaryDirectory(bool strip_executable_name);
+
+ // Returns the location of the runfiles directory.
+ // * Manual invocation gets the runfiles as a relative path to the current
+ // executable.
+ static string GetRunfilesDirectory();
+
+ // Invokes realpath on the original path; updates candidate and returns true
+ // if it succeeds (i.e. a file exists at the path); otherwise, returns false.
+ static bool TrySymbolicDereference(string* candidate);
+
+  // Attempts to find a path to the DSO of interest; otherwise, returns the
+  // bare library name.
+ // Arguments:
+ // library_name: the filename in tree; e.g. libOpenCL.so.1.0.0
+ // runfiles_relpath: where to look for the library relative to the runfiles
+ // root; e.g. third_party/gpus/cuda/lib64
+ static string FindDsoPath(port::StringPiece library_name,
+ port::StringPiece runfiles_relpath);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(DsoLoader);
+};
+
+// Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
+// more than once.
+class CachedDsoLoader {
+ public:
+ // Cached versions of the corresponding DsoLoader methods above.
+ static port::StatusOr<void*> GetCublasDsoHandle();
+ static port::StatusOr<void*> GetCudnnDsoHandle();
+ static port::StatusOr<void*> GetCufftDsoHandle();
+ static port::StatusOr<void*> GetCurandDsoHandle();
+ static port::StatusOr<void*> GetLibcudaDsoHandle();
+ static port::StatusOr<void*> GetLibcuptiDsoHandle();
+
+ private:
+ // Fetches a DSO handle via "load_dso" and returns the StatusOr form of the
+ // result.
+ static port::StatusOr<void*> FetchHandleResult(
+ std::function<port::Status(void**)> load_dso);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CachedDsoLoader);
+};
+
+} // namespace internal
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_
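
For callers, CachedDsoLoader is the usual entry point. Below is a minimal sketch of resolving a cuBLAS symbol through it, assuming port::StatusOr exposes ok(), status(), and ValueOrDie() as in other Google StatusOr variants; "cublasCreate_v2" is the standard cuBLAS handle-creation symbol.

    #include <dlfcn.h>

    #include "tensorflow/stream_executor/dso_loader.h"
    #include "tensorflow/stream_executor/platform/logging.h"

    namespace se = perftools::gputools;

    // Returns the address of cublasCreate_v2, or nullptr if the DSO (or the
    // symbol) could not be resolved. Repeated calls reuse the cached handle.
    void* LookupCublasCreate() {
      auto handle_or = se::internal::CachedDsoLoader::GetCublasDsoHandle();
      if (!handle_or.ok()) {
        LOG(ERROR) << "failed to load libcublas: "
                   << handle_or.status().error_message();
        return nullptr;
      }
      return dlsym(handle_or.ValueOrDie(), "cublasCreate_v2");
    }
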
diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc
new file mode 100644
index 0000000000..79c3d39f24
--- /dev/null
+++ b/tensorflow/stream_executor/event.cc
@@ -0,0 +1,48 @@
+#include "tensorflow/stream_executor/event.h"
+
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/stream.h"
+
+namespace perftools {
+namespace gputools {
+
+internal::EventInterface* CreateEventImplementation(
+ StreamExecutor* stream_exec) {
+ PlatformKind platform_kind = stream_exec->platform_kind();
+ switch (platform_kind) {
+ case PlatformKind::kCuda:
+ return (*internal::MakeCUDAEventImplementation())(stream_exec);
+ default:
+ LOG(FATAL) << "Cannot create event implementation for platform kind: "
+ << PlatformKindString(platform_kind);
+ }
+}
+
+Event::Event(StreamExecutor* stream_exec)
+ : implementation_(CreateEventImplementation(stream_exec)),
+ stream_exec_(stream_exec) {}
+
+Event::~Event() {
+ auto status = stream_exec_->DeallocateEvent(this);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
+}
+
+bool Event::Init() {
+ auto status = stream_exec_->AllocateEvent(this);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ return false;
+ }
+
+ return true;
+}
+
+Event::Status Event::PollForStatus() {
+ return stream_exec_->PollForEventStatus(this);
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h
new file mode 100644
index 0000000000..fdd5112d9a
--- /dev/null
+++ b/tensorflow/stream_executor/event.h
@@ -0,0 +1,63 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_EVENT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_EVENT_H_
+
+#include <memory>
+
+namespace perftools {
+namespace gputools {
+
+namespace internal {
+class EventInterface;
+}
+
+class Stream;
+class StreamExecutor;
+
+// The Event class, when supported by a platform, enables low-overhead status
+// reporting for a Stream. An Event is inserted at a location in a stream via
+// the Stream::ThenRecordEvent() API. From then on, the Event's status can be
+// monitored via the nonblocking Event::PollForStatus() call.
+class Event {
+ public:
+  // Potential states for an Event. If PollForStatus() returns anything other
+  // than kPending or kComplete, an error has occurred; kUnknown indicates that
+  // the status could not be determined. Not all implementations are able to
+  // return all enumeration values; refer to the platform-specific
+  // implementation for details.
+ enum class Status {
+ kUnknown,
+ kError,
+ kPending,
+ kComplete,
+ };
+
+ explicit Event(StreamExecutor* stream_exec); // NOLINT
+
+ // Releases any resources held by the Event object.
+ ~Event();
+
+ // Performs any platform-specific or potentially error-generating
+ // initialization.
+ bool Init();
+
+ // Returns the current Status for the event.
+ Status PollForStatus();
+
+ // Returns a pointer to the underlying platform-specific implementation.
+ internal::EventInterface* implementation() { return implementation_.get(); }
+
+ private:
+ friend class Stream;
+
+ // Pointer to the platform-specific EventInterface implementation underlying
+ // the object. Owned.
+ std::unique_ptr<internal::EventInterface> implementation_;
+
+ // Pointer to the StreamExecutor interface used to create this object.
+ // Not owned.
+ StreamExecutor* stream_exec_;
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_EVENT_H_
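
A minimal sketch of the polling pattern the class comment describes, assuming Stream::ThenRecordEvent() takes an Event* as suggested above (the exact signature lives in stream.h, outside this hunk):

    #include "tensorflow/stream_executor/event.h"
    #include "tensorflow/stream_executor/stream.h"
    #include "tensorflow/stream_executor/stream_executor.h"

    namespace se = perftools::gputools;

    // Records an event after the work already enqueued on `stream` and reports
    // (without blocking) whether that point has been reached yet.
    bool RecordAndPoll(se::StreamExecutor* stream_exec, se::Stream* stream) {
      se::Event event(stream_exec);
      if (!event.Init()) {
        return false;  // Platform could not allocate the event.
      }
      stream->ThenRecordEvent(&event);
      return event.PollForStatus() == se::Event::Status::kComplete;
    }
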
diff --git a/tensorflow/stream_executor/executor_cache.cc b/tensorflow/stream_executor/executor_cache.cc
new file mode 100644
index 0000000000..7bf1a9aa4a
--- /dev/null
+++ b/tensorflow/stream_executor/executor_cache.cc
@@ -0,0 +1,43 @@
+#include "tensorflow/stream_executor/executor_cache.h"
+
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace perftools {
+namespace gputools {
+
+port::Status ExecutorCache::Insert(const StreamExecutorConfig& config,
+ std::unique_ptr<StreamExecutor> entry) {
+ if (Get(config).ok()) {
+ return port::Status(port::error::ALREADY_EXISTS,
+ "An executor with a matching config already exists.");
+ }
+
+ cache_[config.ordinal].emplace_back(Entry(config, std::move(entry)));
+
+ return port::Status::OK();
+}
+
+port::StatusOr<StreamExecutor*> ExecutorCache::Get(
+ const StreamExecutorConfig& config) {
+ auto entries = cache_.find(config.ordinal);
+ if (entries == cache_.end()) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::Printf("No executors registered for ordinal %d", config.ordinal));
+ }
+
+ for (const auto& iter : entries->second) {
+ if (iter.first.plugin_config == config.plugin_config &&
+ iter.first.device_options == config.device_options) {
+ return iter.second.get();
+ }
+ }
+
+ return port::Status(port::error::NOT_FOUND,
+ "No executor found with a matching config.");
+}
+
+void ExecutorCache::DestroyAllExecutors() { cache_.clear(); }
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/executor_cache.h b/tensorflow/stream_executor/executor_cache.h
new file mode 100644
index 0000000000..4d1d9ddb07
--- /dev/null
+++ b/tensorflow/stream_executor/executor_cache.h
@@ -0,0 +1,45 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
+
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+namespace perftools {
+namespace gputools {
+
+// Utility class to allow Platform objects to manage cached StreamExecutors.
+class ExecutorCache {
+ public:
+ ExecutorCache() {}
+
+ // Inserts a new StreamExecutor with the given configuration into the cache.
+ // Will not overwrite if called when a matching element is already present.
+ port::Status Insert(const StreamExecutorConfig& config,
+ std::unique_ptr<StreamExecutor> executor);
+
+ // Returns a pointer to the described executor (if one with a matching config
+ // has been created), or a NOT_FOUND status.
+ port::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
+
+ // Destroys all Executors and clears the cache.
+ // Performs no synchronization - undefined behavior may occur if any executors
+ // are active!
+ void DestroyAllExecutors();
+
+ private:
+ typedef std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>
+ Entry;
+
+ // Maps ordinal number to a list of cached executors for that ordinal.
+ // We key off of ordinal (instead of just looking up all fields in the
+ // StreamExecutorConfig) for a slight improvement in lookup time.
+ std::map<int, std::vector<Entry>> cache_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
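
A sketch of the lookup-then-create pattern a Platform implementation might build on top of this cache; `create_uncached` is a hypothetical factory callback standing in for the platform-specific construction path.

    #include <memory>

    #include "tensorflow/stream_executor/executor_cache.h"

    namespace se = perftools::gputools;

    using ExecutorFactory =
        std::unique_ptr<se::StreamExecutor> (*)(const se::StreamExecutorConfig&);

    // Returns the cached executor for `config`, creating and inserting one via
    // `create_uncached` on a cache miss.
    se::port::StatusOr<se::StreamExecutor*> GetOrCreateExecutor(
        se::ExecutorCache* cache, const se::StreamExecutorConfig& config,
        ExecutorFactory create_uncached) {
      auto cached = cache->Get(config);
      if (cached.ok()) {
        return cached;
      }
      std::unique_ptr<se::StreamExecutor> executor = create_uncached(config);
      se::StreamExecutor* raw = executor.get();
      auto insert_status = cache->Insert(config, std::move(executor));
      if (!insert_status.ok()) {
        return insert_status;
      }
      return raw;
    }
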
diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h
new file mode 100644
index 0000000000..b47921d8f2
--- /dev/null
+++ b/tensorflow/stream_executor/fft.h
@@ -0,0 +1,187 @@
+// Exposes the family of FFT routines as pre-canned high performance calls for
+// use in conjunction with the StreamExecutor abstraction.
+//
+// Note that this interface is optionally supported by platforms; see
+// StreamExecutor::SupportsFft() for details.
+//
+// This abstraction makes it simple to entrain FFT operations on GPU data into
+// a Stream -- users typically will not use this API directly, but will use the
+// Stream builder methods to entrain these operations "under the hood". For
+// example:
+//
+// DeviceMemory<std::complex<float>> x =
+// stream_exec->AllocateArray<std::complex<float>>(1024);
+// DeviceMemory<std::complex<float>> y =
+// stream_exec->AllocateArray<std::complex<float>>(1024);
+// // ... populate x and y ...
+// Stream stream{stream_exec};
+//  std::unique_ptr<Plan> plan =
+//     stream_exec->AsFft()->Create1dPlan(&stream, 1024, Type::kC2CForward,
+//                                        /*in_place_fft=*/false);
+// stream
+// .Init()
+// .ThenFft(plan.get(), x, &y)
+// .BlockHostUntilDone();
+//
+// By using stream operations in this manner the user can easily intermix custom
+// kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned FFT
+// routines.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_FFT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_FFT_H_
+
+#include <complex>
+#include <memory>
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+template <typename ElemT>
+class DeviceMemory;
+
+namespace fft {
+
+// Specifies FFT input and output types, and the direction.
+// R, D, C, and Z stand for single-precision real, double-precision real,
+// single-precision complex, and double-precision complex, respectively.
+enum class Type {
+ kC2CForward,
+ kC2CInverse,
+ kC2R,
+ kR2C,
+ kZ2ZForward,
+ kZ2ZInverse,
+ kZ2D,
+ kD2Z
+};
+
+// FFT plan class. Each FFT implementation should define a plan class that is
+// derived from this class. It does not provide any interface but serves
+// as a common type that is used to execute the plan.
+class Plan {
+ public:
+ virtual ~Plan() {}
+};
+
+// FFT support interface -- this can be derived from a GPU executor when the
+// underlying platform has an FFT library implementation available. See
+// StreamExecutor::AsFft().
+//
+// This support interface is not generally thread-safe: the CUDA platform
+// (cuFFT) implementation is thread-safe, while the host-side FFT support is
+// only thread-compatible, not thread-safe.
+class FftSupport {
+ public:
+ virtual ~FftSupport() {}
+
+ // Creates a 1d FFT plan.
+ virtual std::unique_ptr<Plan> Create1dPlan(Stream *stream, uint64 num_x,
+ Type type, bool in_place_fft) = 0;
+
+ // Creates a 2d FFT plan.
+ virtual std::unique_ptr<Plan> Create2dPlan(Stream *stream, uint64 num_x,
+ uint64 num_y, Type type,
+ bool in_place_fft) = 0;
+
+ // Creates a 3d FFT plan.
+ virtual std::unique_ptr<Plan> Create3dPlan(Stream *stream, uint64 num_x,
+ uint64 num_y, uint64 num_z,
+ Type type, bool in_place_fft) = 0;
+
+ // Creates a batched FFT plan.
+ //
+ // stream: The GPU stream in which the FFT runs.
+ // rank: Dimensionality of the transform (1, 2, or 3).
+ // elem_count: Array of size rank, describing the size of each dimension.
+  // input_embed, output_embed:
+  //               Pointers to arrays of size rank that indicate the storage
+  //               dimensions of the input/output data in memory. If set to
+  //               nullptr, all other advanced data layout parameters are
+  //               ignored.
+ // input_stride: Indicates the distance (number of elements; same below)
+ // between two successive input elements.
+ // input_distance: Indicates the distance between the first element of two
+ // consecutive signals in a batch of the input data.
+ // output_stride: Indicates the distance between two successive output
+ // elements.
+ // output_distance: Indicates the distance between the first element of two
+ // consecutive signals in a batch of the output data.
+ virtual std::unique_ptr<Plan> CreateBatchedPlan(
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed,
+ uint64 output_stride, uint64 output_distance, Type type,
+ bool in_place_fft, int batch_count) = 0;
+
+ // Computes complex-to-complex FFT in the transform direction as specified
+ // by direction parameter.
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<std::complex<float>> *output) = 0;
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<std::complex<double>> *output) = 0;
+
+ // Computes real-to-complex FFT in forward direction.
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<float> &input,
+ DeviceMemory<std::complex<float>> *output) = 0;
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<double> &input,
+ DeviceMemory<std::complex<double>> *output) = 0;
+
+ // Computes complex-to-real FFT in inverse direction.
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<float> *output) = 0;
+ virtual bool DoFft(Stream *stream, Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<double> *output) = 0;
+
+ protected:
+ FftSupport() {}
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(FftSupport);
+};
+
+// Macro used to quickly declare overrides for abstract virtuals in the
+// fft::FftSupport base class. Assumes that it's emitted somewhere inside the
+// ::perftools::gputools namespace.
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
+ std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \
+ fft::Type type, bool in_place_fft) \
+ override; \
+ std::unique_ptr<fft::Plan> Create2dPlan(Stream *stream, uint64 num_x, \
+ uint64 num_y, fft::Type type, \
+ bool in_place_fft) override; \
+ std::unique_ptr<fft::Plan> Create3dPlan( \
+ Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \
+ fft::Type type, bool in_place_fft) override; \
+ std::unique_ptr<fft::Plan> CreateBatchedPlan( \
+ Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed, \
+ uint64 input_stride, uint64 input_distance, uint64 *output_embed, \
+ uint64 output_stride, uint64 output_distance, fft::Type type, \
+ bool in_place_fft, int batch_count) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<float>> &input, \
+ DeviceMemory<std::complex<float>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<double>> &input, \
+ DeviceMemory<std::complex<double>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<float> &input, \
+ DeviceMemory<std::complex<float>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<double> &input, \
+ DeviceMemory<std::complex<double>> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<float>> &input, \
+ DeviceMemory<float> *output) override; \
+ bool DoFft(Stream *stream, fft::Plan *plan, \
+ const DeviceMemory<std::complex<double>> &input, \
+ DeviceMemory<double> *output) override;
+
+} // namespace fft
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_FFT_H_
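
The batched-plan parameters above map onto cuFFT's advanced data layout. A sketch of a contiguous batch of 1-D forward transforms, assuming a Stream and an FftSupport obtained via StreamExecutor::AsFft() as in the file comment, with se::uint64 being the unsigned 64-bit alias these headers use:

    #include <complex>
    #include <memory>

    #include "tensorflow/stream_executor/device_memory.h"
    #include "tensorflow/stream_executor/fft.h"
    #include "tensorflow/stream_executor/stream.h"

    namespace se = perftools::gputools;

    // Enqueues `batch_count` independent length-`fft_size` C2C forward FFTs.
    // Null embed arrays select the default (tightly packed) layout, so stride 1
    // and distance `fft_size` describe contiguous signals.
    bool EnqueueBatched1dFft(se::Stream* stream, se::fft::FftSupport* fft,
                             se::uint64 fft_size, int batch_count,
                             const se::DeviceMemory<std::complex<float>>& in,
                             se::DeviceMemory<std::complex<float>>* out) {
      se::uint64 elem_count[1] = {fft_size};
      std::unique_ptr<se::fft::Plan> plan = fft->CreateBatchedPlan(
          stream, /*rank=*/1, elem_count,
          /*input_embed=*/nullptr, /*input_stride=*/1,
          /*input_distance=*/fft_size,
          /*output_embed=*/nullptr, /*output_stride=*/1,
          /*output_distance=*/fft_size, se::fft::Type::kC2CForward,
          /*in_place_fft=*/false, batch_count);
      return plan != nullptr && fft->DoFft(stream, plan.get(), in, out);
    }
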
diff --git a/tensorflow/stream_executor/gcuda.cc b/tensorflow/stream_executor/gcuda.cc
new file mode 100644
index 0000000000..505534c08f
--- /dev/null
+++ b/tensorflow/stream_executor/gcuda.cc
@@ -0,0 +1,87 @@
+#include "tensorflow/stream_executor/gcuda.h"
+
+namespace perftools {
+namespace gputools {
+
+// Returns the mapping of gcudacc kernel stub to preferred cache
+// configuration. C++ static singleton pattern.
+std::map<void *, KernelCacheConfig> &GetGcudaccStubToCacheConfigMap() {
+ static std::map<void *, KernelCacheConfig> cache_config_by_stub;
+ return cache_config_by_stub;
+}
+
+shared_mem_config::SharedMemConfig DeviceGetSharedMemConfig(
+ StreamExecutor *stream_exec) {
+ SharedMemoryConfig config = stream_exec->GetDeviceSharedMemoryConfig();
+
+ switch (config) {
+ case SharedMemoryConfig::kDefault:
+ return shared_mem_config::kDefaultBankSize;
+ case SharedMemoryConfig::kFourByte:
+ return shared_mem_config::kFourByteBankSize;
+ case SharedMemoryConfig::kEightByte:
+ return shared_mem_config::kEightByteBankSize;
+ default:
+ LOG(FATAL) << "Impossible shared memory config returned: "
+ << static_cast<int>(config);
+ }
+}
+
+void DeviceSetSharedMemConfig(StreamExecutor *stream_exec,
+ shared_mem_config::SharedMemConfig config) {
+ SharedMemoryConfig executor_config;
+ switch (config) {
+ case shared_mem_config::kDefaultBankSize:
+ executor_config = SharedMemoryConfig::kDefault;
+ break;
+ case shared_mem_config::kFourByteBankSize:
+ executor_config = SharedMemoryConfig::kFourByte;
+ break;
+ case shared_mem_config::kEightByteBankSize:
+ executor_config = SharedMemoryConfig::kEightByte;
+ break;
+ default:
+ LOG(FATAL) << "Impossible shared memory config specified: "
+ << static_cast<int>(config);
+ }
+
+ if (!stream_exec->SetDeviceSharedMemoryConfig(executor_config).ok()) {
+ // The message is logged at a higher level.
+ LOG(INFO) << "Unable to set cache configuration; proceeding.";
+ }
+}
+
+template <>
+void FuncSetCacheConfig<void *>(Stream *stream, void *fptr,
+ cache_config::CacheConfig cache_config) {
+ // Map from the legacy to the C++11 type.
+ KernelCacheConfig kernel_cache_config;
+ switch (cache_config) {
+ case cache_config::kPreferShared:
+ kernel_cache_config = KernelCacheConfig::kPreferShared;
+ break;
+ case cache_config::kPreferL1:
+ kernel_cache_config = KernelCacheConfig::kPreferL1;
+ break;
+ case cache_config::kPreferEqual:
+ kernel_cache_config = KernelCacheConfig::kPreferEqual;
+ break;
+ default:
+ kernel_cache_config = KernelCacheConfig::kNoPreference;
+ }
+  auto &cache_config_map = GetGcudaccStubToCacheConfigMap();
+  cache_config_map[fptr] = kernel_cache_config;
+}
+
+template <>
+KernelCacheConfig FuncGetCacheConfig<void *>(void *fptr) {
+  auto &cache_config_map = GetGcudaccStubToCacheConfigMap();
+  auto iter = cache_config_map.find(fptr);
+  if (iter == cache_config_map.end()) {
+    return KernelCacheConfig::kNoPreference;
+  }
+  return iter->second;
+}
+
+} // namespace gputools
+} // namespace perftools
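
A sketch of how the cache-config map above is exercised through the template entry points declared in gcuda.h; MyKernelStub is a hypothetical gcudacc-generated stub, and any function address works since the map is keyed on the stub pointer.

    #include "tensorflow/stream_executor/gcuda.h"

    namespace se = perftools::gputools;

    void MyKernelStub();  // Hypothetical stub emitted by gcudacc.

    // Records an L1 preference for the stub and reads it back; the Stream* is
    // part of the interface but not consulted by the gcudacc path above.
    void ConfigureMyKernel(se::Stream* stream) {
      se::FuncSetCacheConfig(stream, &MyKernelStub, se::cache_config::kPreferL1);
      se::KernelCacheConfig config = se::FuncGetCacheConfig(&MyKernelStub);
      (void)config;  // kPreferL1 after the call above.
    }
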
diff --git a/tensorflow/stream_executor/gcuda.h b/tensorflow/stream_executor/gcuda.h
new file mode 100644
index 0000000000..24b09c5358
--- /dev/null
+++ b/tensorflow/stream_executor/gcuda.h
@@ -0,0 +1,415 @@
+// Common declarations and includes for mixed-mode GPU usage at Google.
+//
+// This header serves to define a "common baseline" for GPU usage,
+// either with gcudacc or nvcc, and on the host or device. The rule of thumb is,
+// "if you're working with mixed-mode GPU code at Google, include this header."
+#ifndef TENSORFLOW_STREAM_EXECUTOR_GCUDA_H_
+#define TENSORFLOW_STREAM_EXECUTOR_GCUDA_H_
+
+// Symbol glossary:
+// __CUDACC__: CUDA capable compiler, compiling host or device
+// __CUDA_ARCH__: Compiling device code
+// __GCUDACC__: Using gcudacc
+// __NVCC__: Using nvcc
+
+// For device code compiled with gcudacc, CUDA_ASSUME(X) tells the compiler
+// that it may assume that X is true. This can enable further optimization.
+// It is undefined behavior if X is not true. X should not have side-effects
+// and gcudacc will try to warn you if it does.
+#if defined(__CUDA_ARCH__) && defined(__GCUDACC__)
+#define CUDA_ASSUME(X) __builtin_assume(X)
+#else
+#define CUDA_ASSUME(X) do {} while (false)
+#endif
+
+namespace perftools {
+namespace gputools {
+namespace cache_config {
+// A version of the KernelCacheConfig enum class, exposed for pre-C++11
+// compilers.
+enum CacheConfig {
+ // Indicates no preference for device L1/shared memory configuration.
+ kNoPreference,
+
+ // Indicates a preference for more shared memory than L1 cache.
+ kPreferShared,
+
+ // Indicates a preference for more L1 cache than shared memory.
+ kPreferL1,
+
+ // Indicates a preference for equal amounts of L1 cache and shared memory.
+ kPreferEqual,
+};
+} // namespace cache_config
+
+namespace shared_mem_config {
+// A compatibility-layer declaration of CUsharedconfig, needed to support
+// cuFuncSetSharedMemConfig/cudaDeviceSetSharedMemConfig. Declared here for
+// compatibility with pre-C++11 compilers.
+enum SharedMemConfig {
+ // Indicates that the context's shared memory config should be used.
+ kDefaultBankSize,
+
+ // Specifies a four-byte bank size for shared memory.
+ kFourByteBankSize,
+
+ // Specifies an eight-byte bank size for shared memory.
+ kEightByteBankSize,
+};
+} // namespace shared_mem_config
+} // namespace gputools
+} // namespace perftools
+
+#if !defined(__NVCC__) && !defined(GCUDACC_STANDALONE_MODE)
+// Using gcudacc, either device-only or mixed-mode code. No special declarations
+// are needed for host-only code being compiled under gcudacc.
+
+// These includes are required by the code introduced during gcudacc operation.
+// Since the user code may not directly include these headers, they may not be
+// present in the build environment without inclusion here.
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/kernel.h"
+#include "tensorflow/stream_executor/kernel_cache_config.h"
+#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/machine_manager.h"
+#include "tensorflow/stream_executor/shared_memory_config.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+// cudaConfigureCall is a symbol used by Clang when it sees a CUDA triple-angle-
+// bracket launch; we declare it here so that the symbol resolves. It is not
+// used by gcudacc-generated code, however, so it is not defined anywhere.
+// In other words, this is a dummy declaration needed for parsing.
+
+#ifdef __GCUDACC__
+// These symbols only need to be defined during compilation with gcudacc.
+namespace perftools {
+namespace gputools {
+
+// This class defines all the implicit conversions necessary to match launch
+// dimensions against the cudaConfigureCall() signature, and sits where a dim3
+// usually would in triple angle launches. This supports the kernel launch
+// dimension styles:
+// kernel<<<1, 1>>>() and
+// kernel<<<BlockDim(...), ThreadDim(...)>>> and
+// kernel<<<dim3(1), dim3(1)>>>
+// All of these are predicated upon implicit conversions, which are frowned upon
+// by the style guide. Rather than add this CUDA-specific bad behavior to
+// StreamExecutor headers, we isolate it here.
+class LaunchDimConverter {
+ public:
+ LaunchDimConverter(unsigned long long int i) : _dim(i, 1, 1) {} // NOLINT
+ LaunchDimConverter(::perftools::gputools::BlockDim dim)
+ : // NOLINT
+ _dim(dim.x, dim.y, dim.z) {}
+ LaunchDimConverter(::perftools::gputools::ThreadDim dim)
+ : // NOLINT
+ _dim(dim.x, dim.y, dim.z) {}
+ LaunchDimConverter(dim3 dim) : // NOLINT
+ _dim(dim.x, dim.y, dim.z) {}
+
+ ::perftools::gputools::BlockDim AsBlockDim() {
+ return ::perftools::gputools::BlockDim(_dim.x, _dim.y, _dim.z);
+ }
+
+ ::perftools::gputools::ThreadDim AsThreadDim() {
+ return ::perftools::gputools::ThreadDim(_dim.x, _dim.y, _dim.z);
+ }
+
+ private:
+ ::perftools::gputools::Dim3D _dim;
+};
+} // namespace gputools
+} // namespace perftools
+
+int cudaConfigureCall(::perftools::gputools::LaunchDimConverter grid_size,
+ ::perftools::gputools::LaunchDimConverter block_size,
+ unsigned shared_size = 0,
+ ::perftools::gputools::Stream *stream = 0);
+#endif
+
+// The rest of the symbols in this block are needed during both StreamExecutor
+// and user library compilation.
+namespace perftools {
+namespace gputools {
+
+// Gets the preferred shared memory configuration for the device to which
+// the specified executor is bound.
+shared_mem_config::SharedMemConfig DeviceGetSharedMemConfig(
+ StreamExecutor *stream_exec);
+
+// Sets the preferred shared memory configuration for the device to which
+// the specified executor is bound.
+// Does not return an error if the current device is invalid.
+void DeviceSetSharedMemConfig(StreamExecutor *stream_exec,
+ shared_mem_config::SharedMemConfig config);
+
+// Sets the preferred cache configuration for the given kernel.
+template <typename KernelT>
+void FuncSetCacheConfig(Stream *stream, KernelT kernel,
+ cache_config::CacheConfig cache_config) {
+ FuncSetCacheConfig(stream, reinterpret_cast<void *>(kernel), cache_config);
+}
+
+// Internal specialization of the above.
+template <>
+void FuncSetCacheConfig<void *>(Stream *stream, void *kernel,
+ cache_config::CacheConfig cache_config);
+
+// Gets the preferred cache configuration for the given kernel.
+template <typename KernelT>
+KernelCacheConfig FuncGetCacheConfig(KernelT kernel) {
+ return FuncGetCacheConfig(reinterpret_cast<void *>(kernel));
+}
+
+// Internal specialization of the above.
+template <>
+KernelCacheConfig FuncGetCacheConfig<void *>(void *kernel);
+
+} // namespace gputools
+} // namespace perftools
+
+#elif defined(__NVCC__)
+// NVCC code compilation, device-only or mixed mode. As above, no special
+// declarations are needed for host-only code.
+namespace perftools {
+namespace gputools {
+class Stream;
+} // namespace gputools
+} // namespace perftools
+
+// --- BEGIN EXTERNALLY-DEFINED FUNCTIONS
+
+// The following functions must be defined in some external library linked in to
+// the final binary - they are _not_ defined in the StreamExecutor
+// (in nvcc mode).
+
+// Sets the preferred cache configuration for the specified kernel.
+template <typename KernelT>
+void SetCudaCacheConfig(perftools::gputools::Stream* stream, KernelT kernel,
+ ::perftools::gputools::cache_config::CacheConfig preference);
+
+// Gets the current device for use in CUDA runtime-emulating routines.
+// "device" is the device ordinal as returned by
+// StreamExecutor::device_ordinal().
+int GetDevice();
+
+// Sets the current device for use in CUDA runtime-emulating routines.
+// "device" is the device ordinal as returned by
+// StreamExecutor::device_ordinal().
+void SetDevice(int device);
+
+// --- END EXTERNALLY-DEFINED FUNCTIONS
+
+namespace perftools {
+namespace gputools {
+template <typename KernelT>
+void FuncSetCacheConfig(Stream *stream, KernelT kernel,
+ cache_config::CacheConfig cache_config) {
+ SetCudaCacheConfig(stream, reinterpret_cast<void*>(kernel), cache_config);
+}
+} // namespace gputools
+} // namespace perftools
+
+// The following functions are declared extern "C" in CUDA's device_functions.h,
+// so we have to wrap them for compatibility with the cuda_builtin namespace.
+// Thin wrappers to break these functions out of cuda_builtin are defined below.
+__forceinline__ __device__ clock_t __gcuda_nvcc_clock() { return clock(); }
+__forceinline__ __device__ int __gcuda_nvcc__clz(int x) {
+ return __clz(x);
+}
+__forceinline__ __device__ int __gcuda_nvcc__clzll(long long int x) {
+ return __clzll(x);
+}
+__forceinline__ __device__ float __gcuda_nvcc__fdividef(float a, float b) {
+ return __fdividef(a, b);
+}
+__forceinline__ __device__ int __gcuda_nvcc__ffsll(long long int x) { // NOLINT
+ return __ffsll(x);
+}
+__forceinline__ __device__ int __gcuda_nvcc__popc(unsigned int x) {
+ return __popc(x);
+}
+__forceinline__ __device__ float __gcuda_nvcc__powf(float a, float b) {
+ return __powf(a, b);
+}
+__forceinline__ __device__ void __gcuda_nvcc__sincosf(
+ float x, float *sptr, float *cptr) {
+ __sincosf(x, sptr, cptr);
+}
+__forceinline__ __device__ unsigned int __gcuda_nvcc__umulhi(
+ unsigned int x, unsigned int y) {
+ return __umulhi(x, y);
+}
+
+#if __CUDA_ARCH__ >= 200 || !defined(__CUDA_ARCH__)
+__forceinline__ __device__ unsigned int __gcuda_nvcc__ballot(int x) {
+ return __ballot(x);
+}
+#endif // __CUDA_ARCH__ >= 200 || !defined(__CUDA_ARCH__)
+
+// Forward-declare printf as nvcc does not declare it by itself and we
+// need this file to compile even if it is included before including
+// stdio.h or cstdio.
+int printf(const char* format, ...);
+
+namespace cuda_builtin {
+using ::abs;
+using ::atomicAdd;
+using ::atomicCAS;
+using ::ceil;
+using ::ceilf;
+using ::cos;
+using ::cosf;
+using ::erfcinv;
+using ::erfcinvf;
+using ::exp;
+using ::expf;
+using ::fabs;
+using ::fabsf;
+using ::floor;
+using ::floorf;
+using ::fma;
+using ::fmaf;
+using ::fmax;
+using ::fmaxf;
+using ::fmin;
+using ::fminf;
+using ::log;
+using ::log1p;
+using ::log1pf;
+using ::logf;
+using ::max;
+using ::min;
+using ::powf;
+using ::printf;
+using ::sin;
+using ::sinf;
+using ::sincos;
+using ::sincosf;
+using ::sincospi;
+using ::sincospif;
+using ::sqrt;
+using ::sqrtf;
+using ::tanh;
+using ::trunc;
+using ::truncf;
+
+// rsqrt and rsqrtf are functions defined by nvcc in both host and device mode.
+// They are added to gcuda.h so that they are available as host device
+// functions here as well: on the device side they correspond to intrinsics,
+// while explicit definitions are provided below for the host side.
+#ifdef __CUDA_ARCH__
+using ::rsqrt;
+using ::rsqrtf;
+#else
+__forceinline__ __host__ __device__ float rsqrtf(float x) {
+ return 1 / std::sqrt(x);
+}
+__forceinline__ __host__ __device__ double rsqrt(double x) {
+ return 1 / std::sqrt(x);
+}
+#endif
+
+__forceinline__ __device__ clock_t clock() { return __gcuda_nvcc_clock(); }
+
+__forceinline__ __device__ int __clz(int x) {
+ return __gcuda_nvcc__clz(x);
+}
+
+__forceinline__ __device__ int __clz(long long int x) {
+ return __gcuda_nvcc__clzll(x);
+}
+
+__forceinline__ __device__ float __fdividef(float a, float b) {
+ return __gcuda_nvcc__fdividef(a, b);
+}
+
+__forceinline__ __device__ int __ffsll(long long int x) { // NOLINT
+ return __gcuda_nvcc__ffsll(x);
+}
+
+__forceinline__ __device__ int __popc(unsigned int x) {
+ return __gcuda_nvcc__popc(x);
+}
+
+__forceinline__ __device__ float __powf(float a, float b) {
+ return __gcuda_nvcc__powf(a, b);
+}
+
+__forceinline__ __device__ void __sincosf(float x, float *sptr, float *cptr) {
+ __gcuda_nvcc__sincosf(x, sptr, cptr);
+}
+
+__forceinline__ __device__ unsigned int __umulhi(unsigned int x,
+ unsigned int y) {
+ return __gcuda_nvcc__umulhi(x, y);
+}
+
+#ifdef __CUDA_ARCH__
+// These symbols are only visible when parsing device code.
+using ::__double_as_longlong;
+using ::__int_as_float;
+using ::__float_as_int;
+using ::__longlong_as_double;
+#endif // __CUDA_ARCH__
+
+#if __CUDA_ARCH__ >= 200 || !defined(__CUDA_ARCH__)
+__forceinline__ __device__ unsigned int __ballot(int x) {
+ return __gcuda_nvcc__ballot(x);
+}
+#endif // __CUDA_ARCH__ >= 200 || !defined(__CUDA_ARCH__)
+
+#if __CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__)
+using ::__shfl;
+using ::__shfl_down;
+using ::__shfl_up;
+using ::__shfl_xor;
+#endif // __CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__)
+
+#if __CUDA_ARCH__ >= 320 || !defined(__CUDA_ARCH__)
+using ::__ldg;
+#endif // __CUDA_ARCH__ >= 320 || !defined(__CUDA_ARCH__)
+
+#if __CUDA_API_VERSION < 6050
+// CUDA < 6.5 defines isfinite as a macro, while CUDA >= 6.5 and gcudacc
+// define isfinite as a function. Work around this for the CUDA 5.5 case,
+// duplicating that macro definition.
+#undef isfinite
+#define __gcuda_nvcc_isfinite(x) \
+ (sizeof(x) == sizeof(float) ? __finitef(x) : \
+ sizeof(x) == sizeof(double) ? __finite(x) : __finitel(x))
+inline __device__ int isfinite(float x) {
+ return __gcuda_nvcc_isfinite(x);
+}
+inline __device__ int isfinite(double x) {
+ return __gcuda_nvcc_isfinite(x);
+}
+inline __device__ int isfinite(long double x) {
+ return __gcuda_nvcc_isfinite(x);
+}
+#else
+// CUDA API >= v6.5
+using ::isfinite;
+#endif // __CUDA_API_VERSION < 6050
+} // namespace cuda_builtin
+
+#if __CUDA_API_VERSION >= 6050
+// The second part of the isfinite workaround.
+inline __device__ int isfinite(float x) {
+ return __gcuda_nvcc_isfinite(x);
+}
+inline __device__ int isfinite(double x) {
+ return __gcuda_nvcc_isfinite(x);
+}
+inline __device__ int isfinite(long double x) {
+ return __gcuda_nvcc_isfinite(x);
+}
+#endif // __CUDA_API_VERSION >= 6050
+
+#endif // defined(__NVCC__)
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_GCUDA_H_
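
As a usage note for the CUDA_ASSUME macro defined near the top of this header, here is a minimal sketch of device code compiled under gcudacc; the kernel and its assumed launch configuration are hypothetical.

    #include "tensorflow/stream_executor/gcuda.h"

    // Tells the compiler it may assume the problem size is a multiple of the
    // block size, letting it drop the remainder-handling path. Behavior is
    // undefined if the assumption does not hold at run time.
    __global__ void ScaleKernel(float* data, int n) {
      CUDA_ASSUME(n % 256 == 0);
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        data[i] *= 2.0f;
      }
    }
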
diff --git a/tensorflow/stream_executor/gpu_launch_dim.h b/tensorflow/stream_executor/gpu_launch_dim.h
new file mode 100644
index 0000000000..51182b2d32
--- /dev/null
+++ b/tensorflow/stream_executor/gpu_launch_dim.h
@@ -0,0 +1,8 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_LAUNCH_DIM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_LAUNCH_DIM_H_
+
+// TODO(rspringer): Temporary redirection until all users - including gcudacc -
+// are using the new file.
+#include "tensorflow/stream_executor/launch_dim.h"
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_LAUNCH_DIM_H_
diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc
new file mode 100644
index 0000000000..5e7fe95627
--- /dev/null
+++ b/tensorflow/stream_executor/kernel.cc
@@ -0,0 +1,95 @@
+// Implementation of the pointer-to-implementation wrapper for the data-parallel
+// kernel abstraction. KernelBase just delegates to the internal
+// platform-specific implementation instance.
+
+#include "tensorflow/stream_executor/kernel.h"
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/lib/demangle.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace perftools {
+namespace gputools {
+
+bool KernelMetadata::registers_per_thread(int *registers_per_thread) const {
+ if (has_registers_per_thread_) {
+ *registers_per_thread = registers_per_thread_;
+ return true;
+ }
+
+ return false;
+}
+
+void KernelMetadata::set_registers_per_thread(int registers_per_thread) {
+ registers_per_thread_ = registers_per_thread;
+ has_registers_per_thread_ = true;
+}
+
+bool KernelMetadata::shared_memory_bytes(int *shared_memory_bytes) const {
+ if (has_shared_memory_bytes_) {
+ *shared_memory_bytes = shared_memory_bytes_;
+ return true;
+ }
+
+ return false;
+}
+
+void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) {
+ shared_memory_bytes_ = shared_memory_bytes;
+ has_shared_memory_bytes_ = true;
+}
+
+static internal::KernelInterface *KernelImplementationFromPlatformKind(
+ PlatformKind platform_kind) {
+ if (platform_kind == PlatformKind::kCuda) {
+ return (*internal::MakeCUDAKernelImplementation())();
+ } else if (platform_kind == PlatformKind::kOpenCL ||
+ platform_kind == PlatformKind::kOpenCLAltera) {
+ return (*internal::MakeOpenCLKernelImplementation())();
+ } else {
+ LOG(FATAL) << "cannot create kernel implementation for platform kind: "
+ << PlatformKindString(platform_kind);
+ }
+}
+
+KernelBase::KernelBase(StreamExecutor *parent)
+ : implementation_(
+ KernelImplementationFromPlatformKind(parent->platform_kind())),
+ parent_(parent) {
+ DCHECK(parent_ != nullptr);
+}
+
+KernelBase::KernelBase(StreamExecutor *parent,
+ internal::KernelInterface *implementation)
+ : implementation_(implementation), parent_(parent) {}
+
+KernelBase::~KernelBase() {}
+
+unsigned KernelBase::Arity() const { return implementation_->Arity(); }
+
+void KernelBase::SetPreferredCacheConfig(KernelCacheConfig config) {
+ return implementation_->SetPreferredCacheConfig(config);
+}
+
+KernelCacheConfig KernelBase::GetPreferredCacheConfig() const {
+ return implementation_->GetPreferredCacheConfig();
+}
+
+// Prefix of the stub functions emitted by the CUDA splitter.
+static const char *kStubPrefix = "__device_stub_";
+
+void KernelBase::set_name(port::StringPiece name) {
+ name_ = name.ToString();
+ port::StringPiece stubless_name = name;
+ if (name.starts_with(kStubPrefix)) {
+ stubless_name.remove_prefix(strlen(kStubPrefix));
+ }
+ demangled_name_ = port::Demangle(stubless_name.data());
+}
+
+} // namespace gputools
+} // namespace perftools
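
The stub-prefix handling in set_name() above can be illustrated with a short sketch; the mangled name is an arbitrary example and `kernel` is assumed to be an already-constructed KernelBase.

    #include "tensorflow/stream_executor/kernel.h"
    #include "tensorflow/stream_executor/platform/logging.h"

    namespace se = perftools::gputools;

    // "__device_stub_" is stripped before demangling, so the splitter-emitted
    // stub name and the plain mangled name produce the same demangled_name().
    void NameKernel(se::KernelBase* kernel) {
      kernel->set_name("__device_stub__Z9AddVectorPfS_i");
      // demangled_name() is now "AddVector(float*, float*, int)" wherever the
      // demangler is available.
      VLOG(1) << kernel->demangled_name();
    }
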
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
new file mode 100644
index 0000000000..da646d0f40
--- /dev/null
+++ b/tensorflow/stream_executor/kernel.h
@@ -0,0 +1,499 @@
+// Suite of datatypes to represent data-parallel kernel objects (code entities).
+// Kernel is the untyped variant, whereas TypedKernel takes a type signature
+// to do some template-based helper generation and give compile-time type
+// checking for kernel launch parameters.
+//
+// Users typically don't see KernelBase, they see typed kernels, analogous to a
+// typed function pointer. TypedKernels express their argument types via
+// template parameters like so:
+//
+// TypedKernel<DeviceMemory<int>*, int>
+//
+// Which expresses a data parallel kernel signature for:
+//
+// void(int*, int);
+//
+// And for a const memory region:
+//
+// TypedKernel<const DeviceMemory<int>&, int>
+//
+// Corresponds to a data parallel kernel signature for:
+//
+// void(const int*, int)
+//
+// Note that kernels always have a void return type, so results typically must
+// be memcpy'ied from device memory to the host.
+//
+// Also note that a scalar integer residing in device memory and an array of
+// integers residing in device memory have the same signature: DeviceMemory<T>.
+// However, in the future, checks may be added for additional safety that arrays
+// of minimum sizes are passed when those minimum sizes are contractually
+// expected by the kernel.
+//
+// For user-defined types whose definitions are appropriately shared between the
+// host code doing the launching and the kernel code being launched, the user
+// defined types are similarly permitted to be expressed as residing in device
+// memory:
+//
+// TypedKernel<DeviceMemory<MyUserDefinedStructure>>
+//
+// And, when the alignment and padding are agreed upon, POD types will also be
+// able to be passed by value; for example, it is a common idiom to specify a
+// bunch of options simultaneously with a structure:
+//
+// TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
+//
+// Which corresponds to a data parallel kernel signature like:
+//
+// void(MyOptionsStructurePassedByValue value, float *result);
+//
+// Users typically won't need to type out the TypedKernel signature in full, it
+// will be typedef'd by automatically generated code; for example, see
+// perftools::gputools::executor_sample::VecReduceAddKernel.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
+
+#include <memory>
+#include <tuple>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/kernel_cache_config.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
+
+namespace perftools {
+namespace gputools {
+
+class DeviceMemoryBase;
+template <typename ElemT>
+class DeviceMemory;
+class StreamExecutor;
+
+namespace internal {
+class KernelInterface;
+} // namespace internal
+
+// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
+// registers allocated, shared memory used, etc.
+// Not all platforms support reporting of all information, so each accessor
+// returns false if the associated field is not populated in the underlying
+// platform.
+class KernelMetadata {
+ public:
+ KernelMetadata()
+ : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}
+
+ // Returns the number of registers used per thread executing this kernel.
+ bool registers_per_thread(int *registers_per_thread) const;
+
+ // Sets the number of registers used per thread executing this kernel.
+ void set_registers_per_thread(int registers_per_thread);
+
+ // Returns the amount of [static] shared memory used per block executing this
+ // kernel. Note that dynamic shared memory allocations are not (and can not)
+ // be reported here (since they're not specified until kernel launch time).
+ bool shared_memory_bytes(int *shared_memory_bytes) const;
+
+ // Sets the amount of [static] shared memory used per block executing this
+ // kernel.
+ void set_shared_memory_bytes(int shared_memory_bytes);
+
+ private:
+ // Holds the value returned by registers_per_thread above.
+ bool has_registers_per_thread_;
+ int registers_per_thread_;
+
+ // Holds the value returned by shared_memory_bytes above.
+ bool has_shared_memory_bytes_;
+  int shared_memory_bytes_;
+};
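
A minimal sketch of how a consumer tolerates platforms that do not populate these fields, using only the accessors declared above:

    #include "tensorflow/stream_executor/kernel.h"

    namespace se = perftools::gputools;

    // Returns the reported register count, or `default_value` when the loading
    // platform did not populate the metadata.
    int RegistersPerThreadOrDefault(const se::KernelMetadata& metadata,
                                    int default_value) {
      int registers = 0;
      return metadata.registers_per_thread(&registers) ? registers
                                                       : default_value;
    }
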
+
+// A data-parallel kernel (code entity) for launching via the StreamExecutor,
+// analogous to a void* device function pointer. See TypedKernel for the typed
+// variant.
+//
+// Thread-compatible.
+class KernelBase {
+ public:
+ // Constructs an "empty" (not-yet-loaded) kernel instance.
+ //
+ // parent is the StreamExecutor that will be responsible for loading the
+ // implementation of this kernel. It must not be null.
+ explicit KernelBase(StreamExecutor *parent);
+
+ // Test-only constructor that can take a mock KernelInterface implementation.
+ // Takes ownership of implementation, it should not be null.
+ KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);
+
+ // Releases resources associated with the kernel instance (i.e.
+ // platform-specific implementation).
+ ~KernelBase();
+
+ // Returns the number of parameters that this kernel accepts. (Arity refers to
+ // nullary, unary, ...).
+ unsigned Arity() const;
+
+ // Returns the StreamExecutor that represents the platform this kernel
+ // executes upon.
+ StreamExecutor *parent() const { return parent_; }
+
+ // Returns a const pointer to the (opaque) platform-dependent implementation.
+ const internal::KernelInterface *implementation() const {
+ return implementation_.get();
+ }
+
+ // Returns a non-const pointer to the (opaque) platform-dependent
+ // implementation.
+ internal::KernelInterface *implementation() { return implementation_.get(); }
+
+ void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }
+
+ const KernelMetadata &metadata() const { return metadata_; }
+
+ // Sets the preferred cache configuration for a kernel. This is just a
+ // suggestion to the runtime, and may not be honored during execution.
+ void SetPreferredCacheConfig(KernelCacheConfig config);
+
+ // Gets the preferred cache configuration for a kernel.
+ KernelCacheConfig GetPreferredCacheConfig() const;
+
+ void set_name(port::StringPiece name);
+ const string &name() const { return name_; }
+ const string &demangled_name() const { return demangled_name_; }
+
+ private:
+ // Implementation delegated to for platform-specific functionality.
+ std::unique_ptr<internal::KernelInterface> implementation_;
+
+ // The StreamExecutor that loads this kernel object.
+ StreamExecutor *parent_;
+
+ string name_;
+ string demangled_name_;
+
+ KernelMetadata metadata_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
+};
+
+// Whether T is a DeviceMemory-family pointer.
+template <typename T>
+struct IsDeviceMemoryPointer {
+ static constexpr bool value = false;
+};
+
+template <typename U>
+struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
+ static constexpr bool value = true;
+};
+
+template <>
+struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
+ static constexpr bool value = true;
+};
+
+// Whether T is a DeviceMemory-family value-like thing (which includes a
+// reference). This trait is useful because we pack values in the same manner as
+// references.
+template <typename T>
+struct IsDeviceMemoryValueLike {
+ static constexpr bool value = false;
+};
+
+template <typename U>
+struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
+ static constexpr bool value = true;
+};
+
+// We need to treat SharedDeviceMemory types differently than other DeviceMemory
+// types (since they maintain no allocations), hence these specializations.
+template <typename U>
+struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
+ static constexpr bool value = false;
+};
+
+template <>
+struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
+ static constexpr bool value = true;
+};
+
+template <typename U>
+struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
+ static constexpr bool value = true;
+};
+
+template <typename U>
+struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
+ static constexpr bool value = false;
+};
+
+template <>
+struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
+ static constexpr bool value = true;
+};
+
+template <typename U>
+struct IsSharedDeviceMemory {
+ static constexpr bool value = false;
+};
+
+template <typename U>
+struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
+ static constexpr bool value = true;
+};
+
+template <typename U>
+struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
+ static constexpr bool value = true;
+};
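
Compile-time spot checks that mirror the intended classification of the traits above; these static_asserts compile to nothing and are shown only as a sketch.

    #include "tensorflow/stream_executor/kernel.h"

    namespace se = perftools::gputools;

    static_assert(se::IsDeviceMemoryPointer<se::DeviceMemory<float>*>::value,
                  "DeviceMemory<T>* is a device-memory pointer");
    static_assert(!se::IsDeviceMemoryPointer<float*>::value,
                  "raw pointers are not device-memory pointers");
    static_assert(se::IsDeviceMemoryValueLike<se::DeviceMemory<float>&>::value,
                  "DeviceMemory<T>& is treated as a device-memory value");
    static_assert(
        !se::IsDeviceMemoryValueLike<se::SharedDeviceMemory<float>&>::value,
        "SharedDeviceMemory<T>& is handled by the shared-memory path");
    static_assert(se::IsSharedDeviceMemory<se::SharedDeviceMemory<float>>::value,
                  "SharedDeviceMemory<T> is classified as shared memory");
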
+
+// KernelArg encapsulates the information necessary for a back-end executor to
+// configure a kernel to launch using the given argument.
+struct KernelArg {
+ // Indicates the type of an argument: normal, to be passed to the kernel
+ // in the standard manner, or shared memory, which has distinct
+ // rules for specification per backend.
+ enum Type {
+ kNormal,
+ kSharedMemory,
+ } type;
+
+ // The data to pass to the kernel - either a pointer to device memory, or the
+  // argument value. The inlined storage prevents smaller args (e.g. u8, u64)
+  // from requiring heap allocation.
+ port::InlinedVector<uint8, 4> data;
+
+ // The size of this argument in bytes.
+ uint64 bytes;
+};
+
+// Typed variant of KernelBase, like a typed device function pointer. See the
+// file comment for details and example usage.
+//
+// This class contains template metaprogramming magic to type check the
+// parameters passed to a kernel launch are acceptable, and subsequently pack
+// them into a form which can be used by the StreamExecutorInterface
+// implementation. (i.e. CUDA and OpenCL both bind void*s with associated
+// sizes as kernel arguments.)
+//
+// Thread-compatible.
+template <typename... Params>
+class TypedKernel : public KernelBase {
+ public:
+ // Delegates to KernelBase::KernelBase(), see that constructor.
+ explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}
+
+ // Test-only constructor that can take a mock KernelInterface implementation.
+ // Takes ownership of implementation, it should not be null.
+ TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
+ : KernelBase(parent, implementation) {}
+
+ private:
+ // Stream needs access to the specific parameter-packing functionality that
+ // the TypedKernel provides for its corresponding type signature (and no other
+ // type signatures).
+ friend class Stream;
+
+ // This is the main entry point into the magic. Packs the parameters (which
+ // must type check against the class template) into the args and sizes
+ // arrays.
+ //
+ // Const refs are taken as parameters on all of the handlers to avoid
+ // implicit type promotion of integers.
+ void PackParams(std::vector<KernelArg> *args, Params... params) const {
+ PackOneParam(args, params...);
+ }
+
+ template <typename T, typename... RestOfParams>
+ void PackOneParam(std::vector<KernelArg> *args, const T &arg,
+ const RestOfParams... rest) const {
+ PackOneParam(args, arg);
+ PackOneParam(args, rest...);
+ }
+
+ // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
+ // The enable_if<> is for excluding DeviceMemoryBase args, which have a
+ // separate implementation below.
+ template <typename T>
+ void PackOneParam(
+ std::vector<KernelArg> *args, const T &arg,
+ typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
+ !IsDeviceMemoryPointer<T>::value &&
+ !IsSharedDeviceMemory<T>::value>::type * =
+ nullptr) const {
+ static_assert(!std::is_pointer<T>::value,
+ "cannot pass raw pointer to the device");
+ static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
+ "cannot pass device memory as a normal value");
+ const uint8 *arg_ptr = reinterpret_cast<const uint8 *>(&arg);
+    args->emplace_back(KernelArg{
+        KernelArg::kNormal,
+        port::InlinedVector<uint8, 4>{arg_ptr, arg_ptr + sizeof(arg)},
+        sizeof(arg)});
+ }
+
+ // DeviceMemoryBase family reference override.
+ template <typename T>
+ void PackOneParam(
+ std::vector<KernelArg> *args, const T &arg,
+ typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
+ nullptr) const {
+ args->emplace_back(parent()->DeviceMemoryToKernelArg(arg));
+ }
+
+ // DeviceMemoryBase family pointer override.
+ template <typename T>
+ void PackOneParam(
+ std::vector<KernelArg> *args, T arg,
+ typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
+ nullptr) const {
+ DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
+ args->emplace_back(parent()->DeviceMemoryToKernelArg(*ptr));
+ }
+
+ // Dynamic shared device memory has a size, but no associated allocation on
+ // the host; internally, the device will allocate storage.
+ template <typename T>
+ void PackOneParam(
+ std::vector<KernelArg> *args, T arg,
+ typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
+ nullptr) const {
+ args->emplace_back(KernelArg{KernelArg::kSharedMemory,
+ port::InlinedVector<uint8, 4>(), arg.size()});
+ }
+
+ // Base case for variadic template expansion - nothing to do!
+ void PackOneParam(std::vector<KernelArg> *args) const {}
+
+ SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
+};
+
+// Template metaprogramming helper type that helps us produce better error
+// messages at compile time when there are mismatches between the parameter
+// type list and the argument type list.
+template <typename ParamTuple, typename ArgTuple>
+struct KernelInvocationChecker {
+ // Whether the parameter tuple and argument tuple match in length.
+ static constexpr bool kLengthMatches =
+ std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;
+
+ // The (matching) length of the parameters and arguments type lists.
+ static constexpr int kTupleLength =
+ static_cast<int>(std::tuple_size<ArgTuple>::value);
+
+ // Helper trait to say whether the parameter wants a DeviceMemory-reference
+ // compatible type. This is for inexact type matches, so that it doesn't have
+ // to be precisely a const DeviceMemory<T>&, but can also be a value that
+ // represents the same.
+ template <typename ParamType, typename ArgType>
+ struct IsCompatibleDeviceMemoryRef {
+ static constexpr bool value = false;
+ };
+
+ // See type trait definition above.
+ template <typename U>
+ struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
+ static constexpr bool value = true;
+ };
+
+ // See type trait definition above.
+ template <typename U>
+ struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
+ SharedDeviceMemory<U>> {
+ static constexpr bool value = true;
+ };
+
+ // Returns whether ParamT and ArgT are compatible for data parallel kernel
+ // parameter packing without any assert functionality.
+ template <typename ParamT, typename ArgT>
+ static constexpr bool CompatibleNoAssert() {
+ return std::is_same<typename std::remove_const<ParamT>::type,
+ ArgT>::value ||
+ IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
+ }
+
+ // Checks whether ParamT and ArgT are compatible for data parallel kernel
+  // parameter packing. kArgumentNumber is unused; it is just for error display.
+ //
+ // NOTE: if you encounter an error here, you can see the mismatch by looking
+ // at the end of the last error message, which will be of the form:
+ //
+ // ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &,
+ // perftools::gputools::DeviceMemory<AnotherThing>, true,
+ // 0>'
+ // requested here
+ //
+ // This means that the 0th argument you passed to the kernel invocation should
+ // have been DeviceMemory<OneThing> but was observed to be
+ // DeviceMemory<AnotherThing>.
+ template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
+ int kArgumentNumber>
+ static constexpr bool Compatible() {
+ static_assert(
+ kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
+ "parameter type (LHS) is not compatible with argument type (RHS)");
+ return CompatibleNoAssert<ParamT, ArgT>();
+ }
+
+ // Checks the parameter/argument match at kArgumentNumber for an out of bounds
+ // argument number.
+ //
+  // This is the base case: we've run out of arguments to check, so we're all
+  // good.
+ template <int kArgumentNumber, bool kShouldStaticAssert>
+ static constexpr bool CheckParam(
+ typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
+ return true;
+ }
+
+ // Checks the parameter/argument match at kArgumentNumber.
+ // kShouldStaticAssert determines whether to assert out on a mismatch, or just
+ // yield the constexpr boolean value.
+ template <int kArgumentNumber, bool kShouldStaticAssert>
+ static constexpr bool CheckParam(
+ typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
+ typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
+ ParamT;
+ typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
+ return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
+ CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
+ }
+
+ // Checks the parameters/arguments for match, but doesn't static assert out.
+ // This is useful for testing/inspecting whether a set of parameters match in
+ // things like tests.
+ static constexpr bool CheckAllNoStaticAssert() {
+ return kLengthMatches && CheckParam<kTupleLength - 1, false>();
+ }
+
+ // Checks the parameters and static asserts out with a helpful error message
+ // (and useful template parameters in the instantiation stack) if there is an
+ // error.
+ static constexpr bool CheckAllStaticAssert() {
+ static_assert(kLengthMatches,
+ "argument length mismatched against typed kernel parameters");
+ return kLengthMatches && CheckParam<kTupleLength - 1, true>();
+ }
+};
+
+// This is a convenience type for checking whether a typed kernel matches
+// against a type list.
+template <typename KernelT, typename... Params>
+struct KernelParamsOk {
+ static constexpr bool kResult = false;
+};
+
+// See above.
+template <typename... Params, typename... Args>
+struct KernelParamsOk<TypedKernel<Params...>, Args...> {
+ static constexpr bool kResult = KernelInvocationChecker<
+ std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
+};
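+
+// Illustrative usage sketch; AddKernel is a hypothetical typed kernel:
+//
+//   using AddKernel = TypedKernel<DeviceMemory<int>, int>;
+//   static_assert(KernelParamsOk<AddKernel, DeviceMemory<int>, int>::kResult,
+//                 "argument types must match the kernel parameter types");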
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
diff --git a/tensorflow/stream_executor/kernel_cache_config.h b/tensorflow/stream_executor/kernel_cache_config.h
new file mode 100644
index 0000000000..9675d2940c
--- /dev/null
+++ b/tensorflow/stream_executor/kernel_cache_config.h
@@ -0,0 +1,29 @@
+// This file contains declarations relating to kernel cache configuration
+// parameters recognized by the StreamExecutor.
+#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
+#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
+
+namespace perftools {
+namespace gputools {
+
+// This enum represents potential configurations of L1/shared memory when
+// running a particular kernel. These values represent user preference, and
+// the runtime is not required to respect these choices.
+enum class KernelCacheConfig {
+ // Indicates no preference for device L1/shared memory configuration.
+ kNoPreference,
+
+ // Indicates a preference for more shared memory than L1 cache.
+ kPreferShared,
+
+ // Indicates a preference for more L1 cache than shared memory.
+ kPreferL1,
+
+ // Indicates a preference for equal amounts of L1 cache and shared memory.
+ kPreferEqual,
+};
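+
+// Illustrative usage sketch; the setter below is hypothetical and is shown
+// only to indicate how a preference is typically communicated to a kernel
+// object:
+//
+//   KernelCacheConfig config = KernelCacheConfig::kPreferShared;
+//   // kernel.SetPreferredCacheConfig(config);  // hypothetical setter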
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc
new file mode 100644
index 0000000000..e3b4b0d951
--- /dev/null
+++ b/tensorflow/stream_executor/kernel_spec.cc
@@ -0,0 +1,236 @@
+#include "tensorflow/stream_executor/kernel_spec.h"
+
+
+namespace perftools {
+namespace gputools {
+
+KernelLoaderSpec::KernelLoaderSpec(port::StringPiece kernelname)
+ : kernelname_(kernelname.ToString()) {}
+
+OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(port::StringPiece filename,
+ port::StringPiece kernelname)
+ : KernelLoaderSpec(kernelname), filename_(filename.ToString()) {}
+
+CudaPtxOnDisk::CudaPtxOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname)
+ : OnDiskKernelLoaderSpec(filename, kernelname) {}
+
+CudaCubinOnDisk::CudaCubinOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname)
+ : OnDiskKernelLoaderSpec(filename, kernelname) {}
+
+CudaCubinInMemory::CudaCubinInMemory(const char *bytes,
+ port::StringPiece kernelname)
+ : KernelLoaderSpec(kernelname), bytes_(bytes) {}
+
+bool CompareComputeCapability(const std::tuple<int, int> &lhs,
+ const std::tuple<int, int> &rhs) {
+ return std::get<0>(lhs) < std::get<0>(rhs) ||
+ (std::get<0>(lhs) == std::get<0>(rhs) &&
+ std::get<1>(lhs) < std::get<1>(rhs));
+}
+
+const std::tuple<int, int> CudaPtxInMemory::kMinimumCapability{1, 0};
+
+CudaPtxInMemory::CudaPtxInMemory(port::StringPiece ptx,
+ port::StringPiece kernel_name,
+ bool ptx_compressed)
+ : KernelLoaderSpec(kernel_name),
+ ptx_by_compute_capability_(CompareComputeCapability) {
+ if (ptx_compressed) {
+ // Lazy decompression. Put an empty string in decompressed_ptx_ showing that
+ // the original ptx is compressed.
+ decompressed_ptx_[ptx.data()] = "";
+ }
+ ptx_by_compute_capability_[kMinimumCapability] = ptx.data();
+}
+
+CudaPtxInMemory::CudaPtxInMemory(
+ const std::initializer_list<CudaPtxInMemory::PtxSpec> &spec_list,
+ port::StringPiece kernel_name, bool ptx_compressed)
+ : KernelLoaderSpec(kernel_name),
+ ptx_by_compute_capability_(CompareComputeCapability) {
+ for (const auto &spec : spec_list) {
+ int major, minor;
+ port::StringPiece ptx;
+ std::tie(major, minor, ptx) = spec;
+ if (ptx_compressed) {
+ // Lazy decompression. Put an empty string in decompressed_ptx_ showing
+ // that the original ptx is compressed.
+ decompressed_ptx_[ptx.data()] = "";
+ }
+ ptx_by_compute_capability_[std::tuple<int, int>{major, minor}] = ptx.data();
+ }
+}
+
+string CudaPtxInMemory::DecompressPtx(const char *ptx) {
+ // Get the length of the PTX string from the beginning of the buffer.
+ uint64 ptx_length = *reinterpret_cast<const uint64 *>(ptx);
+ // Get the PTX string from the buffer with offset and length.
+ string compressed_ptx(ptx + sizeof(uint64),
+ ptx + sizeof(uint64) + ptx_length);
+ string decompressed_ptx;
+ // Decompress the PTX string with bzip2.
+ LOG(FATAL) << "bzip2 decompression is not supported yet.";
+ return decompressed_ptx;
+}
+
+const char *CudaPtxInMemory::default_text() const {
+ if (ptx_by_compute_capability_.empty()) {
+ return nullptr;
+ }
+
+ mutex_lock lock{mu_};
+
+ auto ptx = ptx_by_compute_capability_.begin()->second;
+ // Check if there is an entry in decompressed ptx table.
+ auto decompressed_ptx_iter = decompressed_ptx_.find(ptx);
+ if (decompressed_ptx_iter != decompressed_ptx_.end()) {
+ // If the decompressed string is empty, which means the ptx hasn't been
+ // decompressed, decompress it here.
+ if (decompressed_ptx_iter->second.size() == 0) {
+ decompressed_ptx_iter->second = DecompressPtx(ptx);
+ }
+ return decompressed_ptx_iter->second.c_str();
+ }
+ return ptx;
+}
+
+const char *CudaPtxInMemory::original_default_text() const {
+ if (ptx_by_compute_capability_.empty()) {
+ return nullptr;
+ }
+
+ return ptx_by_compute_capability_.begin()->second;
+}
+
+const char *CudaPtxInMemory::text(int compute_capability_major,
+ int compute_capability_minor) const {
+ std::tuple<int, int> capability{compute_capability_major,
+ compute_capability_minor};
+
+ auto ptx_iter = ptx_by_compute_capability_.find(capability);
+ if (ptx_iter == ptx_by_compute_capability_.end()) {
+ return nullptr;
+ }
+
+ mutex_lock lock{mu_};
+
+ // Check if there is an entry in decompressed ptx table.
+ auto decompressed_ptx_iter = decompressed_ptx_.find(ptx_iter->second);
+ if (decompressed_ptx_iter != decompressed_ptx_.end()) {
+ // If the decompressed string is empty, which means the ptx hasn't been
+ // decompressed, decompress it here.
+ if (decompressed_ptx_iter->second.size() == 0) {
+ decompressed_ptx_iter->second = DecompressPtx(ptx_iter->second);
+ }
+ return decompressed_ptx_iter->second.c_str();
+ }
+ return ptx_iter->second;
+}
+
+const char *CudaPtxInMemory::original_text(int compute_capability_major,
+ int compute_capability_minor) const {
+ std::tuple<int, int> capability{compute_capability_major,
+ compute_capability_minor};
+
+ auto ptx_iter = ptx_by_compute_capability_.find(capability);
+ if (ptx_iter == ptx_by_compute_capability_.end()) {
+ return nullptr;
+ }
+
+ return ptx_iter->second;
+}
+
+OpenCLTextOnDisk::OpenCLTextOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname)
+ : OnDiskKernelLoaderSpec(filename, kernelname) {}
+
+OpenCLTextInMemory::OpenCLTextInMemory(port::StringPiece text,
+ port::StringPiece kernelname)
+ : KernelLoaderSpec(kernelname), text_(text.ToString()) {}
+
+OpenCLBinaryOnDisk::OpenCLBinaryOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname)
+ : OnDiskKernelLoaderSpec(filename, kernelname) {}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddOpenCLTextOnDisk(
+ port::StringPiece filename, port::StringPiece kernelname) {
+ CHECK(ocl_text_on_disk_ == nullptr);
+ ocl_text_on_disk_.reset(new OpenCLTextOnDisk{filename, kernelname});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddOpenCLBinaryOnDisk(
+ port::StringPiece filename, port::StringPiece kernelname) {
+ CHECK(ocl_binary_on_disk_ == nullptr);
+ ocl_binary_on_disk_.reset(new OpenCLBinaryOnDisk{filename, kernelname});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddOpenCLTextInMemory(
+    port::StringPiece ocl_text, port::StringPiece kernelname) {
+  CHECK(ocl_text_in_memory_ == nullptr);
+  ocl_text_in_memory_.reset(new OpenCLTextInMemory{ocl_text, kernelname});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxOnDisk(
+ port::StringPiece filename, port::StringPiece kernelname) {
+ CHECK(cuda_ptx_on_disk_ == nullptr);
+ cuda_ptx_on_disk_.reset(new CudaPtxOnDisk{filename, kernelname});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCubinInMemory(
+ const char *bytes, port::StringPiece kernelname) {
+ CHECK(cuda_cubin_in_memory_ == nullptr);
+ cuda_cubin_in_memory_.reset(new CudaCubinInMemory{bytes, kernelname});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCubinOnDisk(
+ port::StringPiece filename, port::StringPiece kernelname) {
+ CHECK(cuda_cubin_on_disk_ == nullptr);
+ cuda_cubin_on_disk_.reset(new CudaCubinOnDisk{filename, kernelname});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory(
+ port::StringPiece ptx, port::StringPiece kernelname) {
+ CHECK(cuda_ptx_in_memory_ == nullptr);
+ cuda_ptx_in_memory_.reset(
+ new CudaPtxInMemory{ptx, kernelname, false /* ptx_compressed */});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory(
+ port::StringPiece ptx, port::StringPiece kernelname) {
+ CHECK(cuda_ptx_in_memory_ == nullptr);
+ cuda_ptx_in_memory_.reset(
+ new CudaPtxInMemory{ptx, kernelname, true /* ptx_compressed */});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaPtxInMemory(
+ std::initializer_list<CudaPtxInMemory::PtxSpec> spec_list,
+ port::StringPiece kernelname) {
+ CHECK(cuda_ptx_in_memory_ == nullptr);
+ cuda_ptx_in_memory_.reset(
+ new CudaPtxInMemory{spec_list, kernelname, false /* ptx_compressed */});
+ return this;
+}
+
+MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory(
+ std::initializer_list<CudaPtxInMemory::PtxSpec> spec_list,
+ port::StringPiece kernelname) {
+ CHECK(cuda_ptx_in_memory_ == nullptr);
+ cuda_ptx_in_memory_.reset(
+ new CudaPtxInMemory{spec_list, kernelname, true /* ptx_compressed */});
+ return this;
+}
+
+MultiKernelLoaderSpec::MultiKernelLoaderSpec(size_t arity) : arity_(arity) {}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/kernel_spec.h b/tensorflow/stream_executor/kernel_spec.h
new file mode 100644
index 0000000000..01a47ac253
--- /dev/null
+++ b/tensorflow/stream_executor/kernel_spec.h
@@ -0,0 +1,365 @@
+// Kernel-loader specs are structures that describe how to load a data-parallel
+// kernel on a given platform for subsequent launching. Headers that instantiate
+// these data structures will typically be auto-generated. However, users can
+// also instantiate them by hand.
+//
+// A kernel with the same exact functionality and type signature may be
+// implemented on several different platforms. Typical usage is to create a
+// singleton that describes how to load a kernel on the various supported
+// platforms:
+//
+// static const MultiKernelLoaderSpec &SaxpySpec() {
+// static auto *mkls =
+// (new MultiKernelLoaderSpec{4 /* = arity */})
+// ->AddCudaPtxOnDisk(ptx_file_path, ptx_kernelname)
+// ->AddOpenCLTextOnDisk(opencl_text_file_path, ocl_kernelname);
+//
+// return *mkls;
+// }
+//
+// This lazily instantiates an object that describes how to load CUDA PTX
+// present on disk that implements saxpy for the CUDA platform, or
+// OpenCL text present on disk that implements saxpy for an OpenCL-based
+// platform. The CudaPtxOnDisk and OpenCLTextOnDisk objects are subtypes of
+// KernelLoaderSpec -- KernelLoaderSpec describes how to load a kernel for
+// subsequent launching on a single platform.
+//
+// For the loader functionality that accepts these KernelLoaderSpecs in order
+// to grab the kernel appropriately, see StreamExecutor::GetKernel().
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_
+#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_
+
+#include <stddef.h>
+#include <map>
+#include <memory>
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+// Describes how to load a kernel on a target platform.
+//
+// This is an abstract base class, subclassed for specific platforms.
+// Subclasses store the program location or contents (i.e. a PTX or OpenCL
+// loadable translation unit path, or in-memory program text); whether it is a
+// filename or text is exposed via more specifically named accessors in the
+// subclasses.
+//
+// These kernel loader specifications are typically auto-generated into header
+// files at build time, but can also be specified manually.
+class KernelLoaderSpec {
+ public:
+ virtual ~KernelLoaderSpec() {}
+
+ // Returns the kernel name to load out of the program.
+ const string &kernelname() const { return kernelname_; }
+
+ protected:
+ explicit KernelLoaderSpec(port::StringPiece kernelname);
+
+ private:
+ // The kernel name that should be loaded out of the program description given
+ // above.
+ string kernelname_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(KernelLoaderSpec);
+};
+
+// An abstract kernel loader spec that has an associated file path, where
+// there's a canonical suffix for the filename; e.g. see CudaPtxOnDisk whose
+// canonical filename suffix is ".ptx".
+class OnDiskKernelLoaderSpec : public KernelLoaderSpec {
+ public:
+ ~OnDiskKernelLoaderSpec() override {}
+
+ // Returns the path to the on-disk loadable kernel file.
+ const string &filename() const { return filename_; }
+
+ // Returns the canonical suffix for this on-disk kernel loader spec format;
+ // e.g. PTX files on disk have a canonical suffix of ".ptx".
+ virtual const char *CanonicalSuffix() const = 0;
+
+ protected:
+ OnDiskKernelLoaderSpec(port::StringPiece filename,
+ port::StringPiece kernelname);
+
+ string filename_;
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(OnDiskKernelLoaderSpec);
+};
+
+// Kernel loader specification for PTX text that resides on disk.
+class CudaPtxOnDisk : public OnDiskKernelLoaderSpec {
+ public:
+ CudaPtxOnDisk(port::StringPiece filename, port::StringPiece kernelname);
+ ~CudaPtxOnDisk() override {}
+
+ const char *CanonicalSuffix() const override { return ".ptx"; }
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(CudaPtxOnDisk);
+};
+
+// Kernel loader specification for CUBIN binary that resides on disk.
+class CudaCubinOnDisk : public OnDiskKernelLoaderSpec {
+ public:
+ CudaCubinOnDisk(port::StringPiece filename, port::StringPiece kernelname);
+ ~CudaCubinOnDisk() override {}
+
+  const char *CanonicalSuffix() const override { return ".cubin"; }
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(CudaCubinOnDisk);
+};
+
+// Kernel loader specification for PTX text that resides in memory.
+class CudaPtxInMemory : public KernelLoaderSpec {
+ public:
+ // Components: compute capability major number, compute capability minor
+ // number, and PTX source.
+ typedef std::tuple<int, int, port::StringPiece> PtxSpec;
+
+ // Single-PTX constructor. Adds the provided PTX version with an unknown
+ // compute capability. Since the CC is unknown, the PTX is assumed to be very
+ // generally usable - in other words, PTX specified in this manner is VERY
+ // likely to be used as the default! Note that the PTX can be compressed,
+ // which is indicated by the argument ptx_compressed.
+ //
+ // Warning: the string backing the provided port::StringPiece ptx must outlive this
+ // instance.
+ CudaPtxInMemory(port::StringPiece ptx, port::StringPiece kernelname,
+ bool ptx_compressed = false);
+
+ // Multiple-PTX-version constructor. Adds each item in spec_list to this
+ // object. Note that the PTX can be compressed, which is indicated by the
+ // argument ptx_compressed.
+ CudaPtxInMemory(const std::initializer_list<PtxSpec> &spec_list,
+ port::StringPiece kernel_name, bool ptx_compressed = false);
+ ~CudaPtxInMemory() override {}
+
+ // Add the PTX implementation described by ptx_spec to this object. On
+ // collision (i.e., if a version with the same compute_capability already
+ // exists), the existing implementation will be overwritten.
+ void AddSpec(PtxSpec ptx_spec);
+
+  // Returns a pointer to the PTX of the available implementation with the
+  // lowest-valued compute capability. For example, if PTX for CC 2.0, 3.0, and
+  // 3.5 are all available, the CC 2.0 version is returned. Returns nullptr if
+  // no version is available.
+  // When the PTX is compressed, returns the decompressed PTX.
+  const char *default_text() const;
+
+  // Similar to default_text(), but when the PTX is compressed, returns the
+  // original compressed PTX.
+  const char *original_default_text() const;
+
+ // Returns pointer to the ptx for the requested compute capability.
+ // Returns nullptr on failed lookup (if the requested version is not
+ // available).
+ // When the ptx is compressed, returns the decompressed ptx.
+ const char *text(int compute_capability_major,
+ int compute_capability_minor) const;
+
+ // Similar to text().
+ // When the ptx is compressed, returns the original compressed ptx.
+ const char *original_text(int compute_capability_major,
+ int compute_capability_minor) const;
+
+  // Decompresses the PTX string using bzip2 (not yet implemented; see the .cc).
+ static string DecompressPtx(const char *ptx);
+
+ private:
+  // PTX translation unit text contents in memory. The key is a tuple
+  // <cc_major, cc_minor>, e.g. (2, 0), (3, 0), (3, 5). Because compute
+  // capabilities represented in this way have a clear sorting order,
+  // map::begin() will give the lowest-numbered version available, i.e. the
+  // default.
+ std::map<std::tuple<int, int>, const char *,
+ bool (*)(const std::tuple<int, int> &, const std::tuple<int, int> &)>
+ ptx_by_compute_capability_;
+
+  // Stores all decompressed ptx strings, with the original ptx strings as
+  // keys. It is marked mutable to allow lazy decompression.
+ mutable std::map<const char *, string> decompressed_ptx_;
+ mutable mutex mu_;
+
+ // Defines the minimum compute capability possible. Used when PTX has no
+ // compute capability specified (in the single-PTX constructor).
+ static const std::tuple<int, int> kMinimumCapability;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CudaPtxInMemory);
+};
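+
+// Illustrative usage sketch; sm30_ptx and sm35_ptx are hypothetical PTX
+// strings that must outlive the spec (see the constructor comments above):
+//
+//   CudaPtxInMemory spec({std::make_tuple(3, 0, sm30_ptx),
+//                         std::make_tuple(3, 5, sm35_ptx)},
+//                        "my_kernel");
+//   const char *ptx_for_sm35 = spec.text(3, 5);  // nullptr if not present.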
+
+// Kernel loader specification for OpenCL text that resides on disk.
+class OpenCLTextOnDisk : public OnDiskKernelLoaderSpec {
+ public:
+ OpenCLTextOnDisk(port::StringPiece filename, port::StringPiece kernelname);
+ ~OpenCLTextOnDisk() override {}
+
+ const char *CanonicalSuffix() const override { return ".ocl"; }
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(OpenCLTextOnDisk);
+};
+
+// Kernel loader specification for OpenCL binary that resides on disk.
+class OpenCLBinaryOnDisk : public OnDiskKernelLoaderSpec {
+ public:
+ OpenCLBinaryOnDisk(port::StringPiece filename, port::StringPiece kernelname);
+ ~OpenCLBinaryOnDisk() override {}
+
+ const char *CanonicalSuffix() const override { return ".aocx"; }
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(OpenCLBinaryOnDisk);
+};
+
+// Kernel loader specification for OpenCL text that resides in memory.
+class OpenCLTextInMemory : public KernelLoaderSpec {
+ public:
+ OpenCLTextInMemory(port::StringPiece text, port::StringPiece kernelname);
+ ~OpenCLTextInMemory() override {}
+
+ // Returns the OpenCL text contents.
+ const string &text() const { return text_; }
+
+ private:
+ // OpenCL translation unit text contents in memory.
+ string text_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(OpenCLTextInMemory);
+};
+
+// Kernel loader specification for a CUBIN blob that resides in memory.
+class CudaCubinInMemory : public KernelLoaderSpec {
+ public:
+ CudaCubinInMemory(const char *bytes, port::StringPiece kernelname);
+ ~CudaCubinInMemory() override {}
+
+ const char *bytes() const { return bytes_; }
+
+ private:
+ const char *bytes_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(CudaCubinInMemory);
+};
+
+// Describes how to load a kernel on any subset of a number of target platforms.
+class MultiKernelLoaderSpec {
+ public:
+ explicit MultiKernelLoaderSpec(size_t arity);
+
+ // Returns the number of arguments that this kernel accepts.
+ size_t arity() const { return arity_; }
+
+ // Convenience getters for testing whether these platform variants have
+ // kernel loader specifications available.
+ bool has_cuda_ptx_on_disk() const { return cuda_ptx_on_disk_ != nullptr; }
+ bool has_cuda_cubin_on_disk() const { return cuda_cubin_on_disk_ != nullptr; }
+ bool has_cuda_cubin_in_memory() const {
+ return cuda_cubin_in_memory_ != nullptr;
+ }
+ bool has_cuda_ptx_in_memory() const { return cuda_ptx_in_memory_ != nullptr; }
+ bool has_ocl_text_on_disk() const { return ocl_text_on_disk_ != nullptr; }
+ bool has_ocl_binary_on_disk() const { return ocl_binary_on_disk_ != nullptr; }
+ bool has_ocl_text_in_memory() const { return ocl_text_in_memory_ != nullptr; }
+
+ // Accessors for platform variant kernel load specifications.
+ // Precondition: corresponding has_* is true.
+ const CudaPtxOnDisk &cuda_ptx_on_disk() const {
+ CHECK(has_cuda_ptx_on_disk());
+ return *cuda_ptx_on_disk_;
+ }
+ const CudaCubinOnDisk &cuda_cubin_on_disk() const {
+ CHECK(has_cuda_cubin_on_disk());
+ return *cuda_cubin_on_disk_;
+ }
+ const CudaCubinInMemory &cuda_cubin_in_memory() const {
+ CHECK(has_cuda_cubin_in_memory());
+ return *cuda_cubin_in_memory_;
+ }
+ const CudaPtxInMemory &cuda_ptx_in_memory() const {
+ CHECK(has_cuda_ptx_in_memory());
+ return *cuda_ptx_in_memory_;
+ }
+ const OpenCLTextOnDisk &ocl_text_on_disk() const {
+ CHECK(has_ocl_text_on_disk());
+ return *ocl_text_on_disk_;
+ }
+ const OpenCLBinaryOnDisk &ocl_binary_on_disk() const {
+ CHECK(has_ocl_binary_on_disk());
+ return *ocl_binary_on_disk_;
+ }
+ const OpenCLTextInMemory &ocl_text_in_memory() const {
+ CHECK(has_ocl_text_in_memory());
+ return *ocl_text_in_memory_;
+ }
+
+ // Builder-pattern-like methods for use in initializing a
+ // MultiKernelLoaderSpec. Each of these should be used at most once for a
+ // single MultiKernelLoaderSpec object. See file comment for example usage.
+ //
+ // Note that the kernelname parameter must be consistent with the kernel in
+ // the PTX or OpenCL being loaded. Also be aware that in CUDA C++ the kernel
+ // name may be mangled by the compiler if it is not declared in an
+ // extern "C" scope.
+ MultiKernelLoaderSpec *AddOpenCLTextOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddOpenCLBinaryOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddOpenCLTextInMemory(port::StringPiece ocl_text,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddCudaPtxOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddCudaCubinOnDisk(port::StringPiece filename,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddCudaCubinInMemory(const char *cubin_bytes,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddCudaPtxInMemory(port::StringPiece ptx,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddCudaCompressedPtxInMemory(
+ port::StringPiece ptx, port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddCudaPtxInMemory(
+ std::initializer_list<CudaPtxInMemory::PtxSpec> spec_list,
+ port::StringPiece kernelname);
+ MultiKernelLoaderSpec *AddCudaCompressedPtxInMemory(
+ std::initializer_list<CudaPtxInMemory::PtxSpec> spec_list,
+ port::StringPiece kernelname);
+
+ private:
+ std::unique_ptr<CudaPtxOnDisk>
+ cuda_ptx_on_disk_; // PTX text that resides in a file.
+ std::unique_ptr<CudaCubinOnDisk>
+ cuda_cubin_on_disk_; // Binary CUDA program in a file.
+ std::unique_ptr<CudaCubinInMemory>
+ cuda_cubin_in_memory_; // Binary CUDA program in memory.
+ std::unique_ptr<CudaPtxInMemory>
+ cuda_ptx_in_memory_; // PTX text that resides in memory.
+ std::unique_ptr<OpenCLTextOnDisk>
+ ocl_text_on_disk_; // OpenCL text that resides on disk.
+ std::unique_ptr<OpenCLBinaryOnDisk>
+ ocl_binary_on_disk_; // OpenCL binary that resides on disk.
+ std::unique_ptr<OpenCLTextInMemory>
+ ocl_text_in_memory_; // OpenCL text that resides in memory.
+
+ // Number of parameters that the kernel takes. (This is nicer to have in a
+ // constexpr than having to determine it from the types via template
+ // metaprogramming).
+ size_t arity_;
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_
diff --git a/tensorflow/stream_executor/launch_dim.h b/tensorflow/stream_executor/launch_dim.h
new file mode 100644
index 0000000000..9b870ed6aa
--- /dev/null
+++ b/tensorflow/stream_executor/launch_dim.h
@@ -0,0 +1,65 @@
+// Types to express dimensionality of a kernel launch. Blocks and threads
+// are (up to) 3-dimensional.
+//
+// A thread is conceptually like a SIMD lane. Some number of SIMD lanes,
+// typically 32 (though that fact should not be relied on), are tied together
+// with a single PC in a unit called a warp. There is a maximum number of
+// threads that can execute in a shared-context entity called a block.
+// Presently, that number is 1024 -- again, something that should not be relied
+// on from this comment, but checked via
+// perftools::gputools::DeviceDescription.
+//
+// For additional information, see
+// http://docs.nvidia.com/cuda/kepler-tuning-guide/#device-utilization-and-occupancy
+//
+// Because of that modest thread-per-block limit, a kernel can be launched with
+// multiple blocks. Each block is indivisibly scheduled onto a single core.
+// Blocks can also be used in a multi-dimensional configuration, and the block
+// count is far less constrained -- its limits are typically similar to the
+// maximum amount of addressable memory.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LAUNCH_DIM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LAUNCH_DIM_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+// Basic type that represents a 3-dimensional index space.
+struct Dim3D {
+ uint64 x, y, z;
+
+ Dim3D(uint64 x, uint64 y, uint64 z) : x(x), y(y), z(z) {}
+};
+
+// Thread dimensionality for use in a kernel launch. See file comment for
+// details.
+struct ThreadDim : public Dim3D {
+ explicit ThreadDim(uint64 x = 1, uint64 y = 1, uint64 z = 1)
+ : Dim3D(x, y, z) {}
+
+ // Returns a string representation of the thread dimensionality.
+ string ToString() const {
+ return port::StrCat("ThreadDim{", x, ", ", y, ", ", z, "}");
+ }
+};
+
+// Block dimensionality for use in a kernel launch. See file comment for
+// details.
+struct BlockDim : public Dim3D {
+ explicit BlockDim(uint64 x = 1, uint64 y = 1, uint64 z = 1)
+ : Dim3D(x, y, z) {}
+
+ // Returns a string representation of the block dimensionality.
+ string ToString() const {
+ return port::StrCat("BlockDim{", x, ", ", y, ", ", z, "}");
+ }
+};
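+
+// Illustrative usage sketch; n is a hypothetical element count:
+//
+//   ThreadDim thread_dim(256);            // 256 threads per block.
+//   BlockDim block_dim((n + 255) / 256);  // Enough blocks to cover n.
+//   // Both are then passed to the kernel launch alongside the kernel and its
+//   // arguments.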
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LAUNCH_DIM_H_
diff --git a/tensorflow/stream_executor/lib/array_slice.h b/tensorflow/stream_executor/lib/array_slice.h
new file mode 100644
index 0000000000..271b1c15a0
--- /dev/null
+++ b/tensorflow/stream_executor/lib/array_slice.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::gtl::ArraySlice;
+using tensorflow::gtl::MutableArraySlice;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
diff --git a/tensorflow/stream_executor/lib/casts.h b/tensorflow/stream_executor/lib/casts.h
new file mode 100644
index 0000000000..61ff2ab00e
--- /dev/null
+++ b/tensorflow/stream_executor/lib/casts.h
@@ -0,0 +1,85 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
+
+#include <stdlib.h>
+#include <string.h>  // for memcpy
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+// port::bit_cast<Dest,Source> is a template function that implements the
+// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in
+// very low-level functions like the protobuf library and fast math
+// support.
+//
+// float f = 3.14159265358979;
+// int i = port::bit_cast<int32>(f);
+// // i = 0x40490fdb
+//
+// The classical address-casting method is:
+//
+// // WRONG
+// float f = 3.14159265358979; // WRONG
+// int i = * reinterpret_cast<int*>(&f); // WRONG
+//
+// The address-casting method actually produces undefined behavior
+// according to ISO C++ specification section 3.10 -15-. Roughly, this
+// section says: if an object in memory has one type, and a program
+// accesses it with a different type, then the result is undefined
+// behavior for most values of "different type".
+//
+// This is true for any cast syntax, either *(int*)&f or
+// *reinterpret_cast<int*>(&f). And it is particularly true for
+// conversions between integral lvalues and floating-point lvalues.
+//
+// The purpose of 3.10 -15- is to allow optimizing compilers to assume
+// that expressions with different types refer to different memory. gcc
+// 4.0.1 has an optimizer that takes advantage of this. So a
+// non-conforming program quietly produces wildly incorrect output.
+//
+// The problem is not the use of reinterpret_cast. The problem is type
+// punning: holding an object in memory of one type and reading its bits
+// back using a different type.
+//
+// The C++ standard is more subtle and complex than this, but that
+// is the basic idea.
+//
+// Anyways ...
+//
+// port::bit_cast<> calls memcpy() which is blessed by the standard,
+// especially by the example in section 3.9 . Also, of course,
+// port::bit_cast<> wraps up the nasty logic in one place.
+//
+// Fortunately memcpy() is very fast. In optimized mode, with a
+// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
+// code with the minimal amount of data movement. On a 32-bit system,
+// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
+// compiles to two loads and two stores.
+//
+// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+//
+// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
+// is likely to surprise you.
+//
+// Props to Bill Gibbons for the compile time assertion technique and
+// Art Komninos and Igor Tandetnik for the msvc experiments.
+//
+// -- mec 2005-10-17
+
+template <class Dest, class Source>
+inline Dest bit_cast(const Source& source) {
+ // Compile time assertion: sizeof(Dest) == sizeof(Source)
+ // A compile error here means your Dest and Source have different sizes.
+ static_assert(sizeof(Dest) == sizeof(Source),
+ "src and dst types must have equal sizes");
+
+ Dest dest;
+ memcpy(&dest, &source, sizeof(dest));
+ return dest;
+}
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
diff --git a/tensorflow/stream_executor/lib/demangle.cc b/tensorflow/stream_executor/lib/demangle.cc
new file mode 100644
index 0000000000..6b837b803a
--- /dev/null
+++ b/tensorflow/stream_executor/lib/demangle.cc
@@ -0,0 +1,38 @@
+#include "tensorflow/stream_executor/lib/demangle.h"
+
+#if (__GNUC__ >= 4 || (__GNUC__ >= 3 && __GNUC_MINOR__ >= 4)) && \
+ !defined(__mips__)
+# define HAS_CXA_DEMANGLE 1
+#else
+# define HAS_CXA_DEMANGLE 0
+#endif
+
+#include <stdlib.h>
+#if HAS_CXA_DEMANGLE
+#include <cxxabi.h>
+#endif
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+// The API reference of abi::__cxa_demangle() can be found in
+// libstdc++'s manual.
+// https://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-html-USERS-4.3/a01696.html
+string Demangle(const char *mangled) {
+ string demangled;
+ int status = 0;
+ char *result = NULL;
+#if HAS_CXA_DEMANGLE
+ result = abi::__cxa_demangle(mangled, NULL, NULL, &status);
+#endif
+  if (status == 0 && result != NULL) {  // Demangling succeeded.
+ demangled.append(result);
+ free(result);
+ }
+ return demangled;
+}
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/lib/demangle.h b/tensorflow/stream_executor/lib/demangle.h
new file mode 100644
index 0000000000..0420f7101f
--- /dev/null
+++ b/tensorflow/stream_executor/lib/demangle.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_DEMANGLE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_DEMANGLE_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
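+// Returns the demangled form of "mangled", or an empty string if demangling
+// fails or is unavailable on this platform. For example (assuming a
+// GNU-compatible ABI), Demangle("_ZN3FooC1Ev") would typically yield
+// "Foo::Foo()".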
+string Demangle(const char* mangled);
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_DEMANGLE_H_
diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h
new file mode 100644
index 0000000000..74b50ad42d
--- /dev/null
+++ b/tensorflow/stream_executor/lib/env.h
@@ -0,0 +1,29 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_
+
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::Env;
+using tensorflow::ReadFileToString;
+using tensorflow::Thread;
+using tensorflow::WriteStringToFile;
+
+inline bool FileExists(const string& filename) {
+ return Env::Default()->FileExists(filename);
+}
+
+inline bool FileExists(const port::StringPiece& filename) {
+ return Env::Default()->FileExists(filename.ToString());
+}
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_
diff --git a/tensorflow/stream_executor/lib/error.h b/tensorflow/stream_executor/lib/error.h
new file mode 100644
index 0000000000..376ddd3d07
--- /dev/null
+++ b/tensorflow/stream_executor/lib/error.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
+
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+namespace error = tensorflow::error;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
diff --git a/tensorflow/stream_executor/lib/human_readable.h b/tensorflow/stream_executor/lib/human_readable.h
new file mode 100644
index 0000000000..78df4a4a70
--- /dev/null
+++ b/tensorflow/stream_executor/lib/human_readable.h
@@ -0,0 +1,58 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_HUMAN_READABLE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_HUMAN_READABLE_H_
+
+#include <assert.h>
+#include <limits>
+
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+class HumanReadableNumBytes {
+ public:
+ static string ToString(int64 num_bytes) {
+ if (num_bytes == std::numeric_limits<int64>::min()) {
+      // Special case for the value whose negation cannot be represented.
+ return "-8E";
+ }
+
+ const char* neg_str = GetNegStr(&num_bytes);
+
+ // Special case for bytes.
+ if (num_bytes < 1024LL) {
+ // No fractions for bytes.
+ return port::Printf("%s%lldB", neg_str, num_bytes);
+ }
+
+ static const char units[] = "KMGTPE"; // int64 only goes up to E.
+ const char* unit = units;
+ while (num_bytes >= (1024LL) * (1024LL)) {
+ num_bytes /= (1024LL);
+ ++unit;
+ assert(unit < units + sizeof(units));
+ }
+
+ return port::Printf(((*unit == 'K') ? "%s%.1f%c" : "%s%.2f%c"), neg_str,
+ num_bytes / 1024.0, *unit);
+ }
+
+ private:
+ template <typename T>
+ static const char* GetNegStr(T* value) {
+ if (*value < 0) {
+ *value = -(*value);
+ return "-";
+ } else {
+ return "";
+ }
+ }
+};
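+
+// Illustrative examples of the formatting above:
+//
+//   HumanReadableNumBytes::ToString(1536);         // "1.5K"
+//   HumanReadableNumBytes::ToString(1024 * 1024);  // "1.00M"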
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_HUMAN_READABLE_H_
diff --git a/tensorflow/stream_executor/lib/initialize.h b/tensorflow/stream_executor/lib/initialize.h
new file mode 100644
index 0000000000..d1832d6b26
--- /dev/null
+++ b/tensorflow/stream_executor/lib/initialize.h
@@ -0,0 +1,35 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INITIALIZE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_INITIALIZE_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE)
+#else
+
+#undef REGISTER_MODULE_INITIALIZER
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+class Initializer {
+ public:
+ typedef void (*InitializerFunc)();
+ explicit Initializer(InitializerFunc func) { func(); }
+};
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#define REGISTER_INITIALIZER(type, name, body) \
+ static void google_init_##type##_##name() { body; } \
+ perftools::gputools::port::Initializer google_initializer_##type##_##name( \
+ google_init_##type##_##name)
+
+#define REGISTER_MODULE_INITIALIZER(name, body) \
+ REGISTER_INITIALIZER(module, name, body)
+
+#endif // !defined(PLATFORM_GOOGLE)
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_INITIALIZE_H_
diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h
new file mode 100644
index 0000000000..e1f7a29904
--- /dev/null
+++ b/tensorflow/stream_executor/lib/inlined_vector.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
+
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::gtl::InlinedVector;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
diff --git a/tensorflow/stream_executor/lib/mathutil.h b/tensorflow/stream_executor/lib/mathutil.h
new file mode 100644
index 0000000000..dd3d37a19c
--- /dev/null
+++ b/tensorflow/stream_executor/lib/mathutil.h
@@ -0,0 +1,88 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_MATHUTIL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_MATHUTIL_H_
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+class MathUtil {
+ public:
+ template <typename IntegralType>
+ static IntegralType CeilOfRatio(IntegralType numerator,
+ IntegralType denominator) {
+ return CeilOrFloorOfRatio<IntegralType, true>(numerator, denominator);
+ }
+ template <typename IntegralType>
+ static IntegralType FloorOfRatio(IntegralType numerator,
+ IntegralType denominator) {
+ return CeilOrFloorOfRatio<IntegralType, false>(numerator, denominator);
+ }
+ template <typename IntegralType, bool ceil>
+ static IntegralType CeilOrFloorOfRatio(IntegralType numerator,
+ IntegralType denominator);
+};
+
+// ---- CeilOrFloorOfRatio ----
+// This is a branching-free, cast-to-double-free implementation.
+//
+// Casting to double is in general incorrect because of loss of precision
+// when casting an int64 into a double.
+//
+// There are a bunch of 'recipes' on the web for computing an integer ceil (or
+// floor), and most of them are incorrect.
+template<typename IntegralType, bool ceil>
+IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator,
+ IntegralType denominator) {
+ static_assert(std::is_integral<IntegralType>::value,
+ "CeilOfRatio_is_only_defined_for_integral_types");
+ assert(denominator != 0);
+ // Dividing the smallest signed integer by -1 is not supported: it would
+ // SIGFPE
+ assert(!std::is_signed<IntegralType>::value ||
+ numerator != std::numeric_limits<IntegralType>::min() ||
+ denominator != -1);
+
+ const IntegralType rounded_toward_zero = numerator / denominator;
+ const IntegralType intermediate_product = rounded_toward_zero * denominator;
+
+ if (ceil) { // Compile-time condition: not an actual branching
+ // When rounded_toward_zero is negative, then an adjustment is never needed:
+ // the real ratio is negative, and so rounded toward zero is the ceil.
+ // When rounded_toward_zero is non-negative, an adjustment is needed if the
+ // sign of the difference numerator - intermediate_product is the same as
+ // the sign of the denominator.
+ //
+ // Using a bool and then a static_cast to IntegralType is not strictly
+ // necessary, but it makes the code clear, and anyway the compiler should
+ // get rid of it.
+ const bool needs_adjustment = (rounded_toward_zero >= 0) &&
+ ((denominator > 0 && numerator > intermediate_product) ||
+ (denominator < 0 && numerator < intermediate_product));
+ const IntegralType adjustment = static_cast<IntegralType>(needs_adjustment);
+ const IntegralType ceil_of_ratio = rounded_toward_zero + adjustment;
+ return ceil_of_ratio;
+ } else {
+ // Floor case: symmetrical to the previous one
+ const bool needs_adjustment = (rounded_toward_zero <= 0) &&
+ ((denominator > 0 && numerator < intermediate_product) ||
+ (denominator < 0 && numerator > intermediate_product));
+ const IntegralType adjustment = static_cast<IntegralType>(needs_adjustment);
+ const IntegralType floor_of_ratio = rounded_toward_zero - adjustment;
+ return floor_of_ratio;
+ }
+}
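+
+// Illustrative examples of the rounding behavior above:
+//
+//   MathUtil::CeilOfRatio(5, 2);    // == 3
+//   MathUtil::CeilOfRatio(-5, 2);   // == -2  (rounds toward +infinity)
+//   MathUtil::FloorOfRatio(5, 2);   // == 2
+//   MathUtil::FloorOfRatio(-5, 2);  // == -3  (rounds toward -infinity)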
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_MATHUTIL_H_
diff --git a/tensorflow/stream_executor/lib/notification.h b/tensorflow/stream_executor/lib/notification.h
new file mode 100644
index 0000000000..2baa458fc9
--- /dev/null
+++ b/tensorflow/stream_executor/lib/notification.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_NOTIFICATION_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_NOTIFICATION_H_
+
+#include "tensorflow/core/lib/core/notification.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::Notification;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_NOTIFICATION_H_
diff --git a/tensorflow/stream_executor/lib/numbers.cc b/tensorflow/stream_executor/lib/numbers.cc
new file mode 100644
index 0000000000..a9981b0ce6
--- /dev/null
+++ b/tensorflow/stream_executor/lib/numbers.cc
@@ -0,0 +1,27 @@
+#include "tensorflow/stream_executor/lib/numbers.h"
+
+#include <ctype.h>   // for isspace
+#include <stdlib.h>  // for strtol
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+bool safe_strto32(const char* str, int32* value) {
+ char* endptr;
+ *value = strtol(str, &endptr, 10); // NOLINT
+ if (endptr != str) {
+ while (isspace(*endptr)) ++endptr;
+ }
+ return *str != '\0' && *endptr == '\0';
+}
+
+// Converts a string to a 32-bit integer value.
+// Leading and trailing spaces are allowed.
+// Returns true only if the entire string was consumed as an integer.
+bool safe_strto32(const string& str, int32* value) {
+ return port::safe_strto32(str.c_str(), value);
+}
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/lib/numbers.h b/tensorflow/stream_executor/lib/numbers.h
new file mode 100644
index 0000000000..17b2893743
--- /dev/null
+++ b/tensorflow/stream_executor/lib/numbers.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_NUMBERS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_NUMBERS_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+// Converts a string to a 32-bit integer value.
+// Leading and trailing spaces are allowed.
+// Returns true only if the entire string was consumed as an integer.
+bool safe_strto32(const string& str, int32* value);
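+
+// Illustrative examples of the parsing behavior:
+//
+//   int32 v;
+//   safe_strto32("  123 ", &v);  // returns true, v == 123
+//   safe_strto32("12abc", &v);   // returns false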
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_NUMBERS_H_
diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc
new file mode 100644
index 0000000000..a6e76e99b7
--- /dev/null
+++ b/tensorflow/stream_executor/lib/path.cc
@@ -0,0 +1,50 @@
+#include "tensorflow/stream_executor/lib/path.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+
+using ::perftools::gputools::port::StringPiece;
+using ::perftools::gputools::port::StrAppend;
+
+namespace perftools {
+namespace gputools {
+namespace port {
+namespace internal {
+
+static bool IsAbsolutePath(port::StringPiece path) {
+ return !path.empty() && path[0] == '/';
+}
+
+// Joins all of the given paths together, ensuring that the proper path
+// separators are inserted between them.
+string JoinPathImpl(std::initializer_list<port::StringPiece> paths) {
+ string result;
+
+ for (port::StringPiece path : paths) {
+ if (path.empty()) continue;
+
+ if (result.empty()) {
+ result = path.ToString();
+ continue;
+ }
+
+ if (result[result.size() - 1] == '/') {
+ if (IsAbsolutePath(path)) {
+ StrAppend(&result, path.substr(1));
+ } else {
+ StrAppend(&result, path);
+ }
+ } else {
+ if (IsAbsolutePath(path)) {
+ StrAppend(&result, path);
+ } else {
+ StrAppend(&result, "/", path);
+ }
+ }
+ }
+
+ return result;
+}
+
+} // namespace internal
+} // namespace port
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/lib/path.h b/tensorflow/stream_executor/lib/path.h
new file mode 100644
index 0000000000..1d648e8de1
--- /dev/null
+++ b/tensorflow/stream_executor/lib/path.h
@@ -0,0 +1,44 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_
+
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+namespace internal {
+// TODO(rspringer): Move to cc/implementation file.
+// Not part of the public API.
+string JoinPathImpl(std::initializer_list<port::StringPiece> paths);
+} // namespace internal
+
+// Join multiple paths together.
+// JoinPath unconditionally joins all paths together. For example:
+//
+// Arguments | JoinPath
+// ---------------------------+---------------------
+// '/foo', 'bar' | /foo/bar
+// '/foo/', 'bar' | /foo/bar
+// '/foo', '/bar' | /foo/bar
+// '/foo', '/bar', '/baz' | /foo/bar/baz
+//
+// All paths will be treated as relative paths, regardless of whether or not
+// they start with a leading '/'. That is, all paths will be concatenated
+// together, with the appropriate path separator inserted in between.
+// Arguments must be convertible to port::StringPiece.
+//
+// Usage:
+// string path = file::JoinPath("/var/log", dirname, filename);
+// string path = file::JoinPath(FLAGS_test_srcdir, filename);
+template <typename... T>
+inline string JoinPath(const T&... args) {
+ return internal::JoinPathImpl({args...});
+}
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_
diff --git a/tensorflow/stream_executor/lib/process_state.cc b/tensorflow/stream_executor/lib/process_state.cc
new file mode 100644
index 0000000000..c20493b263
--- /dev/null
+++ b/tensorflow/stream_executor/lib/process_state.cc
@@ -0,0 +1,37 @@
+#include "tensorflow/stream_executor/lib/process_state.h"
+
+#include <errno.h>
+#include <unistd.h>
+
+#include <memory>
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+string Hostname() {
+ char hostname[1024];
+ gethostname(hostname, sizeof hostname);
+ hostname[sizeof hostname - 1] = 0;
+ return hostname;
+}
+
+bool GetCurrentDirectory(string* dir) {
+ size_t len = 128;
+ std::unique_ptr<char[]> a(new char[len]);
+ for (;;) {
+ char* p = getcwd(a.get(), len);
+ if (p != NULL) {
+ *dir = p;
+ return true;
+ } else if (errno == ERANGE) {
+ len += len;
+ a.reset(new char[len]);
+ } else {
+ return false;
+ }
+ }
+}
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/lib/process_state.h b/tensorflow/stream_executor/lib/process_state.h
new file mode 100644
index 0000000000..b75879499b
--- /dev/null
+++ b/tensorflow/stream_executor/lib/process_state.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_PROCESS_STATE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_PROCESS_STATE_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+string Hostname();
+bool GetCurrentDirectory(string* dir);
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PROCESS_STATE_H_
diff --git a/tensorflow/stream_executor/lib/ptr_util.h b/tensorflow/stream_executor/lib/ptr_util.h
new file mode 100644
index 0000000000..d10d0bcb8c
--- /dev/null
+++ b/tensorflow/stream_executor/lib/ptr_util.h
@@ -0,0 +1,48 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
+
+#include <stddef.h>
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+// Trait to select overloads and return types for MakeUnique.
+template <typename T>
+struct MakeUniqueResult {
+ using scalar = std::unique_ptr<T>;
+};
+template <typename T>
+struct MakeUniqueResult<T[]> {
+ using array = std::unique_ptr<T[]>;
+};
+template <typename T, size_t N>
+struct MakeUniqueResult<T[N]> {
+ using invalid = void;
+};
+
+// MakeUnique<T>(...) is an early implementation of C++14 std::make_unique.
+// It is designed to be 100% compatible with std::make_unique so that the
+// eventual switchover will be a simple renaming operation.
+template <typename T, typename... Args>
+typename MakeUniqueResult<T>::scalar MakeUnique(Args&&... args) { // NOLINT
+ return std::unique_ptr<T>(
+ new T(std::forward<Args>(args)...)); // NOLINT(build/c++11)
+}
+
+// Overload for array of unknown bound.
+// The allocation of arrays needs to use the array form of new,
+// and cannot take element constructor arguments.
+template <typename T>
+typename MakeUniqueResult<T>::array MakeUnique(size_t n) {
+ return std::unique_ptr<T>(new typename std::remove_extent<T>::type[n]());
+}
+
+// Reject arrays of known bound.
+template <typename T, typename... Args>
+typename MakeUniqueResult<T>::invalid MakeUnique(Args&&... /* args */) =
+ delete; // NOLINT
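+
+// Illustrative usage sketch; Widget and its constructor arguments are
+// hypothetical:
+//
+//   auto w = MakeUnique<Widget>(arg1, arg2);  // std::unique_ptr<Widget>
+//   auto buf = MakeUnique<float[]>(1024);     // std::unique_ptr<float[]>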
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
diff --git a/tensorflow/stream_executor/lib/stacktrace.h b/tensorflow/stream_executor/lib/stacktrace.h
new file mode 100644
index 0000000000..e7d478efe3
--- /dev/null
+++ b/tensorflow/stream_executor/lib/stacktrace.h
@@ -0,0 +1,18 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STACKTRACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STACKTRACE_H_
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+#if !defined(PLATFORM_GOOGLE)
+inline string CurrentStackTrace() { return "No stack trace available"; }
+#endif
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STACKTRACE_H_
diff --git a/tensorflow/stream_executor/lib/static_threadlocal.h b/tensorflow/stream_executor/lib/static_threadlocal.h
new file mode 100644
index 0000000000..9227b2cf0d
--- /dev/null
+++ b/tensorflow/stream_executor/lib/static_threadlocal.h
@@ -0,0 +1,30 @@
+// Copyright 2006 Google Inc.
+// All rights reserved.
+// Author: Yaz Saito (saito@google.com)
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATIC_THREADLOCAL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATIC_THREADLOCAL_H_
+
+// For POD types in TLS mode, s_obj_VAR is the thread-local variable.
+#define SE_STATIC_THREAD_LOCAL_POD(_Type_, _var_) \
+ static thread_local _Type_ s_obj_##_var_; \
+ namespace { \
+ class ThreadLocal_##_var_ { \
+ public: \
+ ThreadLocal_##_var_() {} \
+ void Init() {} \
+ inline _Type_ *pointer() const { \
+ return &s_obj_##_var_; \
+ } \
+ inline _Type_ *safe_pointer() const { \
+ return &s_obj_##_var_; \
+ } \
+ _Type_ &get() const { \
+ return s_obj_##_var_; \
+ } \
+ bool is_native_tls() const { return true; } \
+ private: \
+ SE_DISALLOW_COPY_AND_ASSIGN(ThreadLocal_##_var_); \
+ } _var_; \
+ }
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATIC_THREADLOCAL_H_
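A minimal sketch of how SE_STATIC_THREAD_LOCAL_POD might be used in a .cc file; the counter name and helper function are hypothetical:

    #include "tensorflow/stream_executor/lib/static_threadlocal.h"

    // Declares a thread-local int named s_obj_tls_call_count plus a small
    // anonymous-namespace wrapper object named tls_call_count with
    // pointer()/get() accessors.
    SE_STATIC_THREAD_LOCAL_POD(int, tls_call_count);

    // Each thread sees and increments its own copy of the counter.
    int BumpCallCount() { return ++tls_call_count.get(); }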
diff --git a/tensorflow/stream_executor/lib/status.h b/tensorflow/stream_executor/lib/status.h
new file mode 100644
index 0000000000..b3ad13b0ae
--- /dev/null
+++ b/tensorflow/stream_executor/lib/status.h
@@ -0,0 +1,23 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
+
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::Status;
+
+#define SE_CHECK_OK(val) \
+ CHECK_EQ(::perftools::gputools::port::Status::OK(), (val))
+#define SE_ASSERT_OK(val) \
+ ASSERT_EQ(::perftools::gputools::port::Status::OK(), (val))
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
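A small illustration of SE_CHECK_OK; the DoSetup routine is hypothetical:

    #include "tensorflow/stream_executor/lib/status.h"

    namespace port = perftools::gputools::port;

    // Hypothetical initialization routine that reports failure via Status.
    port::Status DoSetup();

    void InitOrDie() {
      // CHECK-fails (and aborts) if DoSetup() returns a non-OK status.
      SE_CHECK_OK(DoSetup());
    }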
diff --git a/tensorflow/stream_executor/lib/status_macros.h b/tensorflow/stream_executor/lib/status_macros.h
new file mode 100644
index 0000000000..7e1de92a98
--- /dev/null
+++ b/tensorflow/stream_executor/lib/status_macros.h
@@ -0,0 +1,54 @@
+// Helper macros for dealing with the port::Status datatype.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_MACROS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_MACROS_H_
+
+// Early-returns the status if it is in error; otherwise, proceeds.
+//
+// The argument expression is guaranteed to be evaluated exactly once.
+#define SE_RETURN_IF_ERROR(__status) \
+ do { \
+ auto status = __status; \
+ if (!status.ok()) { \
+ return status; \
+ } \
+ } while (false)
+
+// Identifier concatenation helper macros.
+#define SE_MACRO_CONCAT_INNER(__x, __y) __x##__y
+#define SE_MACRO_CONCAT(__x, __y) SE_MACRO_CONCAT_INNER(__x, __y)
+
+// Implementation of SE_ASSIGN_OR_RETURN that uses a unique temporary identifier
+// to avoid collisions in the enclosing scope.
+#define SE_ASSIGN_OR_RETURN_IMPL(__lhs, __rhs, __name) \
+ auto __name = (__rhs); \
+ if (!__name.ok()) { \
+ return __name.status(); \
+ } \
+ __lhs = __name.ConsumeValueOrDie();
+
+// Early-returns the status if it is in error; otherwise, assigns the
+// right-hand-side expression to the left-hand-side expression.
+//
+// The right-hand-side expression is guaranteed to be evaluated exactly once.
+#define SE_ASSIGN_OR_RETURN(__lhs, __rhs) \
+ SE_ASSIGN_OR_RETURN_IMPL(__lhs, __rhs, \
+ SE_MACRO_CONCAT(__status_or_value, __COUNTER__))
+
+// Logs the status and returns false if it is in error; otherwise, returns true.
+//
+// The argument expression is guaranteed to be evaluated exactly once.
+//
+// TODO(leary) remove as many of these as possible with port::Status
+// proliferation.
+#define SE_RETURN_STATUS_AS_BOOL(__status) \
+ do { \
+ auto status = __status; \
+ if (status.ok()) { \
+ return true; \
+ } \
+ LOG(ERROR) << status; \
+ return false; \
+ } while (false)
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_MACROS_H_
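A short sketch of how SE_RETURN_IF_ERROR and SE_ASSIGN_OR_RETURN compose; the Initialize and QueryDeviceCount helpers are hypothetical:

    #include "tensorflow/stream_executor/lib/status.h"
    #include "tensorflow/stream_executor/lib/status_macros.h"
    #include "tensorflow/stream_executor/lib/statusor.h"

    namespace port = perftools::gputools::port;

    port::Status Initialize();               // hypothetical
    port::StatusOr<int> QueryDeviceCount();  // hypothetical

    port::Status Run() {
      SE_RETURN_IF_ERROR(Initialize());  // early-returns the error, if any
      // Unwraps the StatusOr into a local, or early-returns its error status.
      SE_ASSIGN_OR_RETURN(int device_count, QueryDeviceCount());
      if (device_count == 0) {
        return port::Status(port::error::NOT_FOUND, "no devices found");
      }
      return port::Status::OK();
    }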
diff --git a/tensorflow/stream_executor/lib/statusor.h b/tensorflow/stream_executor/lib/statusor.h
new file mode 100644
index 0000000000..38ce35e46e
--- /dev/null
+++ b/tensorflow/stream_executor/lib/statusor.h
@@ -0,0 +1,234 @@
+// Copyright 2008 Google Inc. All Rights Reserved.
+// Author: acm@google.com (Andrew Morrow)
+// Author: zhengxq@google.com (Xiaoqiang Zheng)
+//
+// StatusOr<T> is the union of a Status object and a T
+// object. StatusOr models the concept of an object that is either a
+// usable value, or an error Status explaining why such a value is
+// not present. To this end, StatusOr<T> does not allow its Status
+// value to be Status::OK. Further, StatusOr<T*> does not allow the
+// contained pointer to be NULL.
+//
+// The primary use-case for StatusOr<T> is as the return value of a
+// function which may fail.
+//
+// Example client usage for a StatusOr<T>, where T is not a pointer:
+//
+// StatusOr<float> result = DoBigCalculationThatCouldFail();
+// if (result.ok()) {
+// float answer = result.ValueOrDie();
+// printf("Big calculation yielded: %f", answer);
+// } else {
+// LOG(ERROR) << result.status();
+// }
+//
+// Example client usage for a StatusOr<T*>:
+//
+// StatusOr<Foo*> result = FooFactory::MakeNewFoo(arg);
+// if (result.ok()) {
+// std::unique_ptr<Foo> foo(result.ValueOrDie());
+// foo->DoSomethingCool();
+// } else {
+// LOG(ERROR) << result.status();
+// }
+//
+// Example client usage for a StatusOr<std::unique_ptr<T>>:
+//
+// StatusOr<std::unique_ptr<Foo>> result = FooFactory::MakeNewFoo(arg);
+// if (result.ok()) {
+// std::unique_ptr<Foo> foo = result.ConsumeValueOrDie();
+// foo->DoSomethingCool();
+// } else {
+// LOG(ERROR) << result.status();
+// }
+//
+// Example factory implementation returning StatusOr<T*>:
+//
+// StatusOr<Foo*> FooFactory::MakeNewFoo(int arg) {
+// if (arg <= 0) {
+// return Status(port::error::INVALID_ARGUMENT,
+// "Arg must be positive");
+// } else {
+// return new Foo(arg);
+// }
+// }
+//
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_
+
+#include <new>
+#include "tensorflow/stream_executor/platform/port.h"
+#include <type_traits>
+#include <utility>
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+template<typename T>
+class StatusOr {
+ template<typename U> friend class StatusOr;
+
+ public:
+ // Construct a new StatusOr with Status::UNKNOWN status
+ StatusOr() : status_(error::UNKNOWN, "") {}
+
+ // Construct a new StatusOr with the given non-ok status. After calling
+ // this constructor, calls to ValueOrDie() are invalid.
+ //
+ // NOTE: Not explicit - we want to use StatusOr<T> as a return
+ // value, so it is convenient and sensible to be able to do 'return
+ // Status()' when the return type is StatusOr<T>.
+ //
+ // REQUIRES: status != Status::OK.
+ // In optimized builds, passing Status::OK here will result in an
+ // error::INTERNAL status being stored as a fallback.
+ StatusOr(const Status& status); // NOLINT
+
+ // Construct a new StatusOr with the given value. If T is a plain pointer,
+ // value must not be NULL. After calling this constructor, calls to
+ // ValueOrDie() will succeed, and calls to status() will return OK.
+ //
+ // NOTE: Not explicit - we want to use StatusOr<T> as a return type
+ // so it is convenient and sensible to be able to do 'return T()'
+ // when the return type is StatusOr<T>.
+ //
+ // REQUIRES: if T is a plain pointer, value != NULL.
+ // In optimized builds, passing a NULL pointer here will result in an
+ // error::INTERNAL status being stored as a fallback.
+ StatusOr(const T& value); // NOLINT
+
+ // Conversion copy constructor, T must be copy constructible from U
+ template <typename U>
+ StatusOr(const StatusOr<U>& other) // NOLINT
+ : status_(other.status_),
+ value_(other.value_) {}
+
+ // Conversion assignment operator, T must be assignable from U
+ template <typename U>
+ StatusOr& operator=(const StatusOr<U>& other) {
+ status_ = other.status_;
+ value_ = other.value_;
+ return *this;
+ }
+
+ // Rvalue-reference overloads of the other constructors and assignment
+ // operators, to support move-only types and avoid unnecessary copying.
+ StatusOr(T&& value); // NOLINT
+
+ // Move conversion constructor to avoid an unnecessary copy.
+ // T must be assignable from U.
+ // Not marked with explicit so the implicit conversion can happen.
+ template <typename U>
+ StatusOr(StatusOr<U>&& other) // NOLINT
+ : status_(std::move(other.status_)),
+ value_(std::move(other.value_)) {}
+
+ // Move assignment operator to avoid an unnecessary copy.
+ // T must be assignable from U
+ template <typename U>
+ StatusOr& operator=(StatusOr<U>&& other) {
+ status_ = std::move(other.status_);
+ value_ = std::move(other.value_);
+ return *this;
+ }
+
+ // Returns a reference to our status. If this contains a T, then
+ // returns Status::OK.
+ const Status& status() const { return status_; }
+
+ // Returns this->status().ok()
+ bool ok() const { return status_.ok(); }
+
+ // Returns a reference to our current value, requires that this->ok().
+ // If you need to initialize a T object from the stored value,
+ // ConsumeValueOrDie() may be more efficient.
+ const T& ValueOrDie() const;
+
+ // Returns our current value, requires this->ok(). Use this if
+ // you would otherwise want to say std::move(s.ValueOrDie()), for example
+ // if you need to initialize a T object from the stored value and you don't
+ // need subsequent access to the stored value. It uses T's move constructor,
+ // if it has one, so it will work with move-only types, and will often be
+ // more efficient than ValueOrDie, but may leave the stored value
+ // in an arbitrary valid state.
+ T ConsumeValueOrDie();
+
+ private:
+ Status status_;
+ T value_;
+
+ void CheckValueNotNull(const T& value);
+
+ template <typename U>
+ struct IsNull {
+ // For non-pointer U, a reference can never be NULL.
+ static inline bool IsValueNull(const U& t) { return false; }
+ };
+
+ template <typename U>
+ struct IsNull<U*> {
+ static inline bool IsValueNull(const U* t) { return t == NULL; }
+ };
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// Implementation details for StatusOr<T>
+
+template <typename T>
+StatusOr<T>::StatusOr(const T& value)
+ : status_(), value_(value) {
+ CheckValueNotNull(value);
+}
+
+template <typename T>
+const T& StatusOr<T>::ValueOrDie() const {
+ assert(status_.ok());
+ return value_;
+}
+
+template <typename T>
+T StatusOr<T>::ConsumeValueOrDie() {
+ assert(status_.ok());
+ return std::move(value_);
+}
+
+template <typename T>
+StatusOr<T>::StatusOr(const Status& status)
+ : status_(status) {
+ assert(!status.ok());
+ if (status.ok()) {
+ status_ =
+ Status(error::INTERNAL,
+ "Status::OK is not a valid constructor argument to StatusOr<T>");
+ }
+}
+
+template <typename T>
+StatusOr<T>::StatusOr(T&& value)
+ : status_() {
+ CheckValueNotNull(value);
+ value_ = std::move(value);
+}
+
+template <typename T>
+void StatusOr<T>::CheckValueNotNull(const T& value) {
+ assert(!IsNull<T>::IsValueNull(value));
+ if (IsNull<T>::IsValueNull(value)) {
+ status_ =
+ Status(error::INTERNAL,
+ "NULL is not a valid constructor argument to StatusOr<T*>");
+ }
+}
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_
diff --git a/tensorflow/stream_executor/lib/str_util.h b/tensorflow/stream_executor/lib/str_util.h
new file mode 100644
index 0000000000..021f54dfec
--- /dev/null
+++ b/tensorflow/stream_executor/lib/str_util.h
@@ -0,0 +1,30 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/stream_executor/lib/stringpiece.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::str_util::Join;
+using tensorflow::str_util::Split;
+
+// Returns a copy of the input string 'str' with the given 'suffix'
+// removed. If the suffix doesn't match, returns a copy of the original string.
+inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix) {
+ if (str.ends_with(suffix)) {
+ str.remove_suffix(suffix.size());
+ }
+ return str.ToString();
+}
+
+using tensorflow::str_util::Lowercase;
+using tensorflow::str_util::Uppercase;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_
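For example, StripSuffixString behaves as follows (a minimal sketch; the library names are arbitrary):

    #include <string>
    #include "tensorflow/stream_executor/lib/str_util.h"

    namespace port = perftools::gputools::port;

    void Example() {
      // Suffix present: it is removed.
      std::string a = port::StripSuffixString("libcudart.so.7.0", ".7.0");
      // a == "libcudart.so"

      // Suffix absent: the input is returned unchanged.
      std::string b = port::StripSuffixString("libcudart.so", ".7.0");
      // b == "libcudart.so"
    }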
diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h
new file mode 100644
index 0000000000..b3fe4da327
--- /dev/null
+++ b/tensorflow/stream_executor/lib/strcat.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
+
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::strings::StrCat;
+using tensorflow::strings::StrAppend;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/stream_executor/lib/stringpiece.h
new file mode 100644
index 0000000000..14e6fc99d7
--- /dev/null
+++ b/tensorflow/stream_executor/lib/stringpiece.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::StringPiece;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
diff --git a/tensorflow/stream_executor/lib/stringprintf.h b/tensorflow/stream_executor/lib/stringprintf.h
new file mode 100644
index 0000000000..379e7e9a83
--- /dev/null
+++ b/tensorflow/stream_executor/lib/stringprintf.h
@@ -0,0 +1,18 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPRINTF_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPRINTF_H_
+
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::strings::Printf;
+using tensorflow::strings::Appendf;
+using tensorflow::strings::Appendv;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPRINTF_H_
diff --git a/tensorflow/stream_executor/lib/thread_options.h b/tensorflow/stream_executor/lib/thread_options.h
new file mode 100644
index 0000000000..7d436578d6
--- /dev/null
+++ b/tensorflow/stream_executor/lib/thread_options.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_THREAD_OPTIONS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_THREAD_OPTIONS_H_
+
+#include "tensorflow/core/public/env.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::ThreadOptions;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_THREAD_OPTIONS_H_
diff --git a/tensorflow/stream_executor/lib/threadpool.h b/tensorflow/stream_executor/lib/threadpool.h
new file mode 100644
index 0000000000..3cf297d57b
--- /dev/null
+++ b/tensorflow/stream_executor/lib/threadpool.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_THREADPOOL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_LIB_THREADPOOL_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/thread_options.h"
+
+namespace perftools {
+namespace gputools {
+namespace port {
+
+using tensorflow::thread::ThreadPool;
+
+} // namespace port
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_THREADPOOL_H_
diff --git a/tensorflow/stream_executor/machine_manager.cc b/tensorflow/stream_executor/machine_manager.cc
new file mode 100644
index 0000000000..6d7bc50379
--- /dev/null
+++ b/tensorflow/stream_executor/machine_manager.cc
@@ -0,0 +1,276 @@
+#include "tensorflow/stream_executor/machine_manager.h"
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/dso_loader.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+mutex MachineManager::mu_{LINKER_INITIALIZED};
+
+MachineManager *MachineManager::singleton_ = nullptr;
+
+PlatformKind MachineManager::DetectPreferredPlatform() {
+// TODO(leary) for KNC card experiments, figure out a legitimate way to
+// determine this. For now, we use a compile-time hint so we can compile tests
+// for both.
+#if defined TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_PREFER_OPENCL
+ return PlatformKind::kOpenCL;
+#elif defined TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_PREFER_HOST
+ return PlatformKind::kHost;
+#else
+ return PlatformKind::kCuda;
+#endif
+}
+
+/* static */ port::StatusOr<std::unique_ptr<MachineManager>>
+MachineManager::Create(PlatformKind kind, DeviceOptions options,
+ const PluginConfig &config) {
+ std::unique_ptr<MachineManager> machine_manager{
+ new MachineManager{kind, options, config}};
+ auto init_status = machine_manager->Init();
+ if (!init_status.ok()) {
+ return init_status;
+ }
+
+ return std::move(machine_manager);
+}
+
+MachineManager::MachineManager(PlatformKind platform,
+ DeviceOptions device_options,
+ const PluginConfig &config)
+ : platform_(platform),
+ device_options_(device_options),
+ plugin_config_(config),
+ min_numa_node_(0),
+ limit_numa_node_(0) {}
+
+port::Status MachineManager::Init() {
+ // Initialize the first StreamExecutor, then use that platform interface to
+ // grab the device count.
+ executors_.resize(1);
+ executors_[0].reset(new StreamExecutor{platform_, plugin_config_});
+ auto status = executors_[0]->Init(0 /* = device_ordinal */, device_options_);
+ if (!status.ok()) {
+ return port::Status{
+ port::error::FAILED_PRECONDITION,
+ port::StrCat(
+ "failed to initialize StreamExecutor for device ordinal 0: ",
+ status.ToString())};
+ }
+ int device_count = executors_[0]->PlatformDeviceCount();
+ if (device_count == 0) {
+ LOG(WARNING) << "no devices found for platform "
+ << PlatformKindString(platform_);
+ min_numa_node_ = limit_numa_node_ = 0;
+ return port::Status::OK();
+ }
+
+ streams_.resize(device_count);
+ streams_[0].reset(new Stream(executors_[0].get()));
+ if (!streams_[0]->Init().ok()) {
+ return port::Status{
+ port::error::FAILED_PRECONDITION,
+ "failed to initialize default stream for device ordinal 0"};
+ }
+
+ min_numa_node_ = executors_[0]->GetDeviceDescription().numa_node();
+ limit_numa_node_ = min_numa_node_ + 1;
+
+ executors_.resize(device_count);
+ for (int device_ordinal = 1; device_ordinal < device_count;
+ ++device_ordinal) {
+ StreamExecutor *stream_exec = new StreamExecutor{platform_, plugin_config_};
+ executors_[device_ordinal].reset(stream_exec);
+ auto status = stream_exec->Init(device_ordinal, device_options_);
+ if (!status.ok()) {
+ return port::Status(
+ port::error::FAILED_PRECONDITION,
+ port::StrCat(
+ "failed to initialize StreamExecutor for device ordinal ",
+ device_ordinal, ": ", status.ToString()));
+ }
+
+ min_numa_node_ = std::min(min_numa_node_,
+ stream_exec->GetDeviceDescription().numa_node());
+ limit_numa_node_ = std::max(
+ limit_numa_node_, stream_exec->GetDeviceDescription().numa_node() + 1);
+
+ if (!stream_exec->GetDeviceDescription().ecc_enabled()) {
+ LOG(WARNING) << "ECC not enabled for device ordinal: " << device_ordinal;
+ }
+
+ streams_[device_ordinal].reset(
+ new Stream(executors_[device_ordinal].get()));
+ if (!streams_[device_ordinal]->Init().ok()) {
+ return port::Status(
+ port::error::FAILED_PRECONDITION,
+ port::StrCat(
+ "failed to initialize default stream for device ordinal ",
+ device_ordinal));
+ }
+ }
+
+ return port::Status::OK();
+}
+
+int MachineManager::device_count() const { return executors_.size(); }
+
+port::Status MachineManager::EnablePeerAccess() {
+ auto peer_access_map = GetPeerAccessMap();
+ for (const auto &access : *peer_access_map) {
+ auto devices = access.first;
+ if (access.second) {
+ StreamExecutor *from = executors_[devices.first].get();
+ StreamExecutor *to = executors_[devices.second].get();
+ auto status = from->EnablePeerAccessTo(to);
+ if (!status.ok()) {
+ return status;
+ }
+ } else {
+ LOG(INFO) << "cannot enable peer access from device ordinal "
+ << devices.first << " to device ordinal " << devices.second;
+ }
+ }
+ return port::Status::OK();
+}
+
+std::unique_ptr<std::map<std::pair<int, int>, bool>>
+MachineManager::GetPeerAccessMap() {
+ auto *map = new std::map<std::pair<int, int>, bool>;
+ for (int i = 0; i < device_count(); ++i) {
+ for (int j = 0; j < device_count(); ++j) {
+ StreamExecutor *from = executors_[i].get();
+ StreamExecutor *to = executors_[j].get();
+ (*map)[{i, j}] = from->CanEnablePeerAccessTo(to);
+ }
+ }
+
+ return std::unique_ptr<std::map<std::pair<int, int>, bool>>{map};
+}
+
+StreamExecutor *MachineManager::executor_for_device(int device_ordinal) const {
+ CHECK_GE(device_ordinal, 0) << "device ordinal must be non-negative";
+ CHECK(0 <= device_ordinal && device_ordinal < device_count())
+ << "device " << device_ordinal << " out of range with device count "
+ << device_count();
+ StreamExecutor *executor = executors_[device_ordinal].get();
+ CHECK(executor != nullptr);
+ return executor;
+}
+
+int MachineManager::ExecutorToBus(const StreamExecutor *stream_exec) const {
+ return stream_exec->GetDeviceDescription().numa_node() - min_numa_node_;
+}
+
+int MachineManager::DeviceToBus(int device_ordinal) const {
+ return ExecutorToBus(executor_for_device(device_ordinal));
+}
+
+int MachineManager::ExecutorToNumaNode(
+ const StreamExecutor *stream_exec) const {
+ return stream_exec->GetDeviceDescription().numa_node();
+}
+
+int MachineManager::DeviceToNumaNode(int device_ordinal) const {
+ return ExecutorToNumaNode(executor_for_device(device_ordinal));
+}
+
+StreamExecutor *MachineManager::first_executor_for_bus(int bus_ordinal) {
+ CHECK_LT(bus_ordinal, bus_count()) << "bus ordinal out of available range";
+ for (auto &executor : executors_) {
+ if (ExecutorToBus(executor.get()) == bus_ordinal) {
+ return executor.get();
+ }
+ }
+
+ LOG(WARNING) << "could not find executor requested for bus ordinal: "
+ << bus_ordinal;
+ return nullptr;
+}
+
+StreamExecutor *MachineManager::first_executor_for_numa_node(int numa_node) {
+ for (auto &executor : executors_) {
+ if (ExecutorToNumaNode(executor.get()) == numa_node) {
+ return executor.get();
+ }
+ }
+
+ LOG(WARNING) << "could not find executor requested for numa_node: "
+ << numa_node;
+ return nullptr;
+}
+
+Stream *MachineManager::stream_for_device(int device_ordinal) {
+ CHECK(0 <= device_ordinal && device_ordinal < device_count());
+ Stream *stream = streams_[device_ordinal].get();
+ CHECK(stream != nullptr);
+ return stream;
+}
+
+/* static */ port::StatusOr<MachineManager *>
+MachineManager::CreateSingletonInternal(PlatformKind platform,
+ DeviceOptions options,
+ const PluginConfig &config) {
+ if (singleton_ != nullptr) {
+ return port::Status{
+ port::error::ALREADY_EXISTS,
+ "cannot create machine manager singleton; one already exists"};
+ }
+
+ auto create_status = Create(platform, options, config);
+ if (!create_status.ok()) {
+ return create_status.status();
+ }
+
+ singleton_ = create_status.ConsumeValueOrDie().release();
+
+ VLOG(1) << "machine manager singleton is " << singleton_ << " with platform "
+ << PlatformKindString(platform) << " and device options "
+ << options.ToString();
+
+ return singleton_;
+}
+
+/* static */ MachineManager *MachineManager::CreateSingletonOrDie(
+ PlatformKind platform, DeviceOptions options, const PluginConfig &config) {
+ auto status = CreateSingleton(platform, options, config);
+ if (!status.ok()) {
+ LOG(FATAL) << "failed to create MachineManager singleton: "
+ << status.status();
+ }
+ return status.ValueOrDie();
+}
+
+/* static */ port::StatusOr<MachineManager *> MachineManager::CreateSingleton(
+ PlatformKind platform, DeviceOptions device_options,
+ const PluginConfig &config) {
+ mutex_lock lock{mu_};
+ return CreateSingletonInternal(platform, device_options, config);
+}
+
+/* static */ MachineManager *MachineManager::singleton() {
+ mutex_lock lock{mu_};
+ if (singleton_ == nullptr) {
+ PlatformKind platform = DetectPreferredPlatform();
+ DeviceOptions options = DeviceOptions::Default();
+ auto status = CreateSingletonInternal(platform, options, PluginConfig());
+ if (!status.ok()) {
+ LOG(FATAL)
+ << "failed to create MachineManager singleton: "
+ "singleton accessor attempted lazy construction but failed: "
+ << status.status();
+ }
+ return status.ValueOrDie();
+ }
+
+ return singleton_;
+}
+
+} // namespace gputools
+} // namespace perftools
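A sketch of the non-singleton Create() path implemented above, assuming the logging macros are visible via the included headers; note that MachineManager is deprecated in favor of MultiPlatformManager (see the header below), so this only illustrates the StatusOr-based flow:

    #include <memory>
    #include "tensorflow/stream_executor/machine_manager.h"

    namespace gpu = perftools::gputools;

    void ListBuses() {
      auto manager_or = gpu::MachineManager::Create(
          gpu::MachineManager::DetectPreferredPlatform(),
          gpu::DeviceOptions::Default());
      if (!manager_or.ok()) {
        LOG(ERROR) << "could not create machine manager: "
                   << manager_or.status();
        return;
      }
      std::unique_ptr<gpu::MachineManager> manager =
          manager_or.ConsumeValueOrDie();
      for (int i = 0; i < manager->device_count(); ++i) {
        LOG(INFO) << "device " << i << " is on bus " << manager->DeviceToBus(i);
      }
    }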
diff --git a/tensorflow/stream_executor/machine_manager.h b/tensorflow/stream_executor/machine_manager.h
new file mode 100644
index 0000000000..bcff7a9da0
--- /dev/null
+++ b/tensorflow/stream_executor/machine_manager.h
@@ -0,0 +1,197 @@
+// This interface provides a machine-wide resource management singleton
+// interface as a convenience for users who will want to exploit all of the GPU
+// resources present on the system.
+//
+// To use the singleton interface:
+//
+// // At start of program or in your module initializer.
+// // Do not call this with different sets of arguments!
+// MachineManager::CreateSingletonOrDie(
+// MachineManager::DetectPreferredPlatform(), DeviceOptions::Default());
+//
+// // At any point after that, this convenience interface avoids you having to
+// // pass those two parameters:
+// StreamExecutor *device0_executor =
+// MachineManager::singleton()->executor_for_device(0 /* = ordinal */);
+// ...
+
+// ----------------- THIS CLASS IS DEPRECATED - DO NOT USE ------------------
+// This class is not suitable for open-sourcing, as it does not support
+// plugins and depends on hardcoded PlatformKind enums. MultiPlatformManager and
+// Platform plugins are the replacements.
+// ----------------- THIS CLASS IS DEPRECATED - DO NOT USE ------------------
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_H_
+
+#include <map>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/stream_executor/device_options.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace perftools {
+namespace gputools {
+
+// MachineManager is used to instantiate and manage singleton resources for
+// all the GPUs present on a machine. This basically amounts to having a
+// StreamExecutor-per-device pool.
+//
+// Thread-safe.
+class MachineManager {
+ public:
+ // Inspects the host to determine the preferred GPU execution platform.
+ // To force OpenCL from a build target on a machine that has both OpenCL and
+ // CUDA capabilities, link against the :stream_executor_prefer_opencl target.
+ static PlatformKind DetectPreferredPlatform();
+
+ // Returns the machine manager singleton.
+ // If the singleton has not yet been created when this is invoked, this
+ // creates it with reasonable default options; otherwise it returns the
+ // already-created singleton. If there are errors during creation, this call
+ // will terminate the program.
+ static MachineManager *singleton();
+
+ // Creates and returns the machine manager singleton -- it is generally
+ // assumed that a real-world application will have a single one of these as
+ // its resource manager.
+ //
+ // This should only be called once, at the initialization of an application,
+ // if at all -- MachineManager::singleton() will otherwise return a value with
+ // sensible defaults as determined by DetectPreferredPlatform. Attempts to
+ // create the singleton with options multiple times will result in an error.
+ static port::StatusOr<MachineManager *> CreateSingleton(
+ PlatformKind platform, DeviceOptions device_options,
+ const PluginConfig &config = PluginConfig());
+
+ // Convenience "or die" wrapper around the above call.
+ static MachineManager *CreateSingletonOrDie(
+ PlatformKind platform, DeviceOptions device_options,
+ const PluginConfig &config = PluginConfig());
+
+ // Creates a new instantiation of the MachineManager.
+ // Warning: generally users will want to use the singleton form, see
+ // MachineManager::singleton().
+ //
+ // The machine manager detects a fixed set of devices when it is created,
+ // and that set does not change over the course of its lifetime. Events such
+ // as hot-plugging of GPUs or GPUs dropping off the bus in a recoverable
+ // manner are not supported.
+ static port::StatusOr<std::unique_ptr<MachineManager>> Create(
+ PlatformKind kind, DeviceOptions options,
+ const PluginConfig &config = PluginConfig());
+
+ // Returns the number of devices visible to the machine manager.
+ int device_count() const;
+
+ // Returns the StreamExecutor for one of the machine-manager visible devices.
+ // Checks that device_ordinal is within device_count() bound.
+ StreamExecutor *executor_for_device(int device_ordinal) const;
+
+ // Returns the bus ordinal count (as determined by the span of NUMA nodes
+ // associated with the available devices).
+ int bus_count() const { return limit_numa_node_ - min_numa_node_; }
+
+ // Returns the bus ordinal associated with a given device ordinal.
+ int DeviceToBus(int device_ordinal) const;
+
+ // Returns the NUMA node associated with a given device ordinal.
+ int DeviceToNumaNode(int device_ordinal) const;
+
+ // Returns the first StreamExecutor (within device_count() ordinals) that has
+ // the corresponding bus ordinal, or nullptr if none is found.
+ //
+ // The valid bus ordinals can be enumerated by scanning through the executors
+ // and seeing what bus number they are on.
+ StreamExecutor *first_executor_for_bus(int bus_ordinal);
+
+ // Returns the first StreamExecutor associated with the specified
+ // numa_node, or nullptr if none is found.
+ StreamExecutor *first_executor_for_numa_node(int numa_node);
+
+ // Returns the default stream for the default executor (that returned by
+ // executor_for_device()). The same stream will be returned for all calls to
+ // stream_for_device() (with the same device_ordinal).
+ Stream *stream_for_device(int device_ordinal);
+
+ // Returns the platform that this machine manager was created to target.
+ PlatformKind platform() const { return platform_; }
+
+ // Enables peer access between all possible devices on this platform.
+ // Fails only if enabling peer access fails for a device pair for which
+ // GetPeerAccessMap() reports that access is possible.
+ port::Status EnablePeerAccess();
+
+ // Returns a map that says, for pairs (device ordinal i, device ordinal j),
+ // whether i can access j's memory space.
+ std::unique_ptr<std::map<std::pair<int, int>, bool>> GetPeerAccessMap();
+
+ private:
+ // Guts of the singleton creation mechanism that requires the exclusive
+ // singleton lock to be held, in order to prevent deadlock due to method
+ // composition.
+ static port::StatusOr<MachineManager *> CreateSingletonInternal(
+ PlatformKind platform, DeviceOptions options, const PluginConfig &config)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Private constructor used in singleton creation.
+ MachineManager(PlatformKind platform, DeviceOptions options,
+ const PluginConfig &config);
+
+ // Populates the executors_ vector with an executor per observable device
+ // ordinal on the platform. Returns an error status if any of the
+ // StreamExecutors cannot be created.
+ port::Status Init();
+
+ // Converts a StreamExecutor's NUMA node association into a bus ordinal for
+ // this machine.
+ int ExecutorToBus(const StreamExecutor *stream_exec) const;
+
+ // Returns the NUMA node association for the StreamExecutor.
+ int ExecutorToNumaNode(const StreamExecutor *stream_exec) const;
+
+ // Mutex that guards the initialization of the machine manager static
+ // variable.
+ static mutex mu_;
+
+ // Singleton MachineManager value -- assignment to this is protected by a
+ // static singleton guard clause.
+ static MachineManager *singleton_ GUARDED_BY(mu_);
+
+ // Holds an executor for each device ordinal present in the system; the
+ // vector index is the device ordinal. Immutable after initialization.
+ std::vector<std::unique_ptr<StreamExecutor>> executors_;
+
+ // Holds a stream for each device ordinal present in the system; the vector
+ // index is the device ordinal. Immutable after initialization.
+ std::vector<std::unique_ptr<Stream>> streams_;
+
+ // The platform that this is managing for the machine.
+ PlatformKind platform_;
+
+ // Options used to create StreamExecutors on each of the respective devices.
+ DeviceOptions device_options_;
+
+ // Plugin configuration to use for all StreamExecutors created by this object.
+ PluginConfig plugin_config_;
+
+ // The smallest NUMA node value for any device managed by this machine
+ // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus
+ // ordinals. The NUMA node space occupied by GPUs is assumed to be dense.
+ int min_numa_node_;
+
+ // Larger than the NUMA node value for any device managed by this machine
+ // manager.
+ int limit_numa_node_;
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_H_
diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc
new file mode 100644
index 0000000000..a65add05c5
--- /dev/null
+++ b/tensorflow/stream_executor/multi_platform_manager.cc
@@ -0,0 +1,66 @@
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/str_util.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace perftools {
+namespace gputools {
+
+/* static */ mutex MultiPlatformManager::platforms_mutex_(LINKER_INITIALIZED);
+
+/* static */ port::Status MultiPlatformManager::RegisterPlatform(
+ std::unique_ptr<Platform> platform) {
+ CHECK(platform != nullptr);
+ string key = port::Lowercase(platform->Name());
+ mutex_lock lock(platforms_mutex_);
+ if (GetPlatformMap()->find(key) != GetPlatformMap()->end()) {
+ return port::Status(port::error::INTERNAL,
+ "platform is already registered with name: \"" +
+ platform->Name() + "\"");
+ }
+ GetPlatformByIdMap()->insert(std::make_pair(platform->id(), platform.get()));
+ // Release ownership/uniqueness to prevent destruction on program exit.
+ // This avoids Platforms "cleaning up" on program exit, because otherwise,
+ // there are _very_ tricky races between StreamExecutor and underlying
+ // platforms (CUDA, OpenCL) during exit. Since these are fixed-size and 1x per
+ // program, these are deemed acceptable.
+ (*GetPlatformMap())[key] = platform.release();
+ return port::Status::OK();
+}
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithName(
+ const string& target) {
+ mutex_lock lock(platforms_mutex_);
+ auto it = GetPlatformMap()->find(port::Lowercase(target));
+
+ if (it == GetPlatformMap()->end()) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ "could not find registered platform with name: \"" + target + "\"");
+ }
+
+ return it->second;
+}
+
+/* static */ port::StatusOr<Platform*> MultiPlatformManager::PlatformWithId(
+ const Platform::Id& id) {
+ mutex_lock lock(platforms_mutex_);
+ auto it = GetPlatformByIdMap()->find(id);
+ if (it == GetPlatformByIdMap()->end()) {
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::Printf("could not find registered platform with id: 0x%p", id));
+ }
+
+ return it->second;
+}
+
+/* static */ void MultiPlatformManager::ClearPlatformRegistry() {
+ mutex_lock lock(platforms_mutex_);
+ GetPlatformMap()->clear();
+ GetPlatformByIdMap()->clear();
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h
new file mode 100644
index 0000000000..ade7fac24b
--- /dev/null
+++ b/tensorflow/stream_executor/multi_platform_manager.h
@@ -0,0 +1,144 @@
+// This is a registration-oriented interface for multiple platforms. It will
+// replace the MachineManager singleton interface, as MachineManager does not
+// currently support simultaneous use of multiple platforms.
+//
+// Usage:
+//
+// In your BUILD rule, add a dependency on a platform plugin that you'd like
+// to use, such as:
+//
+// //perftools/gputools/executor/cuda:cuda_platform
+// //perftools/gputools/executor/opencl:opencl_platform
+//
+// This will register platform plugins that can be discovered via this
+// interface. Sample API usage:
+//
+// port::StatusOr<Platform*> platform_status =
+// gpu::MultiPlatformManager::PlatformWithName("OpenCL");
+// if (!platform_status.ok()) { ... }
+// Platform* platform = platform_status.ValueOrDie();
+// LOG(INFO) << platform->VisibleDeviceCount() << " devices visible";
+// if (platform->VisibleDeviceCount() <= 0) { return; }
+//
+// for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
+// port::StatusOr<StreamExecutor*> executor_status =
+// platform->ExecutorForDevice(i);
+// if (!executor_status.ok()) {
+// LOG(INFO) << "could not retrieve executor for device ordinal " << i
+// << ": " << executor_status.status();
+// continue;
+// }
+// LOG(INFO) << "found usable executor: " << executor_status.ValueOrDie();
+// }
+//
+// A few things to note:
+// - There is no standard formatting/practice for identifying the name of a
+// platform. Ideally, a platform will list its registered name in its header
+// or in other associated documentation.
+// - Platform name lookup is case-insensitive. "OpenCL" or "opencl" (or even
+// ("OpEnCl") would work correctly in the above example.
+//
+// And similarly, for standard interfaces (BLAS, RNG, etc.) you can add
+// dependencies on support libraries, e.g.:
+//
+// //perftools/gputools/executor/cuda:pluton_blas_plugin
+// //perftools/gputools/executor/cuda:cudnn_plugin
+// //perftools/gputools/executor/cuda:cublas_plugin
+// //perftools/gputools/executor/cuda:curand_plugin
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+// Manages multiple platforms that may be present on the current machine.
+class MultiPlatformManager {
+ public:
+ // Registers a platform object; returns an error status if a platform with
+ // the same name has already been registered.
+ //
+ // Takes ownership of the platform.
+ static port::Status RegisterPlatform(std::unique_ptr<Platform> platform);
+
+ // Retrieves the platform registered with the given platform name; e.g.
+ // "CUDA", "OpenCL", ...
+ //
+ // If the requested platform is not registered, an error status is returned.
+ // Ownership of the platform is NOT transferred to the caller --
+ // the MultiPlatformManager owns the platforms in a singleton-like fashion.
+ static port::StatusOr<Platform*> PlatformWithName(const string& target);
+
+ // Retrieves the platform registered with the given platform ID, which
+ // is an opaque (but comparable) value.
+ //
+ // If the requested platform is not registered, an error status is returned.
+ // Ownership of the platform is NOT transferred to the caller --
+ // the MultiPlatformManager owns the platforms in a singleton-like fashion.
+ static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id);
+
+ // Clears the set of registered platforms, primarily used for testing.
+ static void ClearPlatformRegistry();
+
+ // Although the MultiPlatformManager "owns" its platforms, it holds them as
+ // undecorated pointers to prevent races during program exit (between this
+ // object's data and the underlying platforms, e.g. CUDA or OpenCL).
+ // Because certain platforms have unpredictable deinitialization
+ // times/sequences, it is not possible to structure a safe deinitialization
+ // sequence. Thus, we intentionally "leak" allocated platforms to defer
+ // cleanup to the OS. This should be acceptable, as these are one-time
+ // allocations per program invocation.
+ // The MultiPlatformManager should be considered the owner
+ // of any platforms registered with it, and leak checking should be disabled
+ // during allocation of such Platforms, to avoid spurious reporting at program
+ // exit.
+ using PlatformMap = std::map<string, Platform*>;
+
+ // Provides access to the available set of platforms under a lock.
+ static port::Status WithPlatforms(
+ std::function<port::Status(PlatformMap*)> callback) {
+ mutex_lock lock(platforms_mutex_);
+ return callback(GetPlatformMap());
+ }
+
+ private:
+ // mutex that guards the platform map.
+ static mutex platforms_mutex_;
+
+ // TODO(b/22689637): Clean up these two maps; make sure they coexist nicely.
+ // TODO(b/22689637): Move this (whatever the final/"official" map is) to
+ // plugin_registry.h, along with the associated functionality.
+ // Platform-name-to-object mapping. These platforms are registered via module
+ // initializers, and linkage determines which platforms are available to a
+ // given target.
+ static PlatformMap* GetPlatformMap() {
+ static PlatformMap* instance = new PlatformMap;
+ return instance;
+ }
+
+ // Holds a Platform::Id-to-object mapping.
+ // Unlike platforms_ above, this map does not own its contents.
+ static std::map<Platform::Id, Platform*>* GetPlatformByIdMap() {
+ using PlatformIdMap = std::map<Platform::Id, Platform*>;
+ static PlatformIdMap* instance = new PlatformIdMap;
+ return instance;
+ }
+
+ SE_DISALLOW_COPY_AND_ASSIGN(MultiPlatformManager);
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_
diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc
new file mode 100644
index 0000000000..8be9353bbe
--- /dev/null
+++ b/tensorflow/stream_executor/platform.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/stream_executor/platform.h"
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+namespace perftools {
+namespace gputools {
+
+string PlatformKindString(PlatformKind kind) {
+ switch (kind) {
+ case PlatformKind::kCuda:
+ return "CUDA";
+ case PlatformKind::kOpenCL:
+ return "OpenCL";
+ case PlatformKind::kOpenCLAltera:
+ return "OpenCL+Altera";
+ case PlatformKind::kHost:
+ return "Host";
+ case PlatformKind::kMock:
+ return "Mock";
+ default:
+ return port::StrCat("InvalidPlatformKind(", static_cast<int>(kind), ")");
+ }
+}
+
+PlatformKind PlatformKindFromString(string kind) {
+ for (int i = 0; i < static_cast<int>(PlatformKind::kSize); ++i) {
+ if (kind == PlatformKindString(static_cast<PlatformKind>(i))) {
+ return static_cast<PlatformKind>(i);
+ }
+ }
+
+ return PlatformKind::kInvalid;
+}
+
+bool PlatformIsRunnable(PlatformKind kind) {
+ switch (kind) {
+ case PlatformKind::kCuda:
+ case PlatformKind::kOpenCL:
+ case PlatformKind::kHost:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool PlatformIsRunnableOnDevice(PlatformKind kind) {
+ switch (kind) {
+ case PlatformKind::kCuda:
+ case PlatformKind::kOpenCL:
+ return true;
+ default:
+ return false;
+ }
+}
+
+void CheckPlatformKindIsValid(PlatformKind kind) {
+ CHECK(static_cast<int>(PlatformKind::kCuda) <= static_cast<int>(kind) &&
+ static_cast<int>(kind) <= static_cast<int>(PlatformKind::kMock))
+ << "invalid GPU executor kind: " << PlatformKindString(kind);
+}
+
+StreamExecutorConfig::StreamExecutorConfig()
+ : ordinal(-1), device_options(DeviceOptions::Default()) {}
+
+StreamExecutorConfig::StreamExecutorConfig(int ordinal_in)
+ : ordinal(ordinal_in), device_options(DeviceOptions::Default()) {}
+
+Platform::~Platform() {}
+
+port::Status Platform::ForceExecutorShutdown() {
+ return port::Status(port::error::UNIMPLEMENTED,
+ "executor shutdown is not supported on this platform");
+}
+
+std::unique_ptr<Platform::PeerAccessMap> Platform::GetPeerAccessMap() {
+ auto *map = new PeerAccessMap;
+
+ int device_count = VisibleDeviceCount();
+ for (int i = 0; i < device_count; ++i) {
+ for (int j = 0; j < device_count; ++j) {
+ StreamExecutor *from = ExecutorForDevice(i).ValueOrDie();
+ StreamExecutor *to = ExecutorForDevice(j).ValueOrDie();
+ (*map)[{i, j}] = from->CanEnablePeerAccessTo(to);
+ }
+ }
+
+ return std::unique_ptr<Platform::PeerAccessMap>{map};
+}
+
+port::Status Platform::EnablePeerAccess() {
+ auto peer_access_map = GetPeerAccessMap();
+ for (const auto &access : *peer_access_map) {
+ auto devices = access.first;
+ if (access.second) {
+ StreamExecutor *from = ExecutorForDevice(devices.first).ValueOrDie();
+ StreamExecutor *to = ExecutorForDevice(devices.second).ValueOrDie();
+ auto status = from->EnablePeerAccessTo(to);
+ if (!status.ok()) {
+ return status;
+ }
+ } else {
+ LOG(INFO) << "cannot enable peer access from device ordinal "
+ << devices.first << " to device ordinal " << devices.second;
+ }
+ }
+ return port::Status::OK();
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h
new file mode 100644
index 0000000000..c8b500b424
--- /dev/null
+++ b/tensorflow/stream_executor/platform.h
@@ -0,0 +1,185 @@
+// Defines types and declares functions for identifying and extracting
+// information about the types of platforms and supporting libraries for which
+// StreamExecutor implementations exist.
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_
+
+#include <map>
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/device_options.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+
+class StreamExecutor;
+
+// Describes the platform for a StreamExecutor instantiation to act upon.
+//
+// Implementors: if you add a value here be sure to update PlatformKindString
+// and CheckPlatformKindIsValid.
+enum class PlatformKind {
+ kInvalid,
+ kCuda,
+ kOpenCL,
+ kOpenCLAltera, // Altera FPGA OpenCL platform.
+ // See documentation: go/fpgaopencl
+ // (StreamExecutor integration)
+ kHost,
+ kMock,
+ kSize,
+};
+
+// Returns true if kind represents a valid platform capable of enqueuing items
+// on a stream, but not necessarily on an accelerator device.
+// Returns false for kMock and any invalid PlatformKind values.
+bool PlatformIsRunnable(PlatformKind kind);
+
+// Returns true if kind represents a valid platform capable of running kernels
+// on an accelerator device. Returns false for kHost*, kMock and any invalid
+// PlatformKind values.
+bool PlatformIsRunnableOnDevice(PlatformKind kind);
+
+// Returns a printable description of a PlatformKind.
+string PlatformKindString(PlatformKind kind);
+
+// Returns the PlatformKind corresponding to the input string; returns kInvalid
+// in the case of no match.
+PlatformKind PlatformKindFromString(string platform_string);
+
+// Checks that kind takes on a valid value.
+void CheckPlatformKindIsValid(PlatformKind kind);
+
+// StreamExecutorConfig encapsulates the set of options for constructing a
+// StreamExecutor for a given platform.
+struct StreamExecutorConfig {
+ // Sets members to defaults: -1 for ordinal (must be changed), and default
+ // PluginConfig and DeviceOptions.
+ StreamExecutorConfig();
+
+ // Simple ordinal-setting constructor.
+ explicit StreamExecutorConfig(int ordinal);
+
+ // The ordinal of the device to be managed by the returned StreamExecutor.
+ int ordinal;
+
+ // The PluginConfig for the returned StreamExecutor.
+ PluginConfig plugin_config;
+
+ // The DeviceOptions for the returned StreamExecutor.
+ DeviceOptions device_options;
+};
+
+// Abstract base class for a platform registered with the MultiPlatformManager.
+class Platform {
+ public:
+ virtual ~Platform();
+
+ // A platform ID is a unique identifier for each registered platform type -
+ // each platform is required to expose an ID to ensure unique registration and
+ // as a target against which plugins can register.
+ //
+ // The macro below is provided to help generate a [process-unique] identifier.
+ using Id = void*;
+
+// Helper macro to define a platform ID. To be used only inside platform
+// implementation files. Works by "reserving" an address/value (guaranteed to
+// be unique) inside a process space.
+#define PLATFORM_DEFINE_ID(ID_VAR_NAME) \
+ namespace { \
+ int plugin_id_value; \
+ } \
+ const perftools::gputools::Platform::Id ID_VAR_NAME = &plugin_id_value;
+
+ // Returns a key uniquely identifying this platform.
+ virtual Id id() const = 0;
+
+ // Returns the number of devices accessible on this platform.
+ //
+ // Note that, though these devices are visible, if there is only one userspace
+ // context allowed for the device at a time and another process is using this
+ // device, a call to ExecutorForDevice may return an error status.
+ virtual int VisibleDeviceCount() const = 0;
+
+ // Name of this platform.
+ virtual const string& Name() const = 0;
+
+ // Returns a StreamExecutor for the device with the given ordinal on this
+ // platform, using a default plugin configuration. If no device with that
+ // ordinal exists, or a context to communicate with the device cannot be
+ // opened, an error status is returned.
+ //
+ // Ownership of the executor is NOT transferred to the caller --
+ // the Platform owns the executors in a singleton-like fashion.
+ virtual port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) = 0;
+
+ // Returns a device or error, as above, with the specified plugins.
+ //
+ // Ownership of the executor is NOT transferred to the caller.
+ virtual port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+ int ordinal, const PluginConfig& plugin_config) = 0;
+
+ // Returns a device constructed with the options specified in "config".
+ // Ownership of the executor is NOT transferred to the caller.
+ virtual port::StatusOr<StreamExecutor*> GetExecutor(
+ const StreamExecutorConfig& config) = 0;
+
+ // Returns a device constructed with the options specified in "config" without
+ // looking in or storing to the Platform's executor cache.
+ // Ownership IS transferred to the caller.
+ virtual port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ const StreamExecutorConfig& config) = 0;
+
+ // Warning: this is a dangerous API and should be used with caution.
+ //
+ // Forces the platform to delete executor instances, releasing their
+ // associated device contexts. There must be no held instances of the executor
+ // and there must be no outstanding activity on the devices for this platform.
+ //
+ // This is only useful on platforms which bind a device to a single process
+ // that has obtained the device context. May return UNIMPLEMENTED on platforms
+ // that have no reason to destroy device contexts.
+ virtual port::Status ForceExecutorShutdown();
+
+ // Registers a TraceListener to listen to all StreamExecutors for this
+ // platform.
+ // Takes ownership of listener.
+ virtual void RegisterTraceListener(
+ std::unique_ptr<TraceListener> listener) = 0;
+
+ // Removes the specified TraceListener from all StreamExecutors.
+ virtual void UnregisterTraceListener(TraceListener* listener) = 0;
+
+ // Map of executor-to-executor coordinate and boolean, indicating if the first
+ // executor can access the second's memory.
+ using PeerAccessMap = std::map<std::pair<int, int>, bool>;
+
+ // Returns a matrix indicating which executors can access which other
+ // executors' memory.
+ virtual std::unique_ptr<PeerAccessMap> GetPeerAccessMap();
+
+ // Attempts to enable all peer-to-peer access links described by the result of
+ // GetPeerAccessMap(). Note that calling this routine will force the creation
+ // of a default-argument (see StreamExecutorConfig) StreamExecutor object for
+ // each device ordinal in the system, should any not yet exist.
+ virtual port::Status EnablePeerAccess();
+
+ protected:
+ // SE_DISALLOW_COPY_AND_ASSIGN declares a copy constructor, which suppresses
+ // the implicitly-declared default constructor. This statement re-enables the
+ // default constructor, which simplifies subclassing.
+ Platform() = default;
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(Platform);
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_
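A sketch of how a platform implementation might use PLATFORM_DEFINE_ID; the id variable and MyPlatform class are hypothetical:

    // In a hypothetical my_platform.cc:
    #include "tensorflow/stream_executor/platform.h"

    namespace perftools {
    namespace gputools {

    PLATFORM_DEFINE_ID(kMyPlatformId);

    // A Platform subclass would then return this id, e.g.:
    //   Platform::Id MyPlatform::id() const { return kMyPlatformId; }
    // and register itself with MultiPlatformManager at module-initialization
    // time so that PlatformWithId(kMyPlatformId) can find it.

    }  // namespace gputools
    }  // namespace perftools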
diff --git a/tensorflow/stream_executor/platform/default/mutex.h b/tensorflow/stream_executor/platform/default/mutex.h
new file mode 100644
index 0000000000..371eb7f156
--- /dev/null
+++ b/tensorflow/stream_executor/platform/default/mutex.h
@@ -0,0 +1,60 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
+
+#include <chrono> // NOLINT
+#include <condition_variable> // NOLINT
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+// std::shared_timed_mutex is a C++14 feature.
+#if (__cplusplus >= 201402L)
+#define STREAM_EXECUTOR_USE_SHARED_MUTEX
+#endif // __cplusplus >= 201402L
+
+#ifdef STREAM_EXECUTOR_USE_SHARED_MUTEX
+#include <shared_mutex> // NOLINT
+#else
+#include <mutex> // NOLINT
+#endif
+
+namespace perftools {
+namespace gputools {
+
+enum ConditionResult { kCond_Timeout, kCond_MaybeNotified };
+
+#ifdef STREAM_EXECUTOR_USE_SHARED_MUTEX
+typedef std::shared_timed_mutex BaseMutex;
+#else
+typedef std::mutex BaseMutex;
+#endif
+
+// A class that wraps around the std::mutex implementation, only adding an
+// additional LinkerInitialized constructor interface.
+class mutex : public BaseMutex {
+ public:
+ mutex() {}
+ // The default implementation of std::mutex is safe to use after linker
+ // initialization, so the LinkerInitialized argument is ignored.
+ explicit mutex(LinkerInitialized x) {}
+};
+
+typedef std::unique_lock<BaseMutex> mutex_lock;
+
+#ifdef STREAM_EXECUTOR_USE_SHARED_MUTEX
+typedef std::shared_lock<BaseMutex> shared_lock;
+#else
+typedef std::unique_lock<BaseMutex> shared_lock;
+#endif
+
+using std::condition_variable;
+
+inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
+ condition_variable* cv, int64 ms) {
+ std::cv_status s = cv->wait_for(*mu, std::chrono::milliseconds(ms));
+ return (s == std::cv_status::timeout) ? kCond_Timeout : kCond_MaybeNotified;
+}
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_
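A minimal sketch of the mutex wrappers defined above; the guarded counter is hypothetical:

    #include "tensorflow/stream_executor/platform/mutex.h"

    namespace gpu = perftools::gputools;

    // LINKER_INITIALIZED lets the mutex be used from other static initializers.
    static gpu::mutex counter_mu(gpu::LINKER_INITIALIZED);
    static int counter = 0;

    int NextId() {
      gpu::mutex_lock lock(counter_mu);  // std::unique_lock over the base mutex
      return ++counter;
    }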
diff --git a/tensorflow/stream_executor/platform/logging.h b/tensorflow/stream_executor/platform/logging.h
new file mode 100644
index 0000000000..a3e2385dd3
--- /dev/null
+++ b/tensorflow/stream_executor/platform/logging.h
@@ -0,0 +1,21 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_LOGGING_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_LOGGING_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+#if !defined(PLATFORM_GOOGLE)
+
+// A CHECK() macro that lets you assert the success of a function that
+// returns -1 and sets errno in case of an error. E.g.
+//
+// CHECK_ERR(mkdir(path, 0700));
+//
+// or
+//
+// int fd = open(filename, flags); CHECK_ERR(fd) << ": open " << filename;
+#define CHECK_ERR(invocation) CHECK((invocation) != -1) << #invocation
+
+#endif
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_LOGGING_H_
diff --git a/tensorflow/stream_executor/platform/mutex.h b/tensorflow/stream_executor/platform/mutex.h
new file mode 100644
index 0000000000..21b1894737
--- /dev/null
+++ b/tensorflow/stream_executor/platform/mutex.h
@@ -0,0 +1,12 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
+
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/stream_executor/platform/google/mutex.h"
+#else
+#include "tensorflow/stream_executor/platform/default/mutex.h"
+#endif
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_MUTEX_H_
diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h
new file mode 100644
index 0000000000..ebe0cf517b
--- /dev/null
+++ b/tensorflow/stream_executor/platform/port.h
@@ -0,0 +1,40 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
+
+#include "tensorflow/core/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+using tensorflow::int8;
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+
+using tensorflow::uint8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+
+#if !defined(PLATFORM_GOOGLE)
+using std::string;
+#endif
+
+#if !defined(COMPILER_MSVC)
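+// ARRAYSIZE(a) yields the number of elements in the statically-sized array a
+// (e.g. for "int buf[16];", ARRAYSIZE(buf) is 16). The second factor divides
+// by zero when sizeof(a) is not a multiple of sizeof(*(a)), turning some
+// accidental pointer arguments into compile-time errors.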
+#define ARRAYSIZE(a) \
+ ((sizeof(a) / sizeof(*(a))) / \
+ static_cast<size_t>(!(sizeof(a) % sizeof(*(a)))))
+#endif
+
+using tensorflow::LinkerInitialized;
+using tensorflow::LINKER_INITIALIZED;
+
+#define SE_FALLTHROUGH_INTENDED TF_FALLTHROUGH_INTENDED
+
+} // namespace gputools
+} // namespace perftools
+
+#define SE_DISALLOW_COPY_AND_ASSIGN TF_DISALLOW_COPY_AND_ASSIGN
+#define SE_MUST_USE_RESULT TF_MUST_USE_RESULT
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
diff --git a/tensorflow/stream_executor/platform/thread_annotations.h b/tensorflow/stream_executor/platform/thread_annotations.h
new file mode 100644
index 0000000000..bce4bb3794
--- /dev/null
+++ b/tensorflow/stream_executor/platform/thread_annotations.h
@@ -0,0 +1,6 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_THREAD_ANNOTATIONS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_THREAD_ANNOTATIONS_H_
+
+#include "tensorflow/core/platform/thread_annotations.h"
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_THREAD_ANNOTATIONS_H_
diff --git a/tensorflow/stream_executor/plugin.cc b/tensorflow/stream_executor/plugin.cc
new file mode 100644
index 0000000000..8ca8ecff38
--- /dev/null
+++ b/tensorflow/stream_executor/plugin.cc
@@ -0,0 +1,40 @@
+#include "tensorflow/stream_executor/plugin.h"
+
+namespace perftools {
+namespace gputools {
+
+// Mostly-arbitrary ID only used as a sentinel "not otherwise initialized"
+// value. This value should never [need to] be specified except by the
+// initialization functions defined in this file and in PluginRegistry.
+PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(PluginConfig::kDefault);
+
+PluginConfig::PluginConfig()
+ : blas_(kDefault), dnn_(kDefault), fft_(kDefault), rng_(kDefault) {}
+
+bool PluginConfig::operator==(const PluginConfig& rhs) const {
+ return blas_ == rhs.blas_ && dnn_ == rhs.dnn_ && fft_ == rhs.fft_ &&
+ rng_ == rhs.rng_;
+}
+
+PluginConfig& PluginConfig::SetBlas(PluginId blas) {
+ blas_ = blas;
+ return *this;
+}
+
+PluginConfig& PluginConfig::SetDnn(PluginId dnn) {
+ dnn_ = dnn;
+ return *this;
+}
+
+PluginConfig& PluginConfig::SetFft(PluginId fft) {
+ fft_ = fft;
+ return *this;
+}
+
+PluginConfig& PluginConfig::SetRng(PluginId rng) {
+ rng_ = rng;
+ return *this;
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/plugin.h b/tensorflow/stream_executor/plugin.h
new file mode 100644
index 0000000000..5dc39b7928
--- /dev/null
+++ b/tensorflow/stream_executor/plugin.h
@@ -0,0 +1,74 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
+
+namespace perftools {
+namespace gputools {
+
+// A plugin ID is a unique identifier for each registered plugin type.
+typedef void* PluginId;
+
+// Helper macro to define a plugin ID. To be used only inside plugin
+// implementation files. Works by "reserving" an address/value (guaranteed to be
+// unique) inside a process space.
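+//
+// For example, a plugin implementation file (names hypothetical) might
+// declare:
+//
+//   PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kMyBlasPlugin);
+//
+// which reserves a unique address in the process and exposes it as
+// kMyBlasPlugin.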
+#define PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(ID_VAR_NAME) \
+ namespace { \
+ int plugin_id_value; \
+ } \
+ const PluginId ID_VAR_NAME = &plugin_id_value;
+
+// kNullPlugin denotes an invalid plugin identifier.
+extern const PluginId kNullPlugin;
+
+// Enumeration to list the supported types of plugins / support libraries.
+enum class PluginKind {
+ kInvalid,
+ kBlas,
+ kDnn,
+ kFft,
+ kRng,
+};
+
+// A PluginConfig describes the set of plugins to be used by a StreamExecutor
+// instance. Each plugin is defined by an arbitrary identifier, usually best
+// set to the address of a static member in the implementation (to avoid
+// conflicts).
+//
+// A PluginConfig may be passed to the StreamExecutor constructor - the plugins
+// described therein will be used to provide BLAS, DNN, FFT, and RNG
+// functionality. Platform-appropriate defaults will be used for any unset
+// libraries. If a platform does not support a specified plugin (e.g. cuBLAS on
+// an OpenCL executor), then an error will be logged and no plugin operations
+// will succeed.
+//
+// The StreamExecutor BUILD target does not link ANY plugin libraries - even
+// common host fallbacks! Any plugins must be explicitly linked by dependent
+// targets. See the cuda, opencl and host BUILD files for implemented plugin
+// support (search for "plugin").
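+//
+// A minimal usage sketch (plugin IDs are hypothetical):
+//
+//   PluginConfig config;
+//   config.SetBlas(kMyBlasPlugin).SetRng(kMyRngPlugin);
+//   // Pass config to the StreamExecutor constructor; DNN and FFT fall back
+//   // to the platform defaults.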
+class PluginConfig {
+ public:
+ // Value specifying the platform's default option for that plugin.
+ static const PluginId kDefault;
+
+ // Initializes all members to the default options.
+ PluginConfig();
+
+ bool operator==(const PluginConfig& rhs) const;
+
+ // Sets the appropriate library kind to that passed in.
+ PluginConfig& SetBlas(PluginId blas);
+ PluginConfig& SetDnn(PluginId dnn);
+ PluginConfig& SetFft(PluginId fft);
+ PluginConfig& SetRng(PluginId rng);
+
+ PluginId blas() const { return blas_; }
+ PluginId dnn() const { return dnn_; }
+ PluginId fft() const { return fft_; }
+ PluginId rng() const { return rng_; }
+
+ private:
+ PluginId blas_, dnn_, fft_, rng_;
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_
diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc
new file mode 100644
index 0000000000..eda44d1146
--- /dev/null
+++ b/tensorflow/stream_executor/plugin_registry.cc
@@ -0,0 +1,228 @@
+#include "tensorflow/stream_executor/plugin_registry.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+namespace perftools {
+namespace gputools {
+
+const PluginId kNullPlugin = nullptr;
+
+// Returns the string representation of the specified PluginKind.
+string PluginKindString(PluginKind plugin_kind) {
+ switch (plugin_kind) {
+ case PluginKind::kBlas:
+ return "BLAS";
+ case PluginKind::kDnn:
+ return "DNN";
+ case PluginKind::kFft:
+ return "FFT";
+ case PluginKind::kRng:
+ return "RNG";
+ case PluginKind::kInvalid:
+ default:
+ return "kInvalid";
+ }
+}
+
+PluginRegistry::DefaultFactories::DefaultFactories() :
+ blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
+
+/* static */ mutex PluginRegistry::mu_(LINKER_INITIALIZED);
+/* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
+
+PluginRegistry::PluginRegistry() {}
+
+/* static */ PluginRegistry* PluginRegistry::Instance() {
+ mutex_lock lock{mu_};
+ if (instance_ == nullptr) {
+ instance_ = new PluginRegistry();
+ }
+ return instance_;
+}
+
+void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
+ Platform::Id platform_id) {
+ platform_id_by_kind_[platform_kind] = platform_id;
+}
+
+template <typename FACTORY_TYPE>
+port::Status PluginRegistry::RegisterFactoryInternal(
+ PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
+ std::map<PluginId, FACTORY_TYPE>* factories) {
+ mutex_lock lock{mu_};
+
+ if (factories->find(plugin_id) != factories->end()) {
+ return port::Status{
+ port::error::ALREADY_EXISTS,
+ port::Printf("Attempting to register factory for plugin %s when "
+ "one has already been registered",
+ plugin_name.c_str())};
+ }
+
+ (*factories)[plugin_id] = factory;
+ plugin_names_[plugin_id] = plugin_name;
+ return port::Status::OK();
+}
+
+template <typename FACTORY_TYPE>
+port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
+ PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
+ const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
+ auto iter = factories.find(plugin_id);
+ if (iter == factories.end()) {
+ iter = generic_factories.find(plugin_id);
+ if (iter == generic_factories.end()) {
+ return port::Status{
+ port::error::NOT_FOUND,
+ port::Printf("Plugin ID %p not registered.", plugin_id)};
+ }
+ }
+
+ return iter->second;
+}
+
+bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
+ PluginKind plugin_kind,
+ PluginId plugin_id) {
+ if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
+ port::StatusOr<Platform*> status =
+ MultiPlatformManager::PlatformWithId(platform_id);
+ string platform_name = "<unregistered platform>";
+ if (status.ok()) {
+ platform_name = status.ValueOrDie()->Name();
+ }
+
+ LOG(ERROR) << "A factory must be registered for a platform before being "
+ << "set as default! "
+ << "Platform name: " << platform_name
+ << ", PluginKind: " << PluginKindString(plugin_kind)
+ << ", PluginId: " << plugin_id;
+ return false;
+ }
+
+ switch (plugin_kind) {
+ case PluginKind::kBlas:
+ default_factories_[platform_id].blas = plugin_id;
+ break;
+ case PluginKind::kDnn:
+ default_factories_[platform_id].dnn = plugin_id;
+ break;
+ case PluginKind::kFft:
+ default_factories_[platform_id].fft = plugin_id;
+ break;
+ case PluginKind::kRng:
+ default_factories_[platform_id].rng = plugin_id;
+ break;
+ default:
+ LOG(ERROR) << "Invalid plugin kind specified: "
+ << static_cast<int>(plugin_kind);
+ return false;
+ }
+
+ return true;
+}
+
+bool PluginRegistry::HasFactory(const PluginFactories& factories,
+ PluginKind plugin_kind,
+ PluginId plugin_id) const {
+ switch (plugin_kind) {
+ case PluginKind::kBlas:
+ return factories.blas.find(plugin_id) != factories.blas.end();
+ case PluginKind::kDnn:
+ return factories.dnn.find(plugin_id) != factories.dnn.end();
+ case PluginKind::kFft:
+ return factories.fft.find(plugin_id) != factories.fft.end();
+ case PluginKind::kRng:
+ return factories.rng.find(plugin_id) != factories.rng.end();
+ default:
+ LOG(ERROR) << "Invalid plugin kind specified: "
+ << PluginKindString(plugin_kind);
+ return false;
+ }
+}
+
+bool PluginRegistry::HasFactory(Platform::Id platform_id,
+ PluginKind plugin_kind,
+ PluginId plugin_id) const {
+ auto iter = factories_.find(platform_id);
+ if (iter != factories_.end()) {
+ if (HasFactory(iter->second, plugin_kind, plugin_id)) {
+ return true;
+ }
+ }
+
+ return HasFactory(generic_factories_, plugin_kind, plugin_id);
+}
+
+// Explicit instantiations to support types exposed in user/public API.
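+// For each (factory type, member, name) triple, this emits explicit
+// instantiations of RegisterFactoryInternal/GetFactoryInternal plus
+// specializations of RegisterFactory, RegisterFactoryForAllPlatforms, and
+// both GetFactory overloads declared in plugin_registry.h.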
+#define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
+ template port::StatusOr<PluginRegistry::FACTORY_TYPE> \
+ PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
+ PluginId plugin_id, \
+ const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories, \
+ const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& \
+ generic_factories) const; \
+ \
+ template port::Status \
+ PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
+ PluginId plugin_id, const string& plugin_name, \
+ PluginRegistry::FACTORY_TYPE factory, \
+ std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories); \
+ \
+ template <> \
+ port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
+ Platform::Id platform_id, PluginId plugin_id, const string& name, \
+ PluginRegistry::FACTORY_TYPE factory) { \
+ return RegisterFactoryInternal(plugin_id, name, factory, \
+ &factories_[platform_id].FACTORY_VAR); \
+ } \
+ \
+ template <> \
+ port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \
+ PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \
+ PluginRegistry::FACTORY_TYPE factory) { \
+ return RegisterFactoryInternal(plugin_id, name, factory, \
+ &generic_factories_.FACTORY_VAR); \
+ } \
+ \
+ template <> \
+ port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
+ Platform::Id platform_id, PluginId plugin_id) { \
+ if (plugin_id == PluginConfig::kDefault) { \
+ plugin_id = default_factories_[platform_id].FACTORY_VAR; \
+ \
+ if (plugin_id == kNullPlugin) { \
+ return port::Status{port::error::FAILED_PRECONDITION, \
+ "No suitable " PLUGIN_STRING \
+ " plugin registered, default or otherwise."}; \
+ } else { \
+ VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \
+ << plugin_names_[plugin_id]; \
+ } \
+ } \
+ return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
+ generic_factories_.FACTORY_VAR); \
+ } \
+ \
+ /* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \
+ template <> \
+ port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
+ PlatformKind platform_kind, PluginId plugin_id) { \
+ auto iter = platform_id_by_kind_.find(platform_kind); \
+ if (iter == platform_id_by_kind_.end()) { \
+ return port::Status{port::error::FAILED_PRECONDITION, \
+ port::Printf("Platform kind %d not registered.", \
+ static_cast<int>(platform_kind))}; \
+ } \
+ return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
+ }
+
+EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
+EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
+EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
+EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
new file mode 100644
index 0000000000..f1ea59853d
--- /dev/null
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -0,0 +1,155 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
+#define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
+
+#include <map>
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/stream_executor/fft.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/plugin.h"
+#include "tensorflow/stream_executor/rng.h"
+
+namespace perftools {
+namespace gputools {
+
+namespace internal {
+class StreamExecutorInterface;
+}
+
+// The PluginRegistry is a singleton that maintains the set of registered
+// "support library" plugins. Currently, there are four kinds of plugins:
+// BLAS, DNN, FFT, and RNG. Each interface is defined in the corresponding
+// gpu_{kind}.h header.
+//
+// At runtime, a StreamExecutor object will query the singleton registry to
+// retrieve the plugin kind that StreamExecutor was configured with (refer to
+// the StreamExecutor and PluginConfig declarations).
+//
+// Plugin libraries are best registered using REGISTER_MODULE_INITIALIZER,
+// but can be registered at any time. When registering a DSO-backed plugin, it
+// is usually a good idea to load the DSO at registration time, to minimize the
+// risk of late loading distorting performance measurements and benchmarks.
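+//
+// A registration sketch (platform ID, plugin ID, and factory names are
+// hypothetical):
+//
+//   PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kMyBlasPlugin);
+//   ...
+//   PluginRegistry::Instance()->RegisterFactory(
+//       kMyPlatformId, kMyBlasPlugin, "MyBLAS", &CreateMyBlasSupport);
+//
+// StreamExecutor can later retrieve the configured factory via
+//
+//   PluginRegistry::Instance()->GetFactory<PluginRegistry::BlasFactory>(
+//       kMyPlatformId, plugin_config.blas());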
+class PluginRegistry {
+ public:
+ typedef blas::BlasSupport* (*BlasFactory)(internal::StreamExecutorInterface*);
+ typedef dnn::DnnSupport* (*DnnFactory)(internal::StreamExecutorInterface*);
+ typedef fft::FftSupport* (*FftFactory)(internal::StreamExecutorInterface*);
+ typedef rng::RngSupport* (*RngFactory)(internal::StreamExecutorInterface*);
+
+ // Gets (and creates, if necessary) the singleton PluginRegistry instance.
+ static PluginRegistry* Instance();
+
+ // Registers the specified factory with the specified platform.
+ // Returns a non-successful status if the factory has already been registered
+ // with that platform (but execution should be otherwise unaffected).
+ template <typename FactoryT>
+ port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id,
+ const string& name, FactoryT factory);
+
+ // Registers the specified factory as usable by _all_ platform types.
+ // Reports errors just as RegisterFactory.
+ template <typename FactoryT>
+ port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id,
+ const string& name,
+ FactoryT factory);
+
+ // TODO(b/22689637): Setter for temporary mapping until all users are using
+ // MultiPlatformManager / PlatformId.
+ void MapPlatformKindToId(PlatformKind platform_kind,
+ Platform::Id platform_id);
+
+ // Potentially sets the plugin identified by plugin_id to be the default
+ // for the specified platform and plugin kind. If this routine is called
+  // multiple times for the same PluginKind, the PluginId given in the last
+  // call will be used.
+ bool SetDefaultFactory(Platform::Id platform_id, PluginKind plugin_kind,
+ PluginId plugin_id);
+
+  // Returns true if the factory/id has been registered for the
+ // specified platform and plugin kind and false otherwise.
+ bool HasFactory(Platform::Id platform_id, PluginKind plugin_kind,
+ PluginId plugin) const;
+
+ // Retrieves the factory registered for the specified kind,
+ // or a port::Status on error.
+ template <typename FactoryT>
+ port::StatusOr<FactoryT> GetFactory(Platform::Id platform_id,
+ PluginId plugin_id);
+
+ // TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
+ // on MultiPlatformManager / PlatformId.
+ template <typename FactoryT>
+ port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
+ PluginId plugin_id);
+
+ private:
+ // Containers for the sets of registered factories, by plugin kind.
+ struct PluginFactories {
+ std::map<PluginId, BlasFactory> blas;
+ std::map<PluginId, DnnFactory> dnn;
+ std::map<PluginId, FftFactory> fft;
+ std::map<PluginId, RngFactory> rng;
+ };
+
+ // Simple structure to hold the currently configured default plugins (for a
+ // particular Platform).
+ struct DefaultFactories {
+ DefaultFactories();
+ PluginId blas, dnn, fft, rng;
+ };
+
+ PluginRegistry();
+
+ // Actually performs the work of registration.
+ template <typename FactoryT>
+ port::Status RegisterFactoryInternal(PluginId plugin_id,
+ const string& plugin_name,
+ FactoryT factory,
+ std::map<PluginId, FactoryT>* factories);
+
+ // Actually performs the work of factory retrieval.
+ template <typename FactoryT>
+ port::StatusOr<FactoryT> GetFactoryInternal(
+ PluginId plugin_id, const std::map<PluginId, FactoryT>& factories,
+ const std::map<PluginId, FactoryT>& generic_factories) const;
+
+  // Returns true if the specified plugin has been registered with the specified
+  // platform factories. Unlike the other overload of this method, this does
+  // not implicitly examine the generic (all-platform) factories.
+ bool HasFactory(const PluginFactories& factories, PluginKind plugin_kind,
+ PluginId plugin) const;
+
+ // As this object is a singleton, a global mutex can be used for static and
+ // instance protection.
+ static mutex mu_;
+
+ // The singleton itself.
+ static PluginRegistry* instance_;
+
+ // TODO(b/22689637): Temporary mapping until all users are using
+ // MultiPlatformManager / PlatformId.
+ std::map<PlatformKind, Platform::Id> platform_id_by_kind_;
+
+ // The set of registered factories, keyed by platform ID.
+ std::map<Platform::Id, PluginFactories> factories_;
+
+ // Plugins supported for all platform kinds.
+ PluginFactories generic_factories_;
+
+ // The sets of default factories, keyed by platform ID.
+ std::map<Platform::Id, DefaultFactories> default_factories_;
+
+ // Lookup table for plugin names.
+ std::map<PluginId, string> plugin_names_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry);
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
diff --git a/tensorflow/stream_executor/rng.cc b/tensorflow/stream_executor/rng.cc
new file mode 100644
index 0000000000..052b502194
--- /dev/null
+++ b/tensorflow/stream_executor/rng.cc
@@ -0,0 +1,36 @@
+#include "tensorflow/stream_executor/rng.h"
+
+#include "tensorflow/stream_executor/platform/logging.h"
+
+namespace perftools {
+namespace gputools {
+namespace rng {
+
+bool RngSupport::CheckSeed(const uint8 *seed, uint64 seed_bytes) {
+ CHECK(seed != nullptr);
+
+ if (seed_bytes < kMinSeedBytes) {
+ LOG(INFO) << "Insufficient RNG seed data specified: " << seed_bytes
+ << ". At least " << RngSupport::kMinSeedBytes
+ << " bytes are required.";
+ return false;
+ }
+
+ if (seed_bytes > kMaxSeedBytes) {
+ LOG(INFO) << "Too much RNG seed data specified: " << seed_bytes
+ << ". At most " << RngSupport::kMaxSeedBytes
+ << " bytes may be provided.";
+ return false;
+ }
+
+ return true;
+}
+
+#if defined(__APPLE__)
+const int RngSupport::kMinSeedBytes;
+const int RngSupport::kMaxSeedBytes;
+#endif
+
+} // namespace rng
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/rng.h b/tensorflow/stream_executor/rng.h
new file mode 100644
index 0000000000..797631d01d
--- /dev/null
+++ b/tensorflow/stream_executor/rng.h
@@ -0,0 +1,80 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_RNG_H_
+#define TENSORFLOW_STREAM_EXECUTOR_RNG_H_
+
+#include <limits.h>
+#include <complex>
+
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+template <typename ElemT>
+class DeviceMemory;
+
+namespace rng {
+
+// Random-number-generation support interface -- this can be derived from a GPU
+// executor when the underlying platform has an RNG library implementation
+// available. See StreamExecutor::AsRng().
+// When a seed is not specified, the backing RNG will be initialized with the
+// default seed for that implementation.
+//
+// Thread-hostile: see StreamExecutor class comment for details on
+// thread-hostility.
+class RngSupport {
+ public:
+ static const int kMinSeedBytes = 16;
+ static const int kMaxSeedBytes = INT_MAX;
+
+ // Releases any random-number-generation resources associated with this
+ // support object in the underlying platform implementation.
+ virtual ~RngSupport() {}
+
+ // Populates a GPU memory allocation with random values appropriate for the
+ // DeviceMemory element type; i.e. populates DeviceMemory<float> with random
+ // float values.
+ virtual bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<float> *v) = 0;
+ virtual bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<double> *v) = 0;
+ virtual bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<float>> *v) = 0;
+ virtual bool DoPopulateRandUniform(Stream *stream,
+ DeviceMemory<std::complex<double>> *v) = 0;
+
+ // Populates a GPU memory allocation with random values sampled from a
+ // Gaussian distribution with the given mean and standard deviation.
+ virtual bool DoPopulateRandGaussian(Stream *stream, float mean, float stddev,
+ DeviceMemory<float> *v) {
+ LOG(ERROR)
+ << "platform's random number generator does not support gaussian";
+ return false;
+ }
+ virtual bool DoPopulateRandGaussian(Stream *stream, double mean,
+ double stddev, DeviceMemory<double> *v) {
+ LOG(ERROR)
+ << "platform's random number generator does not support gaussian";
+ return false;
+ }
+
+ // Specifies the seed used to initialize the RNG.
+ // This call does not transfer ownership of the buffer seed; its data should
+ // not be altered for the lifetime of this call. At least 16 bytes of seed
+ // data must be provided, but not all seed data will necessarily be used.
+ // seed: Pointer to seed data. Must not be null.
+ // seed_bytes: Size of seed buffer in bytes. Must be >= 16.
+ virtual bool SetSeed(Stream *stream, const uint8 *seed,
+ uint64 seed_bytes) = 0;
+
+ protected:
+ static bool CheckSeed(const uint8 *seed, uint64 seed_bytes);
+};
+
+} // namespace rng
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_RNG_H_
diff --git a/tensorflow/stream_executor/shared_memory_config.h b/tensorflow/stream_executor/shared_memory_config.h
new file mode 100644
index 0000000000..f2bfe27117
--- /dev/null
+++ b/tensorflow/stream_executor/shared_memory_config.h
@@ -0,0 +1,21 @@
+// This file defines a uniform interface to configuration options for shared
+// memory for supported devices. As with many StreamExecutor-supported features,
+// support for the options defined herein is device-dependent.
+#ifndef TENSORFLOW_STREAM_EXECUTOR_SHARED_MEMORY_CONFIG_H_
+#define TENSORFLOW_STREAM_EXECUTOR_SHARED_MEMORY_CONFIG_H_
+
+namespace perftools {
+namespace gputools {
+
+// SharedMemoryConfig enum describes potential widths of shared memory banks for
+// a device or kernel.
+enum class SharedMemoryConfig {
+ kDefault, // Use the device default configuration.
+ kFourByte, // Sets shared memory banks to be four bytes wide.
+ kEightByte, // Sets shared memory banks to be eight bytes wide.
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_SHARED_MEMORY_CONFIG_H_
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
new file mode 100644
index 0000000000..ca3ef9aa1a
--- /dev/null
+++ b/tensorflow/stream_executor/stream.cc
@@ -0,0 +1,3329 @@
+#include "tensorflow/stream_executor/stream.h"
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+namespace perftools {
+namespace gputools {
+
+namespace {
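+// Creates the platform-specific StreamInterface implementation matching the
+// parent executor's platform kind; LOG(FATAL)s on unsupported kinds.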
+static internal::StreamInterface *CreateStreamImplementation(
+ StreamExecutor *parent) {
+ PlatformKind platform_kind = parent->platform_kind();
+ if (platform_kind == PlatformKind::kCuda) {
+ return (*internal::MakeCUDAStreamImplementation())(parent);
+ } else if (platform_kind == PlatformKind::kOpenCL ||
+ platform_kind == PlatformKind::kOpenCLAltera) {
+ return (*internal::MakeOpenCLStreamImplementation())(parent);
+ } else if (platform_kind == PlatformKind::kHost) {
+ return internal::MakeHostStreamImplementation(parent);
+ } else {
+ LOG(FATAL) << "cannot create stream implementation for platform kind: "
+ << PlatformKindString(platform_kind);
+ }
+}
+
+// Code to turn parameters of functions on Stream into strings that
+// will be VLOG'ed. We need overloads, instead of
+// e.g. BatchDescriptorToVlogString(), as the code that calls these
+// functions does not know what the type of the parameter is.
+string ToVlogString(const dnn::BatchDescriptor &descriptor) {
+ return descriptor.ToShortString();
+}
+
+string ToVlogString(const dnn::FilterDescriptor &descriptor) {
+ return descriptor.ToShortString();
+}
+
+string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
+ return descriptor.ToShortString();
+}
+
+string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
+ return descriptor.ToShortString();
+}
+
+string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
+ return descriptor.ToShortString();
+}
+
+string ToVlogString(dnn::ActivationMode mode) {
+ return dnn::ActivationModeString(mode);
+}
+
+string ToVlogString(dnn::ElementwiseOperation op) {
+ return dnn::ElementwiseOperationString(op);
+}
+
+string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
+
+string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
+
+string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
+
+string ToVlogString(blas::Side s) { return blas::SideString(s); }
+
+string ToVlogString(const void *ptr) {
+ if (ptr == nullptr) {
+ return "null";
+ }
+
+ // StrCat does not convert pointers to text.
+ std::ostringstream out;
+ out << ptr;
+ return out.str();
+}
+
+template <class T>
+string ToVlogString(const std::complex<T> &c) {
+ // StrCat does not convert std::complex to text.
+ std::ostringstream out;
+ out << c;
+ return out.str();
+}
+
+template <class T>
+string ToVlogString(const std::function<T> &f) {
+ return f == nullptr ? "null" : "<non-null function>";
+}
+
+string ToVlogString(const DeviceMemoryBase &memory) {
+ return ToVlogString(memory.opaque());
+}
+
+string ToVlogString(const DeviceMemoryBase *memory) {
+ return ToVlogString(*memory);
+}
+
+string ToVlogString(int i) { return port::StrCat(i); }
+
+string ToVlogString(uint32 i) { return port::StrCat(i); }
+
+string ToVlogString(uint64 i) { return port::StrCat(i); }
+
+string ToVlogString(float f) { return port::StrCat(f); }
+
+string ToVlogString(double d) { return port::StrCat(d); }
+
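+// Prints the pointer, size, and a prefix of the elements of an ArraySlice.
+// The number of elements shown grows with the active VLOG level: at most 5
+// when only VLOG(1) is enabled, 20 at VLOG(2), 1000 at VLOG(3) through
+// VLOG(10), and everything at VLOG(11) and above.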
+template <class T>
+string ToVlogString(port::ArraySlice<T> elements) {
+ string str = port::StrCat(
+ ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
+ elements.size(), "]{");
+ const char *separator = "";
+ size_t max_to_show = std::numeric_limits<size_t>::max();
+ if (!VLOG_IS_ON(2)) {
+ max_to_show = 5;
+ } else if (!VLOG_IS_ON(3)) {
+ max_to_show = 20;
+ } else if (!VLOG_IS_ON(11)) {
+ max_to_show = 1000;
+ }
+ for (size_t i = 0; i < elements.size(); ++i) {
+ if (i == max_to_show) {
+ str += ", ...";
+ break;
+ }
+ port::StrAppend(&str, separator, ToVlogString(elements[i]));
+ separator = ", ";
+ }
+ str += "}";
+ return str;
+}
+
+template <class T>
+string ToVlogString(port::MutableArraySlice<T> elements) {
+ return ToVlogString(port::ArraySlice<T>(elements));
+}
+
+// Used together with PARAM to VLOG calls made to the stream. Intended
+// to be used like this:
+//
+// VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
+//
+// where a and b are the parameters to MyFunction.
+//
+// See VLOG_CALL for a short-hand for this. This way of doing it saves
+// a tremendous amount of boilerplate code given how many functions
+// there are on Stream and how many parameters they each have.
+string CallStr(const char *function_name, Stream *stream,
+ std::vector<std::pair<const char *, string>> params) {
+ // Do not call this function unless VLOG is on since just
+ // constructing all the strings in params is expensive.
+ CHECK(VLOG_IS_ON(1));
+
+ string str = port::StrCat("Called Stream::", function_name, "(");
+ const char *separator = "";
+ for (const auto &param : params) {
+ port::StrAppend(&str, separator, param.first, "=", param.second);
+ separator = ", ";
+ }
+ port::StrAppend(&str, ") stream=", ToVlogString(stream));
+ return str;
+}
+
+// Use this macro to avoid having to type every parameter twice to log
+// it with VLOG and CallStr.
+#define PARAM(parameter) \
+ { #parameter, ToVlogString(parameter) }
+
+// Use this macro to avoid having to type out the name of each
+// function and to save some boilerplate. Intended to be used like this:
+//
+// VLOG_CALL(PARAM(a), PARAM(b))
+//
+// This saves a tremendous amount of boilerplate compared to the alternative:
+//
+// VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
+// << ", b=" << ToVlogString(b);
+//
+// Note here that most of the parameter names are not short and that
+// most of the functions take many more than 2 parameters.
+#define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
+
+} // namespace
+
+Stream::Stream(StreamExecutor *parent)
+ : implementation_(CreateStreamImplementation(parent)),
+ parent_(parent),
+ allocated_(false),
+ ok_(false),
+ temporary_memory_manager_(this) {
+ VLOG_CALL(PARAM(parent));
+}
+
+Stream::Stream(StreamExecutor *parent,
+ internal::StreamInterface *implementation)
+ : implementation_(implementation),
+ parent_(parent),
+ allocated_(false),
+ ok_(false),
+ temporary_memory_manager_(this) {
+ VLOG_CALL(PARAM(parent), PARAM(implementation));
+}
+
+Stream::~Stream() {
+ VLOG_CALL();
+
+ temporary_memory_manager_.ForceDeallocateAll();
+
+ if (allocated_) {
+ parent_->DeallocateStream(this);
+ }
+}
+
+Stream &Stream::Init() {
+ VLOG_CALL();
+
+ mutex_lock lock{mu_};
+ CHECK_EQ(false, allocated_)
+ << "stream appears to already have been initialized";
+ CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
+
+ if (parent_->AllocateStream(this)) {
+ // Successful initialization!
+ allocated_ = true;
+ ok_ = true;
+ } else {
+ LOG(ERROR) << "failed to allocate stream during initialization";
+ }
+
+ return *this;
+}
+
+Stream &Stream::InitTimer(Timer *timer) {
+ VLOG_CALL(PARAM(timer));
+
+ if (ok()) {
+ CheckError(parent_->AllocateTimer(timer));
+ } else {
+ LOG(INFO) << "did not allocate timer: " << timer;
+ }
+ return *this;
+}
+
+Stream &Stream::InitWithTimer(Timer *timer) {
+ VLOG_CALL(PARAM(timer));
+
+ return Init().InitTimer(timer);
+}
+
+Stream &Stream::ThenRecordEvent(Event *event) {
+ VLOG_CALL(PARAM(event));
+
+ port::Status status = parent_->RecordEvent(this, event);
+ if (!status.ok()) {
+ LOG(ERROR) << "Error recording event in stream: " << status.error_message()
+ << "; not marking stream as bad, as the Event object may be "
+ << "at fault. Monitor for further errors.";
+ }
+
+ return *this;
+}
+
+Stream &Stream::ThenConvolve(
+ const dnn::BatchDescriptor &batch_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output) {
+ VLOG_CALL(PARAM(batch_descriptor), PARAM(input_data),
+ PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(convolution_descriptor), PARAM(output_descriptor),
+ PARAM(output));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoConvolve(
+ this, batch_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenSeparableConvolve(
+ const dnn::BatchDescriptor &batch_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
+ const DeviceMemory<float> &first_weights,
+ const DeviceMemory<float> &second_weights,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output) {
+ VLOG_CALL(
+ PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
+ PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
+ PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoSeparableConvolve(
+ this, batch_descriptor, input_data, filter_descriptor,
+ depth_multiplier, first_weights, second_weights,
+ convolution_descriptor, output_descriptor, output));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardData(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data) {
+ VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(input_descriptor),
+ PARAM(backward_input_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoConvolveBackwardData(
+ this, filter_descriptor, filter_data, output_descriptor,
+ backward_output_data, convolution_descriptor, input_descriptor,
+ backward_input_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenConvolveBackwardFilter(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
+ PARAM(output_descriptor), PARAM(backward_output_data),
+ PARAM(convolution_descriptor), PARAM(filter_descriptor),
+ PARAM(backward_filter_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoConvolveBackwardFilter(
+ this, input_descriptor, input_data, output_descriptor,
+ backward_output_data, convolution_descriptor, filter_descriptor,
+ backward_filter_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
+ const DeviceMemory<float> &weights,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
+ PARAM(output_dimensions), PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
+ output_dimensions, output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMatMulQuantized(
+ const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
+ PARAM(input_dimensions), PARAM(output_dimensions),
+ PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
+ weight_scales, input_dimensions,
+ output_dimensions, output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMatMulQuantized(
+ const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
+ PARAM(input_dimensions), PARAM(output_dimensions),
+ PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
+ weight_scales, input_dimensions,
+ output_dimensions, output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
+ const DeviceMemory<float> &biases,
+ const dnn::BatchDescriptor &dimensions,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
+ PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(
+ dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPoolForward(
+ const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
+ input_data, output_dimensions,
+ output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPoolBackward(
+ const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ const DeviceMemory<float> &output_data,
+ const DeviceMemory<float> &input_diff_data,
+ DeviceMemory<float> *output_diff_data) {
+ VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(input_diff_data), PARAM(output_diff_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
+ input_data, output_dimensions, output_data,
+ input_diff_data, output_diff_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenNormalize(
+ const dnn::NormalizeDescriptor &normalize_descriptor,
+ const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(normalize_descriptor), PARAM(input_data), PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoNormalize(this, normalize_descriptor, input_data,
+ output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
+ PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
+ output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenDepthConcatenate(
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float> *> input_data,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
+ output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenElementwiseOperate(
+ dnn::ElementwiseOperation operation,
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float> *> input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
+ PARAM(output_dimensions), PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
+ input_data, output_dimensions,
+ output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemcpyD2HQuantized(
+ const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<uint8> host_dst) {
+ VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(
+ dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemcpyD2HQuantized(
+ const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<uint16> host_dst) {
+ VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(
+ dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemcpyD2HQuantized(
+ const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<int32> host_dst) {
+ VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(
+ dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemcpyH2DQuantized(
+ port::ArraySlice<uint8> host_src,
+ DeviceMemory<float> *gpu_unquantized_dst) {
+ VLOG_CALL(PARAM(host_src), PARAM(gpu_unquantized_dst));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(
+ dnn->DoMemcpyH2DQuantized(this, host_src, gpu_unquantized_dst));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
+Stream *Stream::GetOrCreateSubStream() {
+ mutex_lock lock{mu_};
+ for (auto &stream : sub_streams_) {
+ if (stream.second) {
+ stream.second = false;
+ return stream.first.get();
+ }
+ }
+ sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
+ false);
+ Stream *sub_stream = sub_streams_.back().first.get();
+ sub_stream->Init();
+  CHECK(sub_stream->ok()) << "sub-stream failed to be initialized";
+
+ return sub_stream;
+}
+
+void Stream::ReturnSubStream(Stream *sub_stream) {
+ mutex_lock lock{mu_};
+ for (auto &stream : sub_streams_) {
+ if (stream.first.get() == sub_stream) {
+ stream.second = true;
+ return;
+ }
+ }
+ LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
+}
+
+Stream &Stream::ThenStartTimer(Timer *t) {
+ VLOG_CALL(PARAM(t));
+
+ if (ok()) {
+ CheckError(parent_->StartTimer(this, t));
+ } else {
+ LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenStopTimer(Timer *t) {
+ VLOG_CALL(PARAM(t));
+
+ if (ok()) {
+ CheckError(parent_->StopTimer(this, t));
+ } else {
+ LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenWaitFor(Stream *other) {
+ VLOG_CALL(PARAM(other));
+
+ CHECK(this != other) << "stream cannot wait for itself";
+ if (ok() && other->ok()) {
+ CheckError(parent_->CreateStreamDependency(this, other));
+ } else {
+ SetError();
+ LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenWaitFor(std::vector<std::unique_ptr<Stream>> *others) {
+ VLOG_CALL(PARAM(others));
+
+ for (auto &stream : *others) {
+ CHECK_NE(stream.get(), this);
+ ThenWaitFor(stream.get());
+ }
+ return *this;
+}
+
+Stream &Stream::ThenWaitFor(Event *event) {
+ VLOG_CALL(PARAM(event));
+
+ if (ok()) {
+ port::Status status = parent_->WaitForEvent(this, event);
+ if (!status.ok()) {
+ LOG(ERROR) << "Error waiting for event in stream: "
+ << status.error_message()
+ << "; not marking stream as bad, as the Event object may be "
+ << "at fault. Monitor for further errors.";
+ }
+ } else {
+ LOG(INFO) << "stream " << this << " did not wait for an event.";
+ }
+ return *this;
+}
+
+// A functor that implements the ThenBlasXXX interfaces by calling the
+// corresponding DoBlasXXX member function and handling the error checking
+// and logging.
+template <typename... Args>
+struct ThenBlasImpl {
+ // blas_func is the DoBlasXXX member function pointer, and args are its
+ // arguments except the first one of Stream* type.
+ Stream &operator()(Stream *stream,
+ bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
+ Args... args);
+};
+
+template <typename... Args>
+Stream &ThenBlasImpl<Args...>::operator()(
+ Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
+ Args... args) {
+ if (stream->ok()) {
+ if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
+ stream->CheckError((blas->*blas_func)(stream, args...));
+ } else {
+ stream->CheckError(false);
+ LOG(WARNING)
+ << "attempting to perform BLAS operation using StreamExecutor "
+ "without BLAS support";
+ }
+ }
+ return *stream;
+}
+
+Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasAsum(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<float> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<float> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasAsum(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<double> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
+ DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
+ y, incy);
+}
+
+Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
+ y, incy);
+}
+
+Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<std::complex<float>> *y,
+ int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
+ y, incy);
+}
+
+Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<std::complex<double>> *y,
+ int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
+ y, incy);
+}
+
+Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasCopy(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<std::complex<float>> *y,
+ int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasCopy(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<std::complex<double>> *y,
+ int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
+ const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
+ result);
+}
+
+Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
+ result);
+}
+
+Stream &Stream::ThenBlasDotc(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx,
+ const DeviceMemory<std::complex<float>> &y,
+ int incy,
+ DeviceMemory<std::complex<float>> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
+ incy, result);
+}
+
+Stream &Stream::ThenBlasDotc(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx,
+ const DeviceMemory<std::complex<double>> &y,
+ int incy,
+ DeviceMemory<std::complex<double>> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
+ incy, result);
+}
+
+Stream &Stream::ThenBlasDotu(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx,
+ const DeviceMemory<std::complex<float>> &y,
+ int incy,
+ DeviceMemory<std::complex<float>> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
+ incy, result);
+}
+
+Stream &Stream::ThenBlasDotu(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx,
+ const DeviceMemory<std::complex<double>> &y,
+ int incy,
+ DeviceMemory<std::complex<double>> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
+ incy, result);
+}
+
+Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasNrm2(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<float> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<float> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasNrm2(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<double> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy, float c,
+ float s) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(c), PARAM(s));
+
+ ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
+ float, float> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
+ c, s);
+}
+
+Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
+ int incx, DeviceMemory<double> *y, int incy,
+ double c, double s) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(c), PARAM(s));
+
+ ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
+ double, double> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
+ c, s);
+}
+
+Stream &Stream::ThenBlasRot(uint64 elem_count,
+ DeviceMemory<std::complex<float>> *x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy,
+ float c, float s) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(c), PARAM(s));
+
+ ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
+ DeviceMemory<std::complex<float>> *, int, float, float> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
+ c, s);
+}
+
+Stream &Stream::ThenBlasRot(uint64 elem_count,
+ DeviceMemory<std::complex<double>> *x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy,
+ double c, double s) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(c), PARAM(s));
+
+ ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
+ DeviceMemory<std::complex<double>> *, int, double, double> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
+ c, s);
+}
+
+Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
+ DeviceMemory<float> *c, DeviceMemory<float> *s) {
+ VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
+
+ ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
+ DeviceMemory<float> *, DeviceMemory<float> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
+}
+
+Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
+ DeviceMemory<double> *c, DeviceMemory<double> *s) {
+ VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
+
+ ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
+ DeviceMemory<double> *, DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
+}
+
+Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
+ DeviceMemory<std::complex<float>> *b,
+ DeviceMemory<float> *c,
+ DeviceMemory<std::complex<float>> *s) {
+ VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
+
+ ThenBlasImpl<DeviceMemory<std::complex<float>> *,
+ DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
+ DeviceMemory<std::complex<float>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
+}
+
+Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
+ DeviceMemory<std::complex<double>> *b,
+ DeviceMemory<double> *c,
+ DeviceMemory<std::complex<double>> *s) {
+ VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
+
+ ThenBlasImpl<DeviceMemory<std::complex<double>> *,
+ DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
+ DeviceMemory<std::complex<double>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
+}
+
+Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
+ int incx, DeviceMemory<float> *y, int incy,
+ const DeviceMemory<float> &param) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(param));
+
+ ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
+ const DeviceMemory<float> &> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
+ incy, param);
+}
+
+Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
+ int incx, DeviceMemory<double> *y, int incy,
+ const DeviceMemory<double> &param) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
+ PARAM(param));
+
+ ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
+ const DeviceMemory<double> &> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
+ incy, param);
+}
+
+Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
+ DeviceMemory<float> *x1,
+ const DeviceMemory<float> &y1,
+ DeviceMemory<float> *param) {
+ VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
+
+ ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
+ DeviceMemory<float> *, const DeviceMemory<float> &,
+ DeviceMemory<float> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
+}
+
+Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
+ DeviceMemory<double> *d2,
+ DeviceMemory<double> *x1,
+ const DeviceMemory<double> &y1,
+ DeviceMemory<double> *param) {
+ VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
+
+ ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
+ DeviceMemory<double> *, const DeviceMemory<double> &,
+ DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
+}
+
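+// Hypothetical usage sketch: scale a device vector in place with unit stride,
+//   stream.ThenBlasScal(elem_count, 2.0f, &x, 1);
+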
+Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
+ DeviceMemory<float> *x, int incx) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
+}
+
+Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
+ DeviceMemory<double> *x, int incx) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
+}
+
+Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
+}
+
+Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
+}
+
+Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
+}
+
+Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<uint64, std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
+}
+
+Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
+ int incx, DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
+ int incx, DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasSwap(uint64 elem_count,
+ DeviceMemory<std::complex<float>> *x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasSwap(uint64 elem_count,
+ DeviceMemory<std::complex<double>> *x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
+ incy);
+}
+
+Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasIamax(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<int> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasIamax(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<int> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasIamin(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<int> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
+ result);
+}
+
+Stream &Stream::ThenBlasIamin(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<int> *result) {
+ VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
+
+ ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<int> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
+ result);
+}
+
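+// BLAS level-2 (matrix-vector) routines.
+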
+Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
+ uint64 kl, uint64 ku, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
+ PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &,
+ int, float, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
+ a, lda, x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
+ uint64 kl, uint64 ku, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
+ PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
+ const DeviceMemory<double> &, int, const DeviceMemory<double> &,
+ int, double, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
+ a, lda, x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
+ uint64 kl, uint64 ku, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
+ PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
+ a, lda, x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
+ uint64 kl, uint64 ku, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
+ PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
+ a, lda, x, incx, beta, y, incy);
+}
+
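+// Hypothetical usage sketch: y = alpha * op(A) * x + beta * y for a
+// column-major m x n matrix A (names are illustrative),
+//   stream.ThenBlasGemv(blas::Transpose::kNoTranspose, m, n, alpha, a, lda,
+//                       x, 1, beta, &y, 1);
+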
+Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &,
+ int, float, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, double,
+ const DeviceMemory<double> &, int, const DeviceMemory<double> &,
+ int, double, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
+ PARAM(incy));
+
+ ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda) {
+ VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
+ const DeviceMemory<float> &, int, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda) {
+ VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, DeviceMemory<double> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx,
+ const DeviceMemory<std::complex<float>> &y,
+ int incy, DeviceMemory<std::complex<float>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<uint64, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx,
+ const DeviceMemory<std::complex<double>> &y,
+ int incy, DeviceMemory<std::complex<double>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<uint64, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx,
+ const DeviceMemory<std::complex<float>> &y,
+ int incy, DeviceMemory<std::complex<float>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<uint64, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx,
+ const DeviceMemory<std::complex<double>> &y,
+ int incy, DeviceMemory<std::complex<double>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
+ PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<uint64, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
+ PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
+ PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
+ incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
+ incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<std::complex<float>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
+ lda);
+}
+
+Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<std::complex<double>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
+ lda);
+}
+
+Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx,
+ const DeviceMemory<std::complex<float>> &y,
+ int incy, DeviceMemory<std::complex<float>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx,
+ const DeviceMemory<std::complex<double>> &y,
+ int incy, DeviceMemory<std::complex<double>> *a,
+ int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &ap,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &,
+ const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
+ beta, y, incy);
+}
+
+Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &ap,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &,
+ const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
+ beta, y, incy);
+}
+
+Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx, DeviceMemory<std::complex<float>> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
+}
+
+Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx, DeviceMemory<std::complex<double>> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
+}
+
+Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x,
+ int incx,
+ const DeviceMemory<std::complex<float>> &y,
+ int incy, DeviceMemory<std::complex<float>> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
+ incy, ap);
+}
+
+Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x,
+ int incx,
+ const DeviceMemory<std::complex<double>> &y,
+ int incy, DeviceMemory<std::complex<double>> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
+ incy, ap);
+}
+
+Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ float alpha, const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
+ PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &,
+ int, float, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
+ PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
+ const DeviceMemory<double> &, int, const DeviceMemory<double> &,
+ int, double, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
+ x, incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &ap,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
+ const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
+ beta, y, incy);
+}
+
+Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &ap,
+ const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
+ const DeviceMemory<double> &, int, double,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
+ beta, y, incy);
+}
+
+Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
+ int, DeviceMemory<float> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
+}
+
+Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
+ int, DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
+}
+
+Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
+ int, const DeviceMemory<float> &, int,
+ DeviceMemory<float> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
+ incy, ap);
+}
+
+Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *ap) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(ap));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
+ int, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
+ incy, ap);
+}
+
+Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
+ int, const DeviceMemory<float> &, int, float,
+ DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
+ incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
+ PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
+ int, const DeviceMemory<double> &, int, double,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
+ incx, beta, y, incy);
+}
+
+Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *a, int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
+ int, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
+ lda);
+}
+
+Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *a, int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
+ int, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
+ lda);
+}
+
+Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
+ int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda) {
+ VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
+ PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
+
+ ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
+ int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
+ incy, a, lda);
+}
+
+Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
+ PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ uint64, const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap,
+ DeviceMemory<float> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap,
+ DeviceMemory<double> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<float>> &,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<double>> &,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap,
+ DeviceMemory<float> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap,
+ DeviceMemory<double> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<float>> &,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
+ PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<double>> &,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
+ incx);
+}
+
+Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<float> &, int, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<double> &, int, DeviceMemory<double> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<float> &, int, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<double> &, int, DeviceMemory<double> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
+Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *x,
+ int incx) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
+ PARAM(lda), PARAM(x), PARAM(incx));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
+ lda, x, incx);
+}
+
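+// BLAS level-3 (matrix-matrix) routines.
+//
+// Hypothetical usage sketch: C = alpha * op(A) * op(B) + beta * C
+// (names are illustrative),
+//   stream.ThenBlasGemm(blas::Transpose::kNoTranspose,
+//                       blas::Transpose::kNoTranspose, m, n, k, alpha, a, lda,
+//                       b, ldb, beta, &c, ldc);
+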
+Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &,
+ int, float, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
+ alpha, a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
+ const DeviceMemory<double> &, int, const DeviceMemory<double> &,
+ int, double, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
+ alpha, a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &b,
+ int ldb, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
+ alpha, a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &b,
+ int ldb, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
+ alpha, a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &b,
+ int ldb, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
+ lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &b,
+ int ldb, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
+ lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, float alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, float beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
+ const DeviceMemory<std::complex<float>> &, int, float,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
+ lda, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, double alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, double beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
+ const DeviceMemory<std::complex<double>> &, int, double,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
+ lda, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &b,
+ int ldb, float beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, const DeviceMemory<std::complex<float>> &, int, float,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
+ a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &b,
+ int ldb, double beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, const DeviceMemory<std::complex<double>> &, int, double,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
+ a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &,
+ int, float, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
+ lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
+ const DeviceMemory<double> &, int, const DeviceMemory<double> &,
+ int, double, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
+ lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &b,
+ int ldb, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
+ lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &b,
+ int ldb, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
+ lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda, float beta,
+ DeviceMemory<float> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
+ const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
+ lda, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ double beta, DeviceMemory<double> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
+ const DeviceMemory<double> &, int, double,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
+ lda, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
+ lda, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
+ lda, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
+ const DeviceMemory<float> &, int, const DeviceMemory<float> &,
+ int, float, DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
+ a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
+ const DeviceMemory<double> &, int, const DeviceMemory<double> &,
+ int, double, DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
+ a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &b,
+ int ldb, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, const DeviceMemory<std::complex<float>> &, int,
+ std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
+ a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &b,
+ int ldb, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc) {
+ VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
+ PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
+ PARAM(ldc));
+
+ ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, const DeviceMemory<std::complex<double>> &, int,
+ std::complex<double>, DeviceMemory<std::complex<double>> *,
+ int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
+ a, lda, b, ldb, beta, c, ldc);
+}
+
+Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *b, int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, float, const DeviceMemory<float> &, int,
+ DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *b, int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, double, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *b,
+ int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *b,
+ int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *b, int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, float, const DeviceMemory<float> &, int,
+ DeviceMemory<float> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *b, int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, double, const DeviceMemory<double> &, int,
+ DeviceMemory<double> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda, DeviceMemory<std::complex<float>> *b,
+ int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, std::complex<float>,
+ const DeviceMemory<std::complex<float>> &, int,
+ DeviceMemory<std::complex<float>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag,
+ uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda, DeviceMemory<std::complex<double>> *b,
+ int ldb) {
+ VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
+ PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
+
+ ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
+ uint64, uint64, std::complex<double>,
+ const DeviceMemory<std::complex<double>> &, int,
+ DeviceMemory<std::complex<double>> *, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
+ n, alpha, a, lda, b, ldb);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
+ int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const port::ArraySlice<DeviceMemory<float> *> &, int,
+ const port::ArraySlice<DeviceMemory<float> *> &, int, float,
+ const port::ArraySlice<DeviceMemory<float> *> &, int, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
+ int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
+ double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
+ const port::ArraySlice<DeviceMemory<double> *> &, int,
+ const port::ArraySlice<DeviceMemory<double> *> &, int, double,
+ const port::ArraySlice<DeviceMemory<double> *> &, int, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
+ std::complex<float> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
+ int,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
+ int, std::complex<float>,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
+ int, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
+ std::complex<double> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+ PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
+ int,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
+ int, std::complex<double>,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
+ int, int> impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+}
+
+Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
+ VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
+
+ if (ok()) {
+ if (rng::RngSupport *rng = parent_->AsRng()) {
+ CheckError(rng->SetSeed(this, seed, seed_bytes));
+ } else {
+ SetError();
+ LOG(INFO) << "stream " << this << " unable to initialize RNG";
+ }
+ } else {
+ LOG(INFO) << "stream " << this
+ << " did not set RNG seed: " << static_cast<const void *>(seed)
+ << "; bytes: " << seed_bytes;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
+ VLOG_CALL(PARAM(values));
+
+ if (ok()) {
+ if (rng::RngSupport *rng = parent_->AsRng()) {
+ CheckError(rng->DoPopulateRandUniform(this, values));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+ "without RNG support.";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
+ DeviceMemory<float> *values) {
+ VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
+
+ if (ok()) {
+ if (rng::RngSupport *rng = parent_->AsRng()) {
+ CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+ "without RNG support.";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
+ DeviceMemory<double> *values) {
+ VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
+
+ if (ok()) {
+ if (rng::RngSupport *rng = parent_->AsRng()) {
+ CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+ "without RNG support.";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
+ VLOG_CALL(PARAM(values));
+
+ if (ok()) {
+ if (rng::RngSupport *rng = parent_->AsRng()) {
+ CheckError(rng->DoPopulateRandUniform(this, values));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+ "without RNG support.";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPopulateRandUniform(
+ DeviceMemory<std::complex<float>> *values) {
+ VLOG_CALL(PARAM(values));
+
+ if (ok()) {
+ if (rng::RngSupport *rng = parent_->AsRng()) {
+ CheckError(rng->DoPopulateRandUniform(this, values));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
+ "without RNG support.";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenPopulateRandUniform(
+ DeviceMemory<std::complex<double>> *values) {
+ VLOG_CALL(PARAM(values));
+
+ if (ok()) {
+ if (rng::RngSupport *rng = parent_->AsRng()) {
+ CheckError(rng->DoPopulateRandUniform(this, values));
+ } else {
+ SetError();
+ LOG(INFO) << "stream " << this
+ << " attempting to perform RNG operation using StreamExecutor "
+ "without RNG support.";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size) {
+ VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
+
+ if (ok()) {
+ CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
+ } else {
+ LOG(INFO) << "stream " << this
+ << " did not memcpy device-to-host; source: " << gpu_src.opaque();
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size) {
+ VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
+
+ if (ok()) {
+ CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
+ } else {
+ LOG(INFO) << "stream " << this
+ << " did not memcpy host-to-device; source: " << host_src;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size) {
+ VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
+
+ if (ok()) {
+ CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
+ } else {
+ LOG(INFO) << "stream " << this
+ << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
+ VLOG_CALL(PARAM(location), PARAM(size));
+
+ if (ok()) {
+ CheckError(parent_->MemZero(this, location, size));
+ } else {
+ LOG(INFO) << "stream " << this
+ << " did not memzero GPU location; source: " << location;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
+ uint64 size) {
+ VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
+
+ if (ok()) {
+ CheckError(parent_->Memset32(this, location, pattern, size));
+ } else {
+ LOG(INFO) << "stream " << this
+ << " did not memset GPU location; source: " << location
+ << "; size: " << size << "; pattern: " << std::hex << pattern;
+ }
+ return *this;
+}
+
+Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
+ VLOG_CALL(PARAM(callback));
+
+ return ThenDoHostCallback(callback);
+}
+
+Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
+ VLOG_CALL(PARAM(callback));
+
+ if (ok()) {
+ CheckError(parent_->HostCallback(this, callback));
+ } else {
+ LOG(INFO) << "stream " << this
+ << " was in error state before adding host callback";
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<std::complex<float>> *output) {
+ VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
+
+ if (ok()) {
+ if (fft::FftSupport *fft = parent_->AsFft()) {
+ CheckError(fft->DoFft(this, plan, input, output));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+ "without FFT support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<std::complex<double>> *output) {
+ VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
+
+ if (ok()) {
+ if (fft::FftSupport *fft = parent_->AsFft()) {
+ CheckError(fft->DoFft(this, plan, input, output));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+ "without FFT support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
+ DeviceMemory<std::complex<float>> *output) {
+ VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
+
+ if (ok()) {
+ if (fft::FftSupport *fft = parent_->AsFft()) {
+ CheckError(fft->DoFft(this, plan, input, output));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+ "without FFT support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
+ DeviceMemory<std::complex<double>> *output) {
+ VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
+
+ if (ok()) {
+ if (fft::FftSupport *fft = parent_->AsFft()) {
+ CheckError(fft->DoFft(this, plan, input, output));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+ "without FFT support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<float> *output) {
+ VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
+
+ if (ok()) {
+ if (fft::FftSupport *fft = parent_->AsFft()) {
+ CheckError(fft->DoFft(this, plan, input, output));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+ "without FFT support";
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<double> *output) {
+ VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
+
+ if (ok()) {
+ if (fft::FftSupport *fft = parent_->AsFft()) {
+ CheckError(fft->DoFft(this, plan, input, output));
+ } else {
+ SetError();
+ LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
+ "without FFT support";
+ }
+ }
+ return *this;
+}
+
+// It looks confusing, but all this is doing is inserting a callback at the
+// present point in the stream to then enqueue a task on the host executor.
+Stream &Stream::ThenEnqueueOnBackgroundThread(
+ std::function<void(StreamExecutor *)> task) {
+ VLOG_CALL(PARAM(task));
+
+ StreamExecutor *stream_executor = this->parent_;
+ std::function<void()> bound_task = std::bind(task, stream_executor);
+
+ return ThenDoHostCallback([stream_executor, bound_task]() {
+ stream_executor->EnqueueOnBackgroundThread(bound_task);
+ });
+}
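+
+// A usage sketch (hypothetical; assumes `stream` is an initialized Stream):
+//
+//   stream.ThenEnqueueOnBackgroundThread([](StreamExecutor *executor) {
+//     LOG(INFO) << "background work for executor " << executor;
+//   });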
+
+bool Stream::BlockHostUntilDone() {
+ VLOG_CALL();
+
+ if (!ok()) {
+ LOG(INFO)
+ << "stream " << this
+ << " did not block host until done; was already in an error state";
+ return false;
+ }
+
+ {
+ // Wait until all active sub-streams have done their tasks.
+ mutex_lock lock{mu_};
+ for (auto &stream : sub_streams_) {
+ if (!stream.second) {
+ CheckError(stream.first->BlockHostUntilDone());
+ // Set this sub-stream as available.
+ stream.second = true;
+ }
+ }
+ }
+
+ temporary_memory_manager_.DeallocateFinalizedTemporaries();
+
+ CheckError(parent_->BlockHostUntilDone(this));
+ return ok();
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
new file mode 100644
index 0000000000..d4d5e7729b
--- /dev/null
+++ b/tensorflow/stream_executor/stream.h
@@ -0,0 +1,1258 @@
+// The Stream is used in conjunction with the StreamExecutor "parent" to
+// perform actions with a linear stream of dependencies. Dependencies can also
+// be created between Streams to do task management (i.e. limit which tasks
+// can be performed concurrently and specify what task dependencies exist).
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
+
+#include <complex>
+#include <functional>
+#include <memory>
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/fft.h"
+#include "tensorflow/stream_executor/kernel.h"
+#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/temporary_memory_manager.h"
+
+namespace perftools {
+namespace gputools {
+
+namespace host {
+class HostBlas;
+class HostFft;
+class HostRng;
+class HostTimer;
+} // namespace host
+
+namespace ocl {
+class CLBlas;
+} // namespace ocl
+
+namespace internal {
+class StreamInterface;
+} // namespace internal
+
+class DeviceMemoryBase;
+template <typename ElemT>
+class DeviceMemory;
+
+class Timer;
+
+namespace dnn {
+struct BatchDescriptor;
+struct FilterDescriptor;
+struct ConvolutionDescriptor;
+} // namespace dnn
+
+class StreamExecutor;
+
+// Represents a stream of dependent computations on a GPU device.
+//
+// The operations within a stream execute linearly and asynchronously until
+// BlockHostUntilDone() is invoked, which synchronously joins host code with
+// the execution of the stream.
+//
+// If any given operation fails when entraining work for the stream, ok() will
+// indicate that an error has occurred. After initialization, once a stream is
+// !ok(), it will never be ok().
+//
+// Thread-safe post-initialization.
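+//
+// A minimal usage sketch (assuming `executor` is a StreamExecutor* obtained
+// elsewhere, `dst` is a DeviceMemory<float>, and `host_src`/`size` are a host
+// buffer and its byte count):
+//
+//   Stream stream(executor);
+//   stream.Init();
+//   stream.ThenMemcpy(&dst, host_src, size);
+//   stream.BlockHostUntilDone();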
+class Stream {
+ public:
+ // Instantiate a stream tied to parent as a platform executor. Work
+ // entrained onto this stream will be launched/managed on that
+ // StreamExecutor's platform.
+ explicit Stream(StreamExecutor *parent);
+
+ // Test only. Use an externally-populated value (like a mock) for the
+ // platform-specific stream implementation.
+ Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
+
+  // Deallocates any stream resources that the parent StreamExecutor has
+  // bestowed upon this object.
+ ~Stream();
+
+ // Returns whether any errors have occurred while entraining work for this
+ // stream.
+ bool ok() const { return !InErrorState(); }
+
+ // Initialize the stream. This must be performed before entraining any other
+ // operations.
+ Stream &Init();
+
+ // Initializes timer t via the StreamExecutor.
+ Stream &InitTimer(Timer *t);
+
+ // Convenience wrapper around Init() and InitTimer().
+ Stream &InitWithTimer(Timer *t);
+
+  // Warning! After calling BlockHostUntilDone(), all sub-streams will be
+  // returned and hence invalid. This may be a temporary solution to issue
+  // b/18070215.
+  //
+  // Gets or creates a sub-stream from this stream. If a sub-stream in the
+  // pool can be reused, that sub-stream is returned; otherwise a new
+  // sub-stream is created.
+ Stream *GetOrCreateSubStream();
+
+ // Return the sub-stream back to the host stream so that it can be reused
+ // later.
+ void ReturnSubStream(Stream *sub_stream);
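+
+  // One possible pattern (a sketch; `scratch` and `bytes` are assumed to be
+  // provided by the caller):
+  //
+  //   Stream *sub = stream.GetOrCreateSubStream();
+  //   sub->ThenMemZero(&scratch, bytes);
+  //   stream.ThenWaitFor(sub);      // later work on `stream` waits on `sub`
+  //   stream.ReturnSubStream(sub);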
+
+  // Allocates temporary device memory. The stream deallocates it when it is
+  // blocked or destroyed.
+ template <typename T>
+ port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
+ AllocateTemporaryArray(uint64 element_count);
+
+ // Entrains onto the stream of operations: a kernel launch with the given
+ // (variadic) parameters for the invocation. These arguments can be things
+ // like DeviceMemory or primitive types such as int. What arguments you may
+ // pass to a given kernel are noted as the template parameters to the
+ // TypedKernel type that the machocc compiler generates.
+ //
+ // Template parameters:
+ // Params... The type list of formal parameters that the typed kernel
+ // expects, which is matched against Args...
+ // Args... The deduced type list for passed actual arguments
+ //
+ // Implementation: A compile-time compatibility check is performed that has
+ // some leniency versus an exact parameter pack match -- for example,
+ // `const DeviceMemory<T>` is considered "pack compatible" with a
+ // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
+ // perfect forwarding support without rvalue references. It also attempts to
+ // spit out helpful static_assert error traces with information as to the
+ // argument number and types that were mismatched.
+ template <typename... Params, typename... Args>
+ Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
+ const TypedKernel<Params...> &kernel, Args... args);
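+
+  // A hedged sketch of a launch (assuming `add_kernel` is a
+  // TypedKernel<DeviceMemory<float>, float> loaded through the parent
+  // StreamExecutor and `buf` is a DeviceMemory<float>):
+  //
+  //   stream.ThenLaunch(ThreadDim(256), BlockDim(32), add_kernel, buf, 1.0f);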
+
+  // Record a "start" event for the interval timer at this point in the
+  // stream's execution (relative to the previously and subsequently enqueued
+  // items in the stream's execution). Timers may be started/stopped multiple
+  // times.
+ Stream &ThenStartTimer(Timer *t);
+
+  // Record a "stop" event for the interval timer at this point in the
+  // stream's execution. See also Stream::ThenStartTimer.
+ Stream &ThenStopTimer(Timer *t);
+
+ // TODO(leary) If work is added to the stream that is being depended upon,
+ // then what? Have to describe what happens.
+ template <typename... Params>
+ Stream &ThenWaitFor(Stream *other, Params... more_streams) {
+ return ThenWaitFor(more_streams...).ThenWaitFor(other);
+ }
+
+ // Create a dependency for this stream's next work on the other stream
+ // completing. Does not take ownership of other, and other must not be
+ // null.
+ //
+  // Checks that a stream does not wait for itself, and it is up to the
+  // user to guarantee that a stream does not come to wait on itself in a
+  // cyclic manner; in that case, behavior is undefined.
+ //
+ // N.B. Base recursion case for the variadic ThenWaitFor.
+ Stream &ThenWaitFor(Stream *other);
+
+  // Waits for all streams in others.
+ // Checks that there is no shallow circular wait (i.e. that "this" is not in
+ // others).
+ Stream &ThenWaitFor(std::vector<std::unique_ptr<Stream>> *others);
+
+ // Waits for an event object to be set.
+ // Note that ThenRecordEvent must have been called on the event before
+ // you call this function; otherwise the event will be considered complete
+ // and this wait will do nothing.
+ Stream &ThenWaitFor(Event *event);
+
+ // Inserts the specified event into the end of this stream. Once the stream
+ // has processed all events prior to the insertion point, the event will be
+ // marked as completed.
+ // The stream does not take ownership of event - meaning that event's lifetime
+ // must extend past the point at which it is marked complete!
+ Stream &ThenRecordEvent(Event *event);
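+
+  // A cross-stream synchronization sketch (assuming `event` is an initialized
+  // Event and `producer`/`consumer` are initialized Streams):
+  //
+  //   producer.ThenMemcpy(&dst, host_src, size).ThenRecordEvent(&event);
+  //   consumer.ThenWaitFor(&event);  // consumer's later work waits on the copy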
+
+ ////////////////
+ // DNN support
+ //
+ // See DnnSupport::* for comments on the following methods.
+
+ // TODO(leary) add double-precision version of this interface.
+ Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output);
+
+ Stream &ThenSeparableConvolve(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
+ const DeviceMemory<float> &first_weights,
+ const DeviceMemory<float> &second_weights,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output);
+
+ Stream &ThenConvolveBackwardData(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data);
+
+ Stream &ThenConvolveBackwardFilter(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data);
+
+ Stream &ThenMatMul(const DeviceMemory<float> &input_data,
+ const DeviceMemory<float> &weights,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
+ const DeviceMemory<int8> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
+ const DeviceMemory<int16> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
+ const DeviceMemory<float> &biases,
+ const dnn::BatchDescriptor &dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ const DeviceMemory<float> &output_data,
+ const DeviceMemory<float> &input_diff_data,
+ DeviceMemory<float> *output_diff_data);
+
+ Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
+ const DeviceMemory<float> &input_data,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenActivate(dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenDepthConcatenate(
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float> *> input_data,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenElementwiseOperate(
+ dnn::ElementwiseOperation operation,
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float> *> input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ // See DnnSupport::DoMemcpyD2HQuantized.
+ // TODO(wgulland) Use a template to merge the versions of
+ // ThenMemcpyD2HQuantized.
+ Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<uint8> host_dst);
+
+ // See DnnSupport::DoMemcpyD2HQuantized.
+ Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<uint16> host_dst);
+
+ // See DnnSupport::DoMemcpyD2HQuantized.
+ Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<int32> host_dst);
+
+ // See DnnSupport::DoMemcpyH2DQuantized.
+ Stream &ThenMemcpyH2DQuantized(port::ArraySlice<uint8> host_src,
+ DeviceMemory<float> *gpu_unquantized_dst);
+
+ /////////////////
+ // BLAS support
+
+ // See BlasSupport::DoBlasAsum.
+ Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *result);
+ Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *result);
+ Stream &ThenBlasAsum(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result);
+ Stream &ThenBlasAsum(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result);
+
+  // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
+  // present in DeviceMemory, it must be an execution-time constant (i.e. a
+  // value that the stream does not change or populate during the course of
+  // execution). The value is effectively captured at stream-enqueue time.
+ Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy);
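+
+  // For example, to compute y = 2.0f * x + y on the device (a sketch; `x` and
+  // `y` are assumed DeviceMemory<float> holding `n` elements):
+  //
+  //   stream.ThenBlasAxpy(n, 2.0f, x, 1, &y, 1);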
+
+ // See BlasSupport::DoBlasCopy.
+ Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasCopy(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasCopy(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasDot.
+ Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *result);
+ Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *result);
+
+ // See BlasSupport::DoBlasDotc.
+ Stream &ThenBlasDotc(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result);
+ Stream &ThenBlasDotc(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result);
+
+ // See BlasSupport::DoBlasDotu.
+ Stream &ThenBlasDotu(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result);
+ Stream &ThenBlasDotu(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result);
+
+ // See BlasSupport::DoBlasNrm2.
+ Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *result);
+ Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *result);
+ Stream &ThenBlasNrm2(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result);
+ Stream &ThenBlasNrm2(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result);
+
+ // See BlasSupport::DoBlasRot.
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy, float c, float s);
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy, double c, double s);
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
+ int incx, DeviceMemory<std::complex<float>> *y, int incy,
+ float c, float s);
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
+ int incx, DeviceMemory<std::complex<double>> *y, int incy,
+ double c, double s);
+
+ // See BlasSupport::DoBlasRotg.
+ Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
+ DeviceMemory<float> *c, DeviceMemory<float> *s);
+ Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
+ DeviceMemory<double> *c, DeviceMemory<double> *s);
+ Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
+ DeviceMemory<std::complex<float>> *b,
+ DeviceMemory<float> *c,
+ DeviceMemory<std::complex<float>> *s);
+ Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
+ DeviceMemory<std::complex<double>> *b,
+ DeviceMemory<double> *c,
+ DeviceMemory<std::complex<double>> *s);
+
+ // See BlasSupport::DoBlasRotm.
+ Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy,
+ const DeviceMemory<float> &param);
+ Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy,
+ const DeviceMemory<double> &param);
+
+ // See BlasSupport::DoBlasRotmg.
+ Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
+ DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
+ DeviceMemory<float> *param);
+ Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
+ DeviceMemory<double> *x1,
+ const DeviceMemory<double> &y1,
+ DeviceMemory<double> *param);
+
+ // See BlasSupport::DoBlasScal.
+ Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
+ int incx);
+ Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
+ int incx);
+ Stream &ThenBlasScal(uint64 elem_count, float alpha,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasScal(uint64 elem_count, double alpha,
+ DeviceMemory<std::complex<double>> *x, int incx);
+ Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasSwap.
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
+ int incx, DeviceMemory<std::complex<float>> *y,
+ int incy);
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
+ int incx, DeviceMemory<std::complex<double>> *y,
+ int incy);
+
+ // See BlasSupport::DoBlasIamax.
+ Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamax(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result);
+ Stream &ThenBlasIamax(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<int> *result);
+
+ // See BlasSupport::DoBlasIamin.
+ Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamin(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result);
+ Stream &ThenBlasIamin(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<int> *result);
+
+ // See BlasSupport::DoBlasGbmv.
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x, int incx,
+ float beta, DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasGemv.
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasGer.
+ Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda);
+ Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda);
+
+ // See BlasSupport::DoBlasGerc.
+ Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasGeru.
+ Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasHbmv.
+ Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasHemv.
+ Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasHer.
+ Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasHer2.
+ Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasHpmv.
+ Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &ap,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &ap,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasHpr.
+ Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *ap);
+ Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *ap);
+
+ // See BlasSupport::DoBlasHpr2.
+ Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *ap);
+ Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *ap);
+
+ // See BlasSupport::DoBlasSbmv.
+ Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+
+ // See BlasSupport::DoBlasSpmv.
+ Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &ap,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &ap,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+
+ // See BlasSupport::DoBlasSpr.
+ Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *ap);
+ Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *ap);
+
+ // See BlasSupport::DoBlasSpr2.
+ Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *ap);
+ Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *ap);
+
+ // See BlasSupport::DoBlasSymv.
+ Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+
+ // See BlasSupport::DoBlasSyr.
+ Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *a, int lda);
+ Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *a, int lda);
+
+ // See BlasSupport::DoBlasSyr2.
+ Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda);
+ Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda);
+
+ // See BlasSupport::DoBlasTbmv.
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTbsv.
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTpmv.
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx);
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap, DeviceMemory<double> *x,
+ int incx);
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTpsv.
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx);
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap, DeviceMemory<double> *x,
+ int incx);
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTrmv.
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTrsv.
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasGemm.
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasGemmBatched.
+ Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<float> *> &a,
+ int lda,
+ const port::ArraySlice<DeviceMemory<float> *> &b,
+ int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<float> *> &c,
+ int ldc, int batch_count);
+ Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, double alpha,
+ const port::ArraySlice<DeviceMemory<double> *> &a,
+ int lda,
+ const port::ArraySlice<DeviceMemory<double> *> &b,
+ int ldb, double beta,
+ const port::ArraySlice<DeviceMemory<double> *> &c,
+ int ldc, int batch_count);
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
+ std::complex<float> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
+ int batch_count);
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
+ std::complex<double> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
+ int batch_count);
+
+ // See BlasSupport::DoBlasHemm.
+ Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasHerk.
+ Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, float alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc);
+ Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, double alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc);
+
+ // See BlasSupport::DoBlasHer2k.
+ Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc);
+ Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc);
+
+ // See BlasSupport::DoBlasSymm.
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasSyrk.
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a,
+ int lda, float beta, DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a,
+ int lda, double beta, DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasSyr2k.
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasTrmm.
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, DeviceMemory<float> *b, int ldb);
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, DeviceMemory<double> *b, int ldb);
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb);
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb);
+
+ // See BlasSupport::DoBlasTrsm.
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, DeviceMemory<float> *b, int ldb);
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, DeviceMemory<double> *b, int ldb);
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb);
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb);
+
+ // See FftSupport::DoFft.
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<std::complex<float>> *output);
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<std::complex<double>> *output);
+ Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
+ DeviceMemory<std::complex<float>> *output);
+ Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
+ DeviceMemory<std::complex<double>> *output);
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<float> *output);
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<double> *output);
+
+ // Makes the RNG use the provided value as the basis for further generation.
+ // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
+ // sources of seed data if the default (high quality) sources are not
+ // desired.
+ // For most use cases, this function will not be necessary; each provided
+ // back-end implementation will be appropriately seeded by default.
+ // At a minimum 16 bytes of data are required in the seed buffer.
+ //
+ // To seed with good (non-reproducible) data:
+ // File* f = File::Open("/dev/random", "r");
+ // int64 bytes_read = f->Read(seed_data, bytes_to_read);
+ // < error checking >
+ // stream.ThenSetRngSeed(seed_data, bytes_read);
+ //
+ // To seed with reproducible data:
+ // uint64_t seed_data[2] = { <data> };
+ // stream.ThenSetRngSeed(seed_data, 16);
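+ //
+ // As another sketch, non-reproducible seed data can be gathered with the
+ // standard <random> facilities instead of reading /dev/random directly
+ // (illustrative only; seed_words is a hypothetical local buffer providing the
+ // required 16 bytes):
+ //   std::random_device rd;
+ //   uint32 seed_words[4];
+ //   for (auto &word : seed_words) word = rd();
+ //   stream.ThenSetRngSeed(reinterpret_cast<const uint8 *>(seed_words),
+ //                         sizeof(seed_words));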
+ Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);
+
+ // Populates the memory indicated by values with uniform-random-distribution
+ // values. TODO(leary) seeding API/description
+ //
+ // Uses the type and size of the DeviceMemory to infer what data should be
+ // populated.
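+ //
+ // For example (a sketch, assuming `stream` has been initialized and `values`
+ // refers to device memory previously allocated for 1024 floats elsewhere):
+ //   DeviceMemory<float> values = ...;  // allocated elsewhere
+ //   stream.ThenPopulateRandUniform(&values);
+ //   stream.BlockHostUntilDone();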
+ Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
+ Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
+ Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
+ Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
+ Stream &ThenPopulateRandGaussian(float mean, float stddev,
+ DeviceMemory<float> *values);
+ Stream &ThenPopulateRandGaussian(double mean, double stddev,
+ DeviceMemory<double> *values);
+
+ // Entrain onto the stream: a memcpy to a host destination from a GPU source
+ // of the given target size. host_dst must be a pointer to host memory
+ // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
+ // then registered with StreamExecutor::HostMemoryRegister.
+ Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size);
+
+ // Entrain onto the stream: a memcpy to a GPU destination from a host source
+ // of the given target size. host_src must be a pointer to host memory
+ // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
+ // then registered with StreamExecutor::HostMemoryRegister.
+ Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size);
+
+ // Alternative interface for memcpying from device to host that takes an
+ // array slice. Checks that the host destination can accommodate the device
+ // source size.
+ template <typename T>
+ Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
+ port::MutableArraySlice<T> host_dst) {
+ auto host_size = host_dst.size() * sizeof(T);
+ CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
+ return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
+ }
+
+ // Alternative interface for memcpying from host to device that takes an
+ // array slice. Checks that the destination size can accommodate the host
+ // slice size.
+ template <typename T>
+ Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
+ DeviceMemory<T> *gpu_dst) {
+ auto host_size = host_src.size() * sizeof(T);
+ CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
+ return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
+ }
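+
+ // For example, a host-to-device round trip using the two helpers above might
+ // look like this (a sketch; gpu_buf is assumed to be a DeviceMemory<float>
+ // allocated elsewhere with room for 1024 elements):
+ //   std::vector<float> host_src(1024, 1.0f);
+ //   std::vector<float> host_dst(1024, 0.0f);
+ //   stream.ThenMemcpyH2D<float>(
+ //       port::ArraySlice<float>(host_src.data(), host_src.size()), &gpu_buf);
+ //   stream.ThenMemcpyD2H<float>(
+ //       gpu_buf,
+ //       port::MutableArraySlice<float>(host_dst.data(), host_dst.size()));
+ //   stream.BlockHostUntilDone();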
+
+ // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
+ // of the given target size. gpu_src/dst must be pointers to GPU memory and
+ // peer access must be enabled between their owning StreamExecutors.
+ Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size);
+
+ // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
+ // ensuring that the host pointer isn't getting confused accidentally with a
+ // device pointer if you're not doing metaprogramming against the API.
+ Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size) {
+ return ThenMemcpy(gpu_dst, gpu_src, size);
+ }
+
+ // Entrain onto the stream: a memset of zero at a GPU location of size
+ // bytes.
+ // The location must not be null.
+ // TODO(leary) Presently the size must be a 4-byte multiple.
+ Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);
+
+ // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
+ // size bytes, where size must be evenly divisible by 4 (i.e. a whole number
+ // of 32-bit words). The location must not be null.
+ Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
+ uint64 size);
+
+ // (Synchronously) block the host code waiting for the operations entrained
+ // on the stream (enqueued to this point in program execution) to complete.
+ bool BlockHostUntilDone();
+
+ // Warning! This method interacts with internal threads in
+ // sometimes-unpredictable ways and is intended for GPU-Executor-internal use
+ // only. Please check with a member of the FASTR team before making use of
+ // this method.
+ //
+ // Entrains onto the stream a function to be executed on the host at some
+ // point in the future.
+ // Async host callbacks DO NOT block the stream the way device functions (or
+ // synchronous host callbacks) do. No synchronization is possible with
+ // asynchronous callbacks; they are strictly fire-and-forget.
+ // This method is private due to the potential for undefined behavior with
+ // synchronization using OpenCL user events.
+ // The ONLY lifetime guarantee in these calls is that the StreamExecutor
+ // parameter will still be valid - this Stream may not be!
+ // Any callbacks requiring device API calls must use this method.
+ Stream &ThenEnqueueOnBackgroundThread(
+ std::function<void(StreamExecutor *)> task);
+
+ // Returns the (opaque) platform-specific backing object. Ownership is not
+ // transferred to the caller.
+ internal::StreamInterface *implementation() { return implementation_.get(); }
+
+ // Entrains onto the stream a callback to the host (from the device).
+ // Host callbacks block/occupy the stream just as device functions
+ // (execute one at a time, block later stream operations).
+ // Behavior is undefined when synchronizing using OpenCL user events.
+ // Behavior is undefined if host callbacks call device routines or insert
+ // them into any stream.
+ // On certain platforms, ThenDoHostCallback is expected to have significant
+ // negative effects on performance.
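+ //
+ // For example (a sketch; host_dst, gpu_src, and size as in the ThenMemcpy
+ // overloads above), a callback can observe on the host that all previously
+ // enqueued work on the stream has finished:
+ //   stream.ThenMemcpy(host_dst, gpu_src, size)
+ //       .ThenDoHostCallback([] { LOG(INFO) << "memcpy has completed"; });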
+ Stream &ThenDoHostCallback(std::function<void()> callback);
+
+ // Identical to ThenDoHostCallback; only exposed for testing purposes.
+ Stream &ThenDoHostCallbackForTest(std::function<void()> callback);
+
+ // Returns the StreamExecutor (parent object) associated with this stream.
+ StreamExecutor *parent() const {
+ CHECK(parent_ != nullptr);
+ return parent_;
+ }
+
+ // Returns the (internal usage) temporary-memory-allocation manager associated
+ // with this stream.
+ internal::TemporaryMemoryManager *temporary_memory_manager();
+
+ private:
+ friend class host::HostBlas; // for parent_.
+ friend class host::HostFft; // for parent_.
+ friend class host::HostRng; // for parent_.
+ template <typename... Args>
+ friend struct ThenBlasImpl; // for implementing ThenBlasXXX.
+ friend class ocl::CLBlas; // for parent_.
+
+ bool InErrorState() const {
+ shared_lock lock{mu_};
+ return !ok_;
+ }
+
+ // Sets the error state if operation_retcode is false.
+ // This is a useful shorthand for many stream routines.
+ void CheckError(bool operation_retcode) {
+ if (operation_retcode) {
+ return;
+ }
+ mutex_lock lock{mu_};
+ ok_ = false;
+ }
+
+ void SetError() { CheckError(false /* = operation_retcode */); }
+
+ // The platform-dependent implementation that the StreamExecutor interface
+ // delegates to.
+ std::unique_ptr<internal::StreamInterface> implementation_;
+
+ // The StreamExecutor that supports the operation of this stream.
+ StreamExecutor *parent_;
+
+ // mutex that guards the allocation / error state flags.
+ // Mutable so that it can be obtained via const reader lock.
+ mutable mutex mu_;
+
+ // Whether Init() was successfully called to allocate this stream on the
+ // underlying platform. It simply flips from 0 to 1 with a sanity check.
+ // See StreamExecutor::AllocateStream.
+ bool allocated_ GUARDED_BY(mu_);
+
+ // Whether all operations have entrained successfully to the current program
+ // point.
+ bool ok_ GUARDED_BY(mu_);
+
+ // Sub-streams that are generated from this stream. Each element has a pointer
+ // to a sub-stream and a boolean value indicating whether that sub-stream is
+ // ready to be reused.
+ std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
+ GUARDED_BY(mu_);
+
+ // Streams can allocate temporary memories to help with work they enqueue
+ // (e.g. for scratch memory spaces). This member tracks those allocations and
+ // notes when they can be reclaimed -- reclamation is attempted when
+ // BlockHostUntilDone() is called.
+ internal::TemporaryMemoryManager temporary_memory_manager_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(Stream);
+};
+
+////////////
+// Inlines
+
+template <typename T>
+inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
+Stream::AllocateTemporaryArray(uint64 element_count) {
+ return temporary_memory_manager_.AllocateArray<T>(element_count);
+}
+
+inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
+ return &temporary_memory_manager_;
+}
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
diff --git a/tensorflow/stream_executor/stream_executor.h b/tensorflow/stream_executor/stream_executor.h
new file mode 100644
index 0000000000..3bccaec5e3
--- /dev/null
+++ b/tensorflow/stream_executor/stream_executor.h
@@ -0,0 +1,50 @@
+// The StreamExecutor is a single-device abstraction for:
+//
+// * Loading/launching data-parallel-kernels
+// * Invoking pre-canned high-performance library routines (like matrix
+// multiply)
+//
+// The appropriately-typed kernel and "loader spec" are automatically generated
+// for the user within a namespace by the gcudacc compiler output, so typical
+// use looks like so:
+//
+// namespace gpu = ::perftools::gputools;
+// namespace gcudacc = ::platforms::gpus::gcudacc;
+//
+// gpu::StreamExecutor stream_exec{PlatformKind::kCuda};
+// gcudacc::kernel::MyKernel my_kernel{&stream_exec};
+// bool ok = stream_exec.GetKernel(gcudacc::spec::MyKernelSpec(),
+// &my_kernel);
+// if (!ok) { ... }
+// gpu::DeviceMemory<int> result = stream_exec.AllocateZeroed<int>();
+// if (result == nullptr) { ... }
+// int host_result;
+// gpu::Stream my_stream{&stream_exec};
+// my_stream
+// .Init()
+// .ThenLaunch(ThreadDim{1024}, BlockDim{1}, my_kernel, result)
+// .ThenMemcpy(&host_result, result, sizeof(host_result))
+// .BlockHostUntilDone();
+// if (!my_stream.ok()) { ... }
+// printf("%d\n", host_result);
+//
+// Since the device may operate asynchronously to the host, the
+// Stream::BlockHostUntilDone() call forces the calling host thread to wait for
+// the chain of commands specified for the Stream to complete execution.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
+
+#include "tensorflow/stream_executor/device_description.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/device_memory.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/device_options.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/event.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/kernel.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/kernel_spec.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/launch_dim.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/platform.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/stream.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/stream_executor_pimpl.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/timer.h" // IWYU pragma: export
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc
new file mode 100644
index 0000000000..b2785e0874
--- /dev/null
+++ b/tensorflow/stream_executor/stream_executor_internal.cc
@@ -0,0 +1,65 @@
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+
+// -- CUDA
+
+StreamExecutorFactory* MakeCUDAExecutorImplementation() {
+ static StreamExecutorFactory instance;
+ return &instance;
+}
+EventFactory* MakeCUDAEventImplementation() {
+ static EventFactory instance;
+ return &instance;
+}
+StreamFactory* MakeCUDAStreamImplementation() {
+ static StreamFactory instance;
+ return &instance;
+}
+TimerFactory* MakeCUDATimerImplementation() {
+ static TimerFactory instance;
+ return &instance;
+}
+KernelFactory* MakeCUDAKernelImplementation() {
+ static KernelFactory instance;
+ return &instance;
+}
+
+// -- OpenCL
+
+StreamExecutorFactory* MakeOpenCLExecutorImplementation() {
+ static StreamExecutorFactory instance;
+ return &instance;
+}
+StreamExecutorFactory* MakeOpenCLAlteraExecutorImplementation() {
+ static StreamExecutorFactory instance;
+ return &instance;
+}
+StreamFactory* MakeOpenCLStreamImplementation() {
+ static StreamFactory instance;
+ return &instance;
+}
+TimerFactory* MakeOpenCLTimerImplementation() {
+ static TimerFactory instance;
+ return &instance;
+}
+KernelFactory* MakeOpenCLKernelImplementation() {
+ static KernelFactory instance;
+ return &instance;
+}
+
+// -- Host
+
+StreamExecutorFactory MakeHostExecutorImplementation;
+StreamFactory MakeHostStreamImplementation;
+TimerFactory MakeHostTimerImplementation;
+
+
+} // namespace internal
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
new file mode 100644
index 0000000000..5b4e596cfe
--- /dev/null
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -0,0 +1,364 @@
+// Interfaces for platform-dependent implementations to satisfy. These are
+// delegated to from the StreamExecutor in pointer-to-implementation style; i.e.
+// the StreamExecutor is just a husk that delegates calls to the
+// platform-specific objects which implement the interfaces defined here.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/stream_executor/device_description.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/device_options.h"
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/kernel.h"
+#include "tensorflow/stream_executor/kernel_cache_config.h"
+#include "tensorflow/stream_executor/kernel_spec.h"
+#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/shared_memory_config.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
+
+namespace perftools {
+namespace gputools {
+
+class KernelBase;
+class Stream;
+class Timer;
+
+namespace blas {
+class BlasSupport;
+} // namespace blas
+
+namespace fft {
+class Support;
+} // namespace fft
+
+namespace rng {
+class RngSupport;
+} // namespace rng
+
+} // namespace gputools
+} // namespace perftools
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+
+// Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
+//
+// Various platforms will provide an implementation that satisfies this interface.
+class StreamExecutorInterface {
+ public:
+ // Default constructor for the abstract interface.
+ StreamExecutorInterface() {}
+
+ // Default destructor for the abstract interface.
+ virtual ~StreamExecutorInterface() {}
+
+ // Returns the (transitively) wrapped executor if this executor is
+ // wrapping another executor; otherwise, returns this.
+ virtual StreamExecutorInterface *GetUnderlyingExecutor() { return this; }
+
+ // See the StreamExecutor interface for comments on the same-named methods.
+ virtual port::Status Init(int device_ordinal,
+ DeviceOptions device_options) = 0;
+ virtual bool GetKernel(const MultiKernelLoaderSpec &spec,
+ KernelBase *kernel) {
+ return false;
+ }
+ virtual bool Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims, const KernelBase &k,
+ const std::vector<KernelArg> &args) {
+ return false;
+ }
+ virtual void *Allocate(uint64 size) = 0;
+ virtual void *AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset,
+ uint64 size) = 0;
+ virtual void Deallocate(DeviceMemoryBase *mem) = 0;
+ virtual void *HostMemoryAllocate(uint64 size) = 0;
+ virtual void HostMemoryDeallocate(void *mem) = 0;
+ virtual bool HostMemoryRegister(void *mem, uint64 size) = 0;
+ virtual bool HostMemoryUnregister(void *mem) = 0;
+ virtual bool SynchronizeAllActivity() = 0;
+ virtual bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) = 0;
+ virtual bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) = 0;
+ virtual bool SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
+ const void *host_src, uint64 size) = 0;
+ virtual bool SynchronousMemcpy(void *host_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) = 0;
+ virtual bool SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) = 0;
+ virtual bool MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) = 0;
+ virtual bool Memset32(Stream *stream, DeviceMemoryBase *location,
+ uint32 pattern, uint64 size) = 0;
+ virtual bool Memcpy(Stream *stream, void *host_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size) = 0;
+ virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
+ const void *host_src, uint64 size) = 0;
+ virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &host_src,
+ uint64 size) = 0;
+ virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0;
+ virtual port::Status AllocateEvent(Event *event) = 0;
+ virtual port::Status DeallocateEvent(Event *event) = 0;
+ virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
+ virtual port::Status WaitForEvent(Stream *stream, Event *event) = 0;
+ virtual Event::Status PollForEventStatus(Event *event) = 0;
+ virtual bool AllocateStream(Stream *stream) = 0;
+ virtual void DeallocateStream(Stream *stream) = 0;
+ virtual bool CreateStreamDependency(Stream *dependent, Stream *other) = 0;
+ virtual bool AllocateTimer(Timer *timer) = 0;
+ virtual void DeallocateTimer(Timer *timer) = 0;
+ virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
+ virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
+ virtual bool BlockHostUntilDone(Stream *stream) = 0;
+ virtual int PlatformDeviceCount() = 0;
+ virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
+ virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;
+ virtual SharedMemoryConfig GetDeviceSharedMemoryConfig() = 0;
+ virtual port::Status SetDeviceSharedMemoryConfig(
+ SharedMemoryConfig config) = 0;
+
+ virtual bool DeviceMemoryUsage(int64 *free, int64 *total) const {
+ return false;
+ }
+
+ // Retrieves the device pointer and size for a symbol. The device pointer is
+ // stored at mem, and the size is stored at bytes. Either mem or bytes may be
+ // null, but they cannot both be null at the same time. To use constant memory
+ // in CUDA, GetSymbol has to be used. Returns true if the symbol is found.
+ virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) {
+ return false;
+ }
+
+ // Creates a new DeviceDescription object. Ownership is transferred to the
+ // caller.
+ virtual DeviceDescription *PopulateDeviceDescription() const = 0;
+
+ virtual KernelArg DeviceMemoryToKernelArg(
+ const DeviceMemoryBase &gpu_mem) const = 0;
+
+ // Attempts to register the provided TraceListener with the device-specific
+ // Executor implementation. When this is called, the PIMPL interface has
+ // already taken ownership of the object and is managing the generic tracing
+ // events. The device-specific implementation must determine if the passed
+ // listener is of a type appropriate for it to trace during registration (and
+ // before dispatching events to it).
+ // Returns true if the listener was successfully registered, false otherwise.
+ // Does not take ownership of listener.
+ virtual bool RegisterTraceListener(TraceListener* listener) { return false; }
+
+ // Unregisters the specified listener from the device-specific Executor.
+ // Returns true if the listener was successfully unregistered, false
+ // otherwise.
+ virtual bool UnregisterTraceListener(TraceListener* listener) {
+ return false;
+ }
+
+ // Returns whether this StreamExecutor has BLAS support for its underlying
+ // platform.
+ virtual bool SupportsBlas() const { return false; }
+
+ // Creates a new BlasSupport object; ownership is transferred to the caller.
+ // If SupportsBlas() is false, this will always return null.
+ //
+ // If SupportsBlas() is true, this may return null, for example, if the BLAS
+ // initialization fails.
+ virtual blas::BlasSupport *CreateBlas() { return nullptr; }
+
+ // Returns whether this StreamExecutor has FFT support for its underlying
+ // platform.
+ virtual bool SupportsFft() const { return false; }
+
+ // Creates a new fft::FftSupport object; ownership is transferred to the
+ // caller.
+ // If SupportsFft() is false, this will always return null.
+ //
+ // If SupportsFft() is true, this may return null, for example, if the FFT
+ // initialization fails.
+ virtual fft::FftSupport *CreateFft() { return nullptr; }
+
+ // Returns whether this StreamExecutor has Random Number Generation support
+ // for its underlying platform.
+ virtual bool SupportsRng() const { return false; }
+
+ // Returns whether this StreamExecutor has neural net support for its
+ // underlying platform.
+ virtual bool SupportsDnn() const { return false; }
+
+ // Creates a new RngSupport object; ownership is transferred to the caller.
+ // If SupportsRng() is false, this will always return null.
+ //
+ // If SupportsRng() is true, this may return null, for example, if the RNG
+ // initialization fails.
+ virtual rng::RngSupport *CreateRng() { return nullptr; }
+
+ // Creates a new DnnSupport object; ownership is transferred to the caller.
+ // If SupportsDnn() is false, this will always return null.
+ //
+ // If SupportsDnn() is true, this may return null, for example, if the DNN
+ // initialization fails.
+ virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
+
+ // Please read the warning below. This method is only temporary. See
+ // http://b/15759750
+ //
+ // Returns the CUDA context associated with this StreamExecutor platform
+ // implementation.
+ //
+ // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
+ // fatal error if it is not. This hack is made available solely for use from
+ // distbelief code, which temporarily has strong ties to CUDA as a platform.
+ virtual void *CudaContextHack() { return nullptr; }
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
+};
+
+// Pointer-to-implementation object type (i.e. the KernelBase class delegates to
+// this interface) with virtual destruction. This class exists for the
+// platform-dependent code to hang any kernel data/resource info/functionality
+// off of.
+class KernelInterface {
+ public:
+ // Default constructor for the abstract interface.
+ KernelInterface() {}
+
+ // Default destructor for the abstract interface.
+ virtual ~KernelInterface() {}
+
+ // Returns the number of formal parameters that this kernel accepts.
+ virtual unsigned Arity() const = 0;
+
+ // Sets the preferred cache configuration.
+ virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;
+
+ // Gets the preferred cache configuration.
+ virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
+};
+
+// Platform-dependent interface class for the generic Events interface, in
+// the PIMPL style.
+class EventInterface {
+ public:
+ EventInterface() {}
+ virtual ~EventInterface() {}
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
+};
+
+// Pointer-to-implementation object type (i.e. the Stream class delegates to
+// this interface) with virtual destruction. This class exists for the
+// platform-dependent code to hang any stream data/resource info/functionality
+// off of.
+class StreamInterface {
+ public:
+ // Default constructor for the abstract interface.
+ StreamInterface() {}
+
+ // Default destructor for the abstract interface.
+ virtual ~StreamInterface() {}
+
+ // Please read the warning below. This method is only temporary. See
+ // http://b/15759750
+ //
+ // Returns the CUDA stream associated with this platform's stream
+ // implementation.
+ //
+ // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
+ // fatal error if it is not. This hack is made available solely for use from
+ // distbelief code, which temporarily has strong ties to CUDA as a platform.
+ virtual void *CudaStreamHack() { return nullptr; }
+
+ // Please read the warning above. This method is only temporary. See
+ // http://b/15759750
+ //
+ // See the above comment on CudaStreamHack -- this further breaks abstraction
+ // for Eigen within distbelief, which has strong ties to CUDA as a platform,
+ // and a historical attachment to a programming model which takes a
+ // stream-slot rather than a stream-value.
+ virtual void **CudaStreamMemberHack() { return nullptr; }
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
+};
+
+// Pointer-to-implementation object type (i.e. the Timer class delegates to
+// this interface) with virtual destruction. This class exists for the
+// platform-dependent code to hang any timer data/resource info/functionality
+// off of.
+class TimerInterface {
+ public:
+ // Default constructor for the abstract interface.
+ TimerInterface() {}
+
+ // Default destructor for the abstract interface.
+ virtual ~TimerInterface() {}
+
+ // Returns the number of microseconds elapsed in a completed timer.
+ virtual uint64 Microseconds() const = 0;
+
+ // Returns the number of nanoseconds elapsed in a completed timer.
+ virtual uint64 Nanoseconds() const = 0;
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
+};
+
+// Extern functions for constructing platform-specific instances that conform to
+// the StreamExecutor interface. (Defining constructor functions extern in this
+// way prevents CUDA/OpenCL headers from leaking into any shared header files.)
+//
+// TODO(leary) switch this all over to registries.
+
+using StreamExecutorFactory =
+ std::function<StreamExecutorInterface *(const PluginConfig &)>;
+using EventFactory = std::function<EventInterface *(StreamExecutor *)>;
+using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>;
+using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>;
+using KernelFactory = std::function<KernelInterface*()>;
+
+EventFactory* MakeCUDAEventImplementation();
+StreamExecutorFactory* MakeCUDAExecutorImplementation();
+StreamFactory* MakeCUDAStreamImplementation();
+TimerFactory* MakeCUDATimerImplementation();
+KernelFactory* MakeCUDAKernelImplementation();
+
+StreamExecutorFactory* MakeOpenCLExecutorImplementation();
+StreamExecutorFactory* MakeOpenCLAlteraExecutorImplementation();
+StreamFactory* MakeOpenCLStreamImplementation();
+TimerFactory* MakeOpenCLTimerImplementation();
+KernelFactory* MakeOpenCLKernelImplementation();
+
+extern StreamExecutorFactory MakeHostExecutorImplementation;
+extern StreamFactory MakeHostStreamImplementation;
+extern TimerFactory MakeHostTimerImplementation;
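+
+// For example, a platform might populate these factories at static
+// initialization time, roughly as follows (a sketch; PlatformSpecificStream is
+// a hypothetical StreamInterface subclass for that platform):
+//
+//   *MakeCUDAStreamImplementation() = [](StreamExecutor *parent) {
+//     return new PlatformSpecificStream(parent);
+//   };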
+
+
+} // namespace internal
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
new file mode 100644
index 0000000000..22b7a50b79
--- /dev/null
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -0,0 +1,642 @@
+// Implements the StreamExecutor interface by passing through to its
+// implementation_ value (in pointer-to-implementation style), which
+// implements StreamExecutorInterface.
+
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+#include <atomic>
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/fft.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/lib/threadpool.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace {
+bool FLAGS_check_gpu_leaks = false;
+} // namespace
+
+namespace perftools {
+namespace gputools {
+namespace {
+
+// Maximum stack depth to report when generating backtrace on mem allocation
+// (for GPU memory leak checker)
+static const int kMaxStackDepth = 256;
+
+// Make sure the executor is done with its work; we know (because this isn't
+// publicly visible) that all enqueued work is quick.
+void BlockOnThreadExecutor(port::ThreadPool *executor) {
+ port::Notification n;
+ executor->Schedule([&n]() { n.Notify(); });
+ n.WaitForNotification();
+}
+
+internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
+ PlatformKind platform_kind, const PluginConfig &plugin_config) {
+ // Note: we use this factory-assignment-in-switch pattern instead of just
+ // invoking the callable in case linkage is messed up -- instead of invoking a
+ // nullptr std::function (due to failed registration) we give a nice
+ // LOG(FATAL) message.
+ internal::StreamExecutorFactory factory;
+ switch (platform_kind) {
+ case PlatformKind::kCuda:
+ factory = *internal::MakeCUDAExecutorImplementation();
+ break;
+ case PlatformKind::kOpenCL:
+ factory = *internal::MakeOpenCLExecutorImplementation();
+ break;
+ case PlatformKind::kOpenCLAltera:
+ factory = *internal::MakeOpenCLAlteraExecutorImplementation();
+ break;
+ case PlatformKind::kHost:
+ factory = internal::MakeHostExecutorImplementation;
+ break;
+ default:
+ factory = nullptr;
+ }
+ if (factory == nullptr) {
+ LOG(FATAL)
+ << "cannot create GPU executor implementation for platform kind: "
+ << PlatformKindString(platform_kind);
+ }
+ return factory(plugin_config);
+}
+
+std::atomic_int_fast64_t correlation_id_generator(0);
+
+} // namespace
+
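+// RAII helper used by the SCOPED_TRACE macro below: when tracing is enabled on
+// the StreamExecutor, the constructor dispatches the "begin" callback and the
+// destructor dispatches the "complete" callback (with the traced call's result)
+// to every registered TraceListener, tagged with a per-call correlation id.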
+template <typename BeginCallT, typename CompleteCallT,
+ typename ReturnT, typename... BeginArgsT>
+class ScopedTracer {
+ public:
+ ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
+ CompleteCallT complete_call, const ReturnT *result,
+ BeginArgsT... begin_args)
+ : stream_exec_(stream_exec),
+ complete_call_(complete_call),
+ result_(result) {
+ if (stream_exec_->tracing_enabled_) {
+ correlation_id_ =
+ correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
+ Trace(begin_call, begin_args...);
+ }
+ }
+
+ ~ScopedTracer() {
+ if (stream_exec_->tracing_enabled_) {
+ Trace(complete_call_, result_);
+ }
+ }
+
+ private:
+ template <typename CallbackT, typename... TraceArgsT>
+ void Trace(CallbackT callback, TraceArgsT... args) {
+ {
+ // Instance tracers held in a block to limit the lock lifetime.
+ shared_lock lock{stream_exec_->mu_};
+ for (TraceListener *listener : stream_exec_->listeners_) {
+ (listener->*callback)(correlation_id_,
+ std::forward<TraceArgsT>(args)...);
+ }
+ }
+ }
+
+ StreamExecutor *stream_exec_;
+ CompleteCallT complete_call_;
+ const ReturnT* result_;
+ int64 correlation_id_;
+};
+
+template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
+ typename... BeginArgsT>
+ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
+MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
+ CompleteCallT complete_call, ReturnT *result,
+ BeginArgsT... begin_args) {
+ return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
+ stream_exec, begin_call, complete_call, result,
+ std::forward<BeginArgsT>(begin_args)...);
+}
+
+#define SCOPED_TRACE(LOC, ...) \
+ auto tracer = MakeScopedTracer(this, &LOC ## Begin, \
+ &LOC ## Complete, ## __VA_ARGS__);
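+
+// For a hypothetical trace point Foo, SCOPED_TRACE(TraceListener::Foo, &result,
+// args...) expands to a ScopedTracer that calls &TraceListener::FooBegin at
+// construction and &TraceListener::FooComplete (with the result) at
+// destruction, for each registered listener.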
+
+/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};
+
+StreamExecutor::StreamExecutor(PlatformKind platform_kind,
+ const PluginConfig &plugin_config)
+ : implementation_(StreamExecutorImplementationFromPlatformKind(
+ platform_kind, plugin_config)),
+ platform_kind_(platform_kind),
+ device_ordinal_(-1),
+ background_threads_(new port::ThreadPool(
+ port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
+ live_stream_count_(0),
+ tracing_enabled_(false) {
+ CheckPlatformKindIsValid(platform_kind);
+}
+
+StreamExecutor::StreamExecutor(
+ PlatformKind platform_kind,
+ internal::StreamExecutorInterface *implementation)
+ : implementation_(implementation),
+ platform_kind_(platform_kind),
+ device_ordinal_(-1),
+ background_threads_(new port::ThreadPool(
+ port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
+ live_stream_count_(0),
+ tracing_enabled_(false) {
+ CheckPlatformKindIsValid(platform_kind);
+}
+
+StreamExecutor::~StreamExecutor() {
+ BlockOnThreadExecutor(background_threads_.get());
+
+ if (live_stream_count_.load() != 0) {
+ LOG(WARNING) << "Not all streams were deallocated at executor destruction "
+ << "time. This may lead to unexpected/bad behavior - "
+ << "especially if any stream is still active!";
+ }
+
+ if (FLAGS_check_gpu_leaks) {
+ for (auto it : mem_allocs_) {
+ LOG(INFO) << "Memory alloced at executor exit: addr: "
+ << port::Printf("%p", it.first)
+ << ", bytes: " << it.second.bytes << ", trace: \n"
+ << it.second.stack_trace;
+ }
+ }
+}
+
+port::Status StreamExecutor::Init(int device_ordinal,
+ DeviceOptions device_options) {
+ device_ordinal_ = device_ordinal;
+ return implementation_->Init(device_ordinal, device_options);
+}
+
+port::Status StreamExecutor::Init() {
+ return Init(0, DeviceOptions::Default());
+}
+
+bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
+ KernelBase *kernel) {
+ return implementation_->GetKernel(spec, kernel);
+}
+
+void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
+ VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
+ << ") mem->size()=" << mem->size();
+
+ if (mem->opaque() != nullptr) {
+ EraseAllocRecord(mem->opaque());
+ }
+ implementation_->Deallocate(mem);
+ mem->Reset(nullptr, 0);
+}
+
+void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
+ shared_lock lock{mu_};
+ *records_out = mem_allocs_;
+}
+
+bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
+ return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
+}
+
+port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
+ return implementation_->EnablePeerAccessTo(other->implementation_.get());
+}
+
+SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
+ return implementation_->GetDeviceSharedMemoryConfig();
+}
+
+port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
+ SharedMemoryConfig config) {
+ if (config != SharedMemoryConfig::kDefault &&
+ config != SharedMemoryConfig::kFourByte &&
+ config != SharedMemoryConfig::kEightByte) {
+ string error_msg = port::Printf(
+ "Invalid shared memory config specified: %d", static_cast<int>(config));
+ LOG(ERROR) << error_msg;
+ return port::Status{port::error::INVALID_ARGUMENT, error_msg};
+ }
+ return implementation_->SetDeviceSharedMemoryConfig(config);
+}
+
+const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
+ mutex_lock lock{mu_};
+ if (device_description_ != nullptr) {
+ return *device_description_;
+ }
+
+ device_description_.reset(PopulateDeviceDescription());
+ return *device_description_;
+}
+
+int StreamExecutor::PlatformDeviceCount() const {
+ return implementation_->PlatformDeviceCount();
+}
+
+bool StreamExecutor::SupportsBlas() const {
+ return implementation_->SupportsBlas();
+}
+
+bool StreamExecutor::SupportsRng() const {
+ return implementation_->SupportsRng();
+}
+
+bool StreamExecutor::SupportsDnn() const {
+ return implementation_->SupportsDnn();
+}
+
+dnn::DnnSupport *StreamExecutor::AsDnn() {
+ mutex_lock lock{mu_};
+ if (dnn_ != nullptr) {
+ return dnn_.get();
+ }
+
+ dnn_.reset(implementation_->CreateDnn());
+ return dnn_.get();
+}
+
+blas::BlasSupport *StreamExecutor::AsBlas() {
+ mutex_lock lock{mu_};
+ if (blas_ != nullptr) {
+ return blas_.get();
+ }
+
+ blas_.reset(implementation_->CreateBlas());
+ return blas_.get();
+}
+
+fft::FftSupport *StreamExecutor::AsFft() {
+ mutex_lock lock{mu_};
+ if (fft_ != nullptr) {
+ return fft_.get();
+ }
+
+ fft_.reset(implementation_->CreateFft());
+ return fft_.get();
+}
+
+rng::RngSupport *StreamExecutor::AsRng() {
+ mutex_lock lock{mu_};
+ if (rng_ != nullptr) {
+ return rng_.get();
+ }
+
+ rng_.reset(implementation_->CreateRng());
+ return rng_.get();
+}
+
+bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims,
+ const KernelBase &kernel,
+ const std::vector<KernelArg> &args) {
+ SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
+ kernel, args);
+
+ return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
+}
+
+bool StreamExecutor::BlockHostUntilDone(Stream *stream) {
+ bool result;
+ SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);
+
+ result = implementation_->BlockHostUntilDone(stream);
+ return result;
+}
+
+void *StreamExecutor::Allocate(uint64 size) {
+ void *buf = implementation_->Allocate(size);
+ VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
+ << ") returns " << buf;
+ CreateAllocRecord(buf, size);
+
+ return buf;
+}
+
+bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
+ size_t *bytes) {
+ return implementation_->GetSymbol(symbol_name, mem, bytes);
+}
+
+void *StreamExecutor::HostMemoryAllocate(uint64 size) {
+ void *buffer = implementation_->HostMemoryAllocate(size);
+ VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
+ << ") returns " << buffer;
+ return buffer;
+}
+
+void StreamExecutor::HostMemoryDeallocate(void *location) {
+ VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location="
+ << location << ")";
+
+ return implementation_->HostMemoryDeallocate(location);
+}
+
+bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
+ VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
+ << ", size=" << size << ")";
+ if (location == nullptr || size == 0) {
+ LOG(WARNING) << "attempting to register null or zero-sized memory: "
+ << location << "; size " << size;
+ }
+ return implementation_->HostMemoryRegister(location, size);
+}
+
+bool StreamExecutor::HostMemoryUnregister(void *location) {
+ VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
+ << ")";
+ return implementation_->HostMemoryUnregister(location);
+}
+
+bool StreamExecutor::SynchronizeAllActivity() {
+ VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()";
+ bool ok = implementation_->SynchronizeAllActivity();
+
+ // This should all be quick and infallible work, so we can perform the
+ // synchronization even in the case of failure.
+ BlockOnThreadExecutor(background_threads_.get());
+
+ return ok;
+}
+
+bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
+ uint64 size) {
+ VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location="
+ << location << ", size=" << size << ")";
+
+ return implementation_->SynchronousMemZero(location, size);
+}
+
+bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) {
+ VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location="
+ << location << ", value=" << value << ", size=" << size << ")";
+
+ return implementation_->SynchronousMemSet(location, value, size);
+}
+
+bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
+ const void *host_src, uint64 size) {
+ VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(gpu_dst="
+ << gpu_dst->opaque() << ", host_src=" << host_src << ", size=" << size
+ << ") H2D";
+
+ // Tracing overloaded methods is very difficult due to issues with type
+ // inference on template args. Since use of these overloaded methods is
+ // discouraged anyway, this isn't a huge deal.
+ return implementation_->SynchronousMemcpy(gpu_dst, host_src, size);
+}
+
+bool StreamExecutor::SynchronousMemcpy(void *host_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) {
+ VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst="
+ << host_dst << ", gpu_src=" << gpu_src.opaque() << ", size=" << size
+ << ") D2H";
+
+ return implementation_->SynchronousMemcpy(host_dst, gpu_src, size);
+}
+
+bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) {
+ VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(gpu_dst="
+ << gpu_dst->opaque() << ", gpu_src=" << gpu_src.opaque() << ", size=" << size
+ << ") D2D";
+
+ return implementation_->SynchronousMemcpyDeviceToDevice(gpu_dst, gpu_src,
+ size);
+}
+
+port::Status StreamExecutor::SynchronousMemcpyD2H(
+ const DeviceMemoryBase &gpu_src, int64 size, void *host_dst) {
+ VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(gpu_src="
+ << gpu_src.opaque() << ", size=" << size << ", host_dst=" << host_dst << ")";
+
+ port::Status result{port::Status::OK()};
+ SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H,
+ &result, gpu_src, size, host_dst);
+
+ if (!implementation_->SynchronousMemcpy(host_dst, gpu_src, size)) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf(
+ "failed to synchronously memcpy device-to-host: GPU %p to host %p "
+ "size %lld",
+ gpu_src.opaque(), host_dst, size)};
+ }
+
+ return result;
+}
+
+port::Status StreamExecutor::SynchronousMemcpyH2D(const void *host_src,
+ int64 size,
+ DeviceMemoryBase *gpu_dst) {
+ VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src="
+ << host_src << ", size=" << size << ", gpu_dst" << gpu_dst->opaque() << ")";
+
+ port::Status result{port::Status::OK()};
+ SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D,
+ &result, host_src, size, gpu_dst);
+
+ if (!implementation_->SynchronousMemcpy(gpu_dst, host_src, size)) {
+ result = port::Status{
+ port::error::INTERNAL,
+ port::Printf("failed to synchronously memcpy host-to-device: host "
+ "%p to GPU %p size %lld",
+ host_src, gpu_dst->opaque(), size)};
+ }
+
+ return result;
+}
+
+bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size) {
+ return implementation_->Memcpy(stream, host_dst, gpu_src, size);
+}
+
+bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
+ const void *host_src, uint64 size) {
+ return implementation_->Memcpy(stream, gpu_dst, host_src, size);
+}
+
+bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
+ DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) {
+ return implementation_->MemcpyDeviceToDevice(stream, gpu_dst, gpu_src, size);
+}
+
+bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) {
+ return implementation_->MemZero(stream, location, size);
+}
+
+bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
+ uint32 pattern, uint64 size) {
+ CHECK_EQ(0, size % 4)
+ << "need 32-bit multiple size to fill with 32-bit pattern";
+ return implementation_->Memset32(stream, location, pattern, size);
+}
+
+bool StreamExecutor::HostCallback(Stream *stream,
+ std::function<void()> callback) {
+ return implementation_->HostCallback(stream, callback);
+}
+
+port::Status StreamExecutor::AllocateEvent(Event *event) {
+ return implementation_->AllocateEvent(event);
+}
+
+port::Status StreamExecutor::DeallocateEvent(Event *event) {
+ return implementation_->DeallocateEvent(event);
+}
+
+port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
+ return implementation_->RecordEvent(stream, event);
+}
+
+port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
+ return implementation_->WaitForEvent(stream, event);
+}
+
+Event::Status StreamExecutor::PollForEventStatus(Event *event) {
+ return implementation_->PollForEventStatus(event);
+}
+
+bool StreamExecutor::AllocateStream(Stream *stream) {
+ live_stream_count_.fetch_add(1, std::memory_order_relaxed);
+ if (!implementation_->AllocateStream(stream)) {
+ auto count = live_stream_count_.fetch_sub(1);
+ CHECK_GE(count, 0) << "live stream count should not dip below zero";
+ LOG(INFO) << "failed to allocate stream; live stream count: " << count;
+ return false;
+ }
+
+ return true;
+}
+
+void StreamExecutor::DeallocateStream(Stream *stream) {
+ implementation_->DeallocateStream(stream);
+ CHECK_GE(live_stream_count_.fetch_sub(1), 0)
+ << "live stream count should not dip below zero";
+}
+
+bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
+ return implementation_->CreateStreamDependency(dependent, other);
+}
+
+bool StreamExecutor::AllocateTimer(Timer *timer) {
+ return implementation_->AllocateTimer(timer);
+}
+
+void StreamExecutor::DeallocateTimer(Timer *timer) {
+ return implementation_->DeallocateTimer(timer);
+}
+
+bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
+ return implementation_->StartTimer(stream, timer);
+}
+
+bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
+ return implementation_->StopTimer(stream, timer);
+}
+
+DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
+ return implementation_->PopulateDeviceDescription();
+}
+
+bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
+ return implementation_->DeviceMemoryUsage(free, total);
+}
+
+KernelArg StreamExecutor::DeviceMemoryToKernelArg(
+ const DeviceMemoryBase &gpu_mem) const {
+ return implementation_->DeviceMemoryToKernelArg(gpu_mem);
+}
+
+void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
+ background_threads_->Schedule(task);
+}
+
+void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
+ if (FLAGS_check_gpu_leaks && opaque != nullptr && bytes != 0) {
+ mutex_lock lock{mu_};
+ mem_allocs_[opaque] = AllocRecord{
+ bytes, ""};
+ }
+}
+
+void StreamExecutor::EraseAllocRecord(void *opaque) {
+ if (FLAGS_check_gpu_leaks && opaque != nullptr) {
+ mutex_lock lock{mu_};
+ if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
+ LOG(ERROR) << "Deallocating unknown pointer: "
+ << port::Printf("0x%p", opaque);
+ } else {
+ mem_allocs_.erase(opaque);
+ }
+ }
+}
+
+void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }
+
+void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
+ {
+ mutex_lock lock{mu_};
+ if (listeners_.find(listener) != listeners_.end()) {
+ LOG(INFO) << "Attempt to register already-registered listener, "
+ << listener;
+ } else {
+ listeners_.insert(listener);
+ }
+ }
+
+ implementation_->RegisterTraceListener(listener);
+}
+
+bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
+ {
+ mutex_lock lock{mu_};
+ if (listeners_.find(listener) == listeners_.end()) {
+ LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
+ return false;
+ }
+ listeners_.erase(listener);
+ }
+
+ implementation_->UnregisterTraceListener(listener);
+ return true;
+}
+
+template <typename TraceCallT, typename... ArgsT>
+void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
+ if (tracing_enabled_) {
+ {
+ // Instance tracers held in a block to limit the lock lifetime.
+ shared_lock lock{mu_};
+ for (TraceListener *listener : listeners_) {
+ (listener->*trace_call)(std::forward<ArgsT>(args)...);
+ }
+ }
+ }
+}
+
+internal::StreamExecutorInterface *StreamExecutor::implementation() {
+ return implementation_->GetUnderlyingExecutor();
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
new file mode 100644
index 0000000000..29ab235d0e
--- /dev/null
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -0,0 +1,725 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
+#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
+
+#include <atomic>
+#include <set>
+#include <tuple>
+#include <vector>
+
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+#include "tensorflow/stream_executor/lib/threadpool.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/shared_memory_config.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+
+// Structure used for device memory leak checking.
+struct AllocRecord {
+ // The requested allocation size of the buffer.
+ uint64 bytes;
+
+ // Holds a representation of the stack at the time the associated buffer was
+ // allocated. Produced in a form described in
+ // //util/symbolize/symbolized_stacktrace.h.
+ string stack_trace;
+};
+
+// Forward declaration of private friend class.
+template <typename BeginCallT, typename CompleteCallT,
+ typename ReturnT, typename... BeginArgsT>
+class ScopedTracer;
+
+// A StreamExecutor manages a single device, in terms of executing work (kernel
+// launches) and memory management (allocation/deallocation, memory copies to
+// and from the device). It is conceptually the "handle" for a device -- Stream
+// objects, which are used to enqueue work to run on the coprocessor, have a
+// StreamExecutor instance as their "parent" object.
+//
+// StreamExecutor objects have an underlying platform that is specified up
+// front; e.g. it is either a CUDA or an OpenCL executor.
+//
+// Thread-safe after initialization.
+// StreamExecutor interface should not be invoked from a signal handler.
+class StreamExecutor {
+ public:
+ explicit StreamExecutor(PlatformKind kind,
+ const PluginConfig &plugin_config = PluginConfig());
+
+ // Primarily used for testing.
+ StreamExecutor(PlatformKind kind,
+ internal::StreamExecutorInterface *implementation);
+
+ ~StreamExecutor();
+
+ port::Status Init();
+ port::Status Init(int device_ordinal, DeviceOptions device_options);
+
+ // Returns the platform that this StreamExecutor is acting upon.
+ PlatformKind platform_kind() const { return platform_kind_; }
+
+ // Retrieves (loads) a kernel for the platform this StreamExecutor is acting
+ // upon, if one exists.
+ //
+ // Parameters:
+ // spec: The MultiKernelLoaderSpec is usually generated as a compile-time
+ // constant into an appropriate namespace. For example, see
+ // perftools::gputools::executor_sample::kKernelLoaderSpecs, from which a
+ // MultiKernelLoaderSpec is selected.
+ // kernel: Outparam that the kernel is loaded into. A given Kernel
+ // instantiation should not be loaded into more than once.
+ //
+ // If an error occurs, or there is no kernel available for the StreamExecutor
+ // platform, false is returned.
+ bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel);
+
+ // Synchronously allocates an array on the GPU device of type T with
+ // element_count elements.
+ template <typename T>
+ DeviceMemory<T> AllocateArray(uint64 element_count);
+
+ // As AllocateArray(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) {
+ return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count));
+ }
+
+ // Convenience wrapper that allocates space for a single element of type T
+ // in GPU memory.
+ template <typename T>
+ DeviceMemory<T> AllocateScalar() {
+ return AllocateArray<T>(1);
+ }
+
+ // As AllocateScalar(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedScalar() {
+ return AllocateOwnedArray<T>(1);
+ }
+
+ // Synchronously allocates a scalar of type T on the GPU device that is
+ // (POD) zero-byte initialized.
+ template <typename T>
+ DeviceMemory<T> AllocateZeroed();
+
+ // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedZeroed() {
+ return ScopedDeviceMemory<T>(this, AllocateZeroed<T>());
+ }
+
+ // Allocate a memory region inside another allocated memory region.
+ // Offset and size are specified in terms of T elements.
+ // Warning: Do not free a parent buffer before its sub-buffers; this may cause
+ // use-after-free issues (the specific behavior is not consistent across
+ // platforms).
+ // - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a
+ // sub-buffer after parent deallocation is expected to be safe. This will
+ // render your code non-platform-portable, however.
+ template <typename T>
+ DeviceMemory<T> AllocateSubBuffer(DeviceMemory<T> *parent,
+ uint64 element_offset,
+ uint64 element_count);
+
+ // As AllocateSubBuffer(), but returns a ScopedDeviceMemory<T>.
+ template <typename T>
+ ScopedDeviceMemory<T> AllocateOwnedSubBuffer(DeviceMemory<T> *parent,
+ uint64 element_offset,
+ uint64 element_count) {
+ return ScopedDeviceMemory<T>(
+ this, AllocateSubBuffer<T>(parent, element_offset, element_count));
+ }
+
+ // Finds a symbol and returns device memory allocated to the symbol. The
+ // symbol is searched in any kernels that were previously loaded through
+ // GetKernel() before the GetSymbol() call. The user has to make sure that the
+ // type of symbol and T match.
+ // - Note: symbol_name should include its namespace as well. For example,
+ // pass "nms0::symbol" if referring to nms0::symbol.
+ template <typename T>
+ port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name);
+
+ // Deallocate the DeviceMemory previously allocated via this interface.
+ // Deallocation of a nullptr-representative value is permitted.
+ //
+ // Resets the internal contents of mem to be null-representative, but this
+ // null-out effect should not be relied upon in client code.
+ void Deallocate(DeviceMemoryBase *mem);
+
+ // Retrieves a mapping of active opaque GPU memory pointer to a string
+ // representation of the [allocating thread's] stack at the time the pointer
+ // was allocated. Useful for tracking GPU memory leaks.
+ //
+ // Note: this will only be populated if --check_gpu_leaks flag is activated.
+ void GetMemAllocs(std::map<void *, AllocRecord> *records_out);
+
+ // Allocates a region of host memory and registers it with the platform API.
+ // Memory allocated in this manner (or allocated and registered with
+ // HostMemoryRegister()) is required for use in asynchronous memcpy operations,
+ // such as Stream::ThenMemcpy.
+ void *HostMemoryAllocate(uint64 bytes);
+
+ // Deallocates a region of host memory allocated by HostMemoryAllocate().
+ void HostMemoryDeallocate(void *location);
+
+ // Registers a region of host memory with the platform API. Registered memory
+ // (or memory allocated with HostMemoryAllocate) is required for use with
+ // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method
+ // is used to register memory allocated outside the StreamExecutor;
+ // HostMemoryAllocate implicitly registers its allocations and
+ // HostMemoryDeallocate implicitly deregisters on deallocation.
+ bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT;
+
+ // Unregisters a region of host memory registered with HostMemoryRegister.
+ // This should be done before deallocating the region with delete[]/free/etc.
+ bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT;
+
+ // Synchronizes all activity occurring in the StreamExecutor's context (most
+ // likely a whole device).
+ bool SynchronizeAllActivity() SE_MUST_USE_RESULT;
+
+ // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the
+ // given location in GPU memory.
+ bool SynchronousMemZero(DeviceMemoryBase *location,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Blocks the caller while "size" bytes are initialized to "value" (in POD
+ // fashion) at the given location in GPU memory.
+ bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // [deprecated] Blocks the caller while a data segment of the given size is
+ // copied from the host source to the GPU destination.
+ //
+ // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+ bool SynchronousMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // [deprecated] Blocks the caller while a data segment of the given size is
+ // copied from the GPU source to the host destination.
+ //
+ // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+ bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above.
+ port::Status SynchronousMemcpyH2D(const void *host_src, int64 size,
+ DeviceMemoryBase *gpu_dst);
+
+ // Alternative interface for memcpying from host to device that takes an
+ // array slice. Checks that the destination size can accommodate the host
+ // slice size.
+ template <class T>
+ port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src,
+ DeviceMemoryBase *gpu_dst) {
+ auto host_size = host_src.size() * sizeof(T);
+ CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
+ return SynchronousMemcpyH2D(host_src.begin(), host_size, gpu_dst);
+ }
+
+ // Same as SynchronousMemcpy(void*, ...) above.
+ port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &gpu_src, int64 size,
+ void *host_dst);
+
+ // Alternative interface for memcpying from device to host that takes an
+ // array slice. Checks that the destination size can accommodate the host
+ // slice size.
+ template <typename T>
+ port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &gpu_src,
+ port::MutableArraySlice<T> host_dst) {
+ auto host_size = host_dst.size() * sizeof(T);
+ CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
+ return SynchronousMemcpyD2H(gpu_src, host_size, host_dst.begin());
+ }
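+
+ // An illustrative round trip through the two slice overloads above (a
+ // sketch, not part of the original interface documentation; "exec" is
+ // assumed to be an initialized StreamExecutor*):
+ //
+ //   std::vector<float> in = {1.f, 2.f, 3.f, 4.f};
+ //   std::vector<float> out(in.size());
+ //   DeviceMemory<float> dev = exec->AllocateArray<float>(in.size());
+ //   port::Status s = exec->SynchronousMemcpyH2D(
+ //       port::ArraySlice<float>(in.data(), in.size()), &dev);
+ //   if (s.ok()) {
+ //     s = exec->SynchronousMemcpyD2H(
+ //         dev, port::MutableArraySlice<float>(out.data(), out.size()));
+ //   }
+ //   exec->Deallocate(&dev);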
+
+ // Blocks the caller while a data segment of the given size is copied from the
+ // GPU source to the GPU destination.
+ bool SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Enqueues an operation onto stream to zero out size bytes at the given GPU
+ // memory location. Neither stream nor location may be null. Returns whether
+ // the operation was successfully enqueued onto the stream.
+ bool MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Enqueues an operation onto stream to set 32-bit patterns starting at
+ // location, for byte count given by size. size must be evenly divisible by
+ // 4 (i.e. a whole number of 32-bit words). Returns whether the operation was
+ // successfully enqueued onto the stream.
+ bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+ uint64 size) SE_MUST_USE_RESULT;
+
+ // Enables peer access from this StreamExecutor to memory
+ // allocated by other, such that launched device code, memcpies, etc. may
+ // access it directly.
+ //
+ // Both this StreamExecutor and other must be backed by the same platform
+ // implementation (e.g. both CUDA or both OpenCL).
+ port::Status EnablePeerAccessTo(StreamExecutor *other);
+
+ // Returns whether it's possible to enable peer access from this
+ // StreamExecutor to memory allocated by another.
+ //
+ // Even when this returns true, EnablePeerAccessTo may fail for other reasons;
+ // this is more an up-front test as to whether it's expressly forbidden.
+ bool CanEnablePeerAccessTo(StreamExecutor *other);
+
+ // Gets the preferred shared memory configuration for the device to which this
+ // executor is bound.
+ SharedMemoryConfig GetDeviceSharedMemoryConfig();
+
+ // Sets the preferred shared memory configuration for the device to which this
+ // executor is bound.
+ port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config);
+
+ // Obtains metadata about the underlying device.
+ // The value is cached on first use.
+ const DeviceDescription &GetDeviceDescription() const;
+
+ // Returns the underlying device memory usage information, if it is available.
+ // If it is not available (false is returned), free/total may not be
+ // initialized.
+ //
+ // Note: "Free" reflects the amount of free memory on the underlying device,
+ // so allocations via other StreamExecutors that have the same underlying
+ // device will be reflected in "free".
+ bool DeviceMemoryUsage(int64 *free, int64 *total) const;
+
+ // The device count reported by this StreamExecutor's platform.
+ // Note: on OpenCL we implicitly select platform zero at the moment.
+ int PlatformDeviceCount() const;
+
+ // Returns whether the StreamExecutor supports BLAS routines for the platform
+ // that underlies this interface.
+ bool SupportsBlas() const;
+
+ // Returns whether the StreamExecutor supports FFT routines for the platform
+ // that underlies this interface.
+ bool SupportsFft() const;
+
+ // Returns whether the StreamExecutor supports RNG routines for the platform
+ // that underlies this interface.
+ bool SupportsRng() const;
+
+ // Returns whether the StreamExecutor supports neural net routines for the
+ // platform that underlies this interface.
+ bool SupportsDnn() const;
+
+ // Returns the device ordinal that this StreamExecutor was initialized with.
+ // Meaningless before initialization.
+ int device_ordinal() const { return device_ordinal_; }
+
+ // Returns a borrowed pointer to the underlying StreamExecutor implementation.
+ internal::StreamExecutorInterface *implementation();
+
+ // Warning: use Stream::ThenLaunch instead; this method is not for general
+ // consumption. However, this is the only way to launch a kernel for which
+ // the type signature is only known at runtime; say, if an application
+ // supports loading/launching kernels with arbitrary type signatures.
+ // In this case, the application is expected to know how to do parameter
+ // packing that obeys the contract of the underlying platform implementation.
+ //
+ // Launches a data parallel kernel with the given thread/block
+ // dimensionality and already-packed args/sizes to pass to the underlying
+ // platform driver.
+ //
+ // This is called by Stream::Launch() to delegate to the platform's launch
+ // implementation in StreamExecutorInterface::Launch().
+ bool Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims, const KernelBase &kernel,
+ const std::vector<KernelArg> &args);
+
+ // Gets-or-creates (creates with memoization) a FftSupport datatype that can
+ // be used to execute FFT routines on the current platform.
+ //
+ // Ownership and user-facing semantics are the same as for AsBlas() below.
+ //
+ // Returns null if there was an error initializing the FFT support for the
+ // underlying platform.
+ fft::FftSupport *AsFft();
+
+ // Gets-or-creates (creates with memoization) a DnnSupport datatype that can
+ // be used for neural network routines on the current platform.
+ //
+ // Ownership and user-facing semantics are the same as for AsBlas() below.
+ //
+ // Returns null if there was an error initializing the DNN support for the
+ // underlying platform.
+ dnn::DnnSupport *AsDnn();
+
+ // Turns StreamExecutor operation tracing on or off.
+ void EnableTracing(bool enable);
+
+ // Registers a trace listener to receive callbacks for only a single
+ // StreamExecutor instance.
+ // To register a listener for all executors for a given platform, see
+ // Platform::RegisterTraceListener().
+ // Does not take ownership of listener.
+ void RegisterTraceListener(TraceListener* listener);
+
+ // Removes a TraceListener from this StreamExecutor instance.
+ // Returns false (and logs) in cases where the argument listener was not
+ // previously registered.
+ bool UnregisterTraceListener(TraceListener* listener);
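+
+ // A sketch of the expected listener workflow (illustrative only;
+ // "MyListener" is a hypothetical TraceListener subclass that overrides the
+ // hooks it cares about, and "exec" is an initialized StreamExecutor*):
+ //
+ //   MyListener listener;
+ //   exec->EnableTracing(true);
+ //   exec->RegisterTraceListener(&listener);
+ //   // ... enqueue work; registered hooks fire via SubmitTrace/ScopedTracer.
+ //   exec->UnregisterTraceListener(&listener);
+ //   exec->EnableTracing(false);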
+
+ // Converts a DeviceMemory object into a KernelArg object for passing to the
+ // device driver for kernel launch.
+ KernelArg DeviceMemoryToKernelArg(const DeviceMemoryBase &gpu_mem) const;
+
+ private:
+ template <typename BeginCallT, typename CompleteCallT,
+ typename ReturnT, typename... BeginArgsT>
+ friend class ScopedTracer;
+ friend class Event;
+ friend class Stream;
+ friend class Timer;
+ template <typename... Params>
+ friend class TypedKernel;
+ template <typename... Args>
+ friend struct ThenBlasImpl;
+
+ // Gets-or-creates (creates with memoization) a BlasSupport datatype that can
+ // be used to execute BLAS routines on the current platform. This is typically
+ // not user-facing, as users will use the Stream::ThenBlas* family of routines
+ // to entrain BLAS operations. See blas.h for additional details.
+ //
+ // Ownership is not transferred to the caller -- ownership is retained by this
+ // object for memoization. This BLAS interface is also only expected to be
+ // used by a Stream for entraining calls to BLAS functionality.
+ //
+ // Returns null if there was an error initializing the BLAS support for the
+ // underlying platform.
+ blas::BlasSupport *AsBlas();
+
+ // Gets-or-creates (creates with memoization) an RngSupport datatype that can
+ // be used for random-number-generation routines on the current platform.
+ //
+ // Ownership and user-facing semantics are the same as for AsBlas() above.
+ //
+ // Returns null if there was an error initializing the RNG support for the
+ // underlying platform.
+ rng::RngSupport *AsRng();
+
+ // Causes the host code to synchronously wait for operations entrained onto
+ // stream to complete. Effectively a join on the asynchronous GPU operations
+ // enqueued on the stream before this program point.
+ bool BlockHostUntilDone(Stream *stream);
+
+ // Synchronously allocates size bytes on the underlying platform and returns
+ // an opaque void* representing that allocation. In the case of failure,
+ // nullptr is returned.
+ void *Allocate(uint64 size);
+
+ // Finds and retrieves device memory for the symbol on the underlying
+ // platform.
+ bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes);
+
+ // Entrains a memcpy operation onto stream, with a host destination location
+ // host_dst and a GPU memory source, with target size size.
+ bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size);
+
+ // Entrains a memcpy operation onto stream, with a GPU destination location
+ // and a host memory source, with target size size.
+ bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size);
+
+ // Entrains a memcpy operation onto stream, with a GPU destination location
+ // and a GPU source location, with target size size. Peer access should have
+ // been enabled between the StreamExecutors owning the GPU memory regions.
+ bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size);
+
+ // Entrains on a stream a user-specified function to be run on the host.
+ // See Stream::ThenDoHostCallback for full details.
+ bool HostCallback(Stream *stream, std::function<void()> callback);
+
+ // Performs platform-specific allocation and initialization of an event.
+ port::Status AllocateEvent(Event *event);
+
+ // Performs platform-specific deallocation and cleanup of an event.
+ port::Status DeallocateEvent(Event *event);
+
+ // Inserts the specified event at the end of the specified stream.
+ port::Status RecordEvent(Stream *stream, Event *event);
+
+ // Wait for the specified event at the end of the specified stream.
+ port::Status WaitForEvent(Stream *stream, Event *event);
+
+ // Requests the current status of the event from the underlying platform.
+ Event::Status PollForEventStatus(Event *event);
+
+ // Allocates stream resources on the underlying platform for subject and
+ // initializes its internals.
+ bool AllocateStream(Stream *subject);
+
+ // Deallocates stream resources on the underlying platform.
+ void DeallocateStream(Stream *subject);
+
+ // Causes dependent to not begin execution until other has finished its
+ // last-enqueued work.
+ bool CreateStreamDependency(Stream *dependent, Stream *other);
+
+ // Allocates timer resources on the underlying platform for subject and
+ // initializes its internals.
+ bool AllocateTimer(Timer *subject);
+
+ // Deallocates timer resources on the underlying platform.
+ void DeallocateTimer(Timer *subject);
+
+ // Records a start event for an interval timer.
+ bool StartTimer(Stream *stream, Timer *timer);
+
+ // Records a stop event for an interval timer.
+ bool StopTimer(Stream *stream, Timer *timer);
+
+ // Allocates a new metadata object, appropriately populated, on the heap, with
+ // ownership transfer to caller.
+ DeviceDescription *PopulateDeviceDescription() const;
+
+ // Adds a task to the port::ThreadPool work queue. These tasks must be
+ // fire-and-forget and have no external data or timing dependencies; their
+ // execution order and completion time have no guarantees.
+ // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal;
+ // there, temporary internal buffers are freed using this method.
+ void EnqueueOnBackgroundThread(std::function<void()> task);
+
+ // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for
+ // leak checking. NULL buffer pointers and buffer sizes of 0 will not be
+ // tracked.
+ void CreateAllocRecord(void *opaque, uint64 size);
+
+ // Removes the AllocRecord keyed by 'opaque' from the record map. NULL
+ // pointers will not be erased (as they're not tracked, per above).
+ void EraseAllocRecord(void *opaque);
+
+ // Calls the relevant TraceListener routine to begin tracing for the specified
+ // asynchronous method.
+ template <typename TraceCallT, typename... ArgsT>
+ void SubmitTrace(TraceCallT trace_call, ArgsT&&... args);
+
+ // Reader/writer lock for class-static StreamExecutor members.
+ static mutex static_mu_;
+
+ // Reader/writer lock for mutable data structures on this StreamExecutor.
+ //
+ // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.)
+ // can acquire the lock on their first (mutating) call as well.
+ mutable mutex mu_;
+
+ // A mapping of pointer (to GPU memory) to string representation of the stack
+ // (of the allocating thread) at the time at which the pointer was allocated.
+ std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);
+
+ // Pointer to the platform-specific-interface implementation. This is
+ // delegated to by the interface routines in pointer-to-implementation
+ // fashion.
+ std::unique_ptr<internal::StreamExecutorInterface> implementation_;
+
+ // Memoized BLAS support object -- we only want to create this once when asked
+ // for a BLAS interface.
+ std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);
+
+ // Memoized DNN support object -- we only want to create this once when asked
+ // for a DNN interface.
+ std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_);
+
+ // Memoized FFT support object -- we only want to create this once when asked
+ // for an FFT interface.
+ std::unique_ptr<fft::FftSupport> fft_ GUARDED_BY(mu_);
+
+ // Memoized RNG support object -- we only want to create this once when asked
+ // for an RNG interface.
+ std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_);
+
+ // Slot to cache the owned DeviceDescription for the underlying device
+ // once it has been queried via GetDeviceDescription().
+ mutable std::unique_ptr<DeviceDescription> device_description_
+ GUARDED_BY(mu_);
+
+ // The kind of the underlying platform that is being targeted, as passed
+ // during construction.
+ //
+ // Immutable post-initialization.
+ PlatformKind platform_kind_;
+
+ // The device ordinal that this object was initialized with.
+ //
+ // Immutable post-initialization.
+ int device_ordinal_;
+
+ // Executor for handling host callback work that cannot be performed
+ // by a host callback thread - for example, cleanup after a host BLAS routine
+ // (which may make device API calls). This work cannot block the host
+ // callback thread, will be completed asynchronously, and should be treated
+ // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued
+ // here.
+ //
+ // Immutable post-initialization. Object is thread-safe.
+ std::unique_ptr<port::ThreadPool> background_threads_;
+
+ // Counter for the current number of live streams. This is used to check
+ // for accidentally-outstanding streams at StreamExecutor teardown time, as
+ // well as to indicate leaks (via a large outstanding count being logged) in
+ // the case we can't allocate more streams.
+ std::atomic_int_fast32_t live_stream_count_;
+
+ // Only one worker thread is needed; little work will be done by the
+ // executor.
+ static const int kNumBackgroundThreads = 1;
+
+ // Indicates if StreamExecutor operation tracing should be performed.
+ bool tracing_enabled_;
+
+ // The set of TraceListeners registered for this StreamExecutor.
+ std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
+
+ SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
+};
+
+////////////
+// Inlines
+
+template <typename T>
+inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) {
+ uint64 bytes = sizeof(T) * element_count;
+ void *opaque = Allocate(bytes);
+ return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
+}
+
+template <typename T>
+inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
+ const string &symbol_name) {
+ // If the symbol lookup fails, opaque/bytes are left unchanged. Initialize
+ // them to nullptr/0 for consistency with DeviceMemory semantics.
+ void *opaque = nullptr;
+ size_t bytes = 0;
+ if (GetSymbol(symbol_name, &opaque, &bytes)) {
+ CHECK_EQ(bytes % sizeof(T), 0);
+ return DeviceMemory<T>::MakeFromByteSize(opaque, bytes);
+ }
+ return port::Status(
+ port::error::NOT_FOUND,
+ port::StrCat("Check if kernel using the symbol is loaded: ",
+ symbol_name));
+}
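+
+// An illustrative lookup (a sketch; "exec" is assumed to be an initialized
+// StreamExecutor* and "nms0::symbol" is the placeholder symbol name from the
+// comment above):
+//
+//   port::StatusOr<DeviceMemory<float>> sym =
+//       exec->GetSymbol<float>("nms0::symbol");
+//   if (sym.ok()) {
+//     DeviceMemory<float> mem = sym.ConsumeValueOrDie();
+//     // mem's byte size is checked to be a multiple of sizeof(float).
+//   }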
+
+template <typename ElemT>
+ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent,
+ DeviceMemoryBase value)
+ : wrapped_(value), parent_(parent) {}
+
+template <typename ElemT>
+ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(
+ StreamExecutor *parent, std::initializer_list<ElemT> values)
+ : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) {
+ if (ptr() != nullptr) {
+ std::vector<ElemT> local(values);
+ if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]),
+ ptr()->size())) {
+ Reset(nullptr);
+ }
+ }
+}
+
+template <typename ElemT>
+ScopedDeviceMemory<ElemT>::~ScopedDeviceMemory() {
+ parent_->Deallocate(&wrapped_);
+}
+
+template <typename ElemT>
+void ScopedDeviceMemory<ElemT>::Reset(DeviceMemory<ElemT> updated) {
+ parent_->Deallocate(&wrapped_);
+ wrapped_ = updated;
+}
+
+template <typename ElemT>
+void ScopedDeviceMemory<ElemT>::Reset(std::nullptr_t) {
+ parent_->Deallocate(&wrapped_);
+ wrapped_ = DeviceMemory<ElemT>{};
+}
+
+template <typename T>
+DeviceMemory<T> StreamExecutor::AllocateZeroed() {
+ void *opaque = Allocate(sizeof(T));
+ if (opaque == nullptr) {
+ return DeviceMemory<T>{};
+ }
+
+ DeviceMemory<T> result = DeviceMemory<T>::MakeFromByteSize(opaque, sizeof(T));
+ bool ok = SynchronousMemZero(&result, sizeof(T));
+ if (!ok) {
+ Deallocate(&result);
+ return DeviceMemory<T>{};
+ }
+
+ return result;
+}
+
+template <typename T>
+DeviceMemory<T> StreamExecutor::AllocateSubBuffer(DeviceMemory<T> *parent,
+ uint64 element_offset,
+ uint64 element_count) {
+ if (element_offset + element_count > parent->ElementCount()) {
+ LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater "
+ << "than parent allocation size: (" << element_offset << " + "
+ << element_count << ") vs. (" << parent->ElementCount() << ")";
+ return DeviceMemory<T>{};
+ }
+
+ void *opaque = implementation_->AllocateSubBuffer(
+ parent, sizeof(T) * element_offset, sizeof(T) * element_count);
+ if (opaque == nullptr) {
+ return DeviceMemory<T>{};
+ }
+ CreateAllocRecord(opaque, sizeof(T) * element_count);
+ return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count,
+ true /* = is_sub_buffer */));
+}
+
+template <typename... Params, typename... Args>
+inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
+ const TypedKernel<Params...> &kernel,
+ Args... args) {
+ KernelInvocationChecker<std::tuple<Params...>,
+ std::tuple<Args...>>::CheckAllStaticAssert();
+ if (ok()) {
+ // This is the core that allows type-safe kernel launching.
+ // Since the platforms take kernel arguments as tuples of (void *, size),
+ // we pack the variadic parameters passed as ...args into the desired
+ // tuple form and pass that packed form to the StreamExecutor::Launch()
+ // implementation.
+ std::vector<KernelArg> kernel_args;
+ kernel_args.reserve(kernel.Arity());
+ kernel.PackParams(&kernel_args, args...);
+ bool ok =
+ parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args);
+ if (!ok) {
+ SetError();
+ LOG(WARNING) << "parent failed to launch kernel: " << &kernel;
+ }
+ }
+ return *this;
+}
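+
+// An illustrative type-safe launch through the inline above (a sketch;
+// "add_one" is a hypothetical, already-loaded
+// TypedKernel<DeviceMemory<float>, int>, and "data"/"stream" are assumed to
+// be set up by the caller):
+//
+//   stream.ThenLaunch(ThreadDim(256), BlockDim(4), add_one, data,
+//                     /*count=*/1024);
+//   if (!stream.ok()) {
+//     LOG(WARNING) << "launch was not enqueued successfully";
+//   }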
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
diff --git a/tensorflow/stream_executor/temporary_device_memory.cc b/tensorflow/stream_executor/temporary_device_memory.cc
new file mode 100644
index 0000000000..d11b58813d
--- /dev/null
+++ b/tensorflow/stream_executor/temporary_device_memory.cc
@@ -0,0 +1,53 @@
+#include "tensorflow/stream_executor/temporary_device_memory.h"
+
+#include "tensorflow/stream_executor/stream.h"
+
+namespace perftools {
+namespace gputools {
+
+TemporaryDeviceMemoryBase::~TemporaryDeviceMemoryBase() {
+ parent_->temporary_memory_manager()->MarkFinalized(device_memory_,
+ allocation_generation_,
+ /*must_exist=*/false);
+}
+
+DeviceMemoryBase* TemporaryDeviceMemoryBase::mutable_device_memory() {
+ DCHECK(!IsFinalized())
+ << "should not access device memory after finalization";
+ return &device_memory_;
+}
+
+const DeviceMemoryBase& TemporaryDeviceMemoryBase::device_memory() const {
+ DCHECK(!IsFinalized())
+ << "should not access device memory after finalization";
+ return device_memory_;
+}
+
+void TemporaryDeviceMemoryBase::Finalize() {
+ DCHECK(!IsFinalized()) << "should not finalize more than once";
+ parent_->temporary_memory_manager()->MarkFinalized(device_memory_,
+ allocation_generation_,
+ /*must_exist=*/true);
+}
+
+bool TemporaryDeviceMemoryBase::IsFinalized() const {
+ return parent_->temporary_memory_manager()->IsFinalized(
+ device_memory_, allocation_generation_);
+}
+
+bool TemporaryDeviceMemoryBase::IsAllocated() const {
+ return parent_->temporary_memory_manager()->HasAllocated(
+ device_memory_, allocation_generation_);
+}
+
+TemporaryDeviceMemoryBase::TemporaryDeviceMemoryBase(
+ Stream* parent, DeviceMemoryBase device_memory,
+ uint64 allocation_generation)
+ : device_memory_(device_memory),
+ allocation_generation_(allocation_generation),
+ parent_(parent) {
+ DCHECK(IsAllocated());
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/temporary_device_memory.h b/tensorflow/stream_executor/temporary_device_memory.h
new file mode 100644
index 0000000000..4e7c63056b
--- /dev/null
+++ b/tensorflow/stream_executor/temporary_device_memory.h
@@ -0,0 +1,123 @@
+// Temporary memories are used to allocate scratch space required by an
+// operation about to be enqueued onto a stream.
+//
+// std::unique_ptr<TemporaryDeviceMemory<float>> temporary_memory =
+// stream.AllocateTemporaryArray<float>(1024).ConsumeValueOrDie();
+// // ... enqueue stuff onto the stream using the temporary memory ...
+// // Note that the memory is accessible via
+// // temporary_memory->device_memory() and similar.
+//
+// // Finalize the temporary memory. The underlying device memory may
+// // be released any time after this program point, as another thread may
+// // call Stream::BlockHostUntilDone, causing synchronization. This
+// // finalization also happens automatically for the user if the unique_ptr
+// // goes out of scope.
+// temporary_memory->Finalize();
+//
+// WARNING: do NOT hold onto the device memory associated with temporary_memory
+// after finalization. If temporary_memory->device_memory() is used after the
+// temporary memory is finalized, it will cause a DCHECK failure.
+//
+// Note that standard usage takes advantage of the type-safe wrapper,
+// TemporaryDeviceMemory<T>, defined below.
+//
+// Also see tests for executable sample usage.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
+
+#include "tensorflow/stream_executor/device_memory.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+namespace internal {
+class TemporaryMemoryManager;
+}
+
+// Untyped base class (analogous to a void*) for temporary device memory
+// allocations associated with a stream.
+class TemporaryDeviceMemoryBase {
+ public:
+ // Marks the temporary memory as finalized if it is not already marked as
+ // such.
+ ~TemporaryDeviceMemoryBase();
+
+ // Precondition: !IsFinalized()
+ DeviceMemoryBase* mutable_device_memory();
+
+ // Precondition: !IsFinalized()
+ const DeviceMemoryBase& device_memory() const;
+
+ // "Finalizes" this temporary memory, making it acceptable to release at the
+ // next stream synchronization point -- the device memory can be reclaimed at
+ // any time after the temporary memory is marked as finalized (e.g. if a
+ // separate thread calls Stream::BlockHostUntilDone). This may only be
+ // called once -- see the precondition below.
+ //
+ // Precondition: !IsFinalized()
+ void Finalize();
+
+ // Returns true iff the temporary memory is finalized (that is, the user is
+ // done referring to the temporary device memory, and thus it can be released
+ // at the next stream synchronization point).
+ bool IsFinalized() const;
+
+ // Returns true iff the temporary memory is still allocated.
+ //
+ // Note: this is a polling call; no guarantee is made that the temporary
+ // memory is still allocated after the call has completed.
+ bool IsAllocated() const;
+
+ private:
+ friend class internal::TemporaryMemoryManager;
+ friend class TemporaryDeviceMemoryTest;
+
+ // Note: construction DCHECKs that the memory is known-allocated in the
+ // stream's temporary-allocation-manager.
+ TemporaryDeviceMemoryBase(Stream* parent, DeviceMemoryBase device_memory,
+ uint64 allocation_generation);
+
+ // The device memory region that has been allocated.
+ DeviceMemoryBase device_memory_;
+
+ // The generation counter value for the temporary memory record in the
+ // temporary memory manager.
+ uint64 allocation_generation_;
+
+ // The stream that this temporary memory was allocated for.
+ Stream* parent_;
+};
+
+// Type-safe wrapper around the base type (which is analogous to a void*).
+template <typename T>
+class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase {
+ public:
+ // Type-safe wrapper around TemporaryDeviceMemoryBase::mutable_device_memory.
+ DeviceMemory<T>* mutable_device_memory() {
+ StaticSlicingAssertionDummy();
+ return reinterpret_cast<DeviceMemory<T>*>(
+ TemporaryDeviceMemoryBase::mutable_device_memory());
+ }
+
+ // Type-safe wrapper around TemporaryDeviceMemoryBase::device_memory.
+ const DeviceMemory<T>& device_memory() const {
+ StaticSlicingAssertionDummy();
+ return reinterpret_cast<const DeviceMemory<T>&>(
+ TemporaryDeviceMemoryBase::device_memory());
+ }
+
+ private:
+ static void StaticSlicingAssertionDummy() {
+ static_assert(
+ sizeof(TemporaryDeviceMemory) == sizeof(TemporaryDeviceMemoryBase),
+ "derived class is simply a wrapper, no members may be added due to "
+ "slicing");
+ }
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_
diff --git a/tensorflow/stream_executor/temporary_memory_manager.cc b/tensorflow/stream_executor/temporary_memory_manager.cc
new file mode 100644
index 0000000000..0352aa4b2b
--- /dev/null
+++ b/tensorflow/stream_executor/temporary_memory_manager.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/stream_executor/temporary_memory_manager.h"
+
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+
+void TemporaryMemoryManager::ForceDeallocateAll() {
+ mutex_lock lock(mutex_);
+ VLOG(1) << "force-deallocating " << records_.size() << " remaining records";
+ for (auto it = records_.begin(); it != records_.end(); ++it) {
+ DeviceMemoryBase device_memory = it->first;
+ stream_->parent()->Deallocate(&device_memory);
+ }
+}
+
+void TemporaryMemoryManager::MarkFinalized(
+ const DeviceMemoryBase& device_memory, uint64 generation, bool must_exist) {
+ mutex_lock lock(mutex_);
+ auto it = records_.find(device_memory);
+ if (it == records_.end()) {
+ if (must_exist) {
+ LOG(FATAL) << "attempted to mark finalization for temporary "
+ "memory that does not exist";
+ }
+ return;
+ }
+ it->second.finalized = true;
+}
+
+void TemporaryMemoryManager::DeallocateFinalizedTemporaries() {
+ mutex_lock lock(mutex_);
+ int deallocated_count = 0;
+ for (auto it = records_.begin(); it != records_.end();) {
+ if (it->second.finalized) {
+ DeviceMemoryBase device_memory = it->first;
+ stream_->parent()->Deallocate(&device_memory);
+ ++deallocated_count;
+ it = records_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ VLOG(1) << "deallocated " << deallocated_count << " finalized temporaries";
+}
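+
+// Editorial sketch of the intended call pattern, derived from the header
+// comments: the owning Stream deallocates finalized temporaries just before
+// blocking the host, and force-deallocates everything on destruction, e.g.:
+//
+//   stream->temporary_memory_manager()->DeallocateFinalizedTemporaries();
+//   // ... then block the host until the stream's work is done ...
+//   // and, in the Stream destructor:
+//   stream->temporary_memory_manager()->ForceDeallocateAll();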
+
+bool TemporaryMemoryManager::IsFinalized(const DeviceMemoryBase& device_memory,
+ uint64 allocation_generation) const {
+ mutex_lock lock(mutex_);
+ auto it = records_.find(device_memory);
+ if (it == records_.end()) {
+ return true; // If there's no record present it's vacuously finalized.
+ }
+
+ if (it->second.allocation_generation == allocation_generation) {
+ return it->second.finalized;
+ }
+
+ // If the allocation generation did not match, it's vacuously true.
+ return true;
+}
+
+bool TemporaryMemoryManager::HasAllocated(const DeviceMemoryBase& device_memory,
+ uint64 generation) const {
+ mutex_lock lock(mutex_);
+ auto it = records_.find(device_memory);
+ if (it == records_.end()) {
+ return false;
+ }
+ return it->second.allocation_generation == generation;
+}
+
+port::StatusOr<std::unique_ptr<TemporaryDeviceMemoryBase>>
+TemporaryMemoryManager::AllocateArrayBase(uint64 element_count,
+ uint64 element_size) {
+ uint64 byte_size = element_count * element_size;
+ DeviceMemoryBase device_memory =
+ stream_->parent()->AllocateArray<uint8>(byte_size);
+ if (device_memory == nullptr) {
+ return port::Status(port::error::RESOURCE_EXHAUSTED,
+ port::StrCat("could not allocate temporary memory of ",
+ byte_size, " bytes"));
+ }
+
+ uint64 generation;
+
+ // Add the record before instantiating the device memory instance so we can
+ // check the allocation invariant at TemporaryDeviceMemory construction time.
+ {
+ mutex_lock lock(mutex_);
+ generation = ++generation_;
+ DCHECK(records_.find(device_memory) == records_.end());
+ records_[device_memory] = {generation,
+ /*finalized=*/false};
+ }
+
+ VLOG(1) << port::Printf(
+ "stream %p allocated temporary device memory at %p (size %llu) in "
+ "generation %llu",
+ stream_, device_memory.opaque(), byte_size, generation);
+ std::unique_ptr<TemporaryDeviceMemoryBase> result(
+ new TemporaryDeviceMemoryBase(stream_, device_memory, generation));
+ return std::move(result);
+}
+
+} // namespace internal
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/temporary_memory_manager.h b/tensorflow/stream_executor/temporary_memory_manager.h
new file mode 100644
index 0000000000..847f0f2182
--- /dev/null
+++ b/tensorflow/stream_executor/temporary_memory_manager.h
@@ -0,0 +1,138 @@
+// The temporary-memory-manager is a helper class for a Stream to keep track of
+// temporary allocations. These allocations defer their deallocation to the next
+// Stream::BlockHostUntilDone call for efficiency purposes (as deallocation
+// itself generally forces synchronization to occur).
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_
+
+#include <map>
+#include <memory>
+
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/temporary_device_memory.h"
+
+namespace perftools {
+namespace gputools {
+namespace internal {
+
+// Record used inside the TemporaryMemoryManager as metadata for a given device
+// memory region.
+struct TemporaryMemoryRecord {
+ // What "generation" this record was allocated in.
+ //
+ // Currently the generation counter is bumped for every allocation, but this
+ // could be made coarser if necessary.
+ uint64 allocation_generation;
+
+ // Notes whether the temporary memory has been marked as finalized, such that
+ // we can release the DeviceMemory associated with this record at
+ // synchronization time.
+ bool finalized;
+};
+
+// Manages temporary memories associated with a stream -- keeps records of
+// outstanding temporaries and their state, and can deallocate them
+// appropriately at points in the Stream lifecycle (e.g. BlockHostUntilDone,
+// destruction).
+class TemporaryMemoryManager {
+ public:
+  explicit TemporaryMemoryManager(Stream* stream)
+      : generation_(0), stream_(stream) {}
+
+ // Allocates a temporary array that is then managed by this object.
+ template <typename T>
+ port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>> AllocateArray(
+ uint64 element_count);
+
+ // Forces deallocation of all managed temporary memory regions.
+ //
+ // Called, for example, when the Stream owning this temporary memory manager
+ // is destroyed.
+ //
+ // Note: These calls to Deallocate will likely force synchronization.
+ void ForceDeallocateAll();
+
+ // Marks the given memory region as finalized.
+ //
+ // If must_exist is set, this will check-fail if the temporary memory record
+ // is not found.
+ void MarkFinalized(const DeviceMemoryBase& device_memory, uint64 generation,
+ bool must_exist);
+
+ // Deallocates temporary memories that have been finalized.
+ //
+  // Note: the underlying Deallocate calls will likely force synchronization,
+  // so this is meant to be called when a "BlockHostUntilDone" is about to be
+  // performed anyway.
+ void DeallocateFinalizedTemporaries();
+
+ // Returns whether the provided device_memory is finalized.
+ //
+  // In the vacuous case where the device memory doesn't appear in the
+  // temporary memory records, it is either not a temporary at all or has
+  // already been deallocated, so this returns true.
+ bool IsFinalized(const DeviceMemoryBase& device_memory,
+ uint64 allocation_generation) const;
+
+ // Returns whether the manager has a live allocation record for the given
+ // device memory pointer with the given generation counter.
+ //
+ // Note: this is a polling call -- there is no guarantee that the region is
+ // still allocated once the call has completed.
+ bool HasAllocated(const DeviceMemoryBase& device_memory,
+ uint64 generation) const;
+
+ private:
+ // Allocates an array without type parameterization, so that the
+ // implementation can live in the source file. Without this base allocation
+ // method, we incur a circular dependency between the StreamExecutor
+ // definition and this class' definition.
+ port::StatusOr<std::unique_ptr<TemporaryDeviceMemoryBase>> AllocateArrayBase(
+ uint64 element_count, uint64 element_size);
+
+ // Mutex to guard temporary record state.
+ mutable mutex mutex_;
+
+ // Mapping from device memory to the current (live) temporary memory record.
+ //
+ // If a device memory is not in this mapping, it is not a temporary currently
+ // allocated and owned by this temporary memory manager.
+ std::map<DeviceMemoryBase, TemporaryMemoryRecord> records_ GUARDED_BY(mutex_);
+
+ // Allocation generation -- we bump this counter to distinguish temporary
+ // memory handles that have been deallocated and later reallocated at the same
+ // device memory address.
+ uint64 generation_ GUARDED_BY(mutex_);
+
+ // The stream (parent object) for this temporary memory manager -- allocations
+ // are performed through this stream handle.
+ Stream* stream_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(TemporaryMemoryManager);
+};
+
+////////////
+// Inlines
+
+template <typename T>
+port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
+TemporaryMemoryManager::AllocateArray(uint64 element_count) {
+ port::StatusOr<std::unique_ptr<TemporaryDeviceMemoryBase>> temporary_memory =
+ AllocateArrayBase(element_count, sizeof(T));
+ if (!temporary_memory.ok()) {
+ return temporary_memory.status();
+ }
+
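+  // Note: this cast assumes TemporaryDeviceMemory<T> adds no data members to
+  // TemporaryDeviceMemoryBase and only layers typed accessors on top, so
+  // reinterpreting the base pointer as the typed subclass is safe.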
+ return std::unique_ptr<TemporaryDeviceMemory<T>>(
+ reinterpret_cast<TemporaryDeviceMemory<T>*>(
+ temporary_memory.ConsumeValueOrDie().release()));
+}
+
+} // namespace internal
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_
diff --git a/tensorflow/stream_executor/timer.cc b/tensorflow/stream_executor/timer.cc
new file mode 100644
index 0000000000..46210a2346
--- /dev/null
+++ b/tensorflow/stream_executor/timer.cc
@@ -0,0 +1,41 @@
+#include "tensorflow/stream_executor/timer.h"
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace perftools {
+namespace gputools {
+
+static internal::TimerInterface *CreateTimerImplementation(
+ StreamExecutor *parent) {
+ PlatformKind platform_kind = parent->platform_kind();
+ if (platform_kind == PlatformKind::kCuda) {
+ return (*internal::MakeCUDATimerImplementation())(parent);
+ } else if (platform_kind == PlatformKind::kOpenCL ||
+ platform_kind == PlatformKind::kOpenCLAltera) {
+ return (*internal::MakeOpenCLTimerImplementation())(parent);
+ } else if (platform_kind == PlatformKind::kHost) {
+ return internal::MakeHostTimerImplementation(parent);
+ } else if (platform_kind == PlatformKind::kMock) {
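+    // The mock platform has no timer implementation; a Timer constructed for
+    // it simply carries a null implementation pointer.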
+ return nullptr;
+ } else {
+ LOG(FATAL) << "cannot create timer implementation for platform kind: "
+ << PlatformKindString(platform_kind);
+ }
+}
+
+Timer::Timer(StreamExecutor *parent)
+ : implementation_(CreateTimerImplementation(parent)), parent_(parent) {}
+
+Timer::~Timer() { parent_->DeallocateTimer(this); }
+
+uint64 Timer::Microseconds() const { return implementation_->Microseconds(); }
+
+uint64 Timer::Nanoseconds() const { return implementation_->Nanoseconds(); }
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/timer.h b/tensorflow/stream_executor/timer.h
new file mode 100644
index 0000000000..ff54c06180
--- /dev/null
+++ b/tensorflow/stream_executor/timer.h
@@ -0,0 +1,60 @@
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TIMER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TIMER_H_
+
+#include <memory>
+
+#include "tensorflow/stream_executor/platform/port.h"
+
+namespace perftools {
+namespace gputools {
+
+namespace internal {
+class TimerInterface;
+} // namespace internal
+
+class StreamExecutor;
+
+// An interval timer, suitable for use in timing the operations that occur in
+// streams.
+//
+// Thread-hostile: CUDA associates a CUDA context with a particular thread in
+// the system. Any operation a user attempts to perform with a Timer on a
+// thread not associated with that CUDA context has unknown behavior at
+// present; see b/13176597
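+//
+// Example usage (a sketch; it assumes the owning Stream provides InitTimer,
+// ThenStartTimer and ThenStopTimer hooks to drive the start/stop lifecycle
+// mentioned below -- those calls are not declared in this header):
+//
+//   Timer timer(stream.parent());
+//   stream.InitTimer(&timer).ThenStartTimer(&timer);
+//   // ... enqueue the work to be measured on "stream" ...
+//   stream.ThenStopTimer(&timer);
+//   stream.BlockHostUntilDone();
+//   uint64 usecs = timer.Microseconds();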
+class Timer {
+ public:
+  // Instantiates a timer tied to the given parent StreamExecutor.
+ explicit Timer(StreamExecutor *parent);
+
+ // Deallocates any timer resources that the parent StreamExecutor has bestowed
+ // upon this object.
+ ~Timer();
+
+  // Returns the elapsed number of microseconds for a completed timer, where
+  // "completed" means the timer has been through a start/stop lifecycle.
+ uint64 Microseconds() const;
+
+  // Returns the elapsed number of nanoseconds for a completed timer, where
+  // "completed" means the timer has been through a start/stop lifecycle.
+ uint64 Nanoseconds() const;
+
+  // Returns the (opaque) backing platform-specific TimerInterface instance.
+  // Ownership is not transferred to the caller.
+ internal::TimerInterface *implementation() { return implementation_.get(); }
+
+ private:
+  // Platform-dependent implementation of the timer internals. The Timer class
+  // just delegates to this opaque instance.
+ std::unique_ptr<internal::TimerInterface> implementation_;
+
+ // The StreamExecutor that manages the platform-specific internals for this
+ // timer.
+ StreamExecutor *parent_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(Timer);
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TIMER_H_
diff --git a/tensorflow/stream_executor/trace_listener.h b/tensorflow/stream_executor/trace_listener.h
new file mode 100644
index 0000000000..dcbb223f4f
--- /dev/null
+++ b/tensorflow/stream_executor/trace_listener.h
@@ -0,0 +1,59 @@
+// This file defines the StreamExecutor trace listener, used for inserting
+// non-device-specific instrumentation into the StreamExecutor.
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TRACE_LISTENER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TRACE_LISTENER_H_
+
+#include <vector>
+
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/kernel.h"
+#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/lib/status.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+
+// Traces StreamExecutor PIMPL-level events.
+// The few StreamExecutor interfaces that are synchronous have both Begin and
+// Complete versions of their trace calls. Asynchronous operations only have
+// Submit calls, as execution of the underlying operations is device-specific.
+// As all tracing calls mirror StreamExecutor routines, documentation here is
+// minimal.
+//
+// All calls have default implementations that perform no work; subclasses
+// should override functionality of interest. Keep in mind that these routines
+// are not called on a dedicated thread, so callbacks should execute quickly.
+//
+// Note: This API is constructed on an as-needed basis. Users should add
+// support for further StreamExecutor operations as required. By enforced
+// convention (see SCOPED_TRACE in stream_executor_pimpl.cc), synchronous
+// tracepoints should be named NameBegin and NameComplete.
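+//
+// Example subclass (a sketch; registering the listener with a StreamExecutor
+// is a separate step that is not shown here):
+//
+//   class LoggingTraceListener : public TraceListener {
+//    public:
+//     void SynchronousMemcpyH2DBegin(int64 correlation_id,
+//                                    const void* host_src, int64 size,
+//                                    DeviceMemoryBase* gpu_dst) override {
+//       VLOG(1) << "H2D copy #" << correlation_id << ": " << size << " bytes";
+//     }
+//   };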
+class TraceListener {
+ public:
+ virtual ~TraceListener() {}
+
+ virtual void LaunchSubmit(Stream* stream, const ThreadDim& thread_dims,
+ const BlockDim& block_dims,
+ const KernelBase& kernel,
+ const std::vector<KernelArg>& args) {}
+
+ virtual void SynchronousMemcpyH2DBegin(int64 correlation_id,
+ const void* host_src, int64 size,
+ DeviceMemoryBase* gpu_dst) {}
+ virtual void SynchronousMemcpyH2DComplete(int64 correlation_id,
+ const port::Status* result) {}
+
+ virtual void SynchronousMemcpyD2HBegin(int64 correlation_id,
+ const DeviceMemoryBase& gpu_src,
+ int64 size, void* host_dst) {}
+ virtual void SynchronousMemcpyD2HComplete(int64 correlation_id,
+ const port::Status* result) {}
+
+ virtual void BlockHostUntilDoneBegin(int64 correlation_id, Stream* stream) {}
+ virtual void BlockHostUntilDoneComplete(int64 correlation_id, bool result) {}
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_TRACE_LISTENER_H_