#include "tensorflow/core/graph/graph_partition.h"

#include <deque>
#include <unordered_map>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {

// Key identifying a (src node, src slot, dst partition, host-ness) tuple so
// that duplicate sends of the same tensor into the same partition can be
// collapsed into a single Send/Recv pair.
struct DupRecvKey {
  int src_node_id;           // Edge's src node id
  int src_output_slot;       // Edge's src node output slot
  GraphDef* dst_graph;       // Edge's dst node is in this subgraph
  bool recv_output_on_host;  // The output of recv is on host
};

struct DupRecvKeyHash {
  size_t operator()(const DupRecvKey& k) const {
    size_t h = Hash64(reinterpret_cast<const char*>(&k.src_node_id),
                      sizeof(k.src_node_id), k.src_output_slot);
    h = Hash64(reinterpret_cast<const char*>(&k.dst_graph), sizeof(k.dst_graph),
               h);
    h = Hash64(reinterpret_cast<const char*>(&k.recv_output_on_host),
               sizeof(k.recv_output_on_host), h);
    return h;
  }
};

struct DupRecvKeyEq {
  bool operator()(const DupRecvKey& x, const DupRecvKey& y) const {
    return (x.src_node_id == y.src_node_id) &&
           (x.src_output_slot == y.src_output_slot) &&
           (x.dst_graph == y.dst_graph) &&
           (x.recv_output_on_host == y.recv_output_on_host);
  }
};

// struct used to store the recvs, so that start times can be properly updated
struct RecvInfo {
  NodeDef* recv;
  NodeDef* real_recv;
  int64 start_time;
};

typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
    DupRecvTable;

// Control flow info for a graph node.
struct ControlFlowInfo { const Node* frame = nullptr; // frame of a node const Node* parent_frame = nullptr; // parent frame of a node string frame_name; // frame name of a node int iter_level = -1; // level of a node }; struct PairIntHash { public: std::size_t operator()(const std::pair& x) const { return std::hash()(x.first) ^ std::hash()(x.second); } }; // A map used to store memory types for the inputs/outputs of every node. // The key is a pair of ints consisting of a node id and input/output index. typedef std::unordered_map, MemoryType, PairIntHash> MemoryTypeMap; // We collect the following information about the graph before performing // graph partitioning. struct GraphInfo { std::vector device_types; MemoryTypeMap input_types; MemoryTypeMap output_types; std::vector cf_info; }; DataType EdgeType(const Edge* e) { if (e->IsControlEdge()) { return DT_FLOAT; } else { return e->dst()->input_type(e->dst_input()); } } // Return true iff we need to add a same device send/recv for 'edge'. bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) { if (edge->IsControlEdge()) { return false; } Node* src = edge->src(); Node* dst = edge->dst(); if (src->assigned_device_name() == dst->assigned_device_name()) { int src_port = edge->src_output(); int dst_port = edge->dst_input(); if (info.device_types[src->id()] == DEVICE_GPU) { auto src_it = info.output_types.find({src->id(), src_port}); DCHECK(src_it != info.output_types.end()); auto dst_it = info.input_types.find({dst->id(), dst_port}); DCHECK(dst_it != info.input_types.end()); return src_it->second != dst_it->second; } } return false; } // Return true iff (dst, dst_input) is specified on host memory. 
bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) { Node* dst = edge->dst(); int dst_port = edge->dst_input(); if (info.device_types[dst->id()] == DEVICE_GPU) { if (edge->IsControlEdge()) return false; auto dst_it = info.input_types.find({dst->id(), dst_port}); DCHECK(dst_it != info.input_types.end()); return dst_it->second == HOST_MEMORY; } return true; } // Add an input to dst that comes from the "src_slot" output of the // node named by "src_name". void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { if (src_slot == Graph::kControlSlot) { dst->add_input(strings::StrCat("^", src_name)); } else if (src_slot == 0) { dst->add_input(src_name.data(), src_name.size()); } else { dst->add_input(strings::StrCat(src_name, ":", src_slot)); } } // Add a control edge from each input to each recv. void AddReadControl(const std::vector& recvs, const std::vector& inputs) { for (NodeDef* recv : recvs) { for (const string& input : inputs) { recv->add_input(strings::StrCat("^", input)); } } } void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge, NodeDefBuilder* builder) { builder->Attr("tensor_name", strings::StrCat("edge_", edge->id(), "_", edge->src()->name())); builder->Attr("send_device", edge->src()->assigned_device_name()); builder->Attr("send_device_incarnation", static_cast( opts.get_incarnation(edge->src()->assigned_device_name()))); builder->Attr("recv_device", edge->dst()->assigned_device_name()); builder->Attr("client_terminated", false); } NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, GraphDef* gdef, const Edge* edge, NodeDefBuilder::NodeOut send_from, int64 start_time, Status* status) { const DataType dtype = send_from.data_type; const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype; const Node* src = edge->src(); const int src_port = edge->src_output(); // host_memory = true iff we need to use HostSend/HostCast. 
bool host_memory = false; if (!edge->IsControlEdge()) { auto src_it = g_info.output_types.find({src->id(), src_port}); DCHECK(src_it != g_info.output_types.end()); host_memory = (src_it->second == HOST_MEMORY); } // Add a cast node that casts dtype to cast_dtype. // NOTE(yuanbyu): Only cast for cross-device send/recv. if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) { const string cast_op = (host_memory) ? "_HostCast" : "Cast"; NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op); cast_builder.Device(src->assigned_device_name()).Input(send_from); if (opts.scheduling_for_recvs) { cast_builder.Attr("_start_time", start_time); } cast_builder.Attr("DstT", cast_dtype); NodeDef* cast = gdef->add_node(); *status = cast_builder.Finalize(cast); if (!status->ok()) return nullptr; // Connect the Send op to the cast. send_from.Reset(cast->name(), 0, cast_dtype); } // Add the send node. const string send_op = (host_memory) ? "_HostSend" : "_Send"; NodeDefBuilder send_builder(opts.new_name(src->name()), send_op); SetSendRecvAttrs(opts, edge, &send_builder); send_builder.Device(src->assigned_device_name()).Input(send_from); if (opts.scheduling_for_recvs) { send_builder.Attr("_start_time", start_time); } NodeDef* send = gdef->add_node(); *status = send_builder.Finalize(send); return send; } NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info, GraphDef* gdef, const Edge* edge, NodeDef** real_recv, Status* status) { const DataType dtype = EdgeType(edge); const Node* src = edge->src(); const Node* dst = edge->dst(); const int dst_port = edge->dst_input(); DataType cast_dtype = dtype; // NOTE(yuanbyu): Only cast for cross-device send/recv. if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) { cast_dtype = opts.should_cast(edge); } // host_memory = true iff we need to use HostRecv/HostCast. 
bool host_memory = false; if (!edge->IsControlEdge()) { auto dst_it = g_info.input_types.find({dst->id(), dst_port}); DCHECK(dst_it != g_info.input_types.end()); host_memory = (dst_it->second == HOST_MEMORY); } // Add the recv node. const string recv_op = (host_memory) ? "_HostRecv" : "_Recv"; NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op); SetSendRecvAttrs(opts, edge, &recv_builder); recv_builder.Device(dst->assigned_device_name()) .Attr("tensor_type", cast_dtype); NodeDef* recv = gdef->add_node(); *status = recv_builder.Finalize(recv); if (!status->ok()) return nullptr; *real_recv = recv; // Add the cast node (from cast_dtype to dtype) or an Identity node. if (dtype != cast_dtype) { const string cast_op = (host_memory) ? "_HostCast" : "Cast"; NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op); cast_builder.Attr("DstT", dtype); cast_builder.Device(dst->assigned_device_name()) .Input(recv->name(), 0, cast_dtype); NodeDef* cast = gdef->add_node(); *status = cast_builder.Finalize(cast); if (!status->ok()) return nullptr; return cast; } else if (edge->IsControlEdge()) { // An Identity is only needed for control edges. NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity"); id_builder.Device(dst->assigned_device_name()) .Input(recv->name(), 0, cast_dtype); NodeDef* id = gdef->add_node(); *status = id_builder.Finalize(id); if (!status->ok()) return nullptr; return id; } else { return recv; } } NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef, const Edge* edge, Status* status) { const Node* src = edge->src(); Tensor tensor(DT_FLOAT, TensorShape({0})); NodeDef* result = gdef->add_node(); *status = NodeDefBuilder(opts.new_name(src->name()), "Const") .Device(src->assigned_device_name()) .Attr("dtype", DT_FLOAT) .Attr("value", tensor) .Finalize(result); return result; } // A dummy node for scheduling. 
NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, const string& assigned_device_name, int64 epoch, int64 starttime, Status* status) { NodeDef* result = gdef->add_node(); *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)), "ControlTrigger") .Device(assigned_device_name) .Attr("_start_time", starttime) .Finalize(result); return result; } // Assign to each node the name of the frame and the level it belongs to. // We check the well-formedness of the graph: All inputs to a node must // come from the same frame and have the same "static" iteration level. // NOTE(yuanbyu): For now, we require all sends/recvs have iteration level // 0. This essentially means there can't be multiple serial Nexts in // an iteration, which all sane front-ends should satisfy. Status BuildControlFlowInfo(Graph* g, std::vector* info) { info->clear(); info->resize(g->num_node_ids()); Node* src_node = g->source_node(); ControlFlowInfo& src_info = (*info)[src_node->id()]; src_info.frame = src_node; src_info.parent_frame = src_node; src_info.iter_level = 0; string frame_name; std::deque ready; ready.push_back(src_node); while (!ready.empty()) { const Node* curr_node = ready.front(); ready.pop_front(); const ControlFlowInfo& curr_info = (*info)[curr_node->id()]; const Node* frame = curr_info.frame; const Node* parent = curr_info.parent_frame; frame_name = curr_info.frame_name; int iter_level = curr_info.iter_level; if (IsExit(curr_node)) { const ControlFlowInfo& parent_info = (*info)[parent->id()]; frame = parent_info.frame; parent = parent_info.parent_frame; frame_name = parent_info.frame_name; iter_level = parent_info.iter_level; } for (const Edge* out_edge : curr_node->out_edges()) { const Node* out = out_edge->dst(); int out_id = out->id(); ControlFlowInfo* out_info = &(*info)[out_id]; const Node* out_parent = out_info->parent_frame; bool is_visited = (out_info->iter_level != -1); // Skip Sink/Source nodes. 
if (!out->IsOp()) continue; // Add to ready queue if not seen. if (!is_visited) { ready.push_back(out); } // Process the node 'out'. if (IsEnter(out)) { if (is_visited) { const string& parent_name = (*info)[out_parent->id()].frame_name; if (parent_name != frame_name || iter_level != out_info->iter_level) { return errors::InvalidArgument( "All inputs to Enter must be from the same frame and level."); } } else { out_info->frame = out; out_info->parent_frame = frame; TF_RETURN_IF_ERROR( GetNodeAttr(out->def(), "frame_name", &out_info->frame_name)); if (out_info->frame_name.empty()) { return errors::InvalidArgument( "Enter must have a non-empty frame name."); } out_info->iter_level = 0; } } else if (IsNextIteration(out)) { if (is_visited) { if (out_info->frame_name != frame_name || out_info->iter_level != (iter_level + 1)) { return errors::InvalidArgument( "All inputs to NextIteration must be from the same frame " "and level."); } } else { out_info->frame = frame; out_info->parent_frame = parent; out_info->frame_name = frame_name; out_info->iter_level = iter_level + 1; } } else { if (is_visited) { if (out_info->frame_name != frame_name) { return errors::InvalidArgument( "All inputs to a node must be from the same frame."); } } else { out_info->frame = frame; out_info->parent_frame = parent; out_info->frame_name = frame_name; out_info->iter_level = iter_level; } } } } return Status::OK(); } string ControlLoopName(const string& name) { return strings::StrCat("_cloop", name); } bool IsControlLoop(const Node* node) { const string& name = node->def().name(); return StringPiece(name).starts_with("_cloop"); } // An enter node for control flow. 
Node* AddControlEnter(Graph* g, const string& node_name,
                      const string& device_name, const string& frame_name,
                      const int parallel_iterations, Status* status) {
  NodeBuilder node_builder(node_name, "Enter", g->op_registry());
  // The "dummy" input is rewired to the real trigger by the caller.
  node_builder.Input({"dummy", 0, DT_FLOAT});
  node_builder.Attr("frame_name", frame_name);
  node_builder.Attr("parallel_iterations", parallel_iterations);
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A merge node for control flow.
Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
                      const string& node_name, const string& device_name,
                      Status* status) {
  NodeBuilder node_builder(node_name, "Merge", g->op_registry());
  node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A switch node for control flow.
Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
                       const string& device_name,
                       const GraphDefBuilder::Options& bopts) {
  Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A next_iteration node for control flow.
Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name, const GraphDefBuilder::Options& bopts) { Node* res_node = ops::UnaryOp("NextIteration", input, bopts); if (bopts.HaveError()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; } Node* EmptyConst(const GraphDefBuilder::Options& options) { if (options.HaveError()) return nullptr; NodeBuilder node_builder(options.GetNameForOp("Const"), "Const", options.op_registry()); const DataType dt = DataTypeToEnum::v(); TensorProto proto; proto.set_dtype(dt); TensorShape empty_shape({0}); empty_shape.AsProto(proto.mutable_tensor_shape()); node_builder.Attr("dtype", dt).Attr("value", proto); return options.FinalizeBuilder(&node_builder); } // A dummy const node for control flow. Node* AddControlConst(const string& device_name, const GraphDefBuilder::Options& bopts) { Node* res_node = EmptyConst(bopts); if (bopts.HaveError()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; } // A synthetic loop, made up of dummy nodes. It performs control-flow actions // on behalf of a leader on a different device. struct ControlLoop { Node* enter = nullptr; Node* merge = nullptr; Node* switch_node = nullptr; }; // Add the control flow info of a new node added during partitioning. // The new node has the same control flow info as edge->src(). void AddControlFlowInfo(const Node* node, const Node* src, std::vector* cf_info) { int id = node->id(); if (static_cast(id) >= cf_info->size()) { cf_info->resize(id + 1); } const ControlFlowInfo& src_info = (*cf_info)[src->id()]; ControlFlowInfo* info = &(*cf_info)[id]; info->frame = src_info.frame; info->parent_frame = src_info.parent_frame; info->frame_name = src_info.frame_name; info->iter_level = src_info.iter_level; } // Constructs a control loop. Returns a struct containing the newly created // enter, merge, and switch nodes. 
The enter and merge nodes are used in the // recursive construction of control loops for nested frames (loops). The // switch node will be connected to the LoopCond node. The merge node will // be connected to all the recvs of the same frame by control edges when // the actual partitioning happens. Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, const Edge* edge, Node* loop_cond, std::vector* cf_info, ControlLoop* loop) { Status status; GraphDefBuilder::Options bopts(g, &status); const ControlFlowInfo& src_info = (*cf_info)[src->id()]; const string& device_name = edge->dst()->assigned_device_name(); const string& frame_name = src_info.frame_name; int parallel_iterations; status = GetNodeAttr(src_info.frame->def(), "parallel_iterations", ¶llel_iterations); if (!status.ok()) return status; // The names of the nodes to be added. const string& enter_name = ControlLoopName(opts.new_name(edge->dst()->name())); const string& merge_name = ControlLoopName(opts.new_name(edge->dst()->name())); const string& switch_name = ControlLoopName(opts.new_name(edge->dst()->name())); const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name())); // Add the nodes to the graph g. 
Node* enter = AddControlEnter(g, enter_name, device_name, frame_name, parallel_iterations, &status); if (!status.ok()) return status; Node* merge = AddControlMerge(enter_name, next_name, g, merge_name, device_name, &status); if (!status.ok()) return status; Node* switch_node = AddControlSwitch(merge, loop_cond, device_name, bopts.WithName(switch_name)); if (!status.ok()) return status; Node* next = AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name)); if (!status.ok()) return status; // Add control flow info for these new nodes: AddControlFlowInfo(enter, src, cf_info); AddControlFlowInfo(merge, src, cf_info); AddControlFlowInfo(switch_node, src, cf_info); AddControlFlowInfo(next, src, cf_info); // Add input edges for the newly created merge node: g->AddEdge(enter, 0, merge, 0); g->AddEdge(next, 0, merge, 1); loop->enter = enter; loop->merge = merge; loop->switch_node = switch_node; return Status::OK(); } // Build memory and device type info for every node in the graph. // TODO(yuanbyu): It might be simpler if we convert MemoryType to // DeviceType for the inputs/outputs of each node. Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { Status status; MemoryTypeVector input_memory_types; MemoryTypeVector output_memory_types; info->device_types.resize(g.num_node_ids(), DEVICE_CPU); for (const Node* node : g.nodes()) { if (!node->IsOp()) continue; // Skip Sink/Source nodes. 
DeviceNameUtils::ParsedName parsed; if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(), &parsed)) { return errors::Internal("Malformed assigned device '", node->assigned_device_name(), "'"); } input_memory_types.clear(); input_memory_types.resize(node->num_inputs()); output_memory_types.clear(); output_memory_types.resize(node->num_outputs()); status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type), node->def(), &input_memory_types, &output_memory_types); if (!status.ok()) return status; int node_id = node->id(); info->device_types[node_id] = DeviceType(parsed.type); for (size_t i = 0; i < input_memory_types.size(); ++i) { info->input_types[{node_id, i}] = input_memory_types[i]; } for (size_t i = 0; i < output_memory_types.size(); ++i) { info->output_types[{node_id, i}] = output_memory_types[i]; } } return status; } // Each participating device needs to decide a) if there is a next iteration, // and b) if the loop terminates. We take the approach to encode this control // flow logic in the dataflow graph. There are at least two possible encodings. // In a completely decentralized encoding, the participants communicate peer // to peer. The other encoding uses a frame leader (the participant who owns // the pivot termination predicate) to broadcast the termination condition to // all the participants. For now we take the latter because it is simpler. // // TODO(yuanbyu): The correctness of this construction is rather subtle. I got // it wrong many times so it would be nice to write a proof to be sure. Status AddControlFlow(const PartitionOptions& opts, Graph* g, GraphInfo* g_info) { Status status; GraphDefBuilder::Options bopts(g, &status); std::vector& cf_info = g_info->cf_info; // Build the control flow info for every node. status = BuildControlFlowInfo(g, &cf_info); if (!status.ok()) return status; // The map from frames to their LoopCond nodes. 
std::unordered_map frame_cond_map; int num_node_ids = g->num_node_ids(); for (int i = 0; i < num_node_ids; ++i) { Node* node = g->FindNodeId(i); if (node == nullptr) continue; if (IsLoopCond(node)) { const string& frame_name = cf_info[node->id()].frame_name; DCHECK(!frame_name.empty()); frame_cond_map[frame_name] = node; } } // Add all control loops for cross-device frames. // A control loop is added only when there is a cross-device edge in a // non-root frame. Nothing is added if there is no loops. We also don't // add anything for a frame that is completely local to a device. For // nested loops, we stack the control loops together by connecting // the merge of the outer loop to the enter of the inner loop. // // A map from to ControlLoop. std::unordered_map control_loops; int num_edge_ids = g->num_edge_ids(); for (int i = 0; i < num_edge_ids; ++i) { const Edge* edge = g->FindEdgeId(i); if (edge == nullptr) continue; const Node* src = edge->src(); const Node* dst = edge->dst(); // Skip Sink/Source nodes. if (!src->IsOp() || !dst->IsOp()) continue; const string& src_device = src->assigned_device_name(); const string& dst_device = dst->assigned_device_name(); // Skip local edges. if (src_device == dst_device) continue; const string& src_frame = cf_info[src->id()].frame_name; const string& dst_frame = cf_info[dst->id()].frame_name; // Skip if src and dst are not in the same frame. if (src_frame.empty() || src_frame != dst_frame) { continue; } // Add the control loop. Start by adding the control loop for the // current frame if needed, and recursively adding the control loop // for its outer frame when nested. ControlLoop child_loop; while (true) { const string& curr_frame = cf_info[src->id()].frame_name; if (curr_frame.empty()) { // We have reached the root frame. 
if (child_loop.merge != nullptr) { const string& node_name = opts.new_name(edge->dst()->name()); const string& device_name = edge->dst()->assigned_device_name(); Node* const_node = AddControlConst(device_name, bopts.WithName(node_name)); if (!status.ok()) return status; AddControlFlowInfo(const_node, src, &cf_info); g->AddEdge(const_node, 0, child_loop.enter, 0); } break; } const string& cl_key = strings::StrCat(curr_frame, "$$", dst_device); auto it = control_loops.find(cl_key); if (it != control_loops.end()) { if (child_loop.enter != nullptr) { g->AddEdge(it->second.merge, 0, child_loop.enter, 0); } break; } // Get the frame's LoopCond. auto cond_it = frame_cond_map.find(curr_frame); if (cond_it == frame_cond_map.end()) { return errors::InvalidArgument( "A cross-device loop must have a pivot predicate: ", curr_frame); } Node* loop_cond = cond_it->second; // Add the control loop. ControlLoop curr_loop; status = AddControlLoop(opts, g, src, edge, loop_cond, &cf_info, &curr_loop); if (!status.ok()) return status; control_loops[cl_key] = curr_loop; if (child_loop.enter != nullptr) { // Connect the merge of the outer loop to the enter of the inner. g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0); } src = cf_info[src->id()].parent_frame; child_loop = curr_loop; } } // For a cross-device edge, on the dst device, add a control edge // from the merge node of the control loop to dst. If a send/recv is // introduced for this edge in future partitioning, we delete this // control edge and add a new control edge from the merge to the recv. num_edge_ids = g->num_edge_ids(); for (int i = 0; i < num_edge_ids; ++i) { const Edge* edge = g->FindEdgeId(i); if (edge == nullptr) continue; const Node* src = edge->src(); Node* dst = edge->dst(); // Skip Sink/Source nodes. 
if (!src->IsOp() || !dst->IsOp()) continue; const string& src_device = src->assigned_device_name(); const string& dst_device = dst->assigned_device_name(); if (src_device != dst_device) { const string& src_frame = cf_info[src->id()].frame_name; const string& dst_frame = cf_info[dst->id()].frame_name; if (!src_frame.empty() && src_frame == dst_frame) { const string& cl_key = strings::StrCat(dst_frame, "$$", dst_device); ControlLoop loop = control_loops[cl_key]; DCHECK(loop.enter != nullptr); g->AddControlEdge(loop.merge, dst); } } } return Status::OK(); } } // end namespace Status AddControlEdges(const PartitionOptions& opts, std::unordered_map* partitions) { Status status; // TODO(yuanbyu): Very naive for now. To be improved. const int num_epochs = 100; const int prefetch = 6; typedef std::pair NodeStartTime; for (auto& part : *partitions) { GraphDef* gdef = &part.second; std::vector start_times; start_times.resize(gdef->node_size()); for (int n = 0; n < gdef->node_size(); ++n) { const NodeDef& ndef = gdef->node(n); int64 start_time; status = GetNodeAttr(ndef, "_start_time", &start_time); if (!status.ok()) { return status; } start_times[n] = std::make_pair(&ndef, start_time); } // Sort the nodes based on their start times. std::sort( start_times.begin(), start_times.end(), [](NodeStartTime x, NodeStartTime y) { return x.second < y.second; }); // Add a dummy node for every epoch, and add a control edge from the // "last" node in the preceding epoch to the dummy node. 
string device_name = gdef->node(0).device(); int64 makespan = start_times.back().second; int64 resolution = (makespan / num_epochs) + 1; int i = 0; int j = 0; std::vector dummys; while (i < num_epochs && static_cast(j) < start_times.size()) { if (i * resolution > start_times[j].second) { j++; } else { NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i, i * resolution, &status); if (!status.ok()) { return status; } dummys.push_back(dummy); if (j > 0) { string src_name = start_times[j - 1].first->name(); AddInput(dummy, src_name, Graph::kControlSlot); } i++; } } // Finally, add the control edges to recvs. for (int n = 0; n < gdef->node_size(); ++n) { NodeDef* ndef = gdef->mutable_node(n); if (ndef->op() == "_Recv") { int64 start_time; status = GetNodeAttr(*ndef, "_start_time", &start_time); if (!status.ok()) { return status; } int recv_epoch = start_time / resolution; if (recv_epoch >= prefetch) { NodeDef* dummy = dummys[recv_epoch - prefetch]; AddInput(ndef, dummy->name(), Graph::kControlSlot); } } } } return Status::OK(); } Status Partition(const PartitionOptions& opts, Graph* g, std::unordered_map* partitions) { Status status; partitions->clear(); GraphInfo g_info; if (!opts.control_flow_added) { // Add the "code" for distributed execution of control flow. Code is // added only for the frames that are placed on multiple devices. The // new graph is an equivalent transformation of the original graph and // has the property that it can be subsequently partitioned arbitrarily // (down to the level of individual device) for distributed execution. status = AddControlFlow(opts, g, &g_info); if (!status.ok()) return status; } // At this point, all the graph mutations have been done. Build memory // and device type info for every node and edge in the graph. 
status = BuildMemoryDeviceInfo(*g, &g_info); if (!status.ok()) return status; string dstp; std::vector inputs; DupRecvTable dup_recv(3); // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref // edge to dst. We will add a control edge for every pair in // (ref_recvs x ref_control_inputs). std::vector ref_recvs; std::vector ref_control_inputs; int32 num_data = 0; int32 num_control = 0; for (const Node* dst : g->nodes()) { if (!dst->IsOp()) continue; // Skip Sink/Source nodes. dstp = opts.node_to_loc(dst); GraphDef* dst_graph = &(*partitions)[dstp]; NodeDef* dst_def = dst_graph->add_node(); *dst_def = dst->def(); dst_def->set_device(dst->assigned_device_name()); dst_def->clear_input(); // Inputs are filled below if (opts.need_to_record_start_times) { int64 start_time = opts.start_times[dst->id()].value(); AddNodeAttr("_start_time", start_time, dst_def); } // Arrange the incoming edges to dst so that input[i] holds the // input flowing into slot numbered i. Trailing entries in input[] // hold control edges. inputs.clear(); inputs.resize(dst->num_inputs(), nullptr); ref_recvs.clear(); ref_control_inputs.clear(); const Edge* control_flow_edge = nullptr; for (const Edge* edge : dst->in_edges()) { if (edge->IsControlEdge()) { if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { // This is one of the control edges added for control flow. There // can be multiple such edges as the dest node may have multiple // remote inputs. We will just take one and ignore the others. control_flow_edge = edge; } else { inputs.push_back(edge); } } else { DCHECK(inputs[edge->dst_input()] == nullptr); inputs[edge->dst_input()] = edge; } } // Process in order so that all data edges are added as inputs to // dst in Edge::dst_input() order. bool recv_added = false; for (const Edge* edge : inputs) { const Node* src = edge->src(); if (!src->IsOp()) continue; // Skip Sink/Source nodes. 
GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)]; if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) { // Same partition and compatible memory types: AddInput(dst_def, src->name(), edge->src_output()); if (edge->IsControlEdge() || !IsRefType(src->output_type(edge->src_output()))) { ref_control_inputs.push_back(src->name()); } continue; } int64 send_start_time = 0; int64 recv_start_time = 0; if (opts.scheduling_for_recvs) { if (opts.need_to_record_start_times) { send_start_time = opts.start_times[src->id()].value(); recv_start_time = opts.start_times[dst->id()].value(); } else { status = GetNodeAttr(src->def(), "_start_time", &send_start_time); if (!status.ok()) { return status; } status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time); if (!status.ok()) { return status; } } } // Check whether there is already a send/recv pair transferring // the same tensor/control from the src to dst partition. const bool on_host = IsDstInputOnHost(edge, g_info); DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host}; auto iter = dup_recv.find(key); if (iter != dup_recv.end()) { // We found one. Reuse the data/control transferred already. const string& recv_node_name = iter->second.recv->name(); if (edge->IsControlEdge()) { AddInput(dst_def, recv_node_name, Graph::kControlSlot); } else { AddInput(dst_def, recv_node_name, 0); } // We want the start_time for the recv to be the smallest of the start // times of it's consumers. So we update this whenever we use a recv, // and write it out to the attribute at the end of the subroutine if (iter->second.start_time > recv_start_time) { iter->second.start_time = recv_start_time; } continue; } NodeDefBuilder::NodeOut send_from; if (edge->IsControlEdge()) { // Insert a dummy const node that will generate a tiny // data element to be sent from send to recv. 
VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "[" << src->name() << "] -> " << dst->assigned_device_name() << "[" << dst->name() << "]"; NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status); if (!status.ok()) return status; // Set the start time for this dummy node. if (opts.scheduling_for_recvs) { AddNodeAttr("_start_time", send_start_time, dummy); } AddInput(dummy, src->name(), Graph::kControlSlot); send_from.Reset(dummy->name(), 0, DT_FLOAT); } else { send_from.Reset(src->name(), edge->src_output(), EdgeType(edge)); } // Need to split edge by placing matching send/recv nodes on // the src/dst sides of the edge. NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from, send_start_time, &status); if (!status.ok()) return status; NodeDef* real_recv = nullptr; NodeDef* recv = AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status); if (!status.ok()) return status; // Fix up the control flow edge. Redirect it to the recv. // NOTE(yuanbyu): 'real_recv' must be the real recv node. recv_added = true; if (control_flow_edge != nullptr) { AddInput(real_recv, control_flow_edge->src()->name(), Graph::kControlSlot); } // For same device send/recv, add a control edge from send to recv. // This prevents the asynchronous recv kernel from being scheduled // immediately. if (src_graph == dst_graph) { AddInput(real_recv, send->name(), Graph::kControlSlot); } if (!edge->IsControlEdge() && IsRefType(src->output_type(edge->src_output()))) { // If src is of ref type and the edge is not a control edge, dst has // read semantics and therefore we must control the recv. ref_recvs.push_back(real_recv); } else { // Memorize the send/recv pair, only if this is not a "ref" edge. // NOTE(yuanbyu): Collapsing ref edges requires extreme care so // for now we don't do it. 
dup_recv[key] = {recv, real_recv, recv_start_time}; ref_control_inputs.push_back(recv->name()); } if (edge->IsControlEdge()) { ++num_control; AddInput(dst_def, recv->name(), Graph::kControlSlot); } else { ++num_data; AddInput(dst_def, recv->name(), 0); } } // Add control edges from 'ref_control_inputs' to 'ref_recvs'. // NOTE(yuanbyu): Adding these control edges should not introduce // deadlocks. 'dst' has implicit "read" nodes that, when we split // across devices, are made explicit; Retargettig the dependencies // to 'dst' to those nodes would not introduce cycles if there isn't // one before the transformation. // NOTE(yuanbyu): This may impact performance because it defers the // execution of recvs until all the other inputs become available. AddReadControl(ref_recvs, ref_control_inputs); // Add back this control edge for control flow if not used. if (!recv_added && (control_flow_edge != nullptr)) { AddInput(dst_def, control_flow_edge->src()->name(), Graph::kControlSlot); } } // Set the start times for recvs at the very end. if (opts.scheduling_for_recvs) { for (auto& it : dup_recv) { AddNodeAttr("_start_time", it.second.start_time, it.second.recv); if (it.second.real_recv != it.second.recv) { AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv); } } } VLOG(1) << "Added send/recv: controls=" << num_control << ", data=" << num_data; return Status::OK(); } } // namespace tensorflow