Diffstat (limited to 'tensorflow/core/common_runtime/executor.cc')
-rw-r--r--  tensorflow/core/common_runtime/executor.cc  2118
1 files changed, 2118 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
new file mode 100644
index 0000000000..7f2473f93b
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -0,0 +1,2118 @@
+#include "tensorflow/core/common_runtime/executor.h"
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <deque>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/edgeset.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+namespace tensorflow {
+
+namespace {
+
+// 1-D, 0 element tensor.
+static const Tensor* const kEmptyTensor = new Tensor;
+
+bool IsInitializationOp(const Node* node) {
+ return node->op_def().allows_uninitialized_input();
+}
+
+// Sets the timeline_label field of *node_stats, using data from *node.
+// Returns true iff the node is a transfer node.
+// TODO(tucker): merge with the DetailText function in session.cc
+// in a common location.
+bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
+ bool is_transfer_node = false;
+ string memory;
+ for (auto& all : node_stats->memory()) {
+ int64 tot = all.total_bytes();
+ if (tot >= 0.1 * 1048576.0) {
+ int64 peak = all.peak_bytes();
+ if (peak > 0) {
+ memory =
+ strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
+ peak / 1048576.0));
+ } else {
+ memory = strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB] ", tot / 1048576.0));
+ }
+ }
+ }
+ const NodeDef& def = node->def();
+ string text = "";
+ if (IsSend(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device));
+ text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
+ tensor_name, " @", recv_device);
+ is_transfer_node = true;
+ } else if (IsRecv(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device));
+ text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
+ tensor_name, " @", send_device);
+ is_transfer_node = true;
+ } else {
+ text = strings::StrCat(
+ memory, def.name(), " = ", def.op(), "(",
+ str_util::Join(
+ std::vector<StringPiece>(def.input().begin(), def.input().end()),
+ ", "),
+ ")");
+ }
+ node_stats->set_timeline_label(text);
+ return is_transfer_node;
+}
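+
+// For example, a node "foo" running MatMul over inputs "a" and "b" that
+// allocated roughly 1.2MB from the "cpu" allocator gets a label of the
+// form "[cpu 1.2MB] foo = MatMul(a, b)", while Send/Recv nodes record the
+// tensor name and the peer device instead of the input list.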
+
+// Helper routines for collecting step stats.
+namespace nodestats {
+inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
+
+void SetScheduled(NodeExecStats* nt, int64 t) { nt->set_scheduled_micros(t); }
+
+void SetAllStart(NodeExecStats* nt) { nt->set_all_start_micros(NowInUsec()); }
+
+void SetOpStart(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetOpEnd(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetAllEnd(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetOutput(NodeExecStats* nt, int slot, AllocationType allocation_type,
+ const Tensor* v) {
+ DCHECK(v);
+ NodeOutput* no = nt->add_output();
+ no->set_slot(slot);
+ no->set_allocation_type(allocation_type);
+ v->FillDescription(no->mutable_tensor_description());
+}
+
+void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AllocatorMemoryUsed* memory = nt->add_memory();
+ // retrieving the sizes from the wrapped allocator removes the
+ // executor's reference to it, so allocator_pair.second must not
+ // be dereferenced again after this statement
+ auto sizes = allocator_pair.second->GetSizesAndUnRef();
+ memory->set_allocator_name(allocator_pair.first->Name());
+    int64 tb = sizes.first;
+ memory->set_total_bytes(tb);
+ if (allocator_pair.first->TracksAllocationSizes()) {
+ memory->set_peak_bytes(sizes.second);
+ }
+ }
+}
+} // namespace nodestats
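+
+// A typical per-node sequence, as used by ExecutorState::Process() and
+// NodeDone() below:
+//   nodestats::SetScheduled(stats, scheduled_usec);
+//   nodestats::SetAllStart(stats);
+//   nodestats::SetOpStart(stats);   // right before the kernel's Compute
+//   nodestats::SetOpEnd(stats);     // right after Compute returns
+//   nodestats::SetMemory(stats, ctx);
+//   nodestats::SetAllEnd(stats);    // in NodeDone, before saving the stats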
+
+struct NodeItem {
+ // A graph node.
+ const Node* node = nullptr;
+
+ // The kernel for this node.
+ OpKernel* kernel = nullptr;
+
+ // ExecutorImpl::tensors_[input_start] is the 1st positional input
+ // for this node.
+ int input_start = 0;
+};
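+
+// For iteration k of a frame, the j-th input of node i lives at
+//   input_tensors[k][nodes_[i].input_start + j]
+// so input_start is just the running sum of num_inputs() over nodes with
+// smaller ids (see ExecutorImpl::Initialize below).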
+
+// Map from std::pair<node_id, output_index> to attributes.
+struct pairhash {
+ public:
+ template <typename T, typename U>
+ std::size_t operator()(const std::pair<T, U>& x) const {
+ return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
+ }
+};
+typedef std::unordered_map<std::pair<int, int>, AllocatorAttributes, pairhash>
+ DevAttrMap;
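+
+// Entries are keyed by (node_id, output_slot), e.g.
+//   auto it = alloc_attr_.find(std::make_pair(node->id(), output_slot));
+// as done in OutputAttributes() below.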
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class ExecutorImpl : public Executor {
+ public:
+ ExecutorImpl(const LocalExecutorParams& p, const Graph* g)
+ : params_(p), graph_(g) {
+ CHECK(p.create_kernel != nullptr);
+ CHECK(p.delete_kernel != nullptr);
+ }
+
+ ~ExecutorImpl() override {
+ for (NodeItem& item : nodes_) {
+ params_.delete_kernel(item.kernel);
+ }
+ delete graph_;
+ }
+
+ Status Initialize();
+
+  // Infer memory allocation attributes of a node n's output, based
+  // on one of its consumers, dst. Note that dst might not be directly
+  // connected to n by a single edge, but might be a downstream
+  // consumer of n's output by reference. *attr is updated with any
+  // necessary attributes.
+ Status InferAllocAttr(const Node* n, const Node* dst,
+ const DeviceNameUtils::ParsedName& local_dev_name,
+ AllocatorAttributes* attr);
+
+ // Process all Nodes in the current graph, attempting to infer the
+ // memory allocation attributes to be used wherever they may allocate
+ // a tensor buffer.
+ Status SetAllocAttrs();
+
+ void RunAsync(const Args& args, DoneCallback done) override;
+
+ private:
+ friend class ExecutorState;
+ friend class SimpleExecutorState;
+
+ // Owned.
+ LocalExecutorParams params_;
+ const Graph* graph_;
+ std::vector<NodeItem> nodes_; // nodes_.size == graph_.num_node_ids().
+ int total_tensors_ = 0; // total_tensors_ = sum(nodes_[*].num_inputs())
+
+ // The number of inputs for each frame in this graph. This is static
+ // information of the graph.
+ std::unordered_map<string, int> frame_input_count_;
+
+ DevAttrMap alloc_attr_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
+};
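+
+// A rough usage sketch (the graph 'g', 'params', 'args' and 'done' are all
+// supplied by the caller, typically via the NewLocalExecutor() factory
+// declared in executor.h):
+//   ExecutorImpl* impl = new ExecutorImpl(params, g);  // takes ownership of g
+//   Status s = impl->Initialize();
+//   if (s.ok()) impl->RunAsync(args, done);
+//   // 'impl' must outlive the invocation; delete it after 'done' has run.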
+
+Status ExecutorImpl::Initialize() {
+ const int num_nodes = graph_->num_node_ids();
+ nodes_.resize(num_nodes);
+
+ Status s;
+ total_tensors_ = 0;
+
+  // Pre-process every node in the graph to create an instance of the
+  // op kernel for each node.
+ for (const Node* n : graph_->nodes()) {
+ const int id = n->id();
+ NodeItem* item = &nodes_[id];
+ item->node = n;
+ item->input_start = total_tensors_;
+ total_tensors_ += n->num_inputs();
+ s = params_.create_kernel(n->def(), &item->kernel);
+ if (!s.ok()) {
+ s = AttachDef(s, n->def());
+ LOG(ERROR) << "Executor failed to create kernel. " << s;
+ break;
+ }
+ CHECK(item->kernel);
+
+ // Initialize static information about the frames in the graph.
+ if (IsEnter(n)) {
+ string frame_name;
+ s = GetNodeAttr(n->def(), "frame_name", &frame_name);
+ if (!s.ok()) return s;
+ ++frame_input_count_[frame_name];
+ }
+ }
+ if (params_.has_control_flow) {
+ VLOG(2) << "Graph has control flow.";
+ }
+ if (!s.ok()) return s;
+ return SetAllocAttrs();
+}
+
+Status ExecutorImpl::SetAllocAttrs() {
+ Status s;
+ Device* device = params_.device;
+ DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
+
+ for (const Node* n : graph_->nodes()) {
+ // Examine the out edges of each node looking for special use
+ // cases that may affect memory allocation attributes.
+ for (auto e : n->out_edges()) {
+ AllocatorAttributes attr;
+ s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
+ if (!s.ok()) return s;
+ if (attr.value != 0) {
+ VLOG(2) << "node " << n->name() << " gets attr " << attr.value
+ << " for output " << e->src_output();
+ alloc_attr_[std::make_pair(n->id(), e->src_output())].Merge(attr);
+ } else {
+ VLOG(2) << "default output attr for node " << n->name() << " output "
+ << e->src_output();
+ }
+ }
+ }
+ return s;
+}
+
+Status ExecutorImpl::InferAllocAttr(
+ const Node* n, const Node* dst,
+ const DeviceNameUtils::ParsedName& local_dev_name,
+ AllocatorAttributes* attr) {
+ Status s;
+ if (IsSend(dst)) {
+ string dst_name;
+ s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
+ if (!s.ok()) return s;
+ DeviceNameUtils::ParsedName parsed_dst_name;
+ if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
+ s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
+ n->name());
+ return s;
+ }
+ if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
+ // Value is going to be the source of an RPC.
+ attr->set_nic_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the source of an RPC out";
+ } else if (local_dev_name.type == "CPU" && parsed_dst_name.type == "GPU") {
+ // Value is going to be the source of a local DMA from CPU to GPU.
+ attr->set_gpu_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
+ } else {
+ VLOG(2) << "default alloc case local type " << local_dev_name.type
+ << " remote type " << parsed_dst_name.type;
+ }
+ } else if (dst->type_string() == "ToFloat") {
+ for (auto e : dst->out_edges()) {
+ s = InferAllocAttr(n, e->dst(), local_dev_name, attr);
+ if (!s.ok()) return s;
+ }
+ }
+ return s;
+}
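+
+// For example, an output that feeds a Send whose "recv_device" lives in a
+// different address space is marked nic_compatible; a CPU-resident output
+// feeding a Send to a local GPU is marked gpu_compatible; everything else
+// keeps the default attributes.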
+
+// The state associated with one invocation of ExecutorImpl::RunAsync.
+// ExecutorState dispatches nodes when they become ready, and keeps
+// track of how many predecessors of a node have not completed yet
+// (pending_).
+class ExecutorState {
+ public:
+ ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
+ ~ExecutorState();
+
+ void RunAsync(Executor::DoneCallback done);
+
+ private:
+ typedef ExecutorState ME;
+
+ // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ // TODO(yuanbyu): A better way to do "has_value"?
+ struct Entry {
+ Tensor val = *kEmptyTensor; // A tensor value.
+ Tensor* ref = nullptr; // A tensor reference.
+ mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr.
+ bool has_value = false; // Whether the value exists
+
+ // Every entry carries an optional DeviceContext containing
+ // Device-specific information about how the Tensor was produced.
+ DeviceContext* device_context = nullptr;
+
+ // The attributes of the allocator that creates the tensor.
+ AllocatorAttributes alloc_attr;
+ };
+
+ // Contains a map from node id to the DeviceContext object that was
+ // assigned by the device at the beginning of a step.
+ DeviceContextMap device_context_map_;
+
+ struct IterationState {
+ // The state of an iteration.
+
+ // The pending count for each graph node. One copy per iteration.
+ // Iteration i can be garbage collected when it is done.
+ // TODO(yuanbyu): This vector currently has size of the number of nodes
+ // in this partition. This is not efficient if the subgraph for the frame
+ // is only a small subset of the partition. We should make the vector
+ // size to be only the size of the frame subgraph.
+ std::vector<int>* pending_count;
+
+ // The dead input count for each graph node. One copy per iteration.
+ std::vector<int>* dead_count;
+
+ // One copy per iteration. For iteration k, i-th node's j-th input is in
+ // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
+ // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ //
+    // NOTE: No need to protect input_tensors[i] by any locks because it
+    // is resized once. Each element of input_tensors is written once by
+    // the source node of an edge and is cleared by the destination of
+    // the same edge. The destination node is never run concurrently with
+    // the source node.
+ std::vector<Entry>* input_tensors;
+
+ // The number of outstanding ops for each iteration.
+ int outstanding_ops;
+
+ // The number of outstanding frames for each iteration.
+ int outstanding_frame_count;
+
+ ~IterationState() {
+ delete pending_count;
+ delete dead_count;
+ delete input_tensors;
+ }
+ };
+
+ struct FrameState {
+ // A new frame is created for each loop. Execution starts at iteration 0.
+ // When a value at iteration 0 passes through a NextIteration node,
+ // iteration 1 is created and starts running. Note that iteration 0 may
+ // still be running so multiple iterations may run in parallel. The
+ // frame maintains the state of iterations in several data structures
+ // such as pending_count and input_tensors. When iteration 0 completes,
+ // we garbage collect the state of iteration 0.
+ //
+ // A frame instance is considered "done" and can be garbage collected
+ // if all its inputs have entered and all its iterations are "done".
+ //
+ // A frame manages the live iterations of an iterative computation.
+ // Iteration i is considered "done" when there are no outstanding ops,
+ // frames at iteration i are done, all recvs for this iteration are
+ // completed, and iteration i-1 is done. For iteration 0, we instead
+ // wait for there to be no more pending inputs of the frame.
+ //
+ // Frames and iterations are garbage collected once they are done.
+ // The state we need to keep around is highly dependent on the
+ // parallelism enabled by the scheduler. We may want to have the
+ // scheduler dynamically control the outstanding number of live
+ // parallel frames and iterations. To reduce the state space, the
+ // scheduler might want to schedule ops in inner frames first and
+ // lower iterations first.
+ //
+ // This frame state is mostly initialized lazily on demand so we
+ // don't introduce unnecessary overhead.
+
+ // The name of this frame, which is the concatenation of its parent
+ // frame name, the iteration of the parent frame when this frame was
+ // created, and the value of the attr 'frame_name'.
+ string frame_name;
+
+ // The unique id for this frame. Generated by fingerprinting
+ // frame_name.
+ uint64 frame_id;
+
+ // The iteration id of its parent frame when this frame is created.
+ // -1 if there is no parent frame. The frame_name/parent_iter pair
+ // uniquely identifies this FrameState.
+ int64 parent_iter = -1;
+
+ // The FrameState of its parent frame.
+ FrameState* parent_frame = nullptr;
+
+ // The highest iteration number we have reached so far in this frame.
+ int64 iteration_count = 0;
+
+    // The number of inputs this frame is still waiting for.
+ int num_pending_inputs = 0;
+
+ // The number of outstanding iterations.
+ int num_outstanding_iterations = 0;
+
+ // The maximum allowed number of parallel iterations.
+ int max_parallel_iterations = 1;
+
+ // The iteration states of this frame.
+ std::vector<IterationState*> iterations;
+
+ // The NextIteration nodes to enter a new iteration. If the number of
+ // outstanding iterations reaches the limit, we will defer the start of
+ // the next iteration until the number of outstanding iterations falls
+ // below the limit.
+ std::vector<std::pair<const Node*, Entry>> next_iter_roots;
+
+ // The values of the loop invariants for this loop. They are added into
+ // this list as they "enter" the frame. When a loop invariant enters,
+ // we make it available to all active iterations. When the frame starts
+ // a new iteration, we make all the current loop invariants available
+ // to the new iteration.
+ std::vector<std::pair<const Node*, Entry>> inv_values;
+
+ // The list of dead exit nodes for the current highest iteration. We
+ // will only "execute" the dead exits of the final iteration.
+ std::vector<const Node*> dead_exits;
+
+ IterationState* GetIteration(int64 iter) {
+ int index = iter % iterations.size();
+ return iterations[index];
+ }
+
+ void SetIteration(int64 iter, IterationState* state) {
+ int index = iter % iterations.size();
+ iterations[index] = state;
+ }
+
+ ~FrameState() {
+ for (size_t i = 0; i < iterations.size(); ++i) {
+ delete iterations[i];
+ iterations[i] = nullptr;
+ }
+ }
+ };
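+
+  // Example: a frame created with parallel_iterations = 2 gets an
+  // 'iterations' vector of 3 slots (see FindOrCreateChildFrame), so
+  // iteration k is stored at index k % 3 and a slot is reused only after
+  // CleanupFramesIterations has deleted the older iteration in it.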
+
+ // A tagged node: <frame*, iter, node*>.
+ struct TaggedNode {
+ const Node* node = nullptr;
+ FrameState* input_frame = nullptr;
+ int64 input_iter = -1;
+ bool is_dead = false;
+
+ TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
+ bool dead) {
+ node = t_node;
+ input_frame = in_frame;
+ input_iter = in_iter;
+ is_dead = dead;
+ }
+ };
+
+ typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
+ typedef gtl::InlinedVector<Entry, 4> EntryVector;
+
+ // Not owned.
+ Rendezvous* rendezvous_;
+ StepStatsCollector* stats_collector_;
+ // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper instead of a
+ // pointer? (avoids having to delete).
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
+ FunctionCallFrame* call_frame_;
+ const ExecutorImpl* impl_;
+ CancellationManager* cancellation_manager_;
+ Executor::Args::Runner runner_;
+
+ // Owned.
+
+ // Step-local resource manager.
+ ResourceMgr step_resource_manager_;
+
+ // The root frame in which the execution of this step is started.
+ FrameState* root_frame_;
+
+ // Invoked when the execution finishes.
+ Executor::DoneCallback done_cb_;
+
+ std::atomic_int_fast32_t num_outstanding_ops_;
+
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+
+ // Mapping from frame name to outstanding frames. A new frame is created
+ // at some iteration of an active frame. So the unique key for the new
+ // child frame is composed of the name of the parent frame, the iteration
+ // number at which the parent frame is creating the new frame, and the
+ // name of the new frame from nodedef.
+ std::unordered_map<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
+
+ // The unique name of a frame.
+ inline string MakeFrameName(FrameState* frame, int64 iter_id, string name) {
+ return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
+ }
+
+ // Initialize the pending count for a graph.
+ static void InitializePending(const Graph* graph, std::vector<int>* pending);
+
+ // Find an existing or create a new child frame in the frame 'frame' at
+ // iteration 'iter'.
+ void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
+ FrameState** child) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Increments the iteration id. If this is a new iteration, initialize it.
+ void IncrementIteration(FrameState* frame, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns true if the computation in the frame is completed.
+ bool IsFrameDone(FrameState* frame) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns true if the iteration of the frame is completed.
+ bool IsIterationDone(FrameState* frame, int64 iter)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Get the output frame/iter of a node. Create a new frame/iteration
+  // if needed. If there are dead roots for the new iteration, we need
+  // to "execute" them, so we add them to the ready queue. Returns true
+  // if we need to check for the completion of the output frame/iter.
+ bool SetOutputFrameIter(const TaggedNode& tagged_node,
+ const EntryVector& outputs, FrameState** frame,
+ int64* iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Clean up frames and iterations.
+ void CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate all the deferred NextIteration nodes in a new iteration.
+ void ActivateNexts(FrameState* frame, int64 iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate all the current loop invariants in a new iteration.
+ void ActivateLoopInvs(FrameState* frame, int64 iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Add a new loop invariant and make it available to all active iterations.
+ void AddLoopInv(FrameState* frame, const Node* node, const Entry& value,
+ TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate the successors of a node.
+ void ActivateNode(const Node* node, const bool is_dead, FrameState* frame,
+ int64 iter, const EntryVector& outputs,
+ TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Process a ready node in current thread.
+ void Process(TaggedNode node, int64 scheduled_usec);
+
+ // Before invoking item->kernel, fills in its "inputs".
+ Status PrepareInputs(const NodeItem& item, Entry* first_input,
+ TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts,
+ AllocatorAttributeVec* input_alloc_attrs,
+ bool* is_input_dead);
+
+ // After item->kernel computation is done, processes its outputs.
+ Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ EntryVector* outputs, NodeExecStats* stats);
+
+ // After processing the outputs, propagates the outputs to their dsts.
+ void PropagateOutputs(const TaggedNode& tagged_node,
+ const EntryVector& outputs, TaggedNodeSeq* ready);
+
+  // "node" has just finished. Takes ownership of "stats". Returns true if
+ // execution has completed.
+ bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
+ NodeExecStats* stats, std::deque<TaggedNode>* inline_ready);
+
+ // Call Process() on all nodes in 'inline_ready'.
+ void ProcessInline(const std::deque<TaggedNode>& inline_ready);
+
+ // Schedule all the expensive nodes in 'ready', and put all the inexpensive
+ // nodes in 'ready' into 'inline_ready'.
+ void ScheduleReady(const TaggedNodeSeq& ready,
+ std::deque<TaggedNode>* inline_ready);
+
+ // One thread of control finishes.
+ void Finish();
+};
+
+ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
+ : rendezvous_(args.rendezvous),
+ stats_collector_(args.stats_collector),
+ slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
+ call_frame_(args.call_frame),
+ impl_(impl),
+ cancellation_manager_(args.cancellation_manager),
+ runner_(args.runner),
+ num_outstanding_ops_(0) {
+  // We start the entire execution in iteration 0 of the root frame,
+  // so we create the root frame and the state for iteration 0 here.
+  // Initialize the frame.
+ root_frame_ = new FrameState;
+  root_frame_->frame_name = "_root";  // assumed to be unique
+ root_frame_->frame_id = 0; // must be 0
+ root_frame_->num_pending_inputs = 0;
+ root_frame_->num_outstanding_iterations = 1;
+ root_frame_->max_parallel_iterations = 1; // enough for root frame
+ root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
+
+ VLOG(2) << "Create frame: " << root_frame_->frame_name;
+
+ // Initialize the iteration.
+ IterationState* iter_state = new IterationState;
+ root_frame_->iterations[0] = iter_state;
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ iter_state->dead_count = new std::vector<int>(impl->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ // Initialize the executor state.
+ outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
+}
+
+ExecutorState::~ExecutorState() {
+ for (auto name_frame : outstanding_frames_) {
+ delete name_frame.second;
+ }
+
+ for (auto it : device_context_map_) {
+ it.second->Unref();
+ }
+
+ delete slice_reader_cache_;
+}
+
+void ExecutorState::InitializePending(const Graph* graph,
+ std::vector<int>* pending) {
+ pending->resize(graph->num_node_ids());
+ for (const Node* n : graph->nodes()) {
+ const int id = n->id();
+ const int num_in_edges = n->in_edges().size();
+ if (IsMerge(n)) {
+      // Merge waits for all of its control inputs, so we initialize the
+      // pending count to be the number of control edges.
+ int32 num_control_edges = 0;
+ for (const Edge* edge : n->in_edges()) {
+ if (edge->IsControlEdge()) {
+ num_control_edges++;
+ }
+ }
+ // Use bit 0 to indicate if there is a ready live data input.
+ (*pending)[id] = num_control_edges << 1;
+ } else {
+ (*pending)[id] = num_in_edges;
+ }
+ }
+}
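+
+// Example of the Merge encoding: a Merge node with one control edge starts
+// with pending = 1 << 1 = 2. ActivateNode subtracts 2 when the control edge
+// fires and ORs in bit 0 when a live data input arrives, so the node
+// becomes ready once all control edges have fired and either one live data
+// input has arrived or all of its data inputs are dead.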
+
+void ExecutorState::RunAsync(Executor::DoneCallback done) {
+ const Graph* graph = impl_->graph_;
+ TaggedNodeSeq ready;
+
+ {
+ // Initialize the executor state. We grab the mutex here just to
+ // keep the thread safety analysis happy.
+ mutex_lock l(mu_);
+ std::vector<int>* pending = root_frame_->iterations[0]->pending_count;
+ InitializePending(graph, pending);
+ }
+
+ // Ask the device to fill in the device context map.
+ Device* device = impl_->params_.device;
+ device->FillContextMap(graph, &device_context_map_);
+
+ // Initialize the ready queue.
+ for (const Node* n : graph->nodes()) {
+ const int num_in_edges = n->in_edges().size();
+ if (num_in_edges == 0) {
+ ready.push_back(TaggedNode{n, root_frame_, 0, false});
+ }
+ }
+ if (ready.empty()) {
+ done(Status::OK());
+ } else {
+ num_outstanding_ops_ = ready.size();
+ root_frame_->iterations[0]->outstanding_ops = ready.size();
+ done_cb_ = done;
+ // Schedule to run all the ready ops in thread pool.
+ ScheduleReady(ready, nullptr);
+ }
+}
+
+namespace {
+
+// This function is provided for use by OpKernelContext when allocating
+// the index'th output of node. It provides access to the
+// AllocatorAttributes computed during initialization to determine in
+// which memory region the tensor should be allocated.
+AllocatorAttributes OutputAttributes(const DevAttrMap* attr_map,
+ const Node* node,
+ const OpKernel* op_kernel, int index) {
+ DCHECK_GE(index, 0);
+
+ AllocatorAttributes attr;
+ int nid = node->id();
+ const auto& iter = attr_map->find(std::make_pair(nid, index));
+ if (iter != attr_map->end()) {
+ attr = iter->second;
+ VLOG(2) << "nondefault attr " << attr.value << " for node " << node->name()
+ << " output " << index;
+ } else {
+ VLOG(2) << "default attr for node " << node->name() << " output " << index;
+ }
+
+ DCHECK_LT(index, op_kernel->output_memory_types().size());
+ bool on_host = op_kernel->output_memory_types()[index] == HOST_MEMORY;
+ attr.set_on_host(on_host);
+ return attr;
+}
+
+// Helper to make a copy of 'p', including copies of its input vector,
+// input device context vector, and input allocator attributes vector.
+//
+// NOTE: We need to copy p.inputs for asynchronous kernels because
+// OpKernelContext methods like input_type(i) need the params to point
+// to valid input vectors. This is not an issue for synchronous kernels
+// because those vectors are kept on the caller's stack.
+OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
+ OpKernelContext::Params* ret = new OpKernelContext::Params;
+ *ret = p;
+ ret->inputs = new TensorValueVec(*p.inputs);
+ ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts);
+ ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs);
+ return ret;
+}
+
+// Helpers to delete 'p' and copies made by CopyParams.
+void DeleteParams(OpKernelContext::Params* p) {
+ delete p->inputs;
+ delete p->input_device_contexts;
+ delete p->input_alloc_attrs;
+ delete p;
+}
+
+} // namespace
+
+void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ TaggedNodeSeq ready;
+ std::deque<TaggedNode> inline_ready;
+
+ // Parameters passed to OpKernel::Compute.
+ TensorValueVec inputs;
+ DeviceContextVec input_device_contexts;
+ AllocatorAttributeVec input_alloc_attrs;
+
+ OpKernelContext::Params params;
+ Device* device = impl_->params_.device;
+ params.device = device;
+ // track allocations if and only if we are collecting statistics
+ params.track_allocations = (stats_collector_ != nullptr);
+ params.rendezvous = rendezvous_;
+ params.cancellation_manager = cancellation_manager_;
+ params.call_frame = call_frame_;
+ params.function_library = impl_->params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_resource_manager = &step_resource_manager_;
+ params.slice_reader_cache = slice_reader_cache_;
+ params.inputs = &inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.input_alloc_attrs = &input_alloc_attrs;
+
+ Status s;
+ NodeExecStats* stats = nullptr;
+ EntryVector outputs;
+ bool completed = false;
+ inline_ready.push_back(tagged_node);
+ while (!inline_ready.empty()) {
+ tagged_node = inline_ready.front();
+ inline_ready.pop_front();
+ const Node* node = tagged_node.node;
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+ const int id = node->id();
+ const NodeItem& item = nodes[id];
+
+ // Set the device_context for this node id, if it exists.
+ auto dc_it = device_context_map_.find(id);
+ if (dc_it != device_context_map_.end()) {
+ params.op_device_context = dc_it->second;
+ }
+
+ if (stats_collector_) {
+ stats = new NodeExecStats;
+ stats->set_node_name(node->name());
+ nodestats::SetScheduled(stats, scheduled_usec);
+ nodestats::SetAllStart(stats);
+ }
+
+ VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def());
+
+ std::vector<Entry>* input_tensors;
+ {
+ // Need the lock because the iterations vector could be resized by
+ // another thread.
+ mutex_lock l(mu_);
+ input_tensors = input_frame->GetIteration(input_iter)->input_tensors;
+ }
+ Entry* first_input = input_tensors->data() + item.input_start;
+ outputs.clear();
+ outputs.resize(node->num_outputs());
+
+ // Only execute this node if it is not dead or it is a send/recv
+ // transfer node. For transfer nodes, we need to propagate the "dead"
+ // bit even when the node is dead.
+ AsyncOpKernel* async = nullptr;
+ if (!tagged_node.is_dead || IsTransferNode(node)) {
+ // Prepares inputs.
+ bool is_input_dead = false;
+ s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
+ &input_alloc_attrs, &is_input_dead);
+ if (!s.ok()) {
+ // Continue to process the nodes in 'inline_ready'.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ continue;
+ }
+
+ // Set up compute params.
+ OpKernel* op_kernel = item.kernel;
+ params.op_kernel = op_kernel;
+ params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
+ params.is_input_dead = is_input_dead;
+ params.output_alloc_attr = [this, node, op_kernel](int index) {
+ return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index);
+ };
+
+ async = op_kernel->AsAsync();
+ if (async) {
+ // Asynchronous computes.
+ auto pcopy = CopyParams(params);
+ auto ctx = new OpKernelContext(*pcopy);
+ auto done = [this, tagged_node, item, first_input, ctx, stats,
+ pcopy]() {
+ VLOG(2) << this << " Async kernel done: "
+ << SummarizeNodeDef(item.node->def());
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+ EntryVector outputs;
+ Status s = ProcessOutputs(item, ctx, &outputs, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, ctx);
+ // Clears inputs.
+ int num_inputs = tagged_node.node->num_inputs();
+ for (int i = 0; i < num_inputs; ++i) {
+ (first_input + i)->val = *kEmptyTensor;
+ }
+ TaggedNodeSeq ready;
+ if (s.ok()) {
+ PropagateOutputs(tagged_node, outputs, &ready);
+ }
+ // Schedule to run all the ready ops in thread pool.
+ bool completed = NodeDone(s, item.node, ready, stats, nullptr);
+ delete ctx;
+ DeleteParams(pcopy);
+ if (completed) Finish();
+ };
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->ComputeAsync(async, ctx, done);
+ } else {
+ // Synchronous computes.
+ OpKernelContext ctx(params);
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+
+ // Processes outputs.
+ s = ProcessOutputs(item, &ctx, &outputs, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, &ctx);
+ }
+ }
+
+ if (!async) {
+ // Clears inputs.
+ int num_inputs = node->num_inputs();
+ for (int i = 0; i < num_inputs; ++i) {
+ (first_input + i)->val = *kEmptyTensor;
+ }
+ // Propagates outputs.
+ if (s.ok()) {
+ PropagateOutputs(tagged_node, outputs, &ready);
+ }
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ // Postprocess.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ }
+ } // while !inline_ready.empty()
+
+ // This thread of computation is done if completed = true.
+ if (completed) Finish();
+}
+
+Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
+ TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts,
+ AllocatorAttributeVec* input_alloc_attrs,
+ bool* is_input_dead) {
+ const Node* node = item.node;
+
+ inputs->clear();
+ inputs->resize(node->num_inputs());
+ input_device_contexts->clear();
+ input_device_contexts->resize(node->num_inputs());
+ input_alloc_attrs->clear();
+ input_alloc_attrs->resize(node->num_inputs());
+
+ *is_input_dead = false;
+
+ bool is_merge = IsMerge(node);
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ const bool expect_ref = IsRefType(node->input_type(i));
+ Entry* entry = first_input + i;
+ (*input_device_contexts)[i] = entry->device_context;
+ (*input_alloc_attrs)[i] = entry->alloc_attr;
+
+ // i-th input.
+ TensorValue* inp = &(*inputs)[i];
+
+ // Only merge and transfer nodes can have no-value inputs.
+ if (!entry->has_value) {
+ if (!is_merge) {
+ DCHECK(IsTransferNode(node));
+ inp->tensor = &entry->val;
+ *is_input_dead = true;
+ }
+ continue;
+ }
+ if (entry->ref == nullptr) {
+ if (expect_ref) {
+ return AttachDef(
+ errors::InvalidArgument(i, "-th input expects a ref type"),
+ item.kernel->def());
+ }
+ inp->tensor = &entry->val;
+ } else {
+ if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
+ return AttachDef(
+ errors::FailedPrecondition("Attempting to use uninitialized value ",
+ item.kernel->def().input(i)),
+ item.kernel->def());
+ }
+ if (expect_ref) {
+ inp->mutex_if_ref = entry->ref_mu;
+ inp->tensor = entry->ref;
+ } else {
+ // Automatically deref the tensor ref when the op expects a
+ // tensor but is given a ref to a tensor. Need to deref it
+ // under the mutex.
+ {
+ mutex_lock l(*(entry->ref_mu));
+ entry->val = *entry->ref;
+ }
+ inp->tensor = &entry->val;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ EntryVector* outputs,
+ NodeExecStats* stats) {
+ const Node* node = item.node;
+ outputs->clear();
+ outputs->resize(node->num_outputs());
+
+ Status s = ctx->status();
+ if (!s.ok()) {
+ s = AttachDef(s, item.kernel->def());
+ LOG(WARNING) << this << " Compute status: " << s;
+ return s;
+ }
+
+ // Get the device_context for this node id, if it exists.
+ DeviceContext* device_context = nullptr;
+ auto dc_it = device_context_map_.find(node->id());
+ if (dc_it != device_context_map_.end()) {
+ device_context = dc_it->second;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ TensorValue val = ctx->release_output(i);
+ // Only Switch and Recv nodes can generate new dead outputs
+ if (*ctx->is_output_dead() || val.tensor == nullptr) {
+ DCHECK(IsSwitch(node) || IsRecv(node));
+ } else {
+ Entry* out = &((*outputs)[i]);
+ out->has_value = true;
+
+ // Set the device context of the output entry.
+ out->device_context = device_context;
+
+ // Set the allocator attributes of the output entry.
+ out->alloc_attr = ctx->output_alloc_attr(i);
+
+ // Sanity check of output tensor types.
+ DataType dtype = val->dtype();
+ if (val.is_ref()) dtype = MakeRefType(dtype);
+ if (dtype == node->output_type(i)) {
+ if (val.is_ref()) {
+ out->ref = val.tensor;
+ out->ref_mu = val.mutex_if_ref;
+ } else {
+ out->val = *val.tensor;
+ }
+ if (stats_collector_ && val.tensor->IsInitialized()) {
+ nodestats::SetOutput(stats, i, ctx->output_allocation_type(i),
+ val.tensor);
+ }
+ } else {
+ s.Update(errors::Internal("Output ", i, " of type ",
+ DataTypeString(dtype),
+ " does not match declared output type ",
+ DataTypeString(node->output_type(i)),
+ " for node ", SummarizeNodeDef(node->def())));
+ }
+ }
+ if (!val.is_ref()) {
+ // If OpKernelContext returns outputs via pass-by-value, we
+ // don't need this trouble.
+ delete val.tensor;
+ }
+ }
+ return s;
+}
+
+void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
+ const EntryVector& outputs,
+ TaggedNodeSeq* ready) {
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+
+ // Propagates outputs along out edges, and puts newly ready nodes
+ // into the ready queue.
+ ready->clear();
+
+ {
+ FrameState* output_frame = input_frame;
+ int64 output_iter = input_iter;
+
+ mutex_lock l(mu_);
+ // Sets the output_frame and output_iter of node.
+ bool maybe_completed = SetOutputFrameIter(
+ tagged_node, outputs, &output_frame, &output_iter, ready);
+ if (output_frame != nullptr) {
+ // Continue to process the out nodes:
+ ActivateNode(tagged_node.node, tagged_node.is_dead, output_frame,
+ output_iter, outputs, ready);
+ }
+
+ // At this point, this node is completely done.
+ input_frame->GetIteration(input_iter)->outstanding_ops--;
+ CleanupFramesIterations(input_frame, input_iter, ready);
+
+ // The execution of a node such as Enter may cause the completion of
+ // output_frame:output_iter, so perform cleanup if output_frame:output_iter
+ // is indeed completed.
+ if (maybe_completed) {
+ CleanupFramesIterations(output_frame, output_iter, ready);
+ }
+ }
+}
+
+void ExecutorState::ActivateNode(const Node* node, const bool is_dead,
+ FrameState* output_frame, int64 output_iter,
+ const EntryVector& outputs,
+ TaggedNodeSeq* ready) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ IterationState* output_iter_state = output_frame->GetIteration(output_iter);
+ std::vector<int>* pending = output_iter_state->pending_count;
+ std::vector<int>* dead_count = output_iter_state->dead_count;
+ for (const Edge* e : node->out_edges()) {
+ const Node* dst_node = e->dst();
+ const int dst_id = dst_node->id();
+ const int src_slot = e->src_output();
+
+ bool dst_dead = false;
+ bool dst_ready = false;
+ bool dst_need_input = !e->IsControlEdge();
+ if (IsMerge(dst_node)) {
+ // A merge node is ready if a) all control edges are enabled and a
+ // live data input becomes available, or b) all control edges are
+ // enabled and all data inputs are dead.
+ if (e->IsControlEdge()) {
+ (*pending)[dst_id] -= 2;
+ int count = (*pending)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = (count == 1) || ((count == 0) && dst_dead);
+ } else {
+ if (outputs[src_slot].has_value) {
+ // This is a live data input.
+ int count = (*pending)[dst_id];
+ (*pending)[dst_id] |= 0x1;
+ dst_ready = (count == 0);
+ } else {
+ // This is a dead data input.
+ ++(*dead_count)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = ((*pending)[dst_id] == 0) && dst_dead;
+ }
+ // This input for dst is not needed if !dst_ready. We suppress the
+ // propagation to make the thread safety analysis happy.
+ dst_need_input = dst_ready;
+ }
+ } else {
+ // A non-merge node is ready if all its inputs are ready. We wait
+ // for all inputs to come in even if we know the node is dead. This
+ // ensures that all input tensors get cleaned up.
+ if (is_dead || (!e->IsControlEdge() && !outputs[src_slot].has_value)) {
+ ++(*dead_count)[dst_id];
+ }
+ dst_dead = (*dead_count)[dst_id] > 0;
+ dst_ready = (--(*pending)[dst_id] == 0);
+ }
+
+ if (dst_need_input) {
+ const NodeItem& dst_item = nodes[dst_id];
+ const int dst_slot = e->dst_input();
+ std::vector<Entry>* input_tensors = output_iter_state->input_tensors;
+ int dst_loc = dst_item.input_start + dst_slot;
+ (*input_tensors)[dst_loc] = outputs[src_slot];
+ }
+
+ // Add dst to the ready queue if it's ready
+ if (dst_ready) {
+ dst_dead = dst_dead && !IsControlTrigger(dst_node);
+ ready->push_back(
+ TaggedNode(dst_node, output_frame, output_iter, dst_dead));
+ output_iter_state->outstanding_ops++;
+ }
+ }
+}
+
+void ExecutorState::ActivateNexts(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate the deferred NextIteration nodes to the new iteration.
+ for (auto& node_entry : frame->next_iter_roots) {
+ const Node* node = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = !entry.has_value;
+ ActivateNode(node, is_dead, frame, iter, {entry}, ready);
+ }
+ frame->next_iter_roots.clear();
+}
+
+void ExecutorState::ActivateLoopInvs(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate loop invariants to the new iteration.
+ for (auto& node_entry : frame->inv_values) {
+ const Node* node = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = !entry.has_value;
+ ActivateNode(node, is_dead, frame, iter, {entry}, ready);
+ }
+}
+
+void ExecutorState::AddLoopInv(FrameState* frame, const Node* node,
+ const Entry& entry, TaggedNodeSeq* ready) {
+ // Store this value.
+ frame->inv_values.push_back({node, entry});
+
+ // Make this value available to all iterations.
+ bool is_dead = !entry.has_value;
+ for (int i = 1; i <= frame->iteration_count; ++i) {
+ ActivateNode(node, is_dead, frame, i, {entry}, ready);
+ }
+}
+
+bool ExecutorState::NodeDone(const Status& s, const Node* node,
+ const TaggedNodeSeq& ready, NodeExecStats* stats,
+ std::deque<TaggedNode>* inline_ready) {
+ if (stats_collector_) {
+ nodestats::SetAllEnd(stats);
+ if (!SetTimelineLabel(node, stats)) {
+ // Only record non-transfer nodes.
+ stats_collector_->Save(impl_->params_.device->name(), stats);
+ } else {
+ delete stats;
+ }
+ }
+
+ Rendezvous* captured_rendezvous = nullptr; // Will be set on error.
+ if (!s.ok()) {
+ // Some error happened. This thread of computation is done.
+ mutex_lock l(mu_);
+ if (status_.ok()) {
+ captured_rendezvous = rendezvous_;
+ if (captured_rendezvous) captured_rendezvous->Ref();
+ status_ = s;
+ }
+ }
+ if (captured_rendezvous) {
+ // If we captured the rendezvous_ pointer, we are in an error condition.
+ // Use captured_rendezvous, in case "this" is deleted by another thread.
+ TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
+ captured_rendezvous->StartAbort(s);
+ captured_rendezvous->Unref();
+ }
+
+ bool completed = false;
+ int ready_size = ready.size();
+ if (ready_size == 0 || !s.ok()) {
+ completed = (num_outstanding_ops_.fetch_sub(1) == 1);
+ } else if (ready_size > 1) {
+ num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
+ }
+
+ // Schedule the ready nodes in 'ready'.
+ if (s.ok()) {
+ ScheduleReady(ready, inline_ready);
+ }
+ return completed;
+}
+
+void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
+ if (inline_ready.empty()) return;
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ for (auto& tagged_node : inline_ready) {
+ Process(tagged_node, scheduled_usec);
+ }
+}
+
+void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
+ std::deque<TaggedNode>* inline_ready) {
+ if (ready.empty()) return;
+
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ if (inline_ready == nullptr) {
+ // Schedule to run all the ready ops in thread pool.
+ for (auto& tagged_node : ready) {
+ runner_(std::bind(&ME::Process, this, tagged_node, scheduled_usec));
+ }
+ return;
+ }
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ const TaggedNode* curr_expensive_node = nullptr;
+ for (auto& tagged_node : ready) {
+ const NodeItem& item = nodes[tagged_node.node->id()];
+ if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
+ // Inline this inexpensive node.
+ inline_ready->push_back(tagged_node);
+ } else {
+ if (curr_expensive_node) {
+ // Dispatch to another thread since there is plenty of work to
+ // do for this thread.
+ runner_(std::bind(&ME::Process, this, *curr_expensive_node,
+ scheduled_usec));
+ }
+ curr_expensive_node = &tagged_node;
+ }
+ }
+ if (curr_expensive_node) {
+ if (inline_ready->empty()) {
+ // Tail recursion optimization
+ inline_ready->push_back(*curr_expensive_node);
+ } else {
+ // There are inline nodes to run already. We dispatch this expensive
+ // node to other thread.
+ runner_(
+ std::bind(&ME::Process, this, *curr_expensive_node, scheduled_usec));
+ }
+ }
+}
+
+void ExecutorState::Finish() {
+ mu_.lock();
+ auto status = status_;
+ auto done_cb = done_cb_;
+ auto runner = runner_;
+ mu_.unlock();
+ delete this;
+ CHECK(done_cb != nullptr);
+ runner([done_cb, status]() { done_cb(status); });
+}
+
+bool ExecutorState::IsFrameDone(FrameState* frame) {
+ return (frame->num_pending_inputs == 0 &&
+ frame->num_outstanding_iterations == 0);
+}
+
+bool ExecutorState::IsIterationDone(FrameState* frame, int64 iter) {
+ IterationState* iter_state = frame->GetIteration(iter);
+ if (iter_state->outstanding_ops == 0 &&
+ iter_state->outstanding_frame_count == 0) {
+ if (iter == 0) {
+ // The enclosing frame has no pending input.
+ return frame->num_pending_inputs == 0;
+ } else {
+ // The preceding iteration is deleted (and therefore done).
+ return (frame->GetIteration(iter - 1) == nullptr);
+ }
+ }
+ return false;
+}
+
+void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
+ const Node* node,
+ FrameState** child) {
+ // Get the child frame name.
+ string enter_name;
+ Status s = GetNodeAttr(node->def(), "frame_name", &enter_name);
+ CHECK(s.ok()) << s;
+ const string child_name = MakeFrameName(frame, iter, enter_name);
+
+ auto it = outstanding_frames_.find(child_name);
+ if (it != outstanding_frames_.end()) {
+ *child = it->second;
+ } else {
+ // Need to create a new frame instance.
+ VLOG(2) << "Create frame: " << child_name;
+
+ FrameState* temp = new FrameState;
+ temp->frame_name = child_name;
+ temp->frame_id = Hash64(child_name);
+ temp->parent_frame = frame;
+ temp->parent_iter = iter;
+ s = GetNodeAttr(node->def(), "parallel_iterations",
+ &temp->max_parallel_iterations);
+ CHECK(s.ok()) << s;
+ // 'iterations' is a fixed-length circular buffer.
+ temp->iterations.resize(temp->max_parallel_iterations + 1);
+ IterationState* iter_state = new IterationState;
+ temp->iterations[0] = iter_state;
+
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ InitializePending(impl_->graph_, iter_state->pending_count);
+ iter_state->dead_count =
+ new std::vector<int>(impl_->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ auto frame_pending = impl_->frame_input_count_.find(enter_name);
+ DCHECK(frame_pending != impl_->frame_input_count_.end());
+ temp->num_pending_inputs = frame_pending->second;
+ temp->num_outstanding_iterations = 1;
+ *child = temp;
+
+ frame->GetIteration(iter)->outstanding_frame_count++;
+ outstanding_frames_[child_name] = temp;
+ }
+}
+
+void ExecutorState::IncrementIteration(FrameState* frame,
+ TaggedNodeSeq* ready) {
+ frame->iteration_count++;
+ int64 next_iter = frame->iteration_count;
+
+ VLOG(2) << "Create iteration: [" << frame->frame_name << ", " << next_iter
+ << "]";
+
+ IterationState* iter_state = new IterationState;
+ frame->SetIteration(next_iter, iter_state);
+ frame->num_outstanding_iterations++;
+ frame->dead_exits.clear();
+
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ InitializePending(impl_->graph_, iter_state->pending_count);
+ iter_state->dead_count = new std::vector<int>(impl_->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ // Activate the successors of the deferred roots in the new iteration.
+ ActivateNexts(frame, next_iter, ready);
+
+ // Activate the loop invariants in the new iteration.
+ ActivateLoopInvs(frame, next_iter, ready);
+}
+
+bool ExecutorState::SetOutputFrameIter(const TaggedNode& tagged_node,
+ const EntryVector& outputs,
+ FrameState** output_frame,
+ int64* output_iter,
+ TaggedNodeSeq* ready) {
+ const Node* node = tagged_node.node;
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+ bool is_dead = tagged_node.is_dead;
+ bool is_enter = IsEnter(node);
+
+ if (is_enter) {
+ FindOrCreateChildFrame(input_frame, input_iter, node, output_frame);
+ // Propagate if this is a loop invariant.
+ bool is_constant;
+ Status s = GetNodeAttr(node->def(), "is_constant", &is_constant);
+ CHECK(s.ok()) << s;
+ if (is_constant) {
+ AddLoopInv(*output_frame, node, outputs[0], ready);
+ }
+ --(*output_frame)->num_pending_inputs;
+ *output_iter = 0;
+ } else if (IsExit(node)) {
+ if (is_dead) {
+ // Stop and remember this node if it is a dead exit.
+ if (input_iter == input_frame->iteration_count) {
+ input_frame->dead_exits.push_back(node);
+ }
+ *output_frame = nullptr;
+ } else {
+ *output_frame = input_frame->parent_frame;
+ *output_iter = input_frame->parent_iter;
+ }
+ } else if (IsNextIteration(node)) {
+ if (is_dead) {
+ // Stop the deadness propagation
+ *output_frame = nullptr;
+ } else {
+ if (input_iter == input_frame->iteration_count &&
+ input_frame->num_outstanding_iterations ==
+ input_frame->max_parallel_iterations) {
+ // Reached the maximum for parallel iterations.
+ input_frame->next_iter_roots.push_back({node, outputs[0]});
+ *output_frame = nullptr;
+ } else {
+ // If this is a new iteration, start it.
+ if (input_iter == input_frame->iteration_count) {
+ IncrementIteration(input_frame, ready);
+ }
+ *output_iter = input_iter + 1;
+ }
+ }
+ }
+ return is_enter;
+}
+
+void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ int64 curr_iter = iter;
+ while (curr_iter <= frame->iteration_count &&
+ IsIterationDone(frame, curr_iter)) {
+ // Delete the iteration curr_iter
+ VLOG(2) << "Delete iteration [" << frame->frame_name << ", " << curr_iter
+ << "].";
+
+ delete frame->GetIteration(curr_iter);
+ frame->SetIteration(curr_iter, nullptr);
+ --frame->num_outstanding_iterations;
+ ++curr_iter;
+
+ // If there is a deferred iteration, start it.
+ if (frame->next_iter_roots.size() > 0) {
+ IncrementIteration(frame, ready);
+ }
+ }
+
+ if (IsFrameDone(frame)) {
+ FrameState* parent_frame = frame->parent_frame;
+ int64 parent_iter = frame->parent_iter;
+
+ // Propagate all the dead exits to the parent frame.
+ for (const Node* node : frame->dead_exits) {
+ auto parent_iter_state = parent_frame->GetIteration(parent_iter);
+ std::vector<int>* pending = parent_iter_state->pending_count;
+ std::vector<int>* dead_count = parent_iter_state->dead_count;
+ for (const Edge* e : node->out_edges()) {
+ const Node* dst_node = e->dst();
+ const int dst_id = dst_node->id();
+
+ bool dst_dead = true;
+ bool dst_ready = false;
+ // We know this is a dead input to dst
+ if (IsMerge(dst_node)) {
+ if (e->IsControlEdge()) {
+ (*pending)[dst_id] -= 2;
+ int count = (*pending)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = (count == 1) || ((count == 0) && dst_dead);
+ } else {
+ ++(*dead_count)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = ((*pending)[dst_id] == 0) && dst_dead;
+ }
+ } else {
+ ++(*dead_count)[dst_id];
+ dst_ready = (--(*pending)[dst_id] == 0);
+ }
+ if (dst_ready) {
+ ready->push_back(
+ TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
+ parent_iter_state->outstanding_ops++;
+ }
+ }
+ }
+
+ // Delete the frame
+ const string& frame_name = frame->frame_name;
+ VLOG(2) << "Delete frame " << frame_name;
+ outstanding_frames_.erase(frame_name);
+ delete frame;
+
+ // Cleanup recursively
+ if (parent_frame != nullptr) {
+ parent_frame->GetIteration(parent_iter)->outstanding_frame_count--;
+ CleanupFramesIterations(parent_frame, parent_iter, ready);
+ }
+ }
+}
+
+// When the ExecutorImpl graph has no control-flow nodes,
+// SimpleExecutorState is used instead of ExecutorState. It maintains
+// less internal state and is convenient for experimenting with async
+// op kernels.
+class SimpleExecutorState {
+ public:
+ SimpleExecutorState(const Executor::Args& args, ExecutorImpl* impl);
+ ~SimpleExecutorState() {
+ for (auto it : device_context_map_) {
+ it.second->Unref();
+ }
+ delete slice_reader_cache_;
+ }
+ void RunAsync(Executor::DoneCallback done);
+
+ private:
+ typedef SimpleExecutorState ME;
+
+ // Not owned.
+ Rendezvous* rendezvous_;
+ StepStatsCollector* stats_collector_;
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
+ FunctionCallFrame* call_frame_;
+ const ExecutorImpl* impl_;
+ CancellationManager* cancellation_manager_;
+ Executor::Args::Runner runner_;
+
+ // Owned.
+
+  // The i-th node's j-th input is in
+  // input_tensors_[impl_->nodes[i].input_start + j]. Each entry is
+  // either a tensor pointer (pass-by-reference) or a tensor
+  // (pass-by-value).
+  //
+  // NOTE: Not protected by mu_ because input_tensors_ is resized once.
+  // Each element is written once by the source node of an edge and is
+  // cleared by the destination of the same edge. The destination node
+  // is never run concurrently with the source node.
+ struct Entry {
+ Tensor val = *kEmptyTensor; // A tensor value.
+ Tensor* ref = nullptr; // A tensor reference.
+ mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr.
+
+ // Every entry carries an optional DeviceContext containing
+ // Device-specific information about how the Tensor was produced.
+ DeviceContext* device_context = nullptr;
+
+ // The attributes of the allocator that creates the tensor.
+ AllocatorAttributes alloc_attr;
+ };
+
+ // Contains a map from node id to the DeviceContext object that was
+ // assigned by the device at the beginning of a step.
+ DeviceContextMap device_context_map_;
+
+ std::vector<Entry> input_tensors_;
+
+ // Step-local resource manager.
+ ResourceMgr step_resource_manager_;
+
+ // Invoked when the execution finishes.
+ Executor::DoneCallback done_cb_;
+
+ // How many active threads of computation are being used. Same as
+ // the number of pending Process() functions.
+ std::atomic_int_fast32_t num_active_;
+
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+
+  // The i-th kernel is still waiting for pending_[i] inputs.
+ class CountDown {
+ public:
+ CountDown() : v_(0) {}
+ void Set(int32 v) { v_.store(v); }
+ bool Dec() {
+ return v_.load(std::memory_order_acquire) == 1 || v_.fetch_sub(1) == 1;
+ }
+
+ private:
+ std::atomic_int_fast32_t v_;
+ };
+ std::vector<CountDown> pending_;
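+
+  // For example, pending_[id].Set(3) followed by three Dec() calls,
+  // possibly from different threads, yields false, false, true: exactly
+  // one caller observes that it accounted for the last outstanding input
+  // and can then treat the node as ready.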
+
+ // Process Node identified by "id" in current thread. "scheduled_usec"
+ // indicates when the node becomes ready and gets scheduled.
+ void Process(int id, int64 scheduled_usec);
+
+ // Before invoking item->kernel, fills in its "inputs".
+ Status PrepareInputs(const NodeItem& item, TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts);
+
+ // After item->kernel computation is done, processes its outputs
+ // and returns nodes that become "ready".
+ typedef gtl::InlinedVector<int, 8> ReadyNodeIds;
+ Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ ReadyNodeIds* ready, NodeExecStats* stats);
+
+  // "node" has just finished. Takes ownership of "stats". Returns true if
+ // execution has completed.
+ bool NodeDone(const Status& s, const Node* node, const ReadyNodeIds& ready,
+ NodeExecStats* stats, std::deque<int>* inline_ready);
+
+ // Call Process() on all nodes in 'inline_ready'.
+ void ProcessInline(const std::deque<int>& inline_ready);
+
+ // Schedule all the expensive nodes in 'ready', and put all the inexpensive
+ // nodes in 'ready' into 'inline_ready'.
+ void ScheduleReady(const ReadyNodeIds& ready, std::deque<int>* inline_ready);
+
+ // One thread of control finishes.
+ void Finish();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SimpleExecutorState);
+};
+
+SimpleExecutorState::SimpleExecutorState(const Executor::Args& args,
+ ExecutorImpl* impl)
+ : rendezvous_(args.rendezvous),
+ stats_collector_(args.stats_collector),
+ slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
+ call_frame_(args.call_frame),
+ impl_(impl),
+ cancellation_manager_(args.cancellation_manager),
+ runner_(args.runner),
+ num_active_(0),
+ pending_(impl_->nodes_.size()) {}
+
+void SimpleExecutorState::ProcessInline(const std::deque<int>& inline_ready) {
+ if (inline_ready.empty()) return;
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ for (int id : inline_ready) {
+ Process(id, scheduled_usec);
+ }
+}
+
+void SimpleExecutorState::ScheduleReady(const ReadyNodeIds& ready,
+ std::deque<int>* inline_ready) {
+ if (ready.empty()) return;
+
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ if (inline_ready == nullptr) {
+ // Schedule to run all the ready ops in thread pool.
+ for (auto id : ready) {
+ runner_(std::bind(&ME::Process, this, id, scheduled_usec));
+ }
+ return;
+ }
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ int curr_expensive_node = -1;
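+  // Run inexpensive nodes inline in the calling thread. Hand every
+  // expensive node except the last one off to the thread pool; the last
+  // expensive node is kept for this thread if no inline work was queued.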
+ for (auto id : ready) {
+ if (!nodes[id].kernel->IsExpensive()) {
+ // Inline this inexpensive node.
+ inline_ready->push_back(id);
+ } else {
+ if (curr_expensive_node != -1) {
+ // Dispatch to another thread since there is plenty of work to
+ // do for this thread.
+ runner_(
+ std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec));
+ }
+ curr_expensive_node = id;
+ }
+ }
+ if (curr_expensive_node != -1) {
+ if (inline_ready->empty()) {
+ // Tail recursion optimization
+ inline_ready->push_back(curr_expensive_node);
+ } else {
+        // There are inline nodes to run already. We dispatch this expensive
+        // node to another thread.
+ runner_(
+ std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec));
+ }
+ }
+}
+
+void SimpleExecutorState::RunAsync(Executor::DoneCallback done) {
+ const Graph* graph = impl_->graph_;
+ ReadyNodeIds ready;
+
+ // Ask the device to fill in the device context map.
+ Device* device = impl_->params_.device;
+ device->FillContextMap(graph, &device_context_map_);
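+  // The device may attach a per-node DeviceContext (e.g. a GPU stream
+  // assignment) that Process() and ProcessOutputs() look up by node id.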
+
+ for (const Node* n : graph->nodes()) {
+ const int id = n->id();
+ const int num_in_edges = n->in_edges().size();
+ pending_[id].Set(num_in_edges);
+ if (num_in_edges == 0) {
+ ready.push_back(id);
+ }
+ }
+ if (ready.empty()) {
+ done(Status::OK());
+ } else {
+ num_active_ = ready.size();
+ done_cb_ = done;
+ input_tensors_.resize(impl_->total_tensors_);
+ // Schedule to run all the ready ops in thread pool.
+ ScheduleReady(ready, nullptr);
+ }
+}
+
+Status SimpleExecutorState::PrepareInputs(
+ const NodeItem& item, TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts) {
+ const Node* node = item.node;
+
+ inputs->clear();
+ inputs->resize(node->num_inputs());
+ input_device_contexts->clear();
+ input_device_contexts->resize(node->num_inputs());
+
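+  // For each input there are four cases, depending on whether the entry
+  // holds a value or a ref and whether the kernel expects a value or a
+  // ref: a ref given where a value is expected is dereferenced under its
+  // mutex; a value given where a ref is expected is an error.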
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ const bool expect_ref = IsRefType(node->input_type(i));
+ Entry* entry = input_tensors_.data() + item.input_start + i;
+ (*input_device_contexts)[i] = entry->device_context;
+
+ // i-th input.
+ TensorValue* inp = &(*inputs)[i];
+
+ if (entry->ref == nullptr) {
+ if (expect_ref) {
+ return AttachDef(
+ errors::InvalidArgument(i, "-th input expects a ref type"),
+ item.kernel->def());
+ }
+ inp->tensor = &entry->val;
+ } else {
+ if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
+ return AttachDef(
+ errors::FailedPrecondition("Attempting to use uninitialized value ",
+ item.kernel->def().input(i)),
+ item.kernel->def());
+ }
+ if (expect_ref) {
+ inp->mutex_if_ref = entry->ref_mu;
+ inp->tensor = entry->ref;
+ } else {
+ // Automatically deref the tensor ref when the op expects a
+ // tensor but is given a ref to a tensor. Need to deref it
+ // under the mutex.
+ {
+ mutex_lock l(*(entry->ref_mu));
+ entry->val = *entry->ref;
+ }
+ inp->tensor = &entry->val;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void SimpleExecutorState::Process(int id, int64 scheduled_usec) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ ReadyNodeIds ready;
+ std::deque<int> inline_ready;
+
+ // Parameters passed to OpKernel::Compute.
+ TensorValueVec inputs;
+ DeviceContextVec input_device_contexts;
+
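+  // Parameters shared by every node processed in this call. Per-node
+  // fields (op_kernel, inputs, op_device_context, output_alloc_attr) are
+  // filled in inside the loop below.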
+ OpKernelContext::Params params;
+ Device* device = impl_->params_.device;
+ params.device = device;
+ // track allocations if and only if we are collecting statistics
+ params.track_allocations = (stats_collector_ != nullptr);
+ params.rendezvous = rendezvous_;
+ params.cancellation_manager = cancellation_manager_;
+ params.call_frame = call_frame_;
+ params.function_library = impl_->params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_resource_manager = &step_resource_manager_;
+ params.slice_reader_cache = slice_reader_cache_;
+ params.inputs = &inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.frame_iter = FrameAndIter(0, 0);
+
+ Status s;
+ NodeExecStats* stats = nullptr;
+ bool completed = false;
+ inline_ready.push_back(id);
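+  // Drain 'inline_ready': NodeDone() may append more inexpensive nodes
+  // that this thread processes without going back to the thread pool.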
+ while (!inline_ready.empty()) {
+ id = inline_ready.front();
+ inline_ready.pop_front();
+ const NodeItem& item = nodes[id];
+ const Node* node = item.node;
+
+ // Set the device_context for this node id, if it exists.
+ auto dc_it = device_context_map_.find(id);
+ if (dc_it != device_context_map_.end()) {
+ params.op_device_context = dc_it->second;
+ }
+
+ if (stats_collector_) {
+ stats = new NodeExecStats;
+ stats->set_node_name(node->name());
+ nodestats::SetScheduled(stats, scheduled_usec);
+ nodestats::SetAllStart(stats);
+ }
+
+ VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def());
+
+ // Prepares inputs.
+ s = PrepareInputs(item, &inputs, &input_device_contexts);
+ if (!s.ok()) {
+ // Continue to process the nodes in 'inline_ready'.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ continue;
+ }
+
+ OpKernel* op_kernel = item.kernel;
+ params.op_kernel = op_kernel;
+ params.output_alloc_attr = [this, node, op_kernel](int index) {
+ return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index);
+ };
+
+ // Asynchronous computes.
+ AsyncOpKernel* async = op_kernel->AsAsync();
+ if (async) {
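+      // The kernel may finish on another thread after Process() returns,
+      // so give it heap copies of the params and context; the done
+      // callback below cleans them up.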
+ auto pcopy = CopyParams(params);
+ auto ctx = new OpKernelContext(*pcopy);
+ auto done = [this, item, ctx, stats, pcopy]() {
+ VLOG(2) << this
+ << " Async kernel done: " << SummarizeNodeDef(item.node->def());
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+ ReadyNodeIds ready;
+ Status s = ProcessOutputs(item, ctx, &ready, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, ctx);
+ // Schedule to run all the ready ops in thread pool.
+ bool completed = NodeDone(s, item.node, ready, stats, nullptr);
+ delete ctx;
+ DeleteParams(pcopy);
+ if (completed) Finish();
+ };
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->ComputeAsync(async, ctx, done);
+ } else {
+ // Synchronous computes.
+ OpKernelContext ctx(params);
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+
+ s = ProcessOutputs(item, &ctx, &ready, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, &ctx);
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ completed = NodeDone(s, node, ready, stats, &inline_ready);
+ }
+ } // while !inline_ready.empty()
+
+ // This thread of computation is done if completed = true.
+ if (completed) Finish();
+}
+
+bool SimpleExecutorState::NodeDone(const Status& s, const Node* node,
+ const ReadyNodeIds& ready,
+ NodeExecStats* stats,
+ std::deque<int>* inline_ready) {
+ if (stats_collector_) {
+ nodestats::SetAllEnd(stats);
+ if (!SetTimelineLabel(node, stats)) {
+ // Only record non-transfer nodes.
+ stats_collector_->Save(impl_->params_.device->name(), stats);
+ } else {
+ delete stats;
+ }
+ }
+
+ Rendezvous* captured_rendezvous = nullptr; // Will be set on error.
+ if (!s.ok()) {
+ // Some error happened. This thread of computation is done.
+ mutex_lock l(mu_);
+ if (status_.ok()) {
+ captured_rendezvous = rendezvous_;
+ if (captured_rendezvous) captured_rendezvous->Ref();
+ status_ = s;
+ }
+ }
+ if (captured_rendezvous) {
+ // If we captured the rendezvous_ pointer, we are in an error condition.
+ // Use captured_rendezvous, in case "this" is deleted by another thread.
+ TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
+ captured_rendezvous->StartAbort(s);
+ captured_rendezvous->Unref();
+ }
+
+ bool completed = false;
+ int ready_size = ready.size();
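+  // The finished node accounts for one unit of num_active_. With no ready
+  // successors (or on error) that unit is released; otherwise it is
+  // transferred to one successor and ready_size - 1 more units are added
+  // for the rest. Execution completes when the count reaches zero.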
+ if (ready_size == 0 || !s.ok()) {
+ completed = (num_active_.fetch_sub(1) == 1);
+ } else if (ready_size > 1) {
+ num_active_.fetch_add(ready_size - 1, std::memory_order_relaxed);
+ }
+
+ // Schedule the ready nodes in 'ready'.
+ if (s.ok()) {
+ ScheduleReady(ready, inline_ready);
+ }
+ return completed;
+}
+
+void SimpleExecutorState::Finish() {
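+  // Copy out everything needed, then delete *this before running the
+  // callback: the callback may trigger teardown of objects that this
+  // state still refers to.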
+ mu_.lock();
+ auto ret = status_;
+ auto done_cb = done_cb_;
+ auto runner = runner_;
+ mu_.unlock();
+ delete this;
+ CHECK(done_cb != nullptr);
+ runner([done_cb, ret]() { done_cb(ret); });
+}
+
+Status SimpleExecutorState::ProcessOutputs(const NodeItem& item,
+ OpKernelContext* ctx,
+ ReadyNodeIds* ready,
+ NodeExecStats* stats) {
+ Status s = ctx->status();
+ if (!s.ok()) {
+ s = AttachDef(s, item.kernel->def());
+ LOG(WARNING) << this << " Compute status: " << s;
+ return s;
+ }
+
+ // Processes outputs.
+ gtl::InlinedVector<Entry, 4> outputs;
+ const Node* node = item.node;
+ outputs.resize(node->num_outputs());
+
+ // Get the device_context for this node id, if it exists.
+ DeviceContext* device_context = nullptr;
+ auto dc_it = device_context_map_.find(node->id());
+ if (dc_it != device_context_map_.end()) {
+ device_context = dc_it->second;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ TensorValue val = ctx->release_output(i);
+ // Sanity check of output tensor types.
+ DataType dtype = val->dtype();
+ if (val.is_ref()) dtype = MakeRefType(dtype);
+ if (dtype == node->output_type(i)) {
+ Entry* out = &(outputs[i]);
+ if (val.is_ref()) {
+ out->ref = val.tensor;
+ out->ref_mu = val.mutex_if_ref;
+ } else {
+ out->val = *val.tensor;
+ }
+
+ // Set the device context of the output entry.
+ out->device_context = device_context;
+
+ // Set the allocator attributes of the output entry.
+ out->alloc_attr = ctx->output_alloc_attr(i);
+
+ if (stats_collector_ && val.tensor->IsInitialized()) {
+ nodestats::SetOutput(stats, i, ctx->output_allocation_type(i),
+ val.tensor);
+ }
+ } else {
+ s.Update(
+ errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
+ " does not match declared output type ",
+ DataTypeString(node->output_type(i)),
+ " for operation ", SummarizeNodeDef(node->def())));
+ }
+ if (!val.is_ref()) {
+      // The output has been copied into the Entry above, so delete the
+      // tensor here. (If OpKernelContext returned outputs by value, this
+      // cleanup would not be needed.)
+ delete val.tensor;
+ }
+ }
+ if (!s.ok()) return s;
+
+ // Clears inputs.
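+  // Overwriting each input with the empty tensor drops this node's
+  // reference to the pass-by-value input buffers as early as possible.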
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ input_tensors_[item.input_start + i].val = *kEmptyTensor;
+ }
+
+ // Propagates outputs along out edges.
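+  // For a data edge, copy the output entry into the destination's input
+  // slot; for every out edge, decrement the destination's pending count
+  // and mark the destination ready when the count reaches zero.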
+ ready->clear();
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ for (const Edge* e : node->out_edges()) {
+ const int src_slot = e->src_output();
+ const int dst_id = e->dst()->id();
+ const NodeItem& dst_item = nodes[dst_id];
+ if (!e->IsControlEdge()) {
+ const int dst_slot = e->dst_input();
+ input_tensors_[dst_item.input_start + dst_slot] = outputs[src_slot];
+ }
+ if (pending_[dst_id].Dec()) {
+ ready->push_back(dst_id);
+ }
+ }
+ return Status::OK();
+}
+
+// NOTE(yuanbyu): Use the executor that supports control flow by default.
+const bool use_control_flow_executor = true;
+void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
+ if (params_.has_control_flow || use_control_flow_executor) {
+ (new ExecutorState(args, this))->RunAsync(done);
+ } else {
+ (new SimpleExecutorState(args, this))->RunAsync(done);
+ }
+}
+
+} // end namespace
+
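+// Typical usage of NewLocalExecutor (sketch; "params", "graph", and "args"
+// stand for values the caller already has):
+//
+//   Executor* executor = nullptr;
+//   Status s = NewLocalExecutor(params, graph, &executor);
+//   if (s.ok()) {
+//     executor->RunAsync(args, [](const Status& status) { /* done */ });
+//   }
+//
+// The caller owns *executor and deletes it once all runs have finished.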
+Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
+ Executor** executor) {
+ ExecutorImpl* impl = new ExecutorImpl(params, graph);
+ Status s = impl->Initialize();
+ if (s.ok()) {
+ *executor = impl;
+ } else {
+ delete impl;
+ }
+ return s;
+}
+
+Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
+ const NodeDef& ndef, OpKernel** kernel) {
+ auto device_type = DeviceType(device->attributes().device_type());
+ auto allocator = device->GetAllocator(AllocatorAttributes());
+ return CreateOpKernel(device_type, device, allocator, flib, ndef, kernel);
+}
+
+void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
+
+Status CreateCachedKernel(Device* device, const string& session,
+ FunctionLibraryRuntime* flib, const NodeDef& ndef,
+ OpKernel** kernel) {
+ auto op_seg = device->op_segment();
+ auto create_fn = [device, flib, &ndef](OpKernel** kernel) {
+ return CreateNonCachedKernel(device, flib, ndef, kernel);
+ };
+ return op_seg->FindOrCreate(session, ndef.name(), kernel, create_fn);
+}
+
+// Deletes "kernel". Kernels created by CreateCachedKernel() are owned by
+// the device's OpSegment, so there is nothing to do here.
+void DeleteCachedKernel(Device* device, const string& session,
+ OpKernel* kernel) {
+ // Do nothing.
+}
+
+} // end namespace tensorflow