aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_partition.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/graph_partition.cc')
-rw-r--r--  tensorflow/core/graph/graph_partition.cc  | 1050
1 files changed, 1050 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
new file mode 100644
index 0000000000..1571790e59
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -0,0 +1,1050 @@
+#include "tensorflow/core/graph/graph_partition.h"
+
+#include <deque>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
// Key identifying a tensor transfer into a destination subgraph. Two edges
// that carry the same output slot of the same source node into the same
// dst_graph (with matching host/device placement of the recv output) can
// share one Send/Recv pair; this key indexes that dedup table.
struct DupRecvKey {
  int src_node_id;           // Edge's src node id
  int src_output_slot;       // Edge's src node output slot
  GraphDef* dst_graph;       // Edge's dst node is in this subgraph
  bool recv_output_on_host;  // The output of recv is on host
};
+
+struct DupRecvKeyHash {
+ size_t operator()(const DupRecvKey& k) const {
+ size_t h = Hash64(reinterpret_cast<const char*>(&k.src_node_id),
+ sizeof(k.src_node_id), k.src_output_slot);
+ h = Hash64(reinterpret_cast<const char*>(&k.dst_graph), sizeof(k.dst_graph),
+ h);
+ h = Hash64(reinterpret_cast<const char*>(&k.recv_output_on_host),
+ sizeof(k.recv_output_on_host), h);
+ return h;
+ }
+};
+
+struct DupRecvKeyEq {
+ bool operator()(const DupRecvKey& x, const DupRecvKey& y) const {
+ return (x.src_node_id == y.src_node_id) &&
+ (x.src_output_slot == y.src_output_slot) &&
+ (x.dst_graph == y.dst_graph) &&
+ (x.recv_output_on_host == y.recv_output_on_host);
+ }
+};
+
// struct used to store the recvs, so that start times can be properly updated
struct RecvInfo {
  NodeDef* recv;       // Node that consumers are wired to (may be a cast/id).
  NodeDef* real_recv;  // The actual _Recv/_HostRecv node.
  int64 start_time;    // Smallest start time over all consumers seen so far.
};

// Dedup table: at most one Recv per DupRecvKey per destination subgraph.
typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
    DupRecvTable;
+
// Control flow info for a graph node. iter_level == -1 marks a node that
// has not been visited yet by BuildControlFlowInfo.
struct ControlFlowInfo {
  const Node* frame = nullptr;         // frame of a node
  const Node* parent_frame = nullptr;  // parent frame of a node
  string frame_name;                   // frame name of a node
  int iter_level = -1;                 // level of a node
};
+
// Hash functor for (node id, port) pairs. Mixes the two component hashes
// with a Boost-style hash_combine instead of a plain XOR: XOR makes
// hash(a, b) == hash(b, a), and hash(a, a) == 0 under the common identity
// std::hash<int>, which clusters MemoryTypeMap buckets badly.
struct PairIntHash {
 public:
  std::size_t operator()(const std::pair<int, int>& x) const {
    std::size_t h = std::hash<int>()(x.first);
    // 0x9e3779b9 is the 32-bit golden-ratio constant used by
    // boost::hash_combine; the shifts spread entropy across bits.
    h ^= std::hash<int>()(x.second) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }
};
+// A map used to store memory types for the inputs/outputs of every node.
+// The key is a pair of ints consisting of a node id and input/output index.
+typedef std::unordered_map<std::pair<int, int>, MemoryType, PairIntHash>
+ MemoryTypeMap;
+
// We collect the following information about the graph before performing
// graph partitioning.
struct GraphInfo {
  std::vector<DeviceType> device_types;  // Indexed by node id.
  MemoryTypeMap input_types;             // Keyed by (node id, input port).
  MemoryTypeMap output_types;            // Keyed by (node id, output port).
  std::vector<ControlFlowInfo> cf_info;  // Indexed by node id.
};
+
+DataType EdgeType(const Edge* e) {
+ if (e->IsControlEdge()) {
+ return DT_FLOAT;
+ } else {
+ return e->dst()->input_type(e->dst_input());
+ }
+}
+
+// Return true iff we need to add a same device send/recv for 'edge'.
+bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
+ if (edge->IsControlEdge()) {
+ return false;
+ }
+
+ Node* src = edge->src();
+ Node* dst = edge->dst();
+ if (src->assigned_device_name() == dst->assigned_device_name()) {
+ int src_port = edge->src_output();
+ int dst_port = edge->dst_input();
+ if (info.device_types[src->id()] == DEVICE_GPU) {
+ auto src_it = info.output_types.find({src->id(), src_port});
+ DCHECK(src_it != info.output_types.end());
+ auto dst_it = info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != info.input_types.end());
+ return src_it->second != dst_it->second;
+ }
+ }
+ return false;
+}
+
+// Return true iff (dst, dst_input) is specified on host memory.
+bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
+ Node* dst = edge->dst();
+ int dst_port = edge->dst_input();
+ if (info.device_types[dst->id()] == DEVICE_GPU) {
+ if (edge->IsControlEdge()) return false;
+ auto dst_it = info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != info.input_types.end());
+ return dst_it->second == HOST_MEMORY;
+ }
+ return true;
+}
+
+// Add an input to dst that comes from the "src_slot" output of the
+// node named by "src_name".
+void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
+ if (src_slot == Graph::kControlSlot) {
+ dst->add_input(strings::StrCat("^", src_name));
+ } else if (src_slot == 0) {
+ dst->add_input(src_name.data(), src_name.size());
+ } else {
+ dst->add_input(strings::StrCat(src_name, ":", src_slot));
+ }
+}
+
+// Add a control edge from each input to each recv.
+void AddReadControl(const std::vector<NodeDef*>& recvs,
+ const std::vector<string>& inputs) {
+ for (NodeDef* recv : recvs) {
+ for (const string& input : inputs) {
+ recv->add_input(strings::StrCat("^", input));
+ }
+ }
+}
+
// Sets the attrs shared by the _Send/_Recv pair created for 'edge': the
// rendezvous tensor name (unique per edge), sender/receiver device names,
// the sender's incarnation, and client_terminated=false.
void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
                      NodeDefBuilder* builder) {
  // Both ends of the pair must agree on this name; edge->id() makes it
  // unique within the graph.
  builder->Attr("tensor_name",
                strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
  builder->Attr("send_device", edge->src()->assigned_device_name());
  builder->Attr("send_device_incarnation",
                static_cast<int64>(
                    opts.get_incarnation(edge->src()->assigned_device_name())));
  builder->Attr("recv_device", edge->dst()->assigned_device_name());
  // Graph-internal transfer, not a client-fed/fetched tensor.
  builder->Attr("client_terminated", false);
}
+
// Adds to 'gdef' a Send node for 'edge' that reads its input from
// 'send_from'. If opts.should_cast requests a different dtype for a
// cross-device transfer, a Cast node is inserted before the Send.
// _HostSend/_HostCast are used when the source output lives in host
// memory. Returns the Send's NodeDef, or nullptr with '*status' set
// on failure.
NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge,
                 NodeDefBuilder::NodeOut send_from, int64 start_time,
                 Status* status) {
  const DataType dtype = send_from.data_type;
  const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
  const Node* src = edge->src();
  const int src_port = edge->src_output();

  // host_memory = true iff we need to use HostSend/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto src_it = g_info.output_types.find({src->id(), src_port});
    DCHECK(src_it != g_info.output_types.end());
    host_memory = (src_it->second == HOST_MEMORY);
  }

  // Add a cast node that casts dtype to cast_dtype.
  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
    cast_builder.Device(src->assigned_device_name()).Input(send_from);
    if (opts.scheduling_for_recvs) {
      cast_builder.Attr("_start_time", start_time);
    }
    cast_builder.Attr("DstT", cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;

    // Connect the Send op to the cast.
    send_from.Reset(cast->name(), 0, cast_dtype);
  }

  // Add the send node.
  const string send_op = (host_memory) ? "_HostSend" : "_Send";
  NodeDefBuilder send_builder(opts.new_name(src->name()), send_op);
  SetSendRecvAttrs(opts, edge, &send_builder);
  send_builder.Device(src->assigned_device_name()).Input(send_from);
  if (opts.scheduling_for_recvs) {
    // "_start_time" is consumed later by AddControlEdges for scheduling.
    send_builder.Attr("_start_time", start_time);
  }
  NodeDef* send = gdef->add_node();
  *status = send_builder.Finalize(send);
  return send;
}
+
// Adds to 'gdef' a Recv node for 'edge'. '*real_recv' is set to the actual
// _Recv/_HostRecv node; the returned NodeDef is the node consumers should
// read from, which is either the recv itself, a Cast back to the edge's
// dtype (when opts.should_cast changed the wire type), or an Identity (for
// control edges). Returns nullptr with '*status' set on failure.
NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
  }

  // Add the recv node. Note it is named after the *source* node but placed
  // on the destination's device.
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op);
  SetSendRecvAttrs(opts, edge, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity");
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}
+
+NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
+ const Edge* edge, Status* status) {
+ const Node* src = edge->src();
+ Tensor tensor(DT_FLOAT, TensorShape({0}));
+ NodeDef* result = gdef->add_node();
+ *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
+ .Device(src->assigned_device_name())
+ .Attr("dtype", DT_FLOAT)
+ .Attr("value", tensor)
+ .Finalize(result);
+ return result;
+}
+
+// A dummy node for scheduling.
+NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
+ const string& assigned_device_name, int64 epoch,
+ int64 starttime, Status* status) {
+ NodeDef* result = gdef->add_node();
+ *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
+ "ControlTrigger")
+ .Device(assigned_device_name)
+ .Attr("_start_time", starttime)
+ .Finalize(result);
+ return result;
+}
+
// Assign to each node the name of the frame and the level it belongs to.
// We check the well-formedness of the graph: All inputs to a node must
// come from the same frame and have the same "static" iteration level.
// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level
// 0. This essentially means there can't be multiple serial Nexts in
// an iteration, which all sane front-ends should satisfy.
//
// Implemented as a BFS from the source node; iter_level == -1 in 'info'
// marks a node not yet visited.
Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
  info->clear();
  info->resize(g->num_node_ids());

  // The source node is its own frame and parent frame, at level 0.
  Node* src_node = g->source_node();
  ControlFlowInfo& src_info = (*info)[src_node->id()];
  src_info.frame = src_node;
  src_info.parent_frame = src_node;
  src_info.iter_level = 0;

  string frame_name;
  std::deque<const Node*> ready;
  ready.push_back(src_node);
  while (!ready.empty()) {
    const Node* curr_node = ready.front();
    ready.pop_front();
    const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
    const Node* frame = curr_info.frame;
    const Node* parent = curr_info.parent_frame;
    frame_name = curr_info.frame_name;
    int iter_level = curr_info.iter_level;

    // An Exit node's successors live in the parent frame, so propagate the
    // parent's info instead of the Exit node's own.
    if (IsExit(curr_node)) {
      const ControlFlowInfo& parent_info = (*info)[parent->id()];
      frame = parent_info.frame;
      parent = parent_info.parent_frame;
      frame_name = parent_info.frame_name;
      iter_level = parent_info.iter_level;
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      const Node* out = out_edge->dst();
      int out_id = out->id();
      ControlFlowInfo* out_info = &(*info)[out_id];
      const Node* out_parent = out_info->parent_frame;
      bool is_visited = (out_info->iter_level != -1);

      // Skip Sink/Source nodes.
      if (!out->IsOp()) continue;

      // Add to ready queue if not seen.
      if (!is_visited) {
        ready.push_back(out);
      }

      // Process the node 'out'. On first visit, assign its frame info;
      // on later visits, verify consistency with what was assigned.
      if (IsEnter(out)) {
        // An Enter starts a new frame: it is its own frame, its parent is
        // the current frame, and its level resets to 0.
        if (is_visited) {
          const string& parent_name = (*info)[out_parent->id()].frame_name;
          if (parent_name != frame_name || iter_level != out_info->iter_level) {
            return errors::InvalidArgument(
                "All inputs to Enter must be from the same frame and level.");
          }
        } else {
          out_info->frame = out;
          out_info->parent_frame = frame;
          TF_RETURN_IF_ERROR(
              GetNodeAttr(out->def(), "frame_name", &out_info->frame_name));
          if (out_info->frame_name.empty()) {
            return errors::InvalidArgument(
                "Enter must have a non-empty frame name.");
          }
          out_info->iter_level = 0;
        }
      } else if (IsNextIteration(out)) {
        // NextIteration stays in the same frame but bumps the level by one.
        if (is_visited) {
          if (out_info->frame_name != frame_name ||
              out_info->iter_level != (iter_level + 1)) {
            return errors::InvalidArgument(
                "All inputs to NextIteration must be from the same frame "
                "and level.");
          }
        } else {
          out_info->frame = frame;
          out_info->parent_frame = parent;
          out_info->frame_name = frame_name;
          out_info->iter_level = iter_level + 1;
        }
      } else {
        // Any other node inherits frame and level unchanged.
        if (is_visited) {
          if (out_info->frame_name != frame_name) {
            return errors::InvalidArgument(
                "All inputs to a node must be from the same frame.");
          }
        } else {
          out_info->frame = frame;
          out_info->parent_frame = parent;
          out_info->frame_name = frame_name;
          out_info->iter_level = iter_level;
        }
      }
    }
  }

  return Status::OK();
}
+
+string ControlLoopName(const string& name) {
+ return strings::StrCat("_cloop", name);
+}
+
+bool IsControlLoop(const Node* node) {
+ const string& name = node->def().name();
+ return StringPiece(name).starts_with("_cloop");
+}
+
// An enter node for control flow. Its data input is declared with the
// placeholder name "dummy"; the real incoming edge into slot 0 is added
// later (see AddControlLoop / AddControlFlow). The node is pinned to
// 'device_name'.
Node* AddControlEnter(Graph* g, const string& node_name,
                      const string& device_name, const string& frame_name,
                      const int parallel_iterations, Status* status) {
  NodeBuilder node_builder(node_name, "Enter", g->op_registry());
  node_builder.Input({"dummy", 0, DT_FLOAT});
  node_builder.Attr("frame_name", frame_name);
  node_builder.Attr("parallel_iterations", parallel_iterations);
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}
+
// A merge node for control flow. Its two float inputs are named by
// 'in_name1'/'in_name2' (the enter and next_iteration of the same control
// loop; see AddControlLoop). The node is pinned to 'device_name'.
Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
                      const string& node_name, const string& device_name,
                      Status* status) {
  NodeBuilder node_builder(node_name, "Merge", g->op_registry());
  node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}
+
+// A switch node for control flow.
+Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
+ const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A next_iteration node for control flow.
+Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = ops::UnaryOp("NextIteration", input, bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
// Builds a Const node holding an empty (shape {0}) float tensor, described
// only by a TensorProto (dtype + shape, no data).
Node* EmptyConst(const GraphDefBuilder::Options& options) {
  if (options.HaveError()) return nullptr;
  NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
                           options.op_registry());
  const DataType dt = DataTypeToEnum<float>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  TensorShape empty_shape({0});
  empty_shape.AsProto(proto.mutable_tensor_shape());
  node_builder.Attr("dtype", dt).Attr("value", proto);
  return options.FinalizeBuilder(&node_builder);
}
+
+// A dummy const node for control flow.
+Node* AddControlConst(const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = EmptyConst(bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
// A synthetic loop, made up of dummy nodes. It performs control-flow actions
// on behalf of a leader on a different device.
struct ControlLoop {
  Node* enter = nullptr;        // Fed by the outer loop's merge (or a const).
  Node* merge = nullptr;        // Source of control edges to recvs.
  Node* switch_node = nullptr;  // Driven by the frame's LoopCond.
};
+
+// Add the control flow info of a new node added during partitioning.
+// The new node has the same control flow info as edge->src().
+void AddControlFlowInfo(const Node* node, const Node* src,
+ std::vector<ControlFlowInfo>* cf_info) {
+ int id = node->id();
+ if (static_cast<size_t>(id) >= cf_info->size()) {
+ cf_info->resize(id + 1);
+ }
+ const ControlFlowInfo& src_info = (*cf_info)[src->id()];
+ ControlFlowInfo* info = &(*cf_info)[id];
+ info->frame = src_info.frame;
+ info->parent_frame = src_info.parent_frame;
+ info->frame_name = src_info.frame_name;
+ info->iter_level = src_info.iter_level;
+}
+
// Constructs a control loop. Returns a struct containing the newly created
// enter, merge, and switch nodes. The enter and merge nodes are used in the
// recursive construction of control loops for nested frames (loops). The
// switch node will be connected to the LoopCond node. The merge node will
// be connected to all the recvs of the same frame by control edges when
// the actual partitioning happens.
Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
                      const Edge* edge, Node* loop_cond,
                      std::vector<ControlFlowInfo>* cf_info,
                      ControlLoop* loop) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  const ControlFlowInfo& src_info = (*cf_info)[src->id()];
  // All new nodes go on the *destination* device of the cross-device edge.
  const string& device_name = edge->dst()->assigned_device_name();
  const string& frame_name = src_info.frame_name;
  int parallel_iterations;
  // Mirror the real loop's parallel_iterations from its Enter node.
  status = GetNodeAttr(src_info.frame->def(), "parallel_iterations",
                       &parallel_iterations);
  if (!status.ok()) return status;

  // The names of the nodes to be added. The "_cloop" prefix lets
  // IsControlLoop() recognize them later.
  const string& enter_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& merge_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& switch_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));

  // Add the nodes to the graph g.
  Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
                                parallel_iterations, &status);
  if (!status.ok()) return status;
  Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
                                device_name, &status);
  if (!status.ok()) return status;
  Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
                                       bopts.WithName(switch_name));
  if (!status.ok()) return status;
  // Output 1 of the switch is the "true" (continue) branch.
  Node* next =
      AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
  if (!status.ok()) return status;

  // Add control flow info for these new nodes: they all inherit src's frame.
  AddControlFlowInfo(enter, src, cf_info);
  AddControlFlowInfo(merge, src, cf_info);
  AddControlFlowInfo(switch_node, src, cf_info);
  AddControlFlowInfo(next, src, cf_info);

  // Add input edges for the newly created merge node:
  g->AddEdge(enter, 0, merge, 0);
  g->AddEdge(next, 0, merge, 1);

  loop->enter = enter;
  loop->merge = merge;
  loop->switch_node = switch_node;
  return Status::OK();
}
+
// Build memory and device type info for every node in the graph.
// Populates info->device_types (indexed by node id) and the
// input_types/output_types maps (keyed by (node id, port)).
// TODO(yuanbyu): It might be simpler if we convert MemoryType to
// DeviceType for the inputs/outputs of each node.
Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
  Status status;
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;

  info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
  for (const Node* node : g.nodes()) {
    if (!node->IsOp()) continue;  // Skip Sink/Source nodes.

    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
                                        &parsed)) {
      return errors::Internal("Malformed assigned device '",
                              node->assigned_device_name(), "'");
    }

    // Reuse the two scratch vectors across iterations; clear+resize keeps
    // their capacity.
    input_memory_types.clear();
    input_memory_types.resize(node->num_inputs());
    output_memory_types.clear();
    output_memory_types.resize(node->num_outputs());
    status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type),
                                node->def(), &input_memory_types,
                                &output_memory_types);
    if (!status.ok()) return status;

    int node_id = node->id();
    info->device_types[node_id] = DeviceType(parsed.type);
    for (size_t i = 0; i < input_memory_types.size(); ++i) {
      info->input_types[{node_id, i}] = input_memory_types[i];
    }
    for (size_t i = 0; i < output_memory_types.size(); ++i) {
      info->output_types[{node_id, i}] = output_memory_types[i];
    }
  }
  return status;
}
+
// Each participating device needs to decide a) if there is a next iteration,
// and b) if the loop terminates. We take the approach to encode this control
// flow logic in the dataflow graph. There are at least two possible encodings.
// In a completely decentralized encoding, the participants communicate peer
// to peer. The other encoding uses a frame leader (the participant who owns
// the pivot termination predicate) to broadcast the termination condition to
// all the participants. For now we take the latter because it is simpler.
//
// TODO(yuanbyu): The correctness of this construction is rather subtle. I got
// it wrong many times so it would be nice to write a proof to be sure.
Status AddControlFlow(const PartitionOptions& opts, Graph* g,
                      GraphInfo* g_info) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;

  // Build the control flow info for every node.
  status = BuildControlFlowInfo(g, &cf_info);
  if (!status.ok()) return status;

  // The map from frames to their LoopCond nodes.
  std::unordered_map<string, Node*> frame_cond_map;
  int num_node_ids = g->num_node_ids();
  for (int i = 0; i < num_node_ids; ++i) {
    Node* node = g->FindNodeId(i);
    if (node == nullptr) continue;

    if (IsLoopCond(node)) {
      const string& frame_name = cf_info[node->id()].frame_name;
      DCHECK(!frame_name.empty());
      frame_cond_map[frame_name] = node;
    }
  }

  // Add all control loops for cross-device frames.
  // A control loop is added only when there is a cross-device edge in a
  // non-root frame. Nothing is added if there is no loops. We also don't
  // add anything for a frame that is completely local to a device. For
  // nested loops, we stack the control loops together by connecting
  // the merge of the outer loop to the enter of the inner loop.
  //
  // A map from <frame_name, device_name> to ControlLoop.
  std::unordered_map<string, ControlLoop> control_loops;
  int num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    const Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    // Skip local edges.
    if (src_device == dst_device) continue;

    const string& src_frame = cf_info[src->id()].frame_name;
    const string& dst_frame = cf_info[dst->id()].frame_name;
    // Skip if src and dst are not in the same frame.
    if (src_frame.empty() || src_frame != dst_frame) {
      continue;
    }

    // Add the control loop. Start by adding the control loop for the
    // current frame if needed, and recursively adding the control loop
    // for its outer frame when nested. Note: 'src' is reassigned below to
    // walk up the parent-frame chain toward the root.
    ControlLoop child_loop;
    while (true) {
      const string& curr_frame = cf_info[src->id()].frame_name;
      if (curr_frame.empty()) {
        // We have reached the root frame. Feed the outermost control
        // loop's enter with a dummy const so it has a data input.
        if (child_loop.merge != nullptr) {
          const string& node_name = opts.new_name(edge->dst()->name());
          const string& device_name = edge->dst()->assigned_device_name();
          Node* const_node =
              AddControlConst(device_name, bopts.WithName(node_name));
          if (!status.ok()) return status;
          AddControlFlowInfo(const_node, src, &cf_info);
          g->AddEdge(const_node, 0, child_loop.enter, 0);
        }
        break;
      }

      // One control loop per (frame, dst device); "$$" is just a key
      // separator unlikely to appear in names.
      const string& cl_key = strings::StrCat(curr_frame, "$$", dst_device);
      auto it = control_loops.find(cl_key);
      if (it != control_loops.end()) {
        // Loop for this frame already exists; just stack the child on it.
        if (child_loop.enter != nullptr) {
          g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
        }
        break;
      }

      // Get the frame's LoopCond.
      auto cond_it = frame_cond_map.find(curr_frame);
      if (cond_it == frame_cond_map.end()) {
        return errors::InvalidArgument(
            "A cross-device loop must have a pivot predicate: ", curr_frame);
      }
      Node* loop_cond = cond_it->second;

      // Add the control loop.
      ControlLoop curr_loop;
      status =
          AddControlLoop(opts, g, src, edge, loop_cond, &cf_info, &curr_loop);
      if (!status.ok()) return status;
      control_loops[cl_key] = curr_loop;

      if (child_loop.enter != nullptr) {
        // Connect the merge of the outer loop to the enter of the inner.
        g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
      }
      // Move one frame outward and repeat.
      src = cf_info[src->id()].parent_frame;
      child_loop = curr_loop;
    }
  }

  // For a cross-device edge, on the dst device, add a control edge
  // from the merge node of the control loop to dst. If a send/recv is
  // introduced for this edge in future partitioning, we delete this
  // control edge and add a new control edge from the merge to the recv.
  num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    if (src_device != dst_device) {
      const string& src_frame = cf_info[src->id()].frame_name;
      const string& dst_frame = cf_info[dst->id()].frame_name;
      if (!src_frame.empty() && src_frame == dst_frame) {
        const string& cl_key = strings::StrCat(dst_frame, "$$", dst_device);
        ControlLoop loop = control_loops[cl_key];
        DCHECK(loop.enter != nullptr);
        g->AddControlEdge(loop.merge, dst);
      }
    }
  }
  return Status::OK();
}
+
+} // end namespace
+
+Status AddControlEdges(const PartitionOptions& opts,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Status status;
+ // TODO(yuanbyu): Very naive for now. To be improved.
+ const int num_epochs = 100;
+ const int prefetch = 6;
+
+ typedef std::pair<const NodeDef*, int64> NodeStartTime;
+ for (auto& part : *partitions) {
+ GraphDef* gdef = &part.second;
+
+ std::vector<NodeStartTime> start_times;
+ start_times.resize(gdef->node_size());
+ for (int n = 0; n < gdef->node_size(); ++n) {
+ const NodeDef& ndef = gdef->node(n);
+ int64 start_time;
+ status = GetNodeAttr(ndef, "_start_time", &start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ start_times[n] = std::make_pair(&ndef, start_time);
+ }
+
+ // Sort the nodes based on their start times.
+ std::sort(
+ start_times.begin(), start_times.end(),
+ [](NodeStartTime x, NodeStartTime y) { return x.second < y.second; });
+
+ // Add a dummy node for every epoch, and add a control edge from the
+ // "last" node in the preceding epoch to the dummy node.
+ string device_name = gdef->node(0).device();
+ int64 makespan = start_times.back().second;
+ int64 resolution = (makespan / num_epochs) + 1;
+
+ int i = 0;
+ int j = 0;
+ std::vector<NodeDef*> dummys;
+ while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
+ if (i * resolution > start_times[j].second) {
+ j++;
+ } else {
+ NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
+ i * resolution, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ dummys.push_back(dummy);
+ if (j > 0) {
+ string src_name = start_times[j - 1].first->name();
+ AddInput(dummy, src_name, Graph::kControlSlot);
+ }
+ i++;
+ }
+ }
+
+ // Finally, add the control edges to recvs.
+ for (int n = 0; n < gdef->node_size(); ++n) {
+ NodeDef* ndef = gdef->mutable_node(n);
+ if (ndef->op() == "_Recv") {
+ int64 start_time;
+ status = GetNodeAttr(*ndef, "_start_time", &start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ int recv_epoch = start_time / resolution;
+ if (recv_epoch >= prefetch) {
+ NodeDef* dummy = dummys[recv_epoch - prefetch];
+ AddInput(ndef, dummy->name(), Graph::kControlSlot);
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status Partition(const PartitionOptions& opts, Graph* g,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Status status;
+ partitions->clear();
+
+ GraphInfo g_info;
+ if (!opts.control_flow_added) {
+ // Add the "code" for distributed execution of control flow. Code is
+ // added only for the frames that are placed on multiple devices. The
+ // new graph is an equivalent transformation of the original graph and
+ // has the property that it can be subsequently partitioned arbitrarily
+ // (down to the level of individual device) for distributed execution.
+ status = AddControlFlow(opts, g, &g_info);
+ if (!status.ok()) return status;
+ }
+ // At this point, all the graph mutations have been done. Build memory
+ // and device type info for every node and edge in the graph.
+ status = BuildMemoryDeviceInfo(*g, &g_info);
+ if (!status.ok()) return status;
+
+ string dstp;
+ std::vector<const Edge*> inputs;
+ DupRecvTable dup_recv(3);
+ // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
+ // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
+ // edge to dst. We will add a control edge for every pair in
+ // (ref_recvs x ref_control_inputs).
+ std::vector<NodeDef*> ref_recvs;
+ std::vector<string> ref_control_inputs;
+
+ // Counters (for the VLOG summary at the end): how many control vs. data
+ // send/recv pairs this partitioning pass introduced.
+ int32 num_data = 0;
+ int32 num_control = 0;
+ for (const Node* dst : g->nodes()) {
+ if (!dst->IsOp()) continue; // Skip Sink/Source nodes.
+
+ // NOTE(review): 'dstp' and 'inputs' are declared earlier in this
+ // function, outside this view — presumably a location string and an
+ // edge vector reused across iterations.
+ dstp = opts.node_to_loc(dst);
+ GraphDef* dst_graph = &(*partitions)[dstp];
+ // Copy dst into its partition's GraphDef; inputs are rebuilt below
+ // because edges may be rerouted through send/recv nodes.
+ NodeDef* dst_def = dst_graph->add_node();
+ *dst_def = dst->def();
+ dst_def->set_device(dst->assigned_device_name());
+ dst_def->clear_input(); // Inputs are filled below
+ if (opts.need_to_record_start_times) {
+ int64 start_time = opts.start_times[dst->id()].value();
+ AddNodeAttr("_start_time", start_time, dst_def);
+ }
+
+ // Arrange the incoming edges to dst so that input[i] holds the
+ // input flowing into slot numbered i. Trailing entries in input[]
+ // hold control edges.
+ inputs.clear();
+ inputs.resize(dst->num_inputs(), nullptr);
+ ref_recvs.clear();
+ ref_control_inputs.clear();
+ const Edge* control_flow_edge = nullptr;
+ for (const Edge* edge : dst->in_edges()) {
+ if (edge->IsControlEdge()) {
+ if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
+ // This is one of the control edges added for control flow. There
+ // can be multiple such edges as the dest node may have multiple
+ // remote inputs. We will just take one and ignore the others.
+ control_flow_edge = edge;
+ } else {
+ // Ordinary control edges are appended after the data input slots.
+ inputs.push_back(edge);
+ }
+ } else {
+ DCHECK(inputs[edge->dst_input()] == nullptr);
+ inputs[edge->dst_input()] = edge;
+ }
+ }
+
+ // Process in order so that all data edges are added as inputs to
+ // dst in Edge::dst_input() order.
+ bool recv_added = false;
+ for (const Edge* edge : inputs) {
+ const Node* src = edge->src();
+ if (!src->IsOp()) continue; // Skip Sink/Source nodes.
+
+ GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
+ if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
+ // Same partition and compatible memory types:
+ AddInput(dst_def, src->name(), edge->src_output());
+ // Record non-ref inputs (including control inputs) of dst; these
+ // later gate the ref recvs via AddReadControl below.
+ if (edge->IsControlEdge() ||
+ !IsRefType(src->output_type(edge->src_output()))) {
+ ref_control_inputs.push_back(src->name());
+ }
+ continue;
+ }
+
+ // Determine scheduling start times for the send/recv pair, either
+ // from the supplied start-time map or from "_start_time" attrs
+ // previously recorded on the nodes.
+ int64 send_start_time = 0;
+ int64 recv_start_time = 0;
+ if (opts.scheduling_for_recvs) {
+ if (opts.need_to_record_start_times) {
+ send_start_time = opts.start_times[src->id()].value();
+ recv_start_time = opts.start_times[dst->id()].value();
+ } else {
+ status = GetNodeAttr(src->def(), "_start_time", &send_start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+ }
+
+ // Check whether there is already a send/recv pair transferring
+ // the same tensor/control from the src to dst partition.
+ const bool on_host = IsDstInputOnHost(edge, g_info);
+ DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
+ auto iter = dup_recv.find(key);
+ if (iter != dup_recv.end()) {
+ // We found one. Reuse the data/control transferred already.
+ const string& recv_node_name = iter->second.recv->name();
+ if (edge->IsControlEdge()) {
+ AddInput(dst_def, recv_node_name, Graph::kControlSlot);
+ } else {
+ AddInput(dst_def, recv_node_name, 0);
+ }
+ // We want the start_time for the recv to be the smallest of the start
+ // times of its consumers. So we update this whenever we use a recv,
+ // and write it out to the attribute at the end of the subroutine
+ if (iter->second.start_time > recv_start_time) {
+ iter->second.start_time = recv_start_time;
+ }
+ continue;
+ }
+
+ NodeDefBuilder::NodeOut send_from;
+ if (edge->IsControlEdge()) {
+ // Insert a dummy const node that will generate a tiny
+ // data element to be sent from send to recv.
+ VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
+ << src->name() << "] -> " << dst->assigned_device_name() << "["
+ << dst->name() << "]";
+ NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
+ if (!status.ok()) return status;
+ // Set the start time for this dummy node.
+ if (opts.scheduling_for_recvs) {
+ AddNodeAttr("_start_time", send_start_time, dummy);
+ }
+ // The dummy fires only after src does, so the control dependency
+ // is preserved across the send/recv pair.
+ AddInput(dummy, src->name(), Graph::kControlSlot);
+ send_from.Reset(dummy->name(), 0, DT_FLOAT);
+ } else {
+ send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
+ }
+
+ // Need to split edge by placing matching send/recv nodes on
+ // the src/dst sides of the edge.
+ NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
+ send_start_time, &status);
+ if (!status.ok()) return status;
+
+ // 'recv' is the node dst reads from; 'real_recv' is the underlying
+ // recv op (they differ when AddRecv inserts an intermediate node).
+ NodeDef* real_recv = nullptr;
+ NodeDef* recv =
+ AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
+ if (!status.ok()) return status;
+
+ // Fix up the control flow edge. Redirect it to the recv.
+ // NOTE(yuanbyu): 'real_recv' must be the real recv node.
+ recv_added = true;
+ if (control_flow_edge != nullptr) {
+ AddInput(real_recv, control_flow_edge->src()->name(),
+ Graph::kControlSlot);
+ }
+
+ // For same device send/recv, add a control edge from send to recv.
+ // This prevents the asynchronous recv kernel from being scheduled
+ // immediately.
+ if (src_graph == dst_graph) {
+ AddInput(real_recv, send->name(), Graph::kControlSlot);
+ }
+
+ if (!edge->IsControlEdge() &&
+ IsRefType(src->output_type(edge->src_output()))) {
+ // If src is of ref type and the edge is not a control edge, dst has
+ // read semantics and therefore we must control the recv.
+ ref_recvs.push_back(real_recv);
+ } else {
+ // Memorize the send/recv pair, only if this is not a "ref" edge.
+ // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
+ // for now we don't do it.
+ dup_recv[key] = {recv, real_recv, recv_start_time};
+ ref_control_inputs.push_back(recv->name());
+ }
+
+ // Wire the recv's output (or its control output) into dst.
+ if (edge->IsControlEdge()) {
+ ++num_control;
+ AddInput(dst_def, recv->name(), Graph::kControlSlot);
+ } else {
+ ++num_data;
+ AddInput(dst_def, recv->name(), 0);
+ }
+ }
+
+ // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
+ // NOTE(yuanbyu): Adding these control edges should not introduce
+ // deadlocks. 'dst' has implicit "read" nodes that, when we split
+ // across devices, are made explicit; Retargeting the dependencies
+ // to 'dst' to those nodes would not introduce cycles if there isn't
+ // one before the transformation.
+ // NOTE(yuanbyu): This may impact performance because it defers the
+ // execution of recvs until all the other inputs become available.
+ AddReadControl(ref_recvs, ref_control_inputs);
+
+ // Add back this control edge for control flow if not used.
+ if (!recv_added && (control_flow_edge != nullptr)) {
+ AddInput(dst_def, control_flow_edge->src()->name(), Graph::kControlSlot);
+ }
+ }
+
+ // Set the start times for recvs at the very end. Each recv gets the
+ // minimum start time over all its consumers, accumulated in dup_recv.
+ if (opts.scheduling_for_recvs) {
+ for (auto& it : dup_recv) {
+ AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
+ if (it.second.real_recv != it.second.recv) {
+ AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
+ }
+ }
+ }
+
+ VLOG(1) << "Added send/recv: controls=" << num_control
+ << ", data=" << num_data;
+ return Status::OK();
+}
+
+} // namespace tensorflow