aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/executor.cc4
1 files changed, 4 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 2c48084cab..40ec1502da 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -1240,6 +1241,7 @@ class ExecutorState {
StepStatsCollectorInterface* const stats_collector_;
const tracing::TraceCollector* const trace_collector_;
const tracing::EventCollector* const event_collector_;
+ Context context_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
@@ -1367,6 +1369,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
trace_collector_(tracing::GetTraceCollector()),
event_collector_(
tracing::GetEventCollector(tracing::EventCategory::kCompute)),
+ context_(ContextKind::kThread),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
@@ -1586,6 +1589,7 @@ bool MightTrace(const NodeItem& item,
}
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
+ WithContext wc(context_);
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
TaggedNodeReadyQueue inline_ready;