aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/executor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/executor.cc')
-rw-r--r--tensorflow/core/common_runtime/executor.cc62
1 files changed, 61 insertions, 1 deletions
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 7cef34ac52..2c48084cab 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1238,6 +1238,9 @@ class ExecutorState {
// Step-local container.
ScopedStepContainer* step_container_;
StepStatsCollectorInterface* const stats_collector_;
+ const tracing::TraceCollector* const trace_collector_;
+ const tracing::EventCollector* const event_collector_;
+
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
@@ -1246,6 +1249,7 @@ class ExecutorState {
CancellationManager* cancellation_manager_;
Executor::Args::Runner runner_;
bool sync_on_finish_;
+ const bool trace_using_annotations_;
// Owned.
@@ -1360,12 +1364,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
tensor_store_(args.tensor_store),
step_container_(args.step_container),
stats_collector_(args.stats_collector),
+ trace_collector_(tracing::GetTraceCollector()),
+ event_collector_(
+ tracing::GetEventCollector(tracing::EventCategory::kCompute)),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
cancellation_manager_(args.cancellation_manager),
runner_(args.runner),
sync_on_finish_(args.sync_on_finish),
+ trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()),
num_outstanding_ops_(0) {
// We start the entire execution in iteration 0 of the root frame
// so let us create the root frame and the state for iteration 0.
@@ -1551,6 +1559,32 @@ struct ExecutorState::AsyncState {
}
};
+// Returns true if `item` might be traced by the given trace and event
+// collectors. Returns false only if `item` definitely will not be traced.
+bool MightTrace(const NodeItem& item,
+ const tracing::TraceCollector* trace_collector,
+ const tracing::EventCollector* event_collector,
+ bool using_annotations) {
+ // Tracing will only be enabled if either `event_collector` is non null,
+ // or `trace_collector` is non-null and enabled for this particular kernel.
+ // Although `tracing::ScopedActivity`,
+ // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
+ // these properties internally in their constructors, the cost of passing the
+ // necessary arguments to them can be significant, so we avoid constructing
+ // them in the common case (when we know they will not be used).
+ if (event_collector != nullptr) {
+ return true;
+ }
+ if (trace_collector) {
+ if (using_annotations) {
+ return trace_collector->IsEnabledForAnnotations();
+ } else {
+ return trace_collector->IsEnabledForActivities(item.kernel_is_expensive);
+ }
+ }
+ return false;
+}
+
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
@@ -1585,6 +1619,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
Status s;
NodeExecStatsInterface* stats = nullptr;
+
EntryVector outputs;
bool completed = false;
inline_ready.push_back(tagged_node);
@@ -1721,7 +1756,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
nodestats::SetOpStart(stats);
- device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+
+ if (TF_PREDICT_FALSE(MightTrace(item, trace_collector_,
+ event_collector_,
+ trace_using_annotations_))) {
+ const string& op_name = op_kernel->name();
+ tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+ op_name);
+ if (trace_using_annotations_) {
+ // The OpKernel may create child activities (such as GPU kernel
+ // launches), so use a `ScopedAnnotation` to relate these activities
+ // in the trace.
+ tracing::ScopedAnnotation activity(op_name,
+ op_kernel->type_string());
+ device->Compute(op_kernel, &ctx);
+ } else {
+ // Use the cheaper `ScopedActivity` to trace just the OpKernel
+ // execution.
+ tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
+ item.kernel_is_expensive);
+ device->Compute(op_kernel, &ctx);
+ }
+ } else {
+ // In the common case, avoid creating any tracing objects.
+ device->Compute(op_kernel, &ctx);
+ }
+
nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
if (s.ok() && impl_->device_record_tensor_accesses_) {