about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu/thunk.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/thunk.h')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/thunk.h 11
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 931c0bffab..4df0bb005b 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -40,7 +41,7 @@ class GpuExecutable;
// This is thread-compatible.
class Thunk {
public:
- enum class Kind {
+ enum Kind {
kConditional,
kConvolution,
kCopy,
@@ -53,6 +54,7 @@ class Thunk {
kKernel,
kMemset32BitValue,
kMemzero,
+ kOutfeed,
kSequential,
kTuple,
kWhile,
@@ -94,11 +96,12 @@ class Thunk {
// Execute the kernel for the thunk on the given stream. This method must be
// called after Initialize and can be called multiple times over Thunk's
- // lifetime. Stream argument must be non-null.
+ // lifetime. 'stream' and 'profiler' must be non-null.
//
// Precondition: Initialize(stream->parent()) has been called.
virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) = 0;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) = 0;
private:
Kind kind_;
@@ -108,6 +111,8 @@ class Thunk {
// A sequence of thunks.
using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
+std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
+
} // namespace gpu
} // namespace xla