diff options
Diffstat (limited to 'tensorflow/core/common_runtime/executor.cc')
-rw-r--r-- | tensorflow/core/common_runtime/executor.cc | 62 |
1 file changed, 61 insertions, 1 deletion
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 7cef34ac52..2c48084cab 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1238,6 +1238,9 @@ class ExecutorState { // Step-local container. ScopedStepContainer* step_container_; StepStatsCollectorInterface* const stats_collector_; + const tracing::TraceCollector* const trace_collector_; + const tracing::EventCollector* const event_collector_; + // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper // instead of a pointer? (avoids having to delete). checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; @@ -1246,6 +1249,7 @@ class ExecutorState { CancellationManager* cancellation_manager_; Executor::Args::Runner runner_; bool sync_on_finish_; + const bool trace_using_annotations_; // Owned. @@ -1360,12 +1364,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) tensor_store_(args.tensor_store), step_container_(args.step_container), stats_collector_(args.stats_collector), + trace_collector_(tracing::GetTraceCollector()), + event_collector_( + tracing::GetEventCollector(tracing::EventCategory::kCompute)), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), impl_(impl), cancellation_manager_(args.cancellation_manager), runner_(args.runner), sync_on_finish_(args.sync_on_finish), + trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()), num_outstanding_ops_(0) { // We start the entire execution in iteration 0 of the root frame // so let us create the root frame and the state for iteration 0. @@ -1551,6 +1559,32 @@ struct ExecutorState::AsyncState { } }; +// Returns true if `item` might be traced by the given trace and event +// collectors. Returns false only if `item` definitely will not be traced. 
+bool MightTrace(const NodeItem& item, + const tracing::TraceCollector* trace_collector, + const tracing::EventCollector* event_collector, + bool using_annotations) { + // Tracing will only be enabled if either `event_collector` is non-null, + // or `trace_collector` is non-null and enabled for this particular kernel. + // Although `tracing::ScopedActivity`, + // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of + // these properties internally in their constructors, the cost of passing the + // necessary arguments to them can be significant, so we avoid constructing + // them in the common case (when we know they will not be used). + if (event_collector != nullptr) { + return true; + } + if (trace_collector) { + if (using_annotations) { + return trace_collector->IsEnabledForAnnotations(); + } else { + return trace_collector->IsEnabledForActivities(item.kernel_is_expensive); + } + } + return false; +} + void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { const GraphView& gview = impl_->gview_; TaggedNodeSeq ready; @@ -1585,6 +1619,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { Status s; NodeExecStatsInterface* stats = nullptr; + EntryVector outputs; bool completed = false; inline_ready.push_back(tagged_node); @@ -1721,7 +1756,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { // Synchronous computes. OpKernelContext ctx(&params, item.num_outputs); nodestats::SetOpStart(stats); - device->Compute(CHECK_NOTNULL(op_kernel), &ctx); + + if (TF_PREDICT_FALSE(MightTrace(item, trace_collector_, + event_collector_, + trace_using_annotations_))) { + const string& op_name = op_kernel->name(); + tracing::ScopedRegion region(tracing::EventCategory::kCompute, + op_name); + if (trace_using_annotations_) { + // The OpKernel may create child activities (such as GPU kernel + // launches), so use a `ScopedAnnotation` to relate these activities + // in the trace. 
+ tracing::ScopedAnnotation activity(op_name, + op_kernel->type_string()); + device->Compute(op_kernel, &ctx); + } else { + // Use the cheaper `ScopedActivity` to trace just the OpKernel + // execution. + tracing::ScopedActivity activity(op_name, op_kernel->type_string(), + item.kernel_is_expensive); + device->Compute(op_kernel, &ctx); + } + } else { + // In the common case, avoid creating any tracing objects. + device->Compute(op_kernel, &ctx); + } + nodestats::SetOpEnd(stats); s = ProcessOutputs(item, &ctx, &outputs, stats); if (s.ok() && impl_->device_record_tensor_accesses_) { |