#include "tensorflow/core/common_runtime/executor.h" #include #include #include #include #include #include #include #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/control_flow.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_segment.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/edgeset.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/port.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/public/tensor.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h" namespace tensorflow { namespace { // 1-D, 0 element tensor. static const Tensor* const kEmptyTensor = new Tensor; bool IsInitializationOp(const Node* node) { return node->op_def().allows_uninitialized_input(); } // Sets the timeline_label field of *node_stats, using data from *node. // Returns true iff the node is a transfer node. // TODO(tucker): merge with the DetailText function in session.cc // in a common location. bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) { bool is_transfer_node = false; string memory; for (auto& all : node_stats->memory()) { int64 tot = all.total_bytes(); if (tot >= 0.1 * 1048576.0) { int64 peak = all.peak_bytes(); if (peak > 0) { memory = strings::StrCat(memory, "[", all.allocator_name(), strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0, peak / 1048576.0)); } else { memory = strings::StrCat(memory, "[", all.allocator_name(), strings::Printf(" %.1fMB] ", tot / 1048576.0)); } } } const NodeDef& def = node->def(); string text = ""; if (IsSend(node)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name)); string recv_device; TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device)); text = strings::StrCat(memory, def.name(), " = ", def.op(), "(", tensor_name, " @", recv_device); is_transfer_node = true; } else if (IsRecv(node)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name)); string send_device; TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device)); text = strings::StrCat(memory, def.name(), " = ", def.op(), "(", tensor_name, " @", send_device); is_transfer_node = true; } else { text = strings::StrCat( memory, def.name(), " = ", def.op(), "(", str_util::Join( std::vector(def.input().begin(), def.input().end()), ", "), ")"); } node_stats->set_timeline_label(text); return is_transfer_node; } // Helper routines for collecting step stats. 
namespace nodestats {

inline int64 NowInUsec() { return Env::Default()->NowMicros(); }

void SetScheduled(NodeExecStats* nt, int64 t) { nt->set_scheduled_micros(t); }

void SetAllStart(NodeExecStats* nt) { nt->set_all_start_micros(NowInUsec()); }

void SetOpStart(NodeExecStats* nt) {
  DCHECK_NE(nt->all_start_micros(), 0);
  nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
}

void SetOpEnd(NodeExecStats* nt) {
  DCHECK_NE(nt->all_start_micros(), 0);
  nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
}

void SetAllEnd(NodeExecStats* nt) {
  DCHECK_NE(nt->all_start_micros(), 0);
  nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
}

void SetOutput(NodeExecStats* nt, int slot, AllocationType allocation_type,
               const Tensor* v) {
  DCHECK(v);
  NodeOutput* no = nt->add_output();
  no->set_slot(slot);
  no->set_allocation_type(allocation_type);
  v->FillDescription(no->mutable_tensor_description());
}

void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
  for (const auto& allocator_pair : ctx->wrapped_allocators()) {
    AllocatorMemoryUsed* memory = nt->add_memory();
    // retrieving the sizes from the wrapped allocator removes the
    // executor's reference to it, so allocator_pair.second must not
    // be dereferenced again after this statement
    auto sizes = allocator_pair.second->GetSizesAndUnRef();
    memory->set_allocator_name(allocator_pair.first->Name());
    int64 tb = sizes.first;
    memory->set_total_bytes(tb);
    if (allocator_pair.first->TracksAllocationSizes()) {
      memory->set_peak_bytes(sizes.second);
    }
  }
}

}  // namespace nodestats

struct NodeItem {
  // A graph node.
  const Node* node = nullptr;

  // The kernel for this node.
  OpKernel* kernel = nullptr;

  // ExecutorImpl::tensors_[input_start] is the 1st positional input
  // for this node.
  int input_start = 0;
};

// Map from std::pair<node id, output index> to attributes.
struct pairhash {
 public:
  template <typename T, typename U>
  std::size_t operator()(const std::pair<T, U>& x) const {
    return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
  }
};
typedef std::unordered_map<std::pair<int, int>, AllocatorAttributes, pairhash>
    DevAttrMap;

typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

class ExecutorImpl : public Executor {
 public:
  ExecutorImpl(const LocalExecutorParams& p, const Graph* g)
      : params_(p), graph_(g) {
    CHECK(p.create_kernel != nullptr);
    CHECK(p.delete_kernel != nullptr);
  }

  ~ExecutorImpl() override {
    for (NodeItem& item : nodes_) {
      params_.delete_kernel(item.kernel);
    }
    delete graph_;
  }

  Status Initialize();

  // Infer memory allocation attributes of a node n's output,
  // based on its use node dst. Note that dst might not be directly
  // connected to n by a single edge, but might be a downstream
  // consumer of n's output by reference. *attr is updated with any
  // necessary attributes.
  Status InferAllocAttr(const Node* n, const Node* dst,
                        const DeviceNameUtils::ParsedName& local_dev_name,
                        AllocatorAttributes* attr);

  // Process all Nodes in the current graph, attempting to infer the
  // memory allocation attributes to be used wherever they may allocate
  // a tensor buffer.
  Status SetAllocAttrs();

  void RunAsync(const Args& args, DoneCallback done) override;

 private:
  friend class ExecutorState;
  friend class SimpleExecutorState;

  // Owned.
  LocalExecutorParams params_;
  const Graph* graph_;
  std::vector<NodeItem> nodes_;  // nodes_.size == graph_.num_node_ids().
  int total_tensors_ = 0;        // total_tensors_ = sum(nodes_[*].num_inputs())

  // The number of inputs for each frame in this graph. This is static
  // information of the graph.
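  // For example, frame_input_count_["k"] is the number of Enter nodes in the
  // graph whose "frame_name" attr equals "k" (counted in Initialize() below);
  // FindOrCreateChildFrame() uses it as the initial num_pending_inputs of a
  // newly created frame. The frame name "k" is only a placeholder here.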
  std::unordered_map<string, int> frame_input_count_;

  DevAttrMap alloc_attr_;

  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
};

Status ExecutorImpl::Initialize() {
  const int num_nodes = graph_->num_node_ids();
  nodes_.resize(num_nodes);

  Status s;
  total_tensors_ = 0;

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    NodeItem* item = &nodes_[id];
    item->node = n;
    item->input_start = total_tensors_;
    total_tensors_ += n->num_inputs();

    s = params_.create_kernel(n->def(), &item->kernel);
    if (!s.ok()) {
      s = AttachDef(s, n->def());
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      break;
    }
    CHECK(item->kernel);

    // Initialize static information about the frames in the graph.
    if (IsEnter(n)) {
      string frame_name;
      s = GetNodeAttr(n->def(), "frame_name", &frame_name);
      if (!s.ok()) return s;
      ++frame_input_count_[frame_name];
    }
  }
  if (params_.has_control_flow) {
    VLOG(2) << "Graph has control flow.";
  }
  if (!s.ok()) return s;
  return SetAllocAttrs();
}

Status ExecutorImpl::SetAllocAttrs() {
  Status s;
  Device* device = params_.device;
  DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();

  for (const Node* n : graph_->nodes()) {
    // Examine the out edges of each node looking for special use
    // cases that may affect memory allocation attributes.
    for (auto e : n->out_edges()) {
      AllocatorAttributes attr;
      s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
      if (!s.ok()) return s;
      if (attr.value != 0) {
        VLOG(2) << "node " << n->name() << " gets attr " << attr.value
                << " for output " << e->src_output();
        alloc_attr_[std::make_pair(n->id(), e->src_output())].Merge(attr);
      } else {
        VLOG(2) << "default output attr for node " << n->name() << " output "
                << e->src_output();
      }
    }
  }
  return s;
}

Status ExecutorImpl::InferAllocAttr(
    const Node* n, const Node* dst,
    const DeviceNameUtils::ParsedName& local_dev_name,
    AllocatorAttributes* attr) {
  Status s;
  if (IsSend(dst)) {
    string dst_name;
    s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
    if (!s.ok()) return s;
    DeviceNameUtils::ParsedName parsed_dst_name;
    if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
      s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
                           n->name());
      return s;
    }
    if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name,
                                             local_dev_name)) {
      // Value is going to be the source of an RPC.
      attr->set_nic_compatible(true);
      VLOG(2) << "node " << n->name() << " is the source of an RPC out";
    } else if (local_dev_name.type == "CPU" && parsed_dst_name.type == "GPU") {
      // Value is going to be the source of a local DMA from CPU to GPU.
      attr->set_gpu_compatible(true);
      VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
    } else {
      VLOG(2) << "default alloc case local type " << local_dev_name.type
              << " remote type " << parsed_dst_name.type;
    }
  } else if (dst->type_string() == "ToFloat") {
    for (auto e : dst->out_edges()) {
      s = InferAllocAttr(n, e->dst(), local_dev_name, attr);
      if (!s.ok()) return s;
    }
  }
  return s;
}

// The state associated with one invocation of ExecutorImpl::Run.
// ExecutorState dispatches nodes when they become ready, and keeps
// track of how many predecessors of a node have not yet completed
// (pending_).
class ExecutorState {
 public:
  ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
  ~ExecutorState();

  void RunAsync(Executor::DoneCallback done);

 private:
  typedef ExecutorState ME;

  // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
  // TODO(yuanbyu): A better way to do "has_value"?
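  // How an Entry is consumed is visible in PrepareInputs() further below: for
  // a value entry the kernel is handed &entry->val, while for a reference
  // entry it is handed entry->ref guarded by entry->ref_mu, unless the op
  // expects a plain tensor, in which case the ref is dereferenced into
  // entry->val under the mutex first. A rough sketch, eliding the error and
  // dead-input handling:
  //
  //   if (entry->ref == nullptr) {
  //     inp->tensor = &entry->val;           // pass-by-value
  //   } else if (expect_ref) {
  //     inp->mutex_if_ref = entry->ref_mu;   // pass-by-reference
  //     inp->tensor = entry->ref;
  //   } else {
  //     mutex_lock l(*entry->ref_mu);        // auto-deref under the lock
  //     entry->val = *entry->ref;
  //     inp->tensor = &entry->val;
  //   }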
struct Entry { Tensor val = *kEmptyTensor; // A tensor value. Tensor* ref = nullptr; // A tensor reference. mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr. bool has_value = false; // Whether the value exists // Every entry carries an optional DeviceContext containing // Device-specific information about how the Tensor was produced. DeviceContext* device_context = nullptr; // The attributes of the allocator that creates the tensor. AllocatorAttributes alloc_attr; }; // Contains a map from node id to the DeviceContext object that was // assigned by the device at the beginning of a step. DeviceContextMap device_context_map_; struct IterationState { // The state of an iteration. // The pending count for each graph node. One copy per iteration. // Iteration i can be garbage collected when it is done. // TODO(yuanbyu): This vector currently has size of the number of nodes // in this partition. This is not efficient if the subgraph for the frame // is only a small subset of the partition. We should make the vector // size to be only the size of the frame subgraph. std::vector* pending_count; // The dead input count for each graph node. One copy per iteration. std::vector* dead_count; // One copy per iteration. For iteration k, i-th node's j-th input is in // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either // a tensor pointer (pass-by-reference) or a tensor (pass-by-value). // // NOTE: No need to protect input_tensors[i] by any locks because it // is resized once. Each element of tensors_ is written once by the // source node of an edge and is cleared by the destination of the same // edge. The latter node is never run concurrently with the former node. std::vector* input_tensors; // The number of outstanding ops for each iteration. int outstanding_ops; // The number of outstanding frames for each iteration. int outstanding_frame_count; ~IterationState() { delete pending_count; delete dead_count; delete input_tensors; } }; struct FrameState { // A new frame is created for each loop. Execution starts at iteration 0. // When a value at iteration 0 passes through a NextIteration node, // iteration 1 is created and starts running. Note that iteration 0 may // still be running so multiple iterations may run in parallel. The // frame maintains the state of iterations in several data structures // such as pending_count and input_tensors. When iteration 0 completes, // we garbage collect the state of iteration 0. // // A frame instance is considered "done" and can be garbage collected // if all its inputs have entered and all its iterations are "done". // // A frame manages the live iterations of an iterative computation. // Iteration i is considered "done" when there are no outstanding ops, // frames at iteration i are done, all recvs for this iteration are // completed, and iteration i-1 is done. For iteration 0, we instead // wait for there to be no more pending inputs of the frame. // // Frames and iterations are garbage collected once they are done. // The state we need to keep around is highly dependent on the // parallelism enabled by the scheduler. We may want to have the // scheduler dynamically control the outstanding number of live // parallel frames and iterations. To reduce the state space, the // scheduler might want to schedule ops in inner frames first and // lower iterations first. // // This frame state is mostly initialized lazily on demand so we // don't introduce unnecessary overhead. 
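    // As a concrete (illustrative) example: for a typical while-loop
    // subgraph -- Enter -> Merge -> Switch -> loop body -> NextIteration back
    // to Merge, with Exit on the other Switch branch -- executing the Enter
    // nodes creates (or finds) one FrameState via FindOrCreateChildFrame(),
    // and the first value to pass through a NextIteration node at the newest
    // iteration triggers IncrementIteration(), which allocates the next
    // IterationState.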
// The name of this frame, which is the concatenation of its parent // frame name, the iteration of the parent frame when this frame was // created, and the value of the attr 'frame_name'. string frame_name; // The unique id for this frame. Generated by fingerprinting // frame_name. uint64 frame_id; // The iteration id of its parent frame when this frame is created. // -1 if there is no parent frame. The frame_name/parent_iter pair // uniquely identifies this FrameState. int64 parent_iter = -1; // The FrameState of its parent frame. FrameState* parent_frame = nullptr; // The highest iteration number we have reached so far in this frame. int64 iteration_count = 0; // The number of inputs this frame is still waiting. int num_pending_inputs = 0; // The number of outstanding iterations. int num_outstanding_iterations = 0; // The maximum allowed number of parallel iterations. int max_parallel_iterations = 1; // The iteration states of this frame. std::vector iterations; // The NextIteration nodes to enter a new iteration. If the number of // outstanding iterations reaches the limit, we will defer the start of // the next iteration until the number of outstanding iterations falls // below the limit. std::vector> next_iter_roots; // The values of the loop invariants for this loop. They are added into // this list as they "enter" the frame. When a loop invariant enters, // we make it available to all active iterations. When the frame starts // a new iteration, we make all the current loop invariants available // to the new iteration. std::vector> inv_values; // The list of dead exit nodes for the current highest iteration. We // will only "execute" the dead exits of the final iteration. std::vector dead_exits; IterationState* GetIteration(int64 iter) { int index = iter % iterations.size(); return iterations[index]; } void SetIteration(int64 iter, IterationState* state) { int index = iter % iterations.size(); iterations[index] = state; } ~FrameState() { for (size_t i = 0; i < iterations.size(); ++i) { delete iterations[i]; iterations[i] = nullptr; } } }; // A tagged node: . struct TaggedNode { const Node* node = nullptr; FrameState* input_frame = nullptr; int64 input_iter = -1; bool is_dead = false; TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter, bool dead) { node = t_node; input_frame = in_frame; input_iter = in_iter; is_dead = dead; } }; typedef gtl::InlinedVector TaggedNodeSeq; typedef gtl::InlinedVector EntryVector; // Not owned. Rendezvous* rendezvous_; StepStatsCollector* stats_collector_; // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper instead of a // pointer? (avoids having to delete). checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; FunctionCallFrame* call_frame_; const ExecutorImpl* impl_; CancellationManager* cancellation_manager_; Executor::Args::Runner runner_; // Owned. // Step-local resource manager. ResourceMgr step_resource_manager_; // The root frame in which the execution of this step is started. FrameState* root_frame_; // Invoked when the execution finishes. Executor::DoneCallback done_cb_; std::atomic_int_fast32_t num_outstanding_ops_; mutex mu_; Status status_ GUARDED_BY(mu_); // Mapping from frame name to outstanding frames. A new frame is created // at some iteration of an active frame. So the unique key for the new // child frame is composed of the name of the parent frame, the iteration // number at which the parent frame is creating the new frame, and the // name of the new frame from nodedef. 
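  // For example (the attr value is made up for illustration): an Enter node
  // with frame_name = "while/ctx" executing at iteration 3 of the root frame
  // yields, via MakeFrameName() below,
  //
  //   strings::StrCat("_root", ";", 3, ";", "while/ctx")  // "_root;3;while/ctx"
  //
  // as the key into this map.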
std::unordered_map outstanding_frames_ GUARDED_BY(mu_); // The unique name of a frame. inline string MakeFrameName(FrameState* frame, int64 iter_id, string name) { return strings::StrCat(frame->frame_name, ";", iter_id, ";", name); } // Initialize the pending count for a graph. static void InitializePending(const Graph* graph, std::vector* pending); // Find an existing or create a new child frame in the frame 'frame' at // iteration 'iter'. void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node, FrameState** child) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Increments the iteration id. If this is a new iteration, initialize it. void IncrementIteration(FrameState* frame, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns true if the computation in the frame is completed. bool IsFrameDone(FrameState* frame) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns true if the iteration of the frame is completed. bool IsIterationDone(FrameState* frame, int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Get the output frame/iter of a node. Create new frame/iteration if // needed. If there are dead roots for the new iteration, we need to // "execute" them so ad them to the ready queue. Returns true if // we need to check for the completion of output frame/iter. bool SetOutputFrameIter(const TaggedNode& tagged_node, const EntryVector& outputs, FrameState** frame, int64* iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Cleanup frames and iterations void CleanupFramesIterations(FrameState* frame, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Activate all the deferred NextIteration nodes in a new iteration. void ActivateNexts(FrameState* frame, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Activate all the current loop invariants in a new iteration. void ActivateLoopInvs(FrameState* frame, int64 iter, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Add a new loop invariant and make it available to all active iterations. void AddLoopInv(FrameState* frame, const Node* node, const Entry& value, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Activate the successors of a node. void ActivateNode(const Node* node, const bool is_dead, FrameState* frame, int64 iter, const EntryVector& outputs, TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_); // Process a ready node in current thread. void Process(TaggedNode node, int64 scheduled_usec); // Before invoking item->kernel, fills in its "inputs". Status PrepareInputs(const NodeItem& item, Entry* first_input, TensorValueVec* inputs, DeviceContextVec* input_device_contexts, AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead); // After item->kernel computation is done, processes its outputs. Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, NodeExecStats* stats); // After processing the outputs, propagates the outputs to their dsts. void PropagateOutputs(const TaggedNode& tagged_node, const EntryVector& outputs, TaggedNodeSeq* ready); // "node" just finishes. Takes ownership of "stats". Returns true if // execution has completed. bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, NodeExecStats* stats, std::deque* inline_ready); // Call Process() on all nodes in 'inline_ready'. void ProcessInline(const std::deque& inline_ready); // Schedule all the expensive nodes in 'ready', and put all the inexpensive // nodes in 'ready' into 'inline_ready'. 
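  // In the implementation further below, dead or inexpensive nodes are pushed
  // onto 'inline_ready' and run on the current thread, while expensive nodes
  // are handed off to runner_; the last expensive node is kept inline instead
  // when 'inline_ready' would otherwise stay empty, a small tail-recursion
  // style optimization that keeps this thread busy. Roughly:
  //
  //   if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
  //     inline_ready->push_back(tagged_node);  // run inline
  //   } else {
  //     runner_(std::bind(&ME::Process, this, tagged_node, scheduled_usec));
  //   }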
void ScheduleReady(const TaggedNodeSeq& ready, std::deque* inline_ready); // One thread of control finishes. void Finish(); }; ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) : rendezvous_(args.rendezvous), stats_collector_(args.stats_collector), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), impl_(impl), cancellation_manager_(args.cancellation_manager), runner_(args.runner), num_outstanding_ops_(0) { // We start the entire execution in iteration 0 of the root frame // so let us create the root frame and the state for iteration 0. // Initialize the frame. root_frame_ = new FrameState; root_frame_->frame_name = "_root"; // assume to be unique root_frame_->frame_id = 0; // must be 0 root_frame_->num_pending_inputs = 0; root_frame_->num_outstanding_iterations = 1; root_frame_->max_parallel_iterations = 1; // enough for root frame root_frame_->iterations.resize(root_frame_->max_parallel_iterations); VLOG(2) << "Create frame: " << root_frame_->frame_name; // Initialize the iteration. IterationState* iter_state = new IterationState; root_frame_->iterations[0] = iter_state; iter_state->outstanding_ops = 0; iter_state->outstanding_frame_count = 0; iter_state->pending_count = new std::vector; iter_state->dead_count = new std::vector(impl->graph_->num_node_ids()); iter_state->input_tensors = new std::vector(impl_->total_tensors_); // Initialize the executor state. outstanding_frames_.insert({root_frame_->frame_name, root_frame_}); } ExecutorState::~ExecutorState() { for (auto name_frame : outstanding_frames_) { delete name_frame.second; } for (auto it : device_context_map_) { it.second->Unref(); } delete slice_reader_cache_; } void ExecutorState::InitializePending(const Graph* graph, std::vector* pending) { pending->resize(graph->num_node_ids()); for (const Node* n : graph->nodes()) { const int id = n->id(); const int num_in_edges = n->in_edges().size(); if (IsMerge(n)) { // merge waits all control inputs so we initialize the pending // count to be the number of control edges. int32 num_control_edges = 0; for (const Edge* edge : n->in_edges()) { if (edge->IsControlEdge()) { num_control_edges++; } } // Use bit 0 to indicate if there is a ready live data input. (*pending)[id] = num_control_edges << 1; } else { (*pending)[id] = num_in_edges; } } } void ExecutorState::RunAsync(Executor::DoneCallback done) { const Graph* graph = impl_->graph_; TaggedNodeSeq ready; { // Initialize the executor state. We grab the mutex here just to // keep the thread safety analysis happy. mutex_lock l(mu_); std::vector* pending = root_frame_->iterations[0]->pending_count; InitializePending(graph, pending); } // Ask the device to fill in the device context map. Device* device = impl_->params_.device; device->FillContextMap(graph, &device_context_map_); // Initialize the ready queue. for (const Node* n : graph->nodes()) { const int num_in_edges = n->in_edges().size(); if (num_in_edges == 0) { ready.push_back(TaggedNode{n, root_frame_, 0, false}); } } if (ready.empty()) { done(Status::OK()); } else { num_outstanding_ops_ = ready.size(); root_frame_->iterations[0]->outstanding_ops = ready.size(); done_cb_ = done; // Schedule to run all the ready ops in thread pool. ScheduleReady(ready, nullptr); } } namespace { // This function is provided for use by OpKernelContext when allocating // the index'th output of node. 
It provides access to the // AllocatorAttributes computed during initialization to determine in // which memory region the tensor should be allocated. AllocatorAttributes OutputAttributes(const DevAttrMap* attr_map, const Node* node, const OpKernel* op_kernel, int index) { DCHECK_GE(index, 0); AllocatorAttributes attr; int nid = node->id(); const auto& iter = attr_map->find(std::make_pair(nid, index)); if (iter != attr_map->end()) { attr = iter->second; VLOG(2) << "nondefault attr " << attr.value << " for node " << node->name() << " output " << index; } else { VLOG(2) << "default attr for node " << node->name() << " output " << index; } DCHECK_LT(index, op_kernel->output_memory_types().size()); bool on_host = op_kernel->output_memory_types()[index] == HOST_MEMORY; attr.set_on_host(on_host); return attr; } // Helpers to make a copy of 'p' and makes a copy of the input type // vector and the device context vector. // // NOTE: We need to make a copy of p.input for asynchronous kernel // because OpKernelContext methods like input_type(i) needs the param // points to valid input type vector. It's not an issue for sync // kernels because the type vector is kept on the stack. OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) { OpKernelContext::Params* ret = new OpKernelContext::Params; *ret = p; ret->inputs = new TensorValueVec(*p.inputs); ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts); ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs); return ret; } // Helpers to delete 'p' and copies made by CopyParams. void DeleteParams(OpKernelContext::Params* p) { delete p->inputs; delete p->input_device_contexts; delete p->input_alloc_attrs; delete p; } } // namespace void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { const std::vector& nodes = impl_->nodes_; TaggedNodeSeq ready; std::deque inline_ready; // Parameters passed to OpKernel::Compute. TensorValueVec inputs; DeviceContextVec input_device_contexts; AllocatorAttributeVec input_alloc_attrs; OpKernelContext::Params params; Device* device = impl_->params_.device; params.device = device; // track allocations if and only if we are collecting statistics params.track_allocations = (stats_collector_ != nullptr); params.rendezvous = rendezvous_; params.cancellation_manager = cancellation_manager_; params.call_frame = call_frame_; params.function_library = impl_->params_.function_library; params.resource_manager = device->resource_manager(); params.step_resource_manager = &step_resource_manager_; params.slice_reader_cache = slice_reader_cache_; params.inputs = &inputs; params.input_device_contexts = &input_device_contexts; params.input_alloc_attrs = &input_alloc_attrs; Status s; NodeExecStats* stats = nullptr; EntryVector outputs; bool completed = false; inline_ready.push_back(tagged_node); while (!inline_ready.empty()) { tagged_node = inline_ready.front(); inline_ready.pop_front(); const Node* node = tagged_node.node; FrameState* input_frame = tagged_node.input_frame; int64 input_iter = tagged_node.input_iter; const int id = node->id(); const NodeItem& item = nodes[id]; // Set the device_context for this node id, if it exists. 
auto dc_it = device_context_map_.find(id); if (dc_it != device_context_map_.end()) { params.op_device_context = dc_it->second; } if (stats_collector_) { stats = new NodeExecStats; stats->set_node_name(node->name()); nodestats::SetScheduled(stats, scheduled_usec); nodestats::SetAllStart(stats); } VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def()); std::vector* input_tensors; { // Need the lock because the iterations vector could be resized by // another thread. mutex_lock l(mu_); input_tensors = input_frame->GetIteration(input_iter)->input_tensors; } Entry* first_input = input_tensors->data() + item.input_start; outputs.clear(); outputs.resize(node->num_outputs()); // Only execute this node if it is not dead or it is a send/recv // transfer node. For transfer nodes, we need to propagate the "dead" // bit even when the node is dead. AsyncOpKernel* async = nullptr; if (!tagged_node.is_dead || IsTransferNode(node)) { // Prepares inputs. bool is_input_dead = false; s = PrepareInputs(item, first_input, &inputs, &input_device_contexts, &input_alloc_attrs, &is_input_dead); if (!s.ok()) { // Continue to process the nodes in 'inline_ready'. completed = NodeDone(s, item.node, ready, stats, &inline_ready); continue; } // Set up compute params. OpKernel* op_kernel = item.kernel; params.op_kernel = op_kernel; params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter); params.is_input_dead = is_input_dead; params.output_alloc_attr = [this, node, op_kernel](int index) { return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index); }; async = op_kernel->AsAsync(); if (async) { // Asynchronous computes. auto pcopy = CopyParams(params); auto ctx = new OpKernelContext(*pcopy); auto done = [this, tagged_node, item, first_input, ctx, stats, pcopy]() { VLOG(2) << this << " Async kernel done: " << SummarizeNodeDef(item.node->def()); if (stats_collector_) nodestats::SetOpEnd(stats); EntryVector outputs; Status s = ProcessOutputs(item, ctx, &outputs, stats); if (stats_collector_) nodestats::SetMemory(stats, ctx); // Clears inputs. int num_inputs = tagged_node.node->num_inputs(); for (int i = 0; i < num_inputs; ++i) { (first_input + i)->val = *kEmptyTensor; } TaggedNodeSeq ready; if (s.ok()) { PropagateOutputs(tagged_node, outputs, &ready); } // Schedule to run all the ready ops in thread pool. bool completed = NodeDone(s, item.node, ready, stats, nullptr); delete ctx; DeleteParams(pcopy); if (completed) Finish(); }; if (stats_collector_) nodestats::SetOpStart(stats); device->ComputeAsync(async, ctx, done); } else { // Synchronous computes. OpKernelContext ctx(params); if (stats_collector_) nodestats::SetOpStart(stats); device->Compute(CHECK_NOTNULL(op_kernel), &ctx); if (stats_collector_) nodestats::SetOpEnd(stats); // Processes outputs. s = ProcessOutputs(item, &ctx, &outputs, stats); if (stats_collector_) nodestats::SetMemory(stats, &ctx); } } if (!async) { // Clears inputs. int num_inputs = node->num_inputs(); for (int i = 0; i < num_inputs; ++i) { (first_input + i)->val = *kEmptyTensor; } // Propagates outputs. if (s.ok()) { PropagateOutputs(tagged_node, outputs, &ready); } if (stats_collector_) { scheduled_usec = nodestats::NowInUsec(); } // Postprocess. completed = NodeDone(s, item.node, ready, stats, &inline_ready); } } // while !inline_ready.empty() // This thread of computation is done if completed = true. 
if (completed) Finish(); } Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, TensorValueVec* inputs, DeviceContextVec* input_device_contexts, AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) { const Node* node = item.node; inputs->clear(); inputs->resize(node->num_inputs()); input_device_contexts->clear(); input_device_contexts->resize(node->num_inputs()); input_alloc_attrs->clear(); input_alloc_attrs->resize(node->num_inputs()); *is_input_dead = false; bool is_merge = IsMerge(node); for (int i = 0; i < node->num_inputs(); ++i) { const bool expect_ref = IsRefType(node->input_type(i)); Entry* entry = first_input + i; (*input_device_contexts)[i] = entry->device_context; (*input_alloc_attrs)[i] = entry->alloc_attr; // i-th input. TensorValue* inp = &(*inputs)[i]; // Only merge and transfer nodes can have no-value inputs. if (!entry->has_value) { if (!is_merge) { DCHECK(IsTransferNode(node)); inp->tensor = &entry->val; *is_input_dead = true; } continue; } if (entry->ref == nullptr) { if (expect_ref) { return AttachDef( errors::InvalidArgument(i, "-th input expects a ref type"), item.kernel->def()); } inp->tensor = &entry->val; } else { if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) { return AttachDef( errors::FailedPrecondition("Attempting to use uninitialized value ", item.kernel->def().input(i)), item.kernel->def()); } if (expect_ref) { inp->mutex_if_ref = entry->ref_mu; inp->tensor = entry->ref; } else { // Automatically deref the tensor ref when the op expects a // tensor but is given a ref to a tensor. Need to deref it // under the mutex. { mutex_lock l(*(entry->ref_mu)); entry->val = *entry->ref; } inp->tensor = &entry->val; } } } return Status::OK(); } Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, NodeExecStats* stats) { const Node* node = item.node; outputs->clear(); outputs->resize(node->num_outputs()); Status s = ctx->status(); if (!s.ok()) { s = AttachDef(s, item.kernel->def()); LOG(WARNING) << this << " Compute status: " << s; return s; } // Get the device_context for this node id, if it exists. DeviceContext* device_context = nullptr; auto dc_it = device_context_map_.find(node->id()); if (dc_it != device_context_map_.end()) { device_context = dc_it->second; } for (int i = 0; i < node->num_outputs(); ++i) { TensorValue val = ctx->release_output(i); // Only Switch and Recv nodes can generate new dead outputs if (*ctx->is_output_dead() || val.tensor == nullptr) { DCHECK(IsSwitch(node) || IsRecv(node)); } else { Entry* out = &((*outputs)[i]); out->has_value = true; // Set the device context of the output entry. out->device_context = device_context; // Set the allocator attributes of the output entry. out->alloc_attr = ctx->output_alloc_attr(i); // Sanity check of output tensor types. DataType dtype = val->dtype(); if (val.is_ref()) dtype = MakeRefType(dtype); if (dtype == node->output_type(i)) { if (val.is_ref()) { out->ref = val.tensor; out->ref_mu = val.mutex_if_ref; } else { out->val = *val.tensor; } if (stats_collector_ && val.tensor->IsInitialized()) { nodestats::SetOutput(stats, i, ctx->output_allocation_type(i), val.tensor); } } else { s.Update(errors::Internal("Output ", i, " of type ", DataTypeString(dtype), " does not match declared output type ", DataTypeString(node->output_type(i)), " for node ", SummarizeNodeDef(node->def()))); } } if (!val.is_ref()) { // If OpKernelContext returns outputs via pass-by-value, we // don't need this trouble. 
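      // For value (non-ref) outputs, release_output() hands back a Tensor
      // object that the executor now owns; its contents were copied into
      // out->val above, so only the owning Tensor object is deleted here,
      // not the underlying buffer it shares with out->val.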
delete val.tensor; } } return s; } void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, const EntryVector& outputs, TaggedNodeSeq* ready) { FrameState* input_frame = tagged_node.input_frame; int64 input_iter = tagged_node.input_iter; // Propagates outputs along out edges, and puts newly ready nodes // into the ready queue. ready->clear(); { FrameState* output_frame = input_frame; int64 output_iter = input_iter; mutex_lock l(mu_); // Sets the output_frame and output_iter of node. bool maybe_completed = SetOutputFrameIter( tagged_node, outputs, &output_frame, &output_iter, ready); if (output_frame != nullptr) { // Continue to process the out nodes: ActivateNode(tagged_node.node, tagged_node.is_dead, output_frame, output_iter, outputs, ready); } // At this point, this node is completely done. input_frame->GetIteration(input_iter)->outstanding_ops--; CleanupFramesIterations(input_frame, input_iter, ready); // The execution of a node such as Enter may cause the completion of // output_frame:output_iter, so perform cleanup if output_frame:output_iter // is indeed completed. if (maybe_completed) { CleanupFramesIterations(output_frame, output_iter, ready); } } } void ExecutorState::ActivateNode(const Node* node, const bool is_dead, FrameState* output_frame, int64 output_iter, const EntryVector& outputs, TaggedNodeSeq* ready) { const std::vector& nodes = impl_->nodes_; IterationState* output_iter_state = output_frame->GetIteration(output_iter); std::vector* pending = output_iter_state->pending_count; std::vector* dead_count = output_iter_state->dead_count; for (const Edge* e : node->out_edges()) { const Node* dst_node = e->dst(); const int dst_id = dst_node->id(); const int src_slot = e->src_output(); bool dst_dead = false; bool dst_ready = false; bool dst_need_input = !e->IsControlEdge(); if (IsMerge(dst_node)) { // A merge node is ready if a) all control edges are enabled and a // live data input becomes available, or b) all control edges are // enabled and all data inputs are dead. if (e->IsControlEdge()) { (*pending)[dst_id] -= 2; int count = (*pending)[dst_id]; dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); dst_ready = (count == 1) || ((count == 0) && dst_dead); } else { if (outputs[src_slot].has_value) { // This is a live data input. int count = (*pending)[dst_id]; (*pending)[dst_id] |= 0x1; dst_ready = (count == 0); } else { // This is a dead data input. ++(*dead_count)[dst_id]; dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); dst_ready = ((*pending)[dst_id] == 0) && dst_dead; } // This input for dst is not needed if !dst_ready. We suppress the // propagation to make the thread safety analysis happy. dst_need_input = dst_ready; } } else { // A non-merge node is ready if all its inputs are ready. We wait // for all inputs to come in even if we know the node is dead. This // ensures that all input tensors get cleaned up. 
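      // To make the counting scheme concrete (see InitializePending()): a
      // plain node with 3 in-edges starts with pending = 3 and becomes ready
      // when the third activation decrements it to 0. A Merge node with one
      // control edge starts with pending = (1 << 1) = 2; each incoming
      // control edge subtracts 2, the first live data input sets bit 0, and
      // the node is ready once the count reaches 1 (all control edges in and
      // live data seen) or reaches 0 with every data input dead.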
if (is_dead || (!e->IsControlEdge() && !outputs[src_slot].has_value)) { ++(*dead_count)[dst_id]; } dst_dead = (*dead_count)[dst_id] > 0; dst_ready = (--(*pending)[dst_id] == 0); } if (dst_need_input) { const NodeItem& dst_item = nodes[dst_id]; const int dst_slot = e->dst_input(); std::vector* input_tensors = output_iter_state->input_tensors; int dst_loc = dst_item.input_start + dst_slot; (*input_tensors)[dst_loc] = outputs[src_slot]; } // Add dst to the ready queue if it's ready if (dst_ready) { dst_dead = dst_dead && !IsControlTrigger(dst_node); ready->push_back( TaggedNode(dst_node, output_frame, output_iter, dst_dead)); output_iter_state->outstanding_ops++; } } } void ExecutorState::ActivateNexts(FrameState* frame, int64 iter, TaggedNodeSeq* ready) { // Propagate the deferred NextIteration nodes to the new iteration. for (auto& node_entry : frame->next_iter_roots) { const Node* node = node_entry.first; const Entry& entry = node_entry.second; const bool is_dead = !entry.has_value; ActivateNode(node, is_dead, frame, iter, {entry}, ready); } frame->next_iter_roots.clear(); } void ExecutorState::ActivateLoopInvs(FrameState* frame, int64 iter, TaggedNodeSeq* ready) { // Propagate loop invariants to the new iteration. for (auto& node_entry : frame->inv_values) { const Node* node = node_entry.first; const Entry& entry = node_entry.second; const bool is_dead = !entry.has_value; ActivateNode(node, is_dead, frame, iter, {entry}, ready); } } void ExecutorState::AddLoopInv(FrameState* frame, const Node* node, const Entry& entry, TaggedNodeSeq* ready) { // Store this value. frame->inv_values.push_back({node, entry}); // Make this value available to all iterations. bool is_dead = !entry.has_value; for (int i = 1; i <= frame->iteration_count; ++i) { ActivateNode(node, is_dead, frame, i, {entry}, ready); } } bool ExecutorState::NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, NodeExecStats* stats, std::deque* inline_ready) { if (stats_collector_) { nodestats::SetAllEnd(stats); if (!SetTimelineLabel(node, stats)) { // Only record non-transfer nodes. stats_collector_->Save(impl_->params_.device->name(), stats); } else { delete stats; } } Rendezvous* captured_rendezvous = nullptr; // Will be set on error. if (!s.ok()) { // Some error happened. This thread of computation is done. mutex_lock l(mu_); if (status_.ok()) { captured_rendezvous = rendezvous_; if (captured_rendezvous) captured_rendezvous->Ref(); status_ = s; } } if (captured_rendezvous) { // If we captured the rendezvous_ pointer, we are in an error condition. // Use captured_rendezvous, in case "this" is deleted by another thread. TRACEPRINTF("StartAbort: %s", s.ToString().c_str()); captured_rendezvous->StartAbort(s); captured_rendezvous->Unref(); } bool completed = false; int ready_size = ready.size(); if (ready_size == 0 || !s.ok()) { completed = (num_outstanding_ops_.fetch_sub(1) == 1); } else if (ready_size > 1) { num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed); } // Schedule the ready nodes in 'ready'. 
if (s.ok()) { ScheduleReady(ready, inline_ready); } return completed; } void ExecutorState::ProcessInline(const std::deque& inline_ready) { if (inline_ready.empty()) return; int64 scheduled_usec = 0; if (stats_collector_) { scheduled_usec = nodestats::NowInUsec(); } for (auto& tagged_node : inline_ready) { Process(tagged_node, scheduled_usec); } } void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready, std::deque* inline_ready) { if (ready.empty()) return; int64 scheduled_usec = 0; if (stats_collector_) { scheduled_usec = nodestats::NowInUsec(); } if (inline_ready == nullptr) { // Schedule to run all the ready ops in thread pool. for (auto& tagged_node : ready) { runner_(std::bind(&ME::Process, this, tagged_node, scheduled_usec)); } return; } const std::vector& nodes = impl_->nodes_; const TaggedNode* curr_expensive_node = nullptr; for (auto& tagged_node : ready) { const NodeItem& item = nodes[tagged_node.node->id()]; if (tagged_node.is_dead || !item.kernel->IsExpensive()) { // Inline this inexpensive node. inline_ready->push_back(tagged_node); } else { if (curr_expensive_node) { // Dispatch to another thread since there is plenty of work to // do for this thread. runner_(std::bind(&ME::Process, this, *curr_expensive_node, scheduled_usec)); } curr_expensive_node = &tagged_node; } } if (curr_expensive_node) { if (inline_ready->empty()) { // Tail recursion optimization inline_ready->push_back(*curr_expensive_node); } else { // There are inline nodes to run already. We dispatch this expensive // node to other thread. runner_( std::bind(&ME::Process, this, *curr_expensive_node, scheduled_usec)); } } } void ExecutorState::Finish() { mu_.lock(); auto status = status_; auto done_cb = done_cb_; auto runner = runner_; mu_.unlock(); delete this; CHECK(done_cb != nullptr); runner([done_cb, status]() { done_cb(status); }); } bool ExecutorState::IsFrameDone(FrameState* frame) { return (frame->num_pending_inputs == 0 && frame->num_outstanding_iterations == 0); } bool ExecutorState::IsIterationDone(FrameState* frame, int64 iter) { IterationState* iter_state = frame->GetIteration(iter); if (iter_state->outstanding_ops == 0 && iter_state->outstanding_frame_count == 0) { if (iter == 0) { // The enclosing frame has no pending input. return frame->num_pending_inputs == 0; } else { // The preceding iteration is deleted (and therefore done). return (frame->GetIteration(iter - 1) == nullptr); } } return false; } void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node, FrameState** child) { // Get the child frame name. string enter_name; Status s = GetNodeAttr(node->def(), "frame_name", &enter_name); CHECK(s.ok()) << s; const string child_name = MakeFrameName(frame, iter, enter_name); auto it = outstanding_frames_.find(child_name); if (it != outstanding_frames_.end()) { *child = it->second; } else { // Need to create a new frame instance. VLOG(2) << "Create frame: " << child_name; FrameState* temp = new FrameState; temp->frame_name = child_name; temp->frame_id = Hash64(child_name); temp->parent_frame = frame; temp->parent_iter = iter; s = GetNodeAttr(node->def(), "parallel_iterations", &temp->max_parallel_iterations); CHECK(s.ok()) << s; // 'iterations' is a fixed-length circular buffer. 
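    // Presumably sized max_parallel_iterations + 1 so that the modular
    // indexing in GetIteration()/SetIteration() cannot hand out a slot that
    // still holds a live IterationState while up to max_parallel_iterations
    // iterations are outstanding.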
temp->iterations.resize(temp->max_parallel_iterations + 1); IterationState* iter_state = new IterationState; temp->iterations[0] = iter_state; iter_state->outstanding_ops = 0; iter_state->outstanding_frame_count = 0; iter_state->pending_count = new std::vector; InitializePending(impl_->graph_, iter_state->pending_count); iter_state->dead_count = new std::vector(impl_->graph_->num_node_ids()); iter_state->input_tensors = new std::vector(impl_->total_tensors_); auto frame_pending = impl_->frame_input_count_.find(enter_name); DCHECK(frame_pending != impl_->frame_input_count_.end()); temp->num_pending_inputs = frame_pending->second; temp->num_outstanding_iterations = 1; *child = temp; frame->GetIteration(iter)->outstanding_frame_count++; outstanding_frames_[child_name] = temp; } } void ExecutorState::IncrementIteration(FrameState* frame, TaggedNodeSeq* ready) { frame->iteration_count++; int64 next_iter = frame->iteration_count; VLOG(2) << "Create iteration: [" << frame->frame_name << ", " << next_iter << "]"; IterationState* iter_state = new IterationState; frame->SetIteration(next_iter, iter_state); frame->num_outstanding_iterations++; frame->dead_exits.clear(); iter_state->outstanding_ops = 0; iter_state->outstanding_frame_count = 0; iter_state->pending_count = new std::vector; InitializePending(impl_->graph_, iter_state->pending_count); iter_state->dead_count = new std::vector(impl_->graph_->num_node_ids()); iter_state->input_tensors = new std::vector(impl_->total_tensors_); // Activate the successors of the deferred roots in the new iteration. ActivateNexts(frame, next_iter, ready); // Activate the loop invariants in the new iteration. ActivateLoopInvs(frame, next_iter, ready); } bool ExecutorState::SetOutputFrameIter(const TaggedNode& tagged_node, const EntryVector& outputs, FrameState** output_frame, int64* output_iter, TaggedNodeSeq* ready) { const Node* node = tagged_node.node; FrameState* input_frame = tagged_node.input_frame; int64 input_iter = tagged_node.input_iter; bool is_dead = tagged_node.is_dead; bool is_enter = IsEnter(node); if (is_enter) { FindOrCreateChildFrame(input_frame, input_iter, node, output_frame); // Propagate if this is a loop invariant. bool is_constant; Status s = GetNodeAttr(node->def(), "is_constant", &is_constant); CHECK(s.ok()) << s; if (is_constant) { AddLoopInv(*output_frame, node, outputs[0], ready); } --(*output_frame)->num_pending_inputs; *output_iter = 0; } else if (IsExit(node)) { if (is_dead) { // Stop and remember this node if it is a dead exit. if (input_iter == input_frame->iteration_count) { input_frame->dead_exits.push_back(node); } *output_frame = nullptr; } else { *output_frame = input_frame->parent_frame; *output_iter = input_frame->parent_iter; } } else if (IsNextIteration(node)) { if (is_dead) { // Stop the deadness propagation *output_frame = nullptr; } else { if (input_iter == input_frame->iteration_count && input_frame->num_outstanding_iterations == input_frame->max_parallel_iterations) { // Reached the maximum for parallel iterations. input_frame->next_iter_roots.push_back({node, outputs[0]}); *output_frame = nullptr; } else { // If this is a new iteration, start it. 
if (input_iter == input_frame->iteration_count) { IncrementIteration(input_frame, ready); } *output_iter = input_iter + 1; } } } return is_enter; } void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter, TaggedNodeSeq* ready) { int64 curr_iter = iter; while (curr_iter <= frame->iteration_count && IsIterationDone(frame, curr_iter)) { // Delete the iteration curr_iter VLOG(2) << "Delete iteration [" << frame->frame_name << ", " << curr_iter << "]."; delete frame->GetIteration(curr_iter); frame->SetIteration(curr_iter, nullptr); --frame->num_outstanding_iterations; ++curr_iter; // If there is a deferred iteration, start it. if (frame->next_iter_roots.size() > 0) { IncrementIteration(frame, ready); } } if (IsFrameDone(frame)) { FrameState* parent_frame = frame->parent_frame; int64 parent_iter = frame->parent_iter; // Propagate all the dead exits to the parent frame. for (const Node* node : frame->dead_exits) { auto parent_iter_state = parent_frame->GetIteration(parent_iter); std::vector* pending = parent_iter_state->pending_count; std::vector* dead_count = parent_iter_state->dead_count; for (const Edge* e : node->out_edges()) { const Node* dst_node = e->dst(); const int dst_id = dst_node->id(); bool dst_dead = true; bool dst_ready = false; // We know this is a dead input to dst if (IsMerge(dst_node)) { if (e->IsControlEdge()) { (*pending)[dst_id] -= 2; int count = (*pending)[dst_id]; dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); dst_ready = (count == 1) || ((count == 0) && dst_dead); } else { ++(*dead_count)[dst_id]; dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs()); dst_ready = ((*pending)[dst_id] == 0) && dst_dead; } } else { ++(*dead_count)[dst_id]; dst_ready = (--(*pending)[dst_id] == 0); } if (dst_ready) { ready->push_back( TaggedNode(dst_node, parent_frame, parent_iter, dst_dead)); parent_iter_state->outstanding_ops++; } } } // Delete the frame const string& frame_name = frame->frame_name; VLOG(2) << "Delete frame " << frame_name; outstanding_frames_.erase(frame_name); delete frame; // Cleanup recursively if (parent_frame != nullptr) { parent_frame->GetIteration(parent_iter)->outstanding_frame_count--; CleanupFramesIterations(parent_frame, parent_iter, ready); } } } // When ExecutorImpl graph has no control flow nodes, // SimpleExecutorState is used instead of ExecutorState. It maintains // fewer internal state and is convenient for experimenting with async // op kernels. class SimpleExecutorState { public: SimpleExecutorState(const Executor::Args& args, ExecutorImpl* impl); ~SimpleExecutorState() { for (auto it : device_context_map_) { it.second->Unref(); } delete slice_reader_cache_; } void RunAsync(Executor::DoneCallback done); private: typedef SimpleExecutorState ME; // Not owned. Rendezvous* rendezvous_; StepStatsCollector* stats_collector_; checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; FunctionCallFrame* call_frame_; const ExecutorImpl* impl_; CancellationManager* cancellation_manager_; Executor::Args::Runner runner_; // Owned. // i-th node's j-th input is in tensors_[impl_->nodes[i].input_start // + j]. The output is either a tensor pointer (pass-by-reference) // or a tensor (pass-by-value). // // NOTE: Not protected by mu_ because tensors_ is resized once. Each // element of tensors_ is written once by the source node of an edge // and is cleared by the destination of the same edge. The latter // node is never run concurrently with the former node. struct Entry { Tensor val = *kEmptyTensor; // A tensor value. 
Tensor* ref = nullptr; // A tensor reference. mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr. // Every entry carries an optional DeviceContext containing // Device-specific information about how the Tensor was produced. DeviceContext* device_context = nullptr; // The attributes of the allocator that creates the tensor. AllocatorAttributes alloc_attr; }; // Contains a map from node id to the DeviceContext object that was // assigned by the device at the beginning of a step. DeviceContextMap device_context_map_; std::vector input_tensors_; // Step-local resource manager. ResourceMgr step_resource_manager_; // Invoked when the execution finishes. Executor::DoneCallback done_cb_; // How many active threads of computation are being used. Same as // the number of pending Process() functions. std::atomic_int_fast32_t num_active_; mutex mu_; Status status_ GUARDED_BY(mu_); // i-th kernel is still waiting for pending[i] inputs. class CountDown { public: CountDown() : v_(0) {} void Set(int32 v) { v_.store(v); } bool Dec() { return v_.load(std::memory_order_acquire) == 1 || v_.fetch_sub(1) == 1; } private: std::atomic_int_fast32_t v_; }; std::vector pending_; // Process Node identified by "id" in current thread. "scheduled_usec" // indicates when the node becomes ready and gets scheduled. void Process(int id, int64 scheduled_usec); // Before invoking item->kernel, fills in its "inputs". Status PrepareInputs(const NodeItem& item, TensorValueVec* inputs, DeviceContextVec* input_device_contexts); // After item->kernel computation is done, processes its outputs // and returns nodes that become "ready". typedef gtl::InlinedVector ReadyNodeIds; Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, ReadyNodeIds* ready, NodeExecStats* stats); // "node" just finishes. Takes ownership of "stats". Returns true if // execution has completed. bool NodeDone(const Status& s, const Node* node, const ReadyNodeIds& ready, NodeExecStats* stats, std::deque* inline_ready); // Call Process() on all nodes in 'inline_ready'. void ProcessInline(const std::deque& inline_ready); // Schedule all the expensive nodes in 'ready', and put all the inexpensive // nodes in 'ready' into 'inline_ready'. void ScheduleReady(const ReadyNodeIds& ready, std::deque* inline_ready); // One thread of control finishes. void Finish(); TF_DISALLOW_COPY_AND_ASSIGN(SimpleExecutorState); }; SimpleExecutorState::SimpleExecutorState(const Executor::Args& args, ExecutorImpl* impl) : rendezvous_(args.rendezvous), stats_collector_(args.stats_collector), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), impl_(impl), cancellation_manager_(args.cancellation_manager), runner_(args.runner), num_active_(0), pending_(impl_->nodes_.size()) {} void SimpleExecutorState::ProcessInline(const std::deque& inline_ready) { if (inline_ready.empty()) return; int64 scheduled_usec = 0; if (stats_collector_) { scheduled_usec = nodestats::NowInUsec(); } for (int id : inline_ready) { Process(id, scheduled_usec); } } void SimpleExecutorState::ScheduleReady(const ReadyNodeIds& ready, std::deque* inline_ready) { if (ready.empty()) return; int64 scheduled_usec = 0; if (stats_collector_) { scheduled_usec = nodestats::NowInUsec(); } if (inline_ready == nullptr) { // Schedule to run all the ready ops in thread pool. 
for (auto id : ready) { runner_(std::bind(&ME::Process, this, id, scheduled_usec)); } return; } const std::vector& nodes = impl_->nodes_; int curr_expensive_node = -1; for (auto id : ready) { if (!nodes[id].kernel->IsExpensive()) { // Inline this inexpensive node. inline_ready->push_back(id); } else { if (curr_expensive_node != -1) { // Dispatch to another thread since there is plenty of work to // do for this thread. runner_( std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec)); } curr_expensive_node = id; } } if (curr_expensive_node != -1) { if (inline_ready->empty()) { // Tail recursion optimization inline_ready->push_back(curr_expensive_node); } else { // There are inline nodes to run already. We dispatch this expensive // node to other thread. runner_( std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec)); } } } void SimpleExecutorState::RunAsync(Executor::DoneCallback done) { const Graph* graph = impl_->graph_; ReadyNodeIds ready; // Ask the device to fill in the device context map. Device* device = impl_->params_.device; device->FillContextMap(graph, &device_context_map_); for (const Node* n : graph->nodes()) { const int id = n->id(); const int num_in_edges = n->in_edges().size(); pending_[id].Set(num_in_edges); if (num_in_edges == 0) { ready.push_back(id); } } if (ready.empty()) { done(Status::OK()); } else { num_active_ = ready.size(); done_cb_ = done; input_tensors_.resize(impl_->total_tensors_); // Schedule to run all the ready ops in thread pool. ScheduleReady(ready, nullptr); } } Status SimpleExecutorState::PrepareInputs( const NodeItem& item, TensorValueVec* inputs, DeviceContextVec* input_device_contexts) { const Node* node = item.node; inputs->clear(); inputs->resize(node->num_inputs()); input_device_contexts->clear(); input_device_contexts->resize(node->num_inputs()); for (int i = 0; i < node->num_inputs(); ++i) { const bool expect_ref = IsRefType(node->input_type(i)); Entry* entry = input_tensors_.data() + item.input_start + i; (*input_device_contexts)[i] = entry->device_context; // i-th input. TensorValue* inp = &(*inputs)[i]; if (entry->ref == nullptr) { if (expect_ref) { return AttachDef( errors::InvalidArgument(i, "-th input expects a ref type"), item.kernel->def()); } inp->tensor = &entry->val; } else { if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) { return AttachDef( errors::FailedPrecondition("Attempting to use uninitialized value ", item.kernel->def().input(i)), item.kernel->def()); } if (expect_ref) { inp->mutex_if_ref = entry->ref_mu; inp->tensor = entry->ref; } else { // Automatically deref the tensor ref when the op expects a // tensor but is given a ref to a tensor. Need to deref it // under the mutex. { mutex_lock l(*(entry->ref_mu)); entry->val = *entry->ref; } inp->tensor = &entry->val; } } } return Status::OK(); } void SimpleExecutorState::Process(int id, int64 scheduled_usec) { const std::vector& nodes = impl_->nodes_; ReadyNodeIds ready; std::deque inline_ready; // Parameters passed to OpKernel::Compute. 
TensorValueVec inputs; DeviceContextVec input_device_contexts; OpKernelContext::Params params; Device* device = impl_->params_.device; params.device = device; // track allocations if and only if we are collecting statistics params.track_allocations = (stats_collector_ != nullptr); params.rendezvous = rendezvous_; params.cancellation_manager = cancellation_manager_; params.call_frame = call_frame_; params.function_library = impl_->params_.function_library; params.resource_manager = device->resource_manager(); params.step_resource_manager = &step_resource_manager_; params.slice_reader_cache = slice_reader_cache_; params.inputs = &inputs; params.input_device_contexts = &input_device_contexts; params.frame_iter = FrameAndIter(0, 0); Status s; NodeExecStats* stats = nullptr; bool completed = false; inline_ready.push_back(id); while (!inline_ready.empty()) { id = inline_ready.front(); inline_ready.pop_front(); const NodeItem& item = nodes[id]; const Node* node = item.node; // Set the device_context for this node id, if it exists. auto dc_it = device_context_map_.find(id); if (dc_it != device_context_map_.end()) { params.op_device_context = dc_it->second; } if (stats_collector_) { stats = new NodeExecStats; stats->set_node_name(node->name()); nodestats::SetScheduled(stats, scheduled_usec); nodestats::SetAllStart(stats); } VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def()); // Prepares inputs. s = PrepareInputs(item, &inputs, &input_device_contexts); if (!s.ok()) { // Continue to process the nodes in 'inline_ready'. completed = NodeDone(s, item.node, ready, stats, &inline_ready); continue; } OpKernel* op_kernel = item.kernel; params.op_kernel = op_kernel; params.output_alloc_attr = [this, node, op_kernel](int index) { return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index); }; // Asynchronous computes. AsyncOpKernel* async = op_kernel->AsAsync(); if (async) { auto pcopy = CopyParams(params); auto ctx = new OpKernelContext(*pcopy); auto done = [this, item, ctx, stats, pcopy]() { VLOG(2) << this << " Async kernel done: " << SummarizeNodeDef(item.node->def()); if (stats_collector_) nodestats::SetOpEnd(stats); ReadyNodeIds ready; Status s = ProcessOutputs(item, ctx, &ready, stats); if (stats_collector_) nodestats::SetMemory(stats, ctx); // Schedule to run all the ready ops in thread pool. bool completed = NodeDone(s, item.node, ready, stats, nullptr); delete ctx; DeleteParams(pcopy); if (completed) Finish(); }; if (stats_collector_) nodestats::SetOpStart(stats); device->ComputeAsync(async, ctx, done); } else { // Synchronous computes. OpKernelContext ctx(params); if (stats_collector_) nodestats::SetOpStart(stats); device->Compute(CHECK_NOTNULL(op_kernel), &ctx); if (stats_collector_) nodestats::SetOpEnd(stats); s = ProcessOutputs(item, &ctx, &ready, stats); if (stats_collector_) nodestats::SetMemory(stats, &ctx); if (stats_collector_) { scheduled_usec = nodestats::NowInUsec(); } completed = NodeDone(s, node, ready, stats, &inline_ready); } } // while !inline_ready.empty() // This thread of computation is done if completed = true. if (completed) Finish(); } bool SimpleExecutorState::NodeDone(const Status& s, const Node* node, const ReadyNodeIds& ready, NodeExecStats* stats, std::deque* inline_ready) { if (stats_collector_) { nodestats::SetAllEnd(stats); if (!SetTimelineLabel(node, stats)) { // Only record non-transfer nodes. 
      stats_collector_->Save(impl_->params_.device->name(), stats);
    } else {
      delete stats;
    }
  }

  Rendezvous* captured_rendezvous = nullptr;  // Will be set on error.
  if (!s.ok()) {
    // Some error happened. This thread of computation is done.
    mutex_lock l(mu_);
    if (status_.ok()) {
      captured_rendezvous = rendezvous_;
      if (captured_rendezvous) captured_rendezvous->Ref();
      status_ = s;
    }
  }
  if (captured_rendezvous) {
    // If we captured the rendezvous_ pointer, we are in an error condition.
    // Use captured_rendezvous, in case "this" is deleted by another thread.
    TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
    captured_rendezvous->StartAbort(s);
    captured_rendezvous->Unref();
  }

  bool completed = false;
  int ready_size = ready.size();
  if (ready_size == 0 || !s.ok()) {
    // This node produced no new work (or failed); if it was the last active
    // node, the whole computation is complete.
    completed = (num_active_.fetch_sub(1) == 1);
  } else if (ready_size > 1) {
    num_active_.fetch_add(ready_size - 1, std::memory_order_relaxed);
  }

  // Schedule the ready nodes in 'ready'.
  if (s.ok()) {
    ScheduleReady(ready, inline_ready);
  }
  return completed;
}

void SimpleExecutorState::Finish() {
  mu_.lock();
  auto ret = status_;
  auto done_cb = done_cb_;
  auto runner = runner_;
  mu_.unlock();
  delete this;
  CHECK(done_cb != nullptr);
  runner([done_cb, ret]() { done_cb(ret); });
}

Status SimpleExecutorState::ProcessOutputs(const NodeItem& item,
                                           OpKernelContext* ctx,
                                           ReadyNodeIds* ready,
                                           NodeExecStats* stats) {
  Status s = ctx->status();
  if (!s.ok()) {
    s = AttachDef(s, item.kernel->def());
    LOG(WARNING) << this << " Compute status: " << s;
    return s;
  }

  // Processes outputs.
  gtl::InlinedVector<Entry, 4> outputs;
  const Node* node = item.node;
  outputs.resize(node->num_outputs());

  // Get the device_context for this node id, if it exists.
  DeviceContext* device_context = nullptr;
  auto dc_it = device_context_map_.find(node->id());
  if (dc_it != device_context_map_.end()) {
    device_context = dc_it->second;
  }

  for (int i = 0; i < node->num_outputs(); ++i) {
    TensorValue val = ctx->release_output(i);
    // Sanity check of output tensor types.
    DataType dtype = val->dtype();
    if (val.is_ref()) dtype = MakeRefType(dtype);
    if (dtype == node->output_type(i)) {
      Entry* out = &(outputs[i]);
      if (val.is_ref()) {
        out->ref = val.tensor;
        out->ref_mu = val.mutex_if_ref;
      } else {
        out->val = *val.tensor;
      }
      // Set the device context of the output entry.
      out->device_context = device_context;
      // Set the allocator attributes of the output entry.
      out->alloc_attr = ctx->output_alloc_attr(i);
      if (stats_collector_ && val.tensor->IsInitialized()) {
        nodestats::SetOutput(stats, i, ctx->output_allocation_type(i),
                             val.tensor);
      }
    } else {
      s.Update(
          errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
                           " does not match declared output type ",
                           DataTypeString(node->output_type(i)),
                           " for operation ", SummarizeNodeDef(node->def())));
    }
    if (!val.is_ref()) {
      // If OpKernelContext returned outputs via pass-by-value, we
      // wouldn't need this trouble.
      delete val.tensor;
    }
  }
  if (!s.ok()) return s;

  // Clears inputs.
  for (int i = 0; i < node->num_inputs(); ++i) {
    input_tensors_[item.input_start + i].val = *kEmptyTensor;
  }

  // Propagates outputs along out edges.
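  // Each destination's pending_ count was initialized in RunAsync to its
  // number of in-edges; once Dec() signals that the last pending in-edge has
  // been satisfied, the destination node is appended to 'ready' for
  // scheduling.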
  ready->clear();
  const std::vector<NodeItem>& nodes = impl_->nodes_;
  for (const Edge* e : node->out_edges()) {
    const int src_slot = e->src_output();
    const int dst_id = e->dst()->id();
    const NodeItem& dst_item = nodes[dst_id];
    if (!e->IsControlEdge()) {
      const int dst_slot = e->dst_input();
      input_tensors_[dst_item.input_start + dst_slot] = outputs[src_slot];
    }
    if (pending_[dst_id].Dec()) {
      ready->push_back(dst_id);
    }
  }
  return Status::OK();
}

// NOTE(yuanbyu): Use the executor that supports control flow by default.
const bool use_control_flow_executor = true;

void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
  if (params_.has_control_flow || use_control_flow_executor) {
    (new ExecutorState(args, this))->RunAsync(done);
  } else {
    (new SimpleExecutorState(args, this))->RunAsync(done);
  }
}

}  // end namespace

Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
                        Executor** executor) {
  ExecutorImpl* impl = new ExecutorImpl(params, graph);
  Status s = impl->Initialize();
  if (s.ok()) {
    *executor = impl;
  } else {
    delete impl;
  }
  return s;
}

Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const NodeDef& ndef, OpKernel** kernel) {
  auto device_type = DeviceType(device->attributes().device_type());
  auto allocator = device->GetAllocator(AllocatorAttributes());
  return CreateOpKernel(device_type, device, allocator, flib, ndef, kernel);
}

void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }

Status CreateCachedKernel(Device* device, const string& session,
                          FunctionLibraryRuntime* flib, const NodeDef& ndef,
                          OpKernel** kernel) {
  auto op_seg = device->op_segment();
  auto create_fn = [device, flib, &ndef](OpKernel** kernel) {
    return CreateNonCachedKernel(device, flib, ndef, kernel);
  };
  return op_seg->FindOrCreate(session, ndef.name(), kernel, create_fn);
}

// Releases the caller's reference to "kernel". Cached kernels are owned by
// the device's OpSegment, so there is nothing to delete here.
void DeleteCachedKernel(Device* device, const string& session,
                        OpKernel* kernel) {
  // Do nothing.
}

}  // end namespace tensorflow
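// Usage sketch (illustrative only, not part of the executor implementation):
// a caller that already owns a Device* and a FunctionLibraryRuntime* could
// wire the kernel-creation hooks above into LocalExecutorParams roughly as
// follows. Variable names such as `device`, `flib`, `graph`, and
// `thread_pool` are hypothetical, and the exact field sets and callback
// signatures of LocalExecutorParams and Executor::Args are assumed from how
// those structs are consumed in this file rather than quoted from their
// headers.
//
//   LocalExecutorParams params;
//   params.device = device;
//   params.function_library = flib;
//   params.create_kernel = [device, flib](const NodeDef& ndef,
//                                         OpKernel** kernel) {
//     return CreateNonCachedKernel(device, flib, ndef, kernel);
//   };
//   params.delete_kernel = [](OpKernel* kernel) {
//     DeleteNonCachedKernel(kernel);
//   };
//
//   Executor* executor = nullptr;
//   // NewLocalExecutor takes ownership of 'graph' (ExecutorImpl deletes it).
//   TF_CHECK_OK(NewLocalExecutor(params, graph, &executor));
//
//   Executor::Args args;
//   args.runner = [thread_pool](std::function<void()> fn) {
//     thread_pool->Schedule(fn);
//   };
//   executor->RunAsync(args, [](const Status& s) { TF_CHECK_OK(s); });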