aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r--tensorflow/stream_executor/stream.h1258
1 files changed, 1258 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
new file mode 100644
index 0000000000..d4d5e7729b
--- /dev/null
+++ b/tensorflow/stream_executor/stream.h
@@ -0,0 +1,1258 @@
+// The Stream is used in conjunction with the StreamExecutor "parent" to
+// perform actions with a linear stream of dependencies. Dependencies can also
+// be created between Streams to do task management (i.e. limit which tasks
+// can be performed concurrently and specify what task dependencies exist).
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
+#define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
+
+#include <complex>
+#include <functional>
+#include <memory>
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/fft.h"
+#include "tensorflow/stream_executor/kernel.h"
+#include "tensorflow/stream_executor/launch_dim.h"
+#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/temporary_memory_manager.h"
+
+namespace perftools {
+namespace gputools {
+
+namespace host {
+class HostBlas;
+class HostFft;
+class HostRng;
+class HostTimer;
+} // namespace host
+
+namespace ocl {
+class CLBlas;
+} // namespace ocl
+
+namespace internal {
+class StreamInterface;
+} // namespace internal
+
+class DeviceMemoryBase;
+template <typename ElemT>
+class DeviceMemory;
+
+class Timer;
+
+namespace dnn {
+struct BatchDescriptor;
+struct FilterDescriptor;
+struct ConvolutionDescriptor;
+} // namespace dnn
+
+class StreamExecutor;
+
+// Represents a stream of dependent computations on a GPU device.
+//
+// The operations within a stream execute linearly and asynchronously until
+// BlockHostUntilDone() is invoked, which synchronously joins host code with
+// the execution of the stream.
+//
+// If any given operation fails when entraining work for the stream, ok() will
+// indicate that an error has occurred. After initialization, once a stream is
+// !ok(), it will never be ok().
+//
+// Thread-safe post-initialization.
+class Stream {
+ public:
+ // Instantiate a stream tied to parent as a platform executor. Work
+ // entrained onto this stream will be launched/managed on that
+ // StreamExecutor's platform.
+ explicit Stream(StreamExecutor *parent);
+
+ // Test only. Use an externally-populated value (like a mock) for the
+ // platform-specific stream implementation.
+ Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
+
+ // Deallocates any stream resources that the parent StreamExecutor has
+ // bestowed
+ // upon this object.
+ ~Stream();
+
+ // Returns whether any errors have occurred while entraining work for this
+ // stream.
+ bool ok() const { return !InErrorState(); }
+
+ // Initialize the stream. This must be performed before entraining any other
+ // operations.
+ Stream &Init();
+
+ // Initializes timer t via the StreamExecutor.
+ Stream &InitTimer(Timer *t);
+
+ // Convenience wrapper around Init() and InitTimer().
+ Stream &InitWithTimer(Timer *t);
+
+ // Warning! After calling BlockHostUntilDone(), all sub-streams will be
+ // returned and hence invalid. This may be a temporary solution to the issue
+ // b/18070215.
+ // Get or create a sub-stream from this stream. If there is any sub-stream
+ // in the pool that can be reused then just return this sub-stream.
+ // Otherwise
+ // create a new sub-stream.
+ Stream *GetOrCreateSubStream();
+
+ // Return the sub-stream back to the host stream so that it can be reused
+ // later.
+ void ReturnSubStream(Stream *sub_stream);
+
+ // Allocate temporary memories. The stream will deallocate them when blocked
+ // or destroyed.
+ template <typename T>
+ port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
+ AllocateTemporaryArray(uint64 element_count);
+
+ // Entrains onto the stream of operations: a kernel launch with the given
+ // (variadic) parameters for the invocation. These arguments can be things
+ // like DeviceMemory or primitive types such as int. What arguments you may
+ // pass to a given kernel are noted as the template parameters to the
+ // TypedKernel type that the machocc compiler generates.
+ //
+ // Template parameters:
+ // Params... The type list of formal parameters that the typed kernel
+ // expects, which is matched against Args...
+ // Args... The deduced type list for passed actual arguments
+ //
+ // Implementation: A compile-time compatibility check is performed that has
+ // some leniency versus an exact parameter pack match -- for example,
+ // `const DeviceMemory<T>` is considered "pack compatible" with a
+ // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
+ // perfect forwarding support without rvalue references. It also attempts to
+ // spit out helpful static_assert error traces with information as to the
+ // argument number and types that were mismatched.
+ template <typename... Params, typename... Args>
+ Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
+ const TypedKernel<Params...> &kernel, Args... args);
+
+ // Record a "start" event for the interval timer at this point in the
+ // stream's
+ // execution (relative to the previously and subsequently enqueued items in
+ // the stream's execution). Streams may be started/stopped multiple times.
+ Stream &ThenStartTimer(Timer *t);
+
+ // Record a "stop" event for the interval timer at this point in the
+ // stream's
+ // execution. See also Stream::ThenStartTimer.
+ Stream &ThenStopTimer(Timer *t);
+
+ // TODO(leary) If work is added to the stream that is being depended upon,
+ // then what? Have to describe what happens.
+ template <typename... Params>
+ Stream &ThenWaitFor(Stream *other, Params... more_streams) {
+ return ThenWaitFor(more_streams...).ThenWaitFor(other);
+ }
+
+ // Create a dependency for this stream's next work on the other stream
+ // completing. Does not take ownership of other, and other must not be
+ // null.
+ //
+ // Checks that a stream does not wait for itself, and it is up to the
+ // user to guarantee that a stream does not come to wait on itself in a
+ // cyclic
+ // manner; in that case, behavior is undefined.
+ //
+ // N.B. Base recursion case for the variadic ThenWaitFor.
+ Stream &ThenWaitFor(Stream *other);
+
+ // Waits for all streams values in others.
+ // Checks that there is no shallow circular wait (i.e. that "this" is not in
+ // others).
+ Stream &ThenWaitFor(std::vector<std::unique_ptr<Stream>> *others);
+
+ // Waits for an event object to be set.
+ // Note that ThenRecordEvent must have been called on the event before
+ // you call this function; otherwise the event will be considered complete
+ // and this wait will do nothing.
+ Stream &ThenWaitFor(Event *event);
+
+ // Inserts the specified event into the end of this stream. Once the stream
+ // has processed all events prior to the insertion point, the event will be
+ // marked as completed.
+ // The stream does not take ownership of event - meaning that event's lifetime
+ // must extend past the point at which it is marked complete!
+ Stream &ThenRecordEvent(Event *event);
+
+ ////////////////
+ // DNN support
+ //
+ // See DnnSupport::* for comments on the following methods.
+
+ // TODO(leary) add double-precision version of this interface.
+ Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output);
+
+ Stream &ThenSeparableConvolve(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
+ const DeviceMemory<float> &first_weights,
+ const DeviceMemory<float> &second_weights,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output);
+
+ Stream &ThenConvolveBackwardData(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data);
+
+ Stream &ThenConvolveBackwardFilter(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data);
+
+ Stream &ThenMatMul(const DeviceMemory<float> &input_data,
+ const DeviceMemory<float> &weights,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
+ const DeviceMemory<int8> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
+ const DeviceMemory<int16> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
+ const DeviceMemory<float> &biases,
+ const dnn::BatchDescriptor &dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ const DeviceMemory<float> &output_data,
+ const DeviceMemory<float> &input_diff_data,
+ DeviceMemory<float> *output_diff_data);
+
+ Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
+ const DeviceMemory<float> &input_data,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenActivate(dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenDepthConcatenate(
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float> *> input_data,
+ DeviceMemory<float> *output_data);
+
+ Stream &ThenElementwiseOperate(
+ dnn::ElementwiseOperation operation,
+ port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
+ port::ArraySlice<const DeviceMemory<float> *> input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
+
+ // See DnnSupport::DoMemcpyD2HQuantized.
+ // TODO(wgulland) Use a template to merge the versions of
+ // ThenMemcpyD2HQuantized.
+ Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<uint8> host_dst);
+
+ // See DnnSupport::DoMemcpyD2HQuantized.
+ Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<uint16> host_dst);
+
+ // See DnnSupport::DoMemcpyD2HQuantized.
+ Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<int32> host_dst);
+
+ // See DnnSupport::DoMemcpyH2DQuantized.
+ Stream &ThenMemcpyH2DQuantized(port::ArraySlice<uint8> host_src,
+ DeviceMemory<float> *gpu_unquantized_dst);
+
+ /////////////////
+ // BLAS support
+
+ // See BlasSupport::DoBlasAsum.
+ Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *result);
+ Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *result);
+ Stream &ThenBlasAsum(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result);
+ Stream &ThenBlasAsum(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result);
+
+ // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
+ // present in DeviceMemory, it must be an execution-time constant (i.e. a
+ // value
+ // that the stream does not change or populate during the course of
+ // execution). The value is effectively captured at stream-enqueue time.
+ Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasCopy.
+ Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasCopy(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasCopy(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasDot.
+ Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *result);
+ Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *result);
+
+ // See BlasSupport::DoBlasDotc.
+ Stream &ThenBlasDotc(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result);
+ Stream &ThenBlasDotc(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result);
+
+ // See BlasSupport::DoBlasDotu.
+ Stream &ThenBlasDotu(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *result);
+ Stream &ThenBlasDotu(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *result);
+
+ // See BlasSupport::DoBlasNrm2.
+ Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<float> *result);
+ Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<double> *result);
+ Stream &ThenBlasNrm2(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<float> *result);
+ Stream &ThenBlasNrm2(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<double> *result);
+
+ // See BlasSupport::DoBlasRot.
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy, float c, float s);
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy, double c, double s);
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
+ int incx, DeviceMemory<std::complex<float>> *y, int incy,
+ float c, float s);
+ Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
+ int incx, DeviceMemory<std::complex<double>> *y, int incy,
+ double c, double s);
+
+ // See BlasSupport::DoBlasRotg.
+ Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
+ DeviceMemory<float> *c, DeviceMemory<float> *s);
+ Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
+ DeviceMemory<double> *c, DeviceMemory<double> *s);
+ Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
+ DeviceMemory<std::complex<float>> *b,
+ DeviceMemory<float> *c,
+ DeviceMemory<std::complex<float>> *s);
+ Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
+ DeviceMemory<std::complex<double>> *b,
+ DeviceMemory<double> *c,
+ DeviceMemory<std::complex<double>> *s);
+
+ // See BlasSupport::DoBlasRotm.
+ Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy,
+ const DeviceMemory<float> &param);
+ Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy,
+ const DeviceMemory<double> &param);
+
+ // See BlasSupport::DoBlasRotmg.
+ Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
+ DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
+ DeviceMemory<float> *param);
+ Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
+ DeviceMemory<double> *x1,
+ const DeviceMemory<double> &y1,
+ DeviceMemory<double> *param);
+
+ // See BlasSupport::DoBlasScal.
+ Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
+ int incx);
+ Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
+ int incx);
+ Stream &ThenBlasScal(uint64 elem_count, float alpha,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasScal(uint64 elem_count, double alpha,
+ DeviceMemory<std::complex<double>> *x, int incx);
+ Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasSwap.
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
+ DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
+ int incx, DeviceMemory<std::complex<float>> *y,
+ int incy);
+ Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
+ int incx, DeviceMemory<std::complex<double>> *y,
+ int incy);
+
+ // See BlasSupport::DoBlasIamax.
+ Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamax(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result);
+ Stream &ThenBlasIamax(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<int> *result);
+
+ // See BlasSupport::DoBlasIamin.
+ Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
+ int incx, DeviceMemory<int> *result);
+ Stream &ThenBlasIamin(uint64 elem_count,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<int> *result);
+ Stream &ThenBlasIamin(uint64 elem_count,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<int> *result);
+
+ // See BlasSupport::DoBlasGbmv.
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &x, int incx,
+ float beta, DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &x, int incx,
+ double beta, DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
+ uint64 ku, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasGemv.
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasGer.
+ Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda);
+ Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda);
+
+ // See BlasSupport::DoBlasGerc.
+ Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasGeru.
+ Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasHbmv.
+ Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasHemv.
+ Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasHer.
+ Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasHer2.
+ Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *a, int lda);
+ Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *a, int lda);
+
+ // See BlasSupport::DoBlasHpmv.
+ Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &ap,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *y, int incy);
+ Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &ap,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *y, int incy);
+
+ // See BlasSupport::DoBlasHpr.
+ Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ DeviceMemory<std::complex<float>> *ap);
+ Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ DeviceMemory<std::complex<double>> *ap);
+
+ // See BlasSupport::DoBlasHpr2.
+ Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &x, int incx,
+ const DeviceMemory<std::complex<float>> &y, int incy,
+ DeviceMemory<std::complex<float>> *ap);
+ Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &x, int incx,
+ const DeviceMemory<std::complex<double>> &y, int incy,
+ DeviceMemory<std::complex<double>> *ap);
+
+ // See BlasSupport::DoBlasSbmv.
+ Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+
+ // See BlasSupport::DoBlasSpmv.
+ Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &ap,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &ap,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+
+ // See BlasSupport::DoBlasSpr.
+ Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *ap);
+ Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *ap);
+
+ // See BlasSupport::DoBlasSpr2.
+ Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *ap);
+ Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *ap);
+
+ // See BlasSupport::DoBlasSymv.
+ Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &x, int incx, float beta,
+ DeviceMemory<float> *y, int incy);
+ Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &x, int incx, double beta,
+ DeviceMemory<double> *y, int incy);
+
+ // See BlasSupport::DoBlasSyr.
+ Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ DeviceMemory<float> *a, int lda);
+ Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ DeviceMemory<double> *a, int lda);
+
+ // See BlasSupport::DoBlasSyr2.
+ Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
+ const DeviceMemory<float> &x, int incx,
+ const DeviceMemory<float> &y, int incy,
+ DeviceMemory<float> *a, int lda);
+ Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
+ const DeviceMemory<double> &x, int incx,
+ const DeviceMemory<double> &y, int incy,
+ DeviceMemory<double> *a, int lda);
+
+ // See BlasSupport::DoBlasTbmv.
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTbsv.
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n, uint64 k,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTpmv.
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx);
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap, DeviceMemory<double> *x,
+ int incx);
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTpsv.
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &ap, DeviceMemory<float> *x,
+ int incx);
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &ap, DeviceMemory<double> *x,
+ int incx);
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &ap,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &ap,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTrmv.
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasTrsv.
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<float> &a, int lda,
+ DeviceMemory<float> *x, int incx);
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<double> &a, int lda,
+ DeviceMemory<double> *x, int incx);
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *x, int incx);
+ Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
+ blas::Diagonal diag, uint64 n,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *x, int incx);
+
+ // See BlasSupport::DoBlasGemm.
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb, float beta,
+ DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb, double beta,
+ DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasGemmBatched.
+ Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, float alpha,
+ const port::ArraySlice<DeviceMemory<float> *> &a,
+ int lda,
+ const port::ArraySlice<DeviceMemory<float> *> &b,
+ int ldb, float beta,
+ const port::ArraySlice<DeviceMemory<float> *> &c,
+ int ldc, int batch_count);
+ Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, double alpha,
+ const port::ArraySlice<DeviceMemory<double> *> &a,
+ int lda,
+ const port::ArraySlice<DeviceMemory<double> *> &b,
+ int ldb, double beta,
+ const port::ArraySlice<DeviceMemory<double> *> &c,
+ int ldc, int batch_count);
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
+ std::complex<float> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
+ int batch_count);
+ Stream &ThenBlasGemmBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
+ std::complex<double> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
+ int batch_count);
+
+ // See BlasSupport::DoBlasHemm.
+ Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasHerk.
+ Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, float alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc);
+ Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, double alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc);
+
+ // See BlasSupport::DoBlasHer2k.
+ Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ float beta, DeviceMemory<std::complex<float>> *c,
+ int ldc);
+ Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ double beta, DeviceMemory<std::complex<double>> *c,
+ int ldc);
+
+ // See BlasSupport::DoBlasSymm.
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasSyrk.
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a,
+ int lda, float beta, DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a,
+ int lda, double beta, DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasSyr2k.
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a,
+ int lda, const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc);
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a,
+ int lda, const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc);
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ const DeviceMemory<std::complex<float>> &b, int ldb,
+ std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ const DeviceMemory<std::complex<double>> &b, int ldb,
+ std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c, int ldc);
+
+ // See BlasSupport::DoBlasTrmm.
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, DeviceMemory<float> *b, int ldb);
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, DeviceMemory<double> *b, int ldb);
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb);
+ Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb);
+
+ // See BlasSupport::DoBlasTrsm.
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, float alpha, const DeviceMemory<float> &a,
+ int lda, DeviceMemory<float> *b, int ldb);
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, double alpha, const DeviceMemory<double> &a,
+ int lda, DeviceMemory<double> *b, int ldb);
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda,
+ DeviceMemory<std::complex<float>> *b, int ldb);
+ Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
+ blas::Transpose transa, blas::Diagonal diag, uint64 m,
+ uint64 n, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda,
+ DeviceMemory<std::complex<double>> *b, int ldb);
+
+ // See FftSupport::DoFft.
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<std::complex<float>> *output);
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<std::complex<double>> *output);
+ Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
+ DeviceMemory<std::complex<float>> *output);
+ Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
+ DeviceMemory<std::complex<double>> *output);
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<float>> &input,
+ DeviceMemory<float> *output);
+ Stream &ThenFft(fft::Plan *plan,
+ const DeviceMemory<std::complex<double>> &input,
+ DeviceMemory<double> *output);
+
+ // Makes the RNG use the provided value as the basis for further generation.
+ // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
+ // sources of seed data if the default (high quality) sources are not
+ // desired.
+ // For most use cases, this function will not be necessary; each provided
+ // back-end implementation will be appropriately seeded by default.
+ // At a minimum 16 bytes of data are required in the seed buffer.
+ //
+ // To seed with good (non-reproducable) data:
+ // File* f = File::Open("/dev/random", "r");
+ // int64 bytes_read = f->Read(seed_data, bytes_to_read);
+ // < error checking >
+ // stream.ThenSetRngSeed(seed_data, bytes_read);
+ //
+ // To seed with reproducible data:
+ // uint64_t seed_data[2] = { <data> };
+ // stream.ThenSetRngSeed(seed_data, 16);
+ Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);
+
+ // Populates the memory indicated by values with uniform-random-distribution
+ // values. TODO(leary) seeding API/description
+ //
+ // Uses the type and size of the DeviceMemory to infer what data should be
+ // populated.
+ Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
+ Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
+ Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
+ Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
+ Stream &ThenPopulateRandGaussian(float mean, float stddev,
+ DeviceMemory<float> *values);
+ Stream &ThenPopulateRandGaussian(double mean, double stddev,
+ DeviceMemory<double> *values);
+
+ // Entrain onto the stream: a memcpy to a host destination from a GPU source
+ // of the given target size. host_dst must be a pointer to host memory
+ // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
+ // then registered with StreamExecutor::HostMemoryRegister.
+ Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size);
+
+ // Entrain onto the stream: a memcpy to a GPU destination from a host source
+ // of the given target size. host_src must be a pointer to host memory
+ // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
+ // then registered with StreamExecutor::HostMemoryRegister.
+ Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
+ uint64 size);
+
+ // Alternative interface for memcpying from device to host that takes an
+ // array slice. Checks that the destination size can accomodate the host
+ // slice size.
+ template <typename T>
+ Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
+ port::MutableArraySlice<T> host_dst) {
+ auto host_size = host_dst.size() * sizeof(T);
+ CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
+ return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
+ }
+
+ // Alternative interface for memcpying from host to device that takes an
+ // array slice. Checks that the destination size can accomodate the host
+ // slice size.
+ template <typename T>
+ Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
+ DeviceMemory<T> *gpu_dst) {
+ auto host_size = host_src.size() * sizeof(T);
+ CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
+ return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
+ }
+
+ // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
+ // of the given target size. gpu_src/dst must be pointers to GPU memory and
+ // peer access must be enabled between their owning StreamExecutors.
+ Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
+ uint64 size);
+
+ // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
+ // ensuring that the host pointer isn't getting confused accidentally with a
+ // device pointer if you're not doing metaprogramming against the API.
+ Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
+ const DeviceMemoryBase &gpu_src, uint64 size) {
+ return ThenMemcpy(gpu_dst, gpu_src, size);
+ }
+
+ // Entrain onto the stream: a memset of zero at a GPU location of size
+ // bytes.
+ // The location must not be null.
+ // TODO(leary) Presently the size must be a 4-byte multiple.
+ Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);
+
+ // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location
+ // of
+ // size bytes, where bytes must be evenly 32-bit sized (i.e. evently
+ // divisible
+ // by 4). The location must not be null.
+ Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
+ uint64 size);
+
+ // (Synchronously) block the host code waiting for the operations entrained
+ // on
+ // the stream (enqueued to this point in program execution) to complete.
+ bool BlockHostUntilDone();
+
+ // Warning! This method interacts with internal threads in
+ // sometimes-unpredictable ways and is intended for GPU-Executor-internal
+ // use
+ // only. Please check with a member of the FASTR team before making use of
+ // this method.
+ //
+ // Entrains onto the stream a function to be executed on the host at some
+ // point in the future.
+ // Async host callbacks DO NOT block the stream as device functions (or as
+ // synchronous host callbacks). No synchronization is possible with
+ // asynchronous callbacks; they are strictly fire-and-forget.
+ // This method is private due to the potential for undefined behavior with
+ // synchronization using OpenCL user events.
+ // The ONLY lifetime guarantee in these calls is that the StreamExecutor
+ // parameter will still be valid - this Stream may not be!
+ // Any callbacks requiring device API calls must use this method.
+ Stream &ThenEnqueueOnBackgroundThread(
+ std::function<void(StreamExecutor *)> task);
+
+ // Returns the (opaque) platform-specific backing object. Ownership is not
+ // transferred to the caller.
+ internal::StreamInterface *implementation() { return implementation_.get(); }
+
+ // Entrains onto the stream a callback to the host (from the device).
+ // Host callbacks block/occupy the stream just as device functions
+ // (execute one at a time, block later stream operations).
+ // Behavior is undefined when synchronizing using OpenCL user events.
+ // Behavior is undefined if host callbacks call device routines or insert
+ // them into any stream.
+ // On certain platforms, ThenDoHostCallback is expected to have significant
+ // negative effects on performance.
+ Stream &ThenDoHostCallback(std::function<void()> callback);
+
+ // Identical to ThenDoHostCallback; only exposed for testing purposes.
+ Stream &ThenDoHostCallbackForTest(std::function<void()> callback);
+
+ // Returns the StreamExecutor (parent object) associated with this stream.
+ StreamExecutor *parent() const {
+ CHECK(parent_ != nullptr);
+ return parent_;
+ }
+
+ // Returns the (internal usage) temporary-memory-allocation manager associated
+ // with this stream.
+ internal::TemporaryMemoryManager *temporary_memory_manager();
+
+ private:
+ friend class host::HostBlas; // for parent_.
+ friend class host::HostFft; // for parent_.
+ friend class host::HostRng; // for parent_.
+ template <typename... Args>
+ friend struct ThenBlasImpl; // for implementing ThenBlasXXX.
+ friend class ocl::CLBlas; // for parent_.
+
+ bool InErrorState() const {
+ shared_lock lock{mu_};
+ return !ok_;
+ }
+
+ // Sets the error state if operation_retcode is false.
+ // This is a useful shorthand for many stream routines.
+ void CheckError(bool operation_retcode) {
+ if (operation_retcode) {
+ return;
+ }
+ mutex_lock lock{mu_};
+ ok_ = false;
+ }
+
+ void SetError() { CheckError(false /* = operation_retcode */); }
+
+ // The platform-dependent implementation that the StreamExecutor interface
+ // delegates to.
+ std::unique_ptr<internal::StreamInterface> implementation_;
+
+ // The StreamExecutor that supports the operation of this stream.
+ StreamExecutor *parent_;
+
+ // mutex that guards the allocation / error state flags.
+ // Mutable so that it can be obtained via const reader lock.
+ mutable mutex mu_;
+
+ // Whether Init() was successfully called to allocate this stream on the
+ // underlying platform. It simply flips from 0 to 1 with a sanity check.
+ // See StreamExecutor::AllocateStream.
+ bool allocated_ GUARDED_BY(mu_);
+
+ // Whether all operations have entrained successfully to the current program
+ // point.
+ bool ok_ GUARDED_BY(mu_);
+
+ // Sub-streams that are generated from this stream. Each element has a pointer
+ // to sub-stream and a boolean value indicating if this substream is ready to
+ // be reused.
+ std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
+ GUARDED_BY(mu_);
+
+ // Streams can allocate temporary memories to help with work they enqueue
+ // (e.g. for scratch memory spaces). This member tracks those allocations and
+ // notes when they can be reclaimed -- reclamation is attempted when
+ // BlockHostUntilDone() is called.
+ internal::TemporaryMemoryManager temporary_memory_manager_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(Stream);
+};
+
+////////////
+// Inlines
+
+template <typename T>
+inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
+Stream::AllocateTemporaryArray(uint64 element_count) {
+ return temporary_memory_manager_.AllocateArray<T>(element_count);
+}
+
+inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
+ return &temporary_memory_manager_;
+}
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_