diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/thunk.h')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/thunk.h | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 931c0bffab..4df0bb005b 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -40,7 +41,7 @@ class GpuExecutable;
 // This is thread-compatible.
 class Thunk {
  public:
-  enum class Kind {
+  enum Kind {
     kConditional,
     kConvolution,
     kCopy,
@@ -53,6 +54,7 @@ class Thunk {
     kKernel,
     kMemset32BitValue,
     kMemzero,
+    kOutfeed,
     kSequential,
     kTuple,
     kWhile,
@@ -94,11 +96,12 @@ class Thunk {
 
   // Execute the kernel for the thunk on the given stream. This method must be
   // called after Initialize and can be called multiple times over Thunk's
-  // lifetime. Stream argument must be non-null.
+  // lifetime. 'stream' and 'profiler' must be non-null.
   //
   // Precondition: Initialize(stream->parent()) has been called.
   virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
-                                 se::Stream* stream) = 0;
+                                 se::Stream* stream,
+                                 HloExecutionProfiler* profiler) = 0;
 
  private:
   Kind kind_;
@@ -108,6 +111,8 @@ class Thunk {
 // A sequence of thunks.
 using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
 
+std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
+
 } // namespace gpu
 } // namespace xla
 