diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-31 14:21:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-31 15:33:19 -0700 |
commit | f26dcdc819e903267cb03c33c2f09876d0003f52 (patch) | |
tree | 0445dc522136f3fcc90c50eb24106e62c961db9f /tensorflow/core/graph/graph_partition.cc | |
parent | 5a8af550174de742eedde4288be1ba4bbb1042d1 (diff) |
Split code to compute control flow structure of a graph out of graph_partition.cc and into its own module (control_flow.{cc,h}).
Change: 137756517
Diffstat (limited to 'tensorflow/core/graph/graph_partition.cc')
-rw-r--r-- | tensorflow/core/graph/graph_partition.cc | 175 |
1 files changed, 38 insertions, 137 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 3275cde762..488fe47f38 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/node_builder.h" @@ -72,13 +74,6 @@ struct RecvInfo { typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq> DupRecvTable; -// Control flow info for a graph node. -struct ControlFlowInfo { - const Node* frame = nullptr; // frame of a node - const Node* parent_frame = nullptr; // parent frame of a node - string frame_name; // frame name of a node -}; - struct PairIntHash { public: std::size_t operator()(const std::pair<int, int>& x) const { @@ -315,144 +310,48 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, // TODO(yuanbyu): In this case, we don't respect the requested device in // the GraphDef for these nodes. Ideally, the placer would enforce the // colocation to render this unnecessary. -void OptimizeControlFlowColocation(Node* node) { - if (IsSwitch(node)) { - for (const Edge* in_edge : node->in_edges()) { - if (in_edge->dst_input() == 0) { - // Colocate with the data input. - node->set_assigned_device_name(in_edge->src()->assigned_device_name()); - return; - } - } - } else if (IsExit(node)) { - for (const Edge* in_edge : node->in_edges()) { - if (!in_edge->IsControlEdge()) { - // Colocate with upstream node. - node->set_assigned_device_name(in_edge->src()->assigned_device_name()); - return; - } - } - } else { - if ((IsEnter(node) && !IsRefType(node->input_type(0))) || - IsNextIteration(node)) { - const Edge* data_edge = nullptr; - for (const Edge* out_edge : node->out_edges()) { - if (!out_edge->IsControlEdge()) { - if (data_edge) { - data_edge = nullptr; - return; - } - data_edge = out_edge; +void OptimizeControlFlowColocation(Graph* graph) { + auto visit = [](Node* node) { + if (IsSwitch(node)) { + for (const Edge* in_edge : node->in_edges()) { + if (in_edge->dst_input() == 0) { + // Colocate with the data input. + node->set_assigned_device_name( + in_edge->src()->assigned_device_name()); + return; } } - // Colocate if there is only one downstream data node. - if (data_edge) { - node->set_assigned_device_name( - data_edge->dst()->assigned_device_name()); - } - } - } -} - -// Assign to each node the name of the frame and the level it belongs to. -// We check the well-formedness of the graph: All inputs to a node must -// come from the same frame and have the same "static" iteration level. -// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level -// 0. This essentially means there can't be multiple serial Nexts in -// an iteration, which all sane front-ends should satisfy. -Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) { - info->clear(); - info->resize(g->num_node_ids()); - - std::vector<const Node*> parent_nodes; - parent_nodes.resize(g->num_node_ids()); - - Node* src_node = g->source_node(); - ControlFlowInfo& src_info = (*info)[src_node->id()]; - src_info.frame = src_node; - src_info.parent_frame = src_node; - - string frame_name; - std::deque<Node*> ready; - ready.push_back(src_node); - while (!ready.empty()) { - Node* curr_node = ready.front(); - ready.pop_front(); - const ControlFlowInfo& curr_info = (*info)[curr_node->id()]; - const Node* frame = curr_info.frame; - const Node* parent = curr_info.parent_frame; - frame_name = curr_info.frame_name; - - if (IsExit(curr_node)) { - // Exit to the parent frame. - const ControlFlowInfo& parent_info = (*info)[parent->id()]; - frame = parent_info.frame; - parent = parent_info.parent_frame; - frame_name = parent_info.frame_name; - } - - // Optimize colocation for control flow nodes. - OptimizeControlFlowColocation(curr_node); - - for (const Edge* out_edge : curr_node->out_edges()) { - Node* out = out_edge->dst(); - int out_id = out->id(); - ControlFlowInfo* out_info = &(*info)[out_id]; - const Node* out_parent = out_info->parent_frame; - bool is_visited = (parent_nodes[out_id] != nullptr); - - // Skip Sink/Source nodes. - if (!out->IsOp()) continue; - - // Add to ready queue if not seen. - if (!is_visited) { - parent_nodes[out->id()] = curr_node; - ready.push_back(out); + } else if (IsExit(node)) { + for (const Edge* in_edge : node->in_edges()) { + if (!in_edge->IsControlEdge()) { + // Colocate with upstream node. + node->set_assigned_device_name( + in_edge->src()->assigned_device_name()); + return; + } } - - // Process the node 'out'. - if (IsEnter(out)) { - if (is_visited) { - const string& parent_frame = (*info)[out_parent->id()].frame_name; - if (parent_frame != frame_name) { - return errors::InvalidArgument( - "The node '", out->name(), - "' has inputs from different " - "frames. The input '", - curr_node->name(), "' is in frame '", frame_name, - "'. The input '", parent_nodes[out->id()]->name(), - "' is in frame '", parent_frame, "'."); - } - } else { - out_info->frame = out; - out_info->parent_frame = frame; - TF_RETURN_IF_ERROR( - GetNodeAttr(out->def(), "frame_name", &out_info->frame_name)); - if (out_info->frame_name.empty()) { - return errors::InvalidArgument("The Enter node ", out->name(), - " must have a frame name."); + } else { + if ((IsEnter(node) && !IsRefType(node->input_type(0))) || + IsNextIteration(node)) { + const Edge* data_edge = nullptr; + for (const Edge* out_edge : node->out_edges()) { + if (!out_edge->IsControlEdge()) { + if (data_edge) { + data_edge = nullptr; + return; + } + data_edge = out_edge; } } - } else { - if (is_visited) { - if (out_info->frame_name != frame_name) { - return errors::InvalidArgument( - "The node '", out->name(), - "' has inputs from different " - "frames. The input '", - curr_node->name(), "' is in frame '", frame_name, - "'. The input '", parent_nodes[out->id()]->name(), - "' is in frame '", out_info->frame_name, "'."); - } - } else { - out_info->frame = frame; - out_info->parent_frame = parent; - out_info->frame_name = frame_name; + // Colocate if there is only one downstream data node. + if (data_edge) { + node->set_assigned_device_name( + data_edge->dst()->assigned_device_name()); } } } - } - return Status::OK(); + }; + DFS(*graph, visit, {}); } string ControlLoopName(const string& name) { @@ -689,6 +588,8 @@ Status AddControlFlow(const PartitionOptions& opts, Graph* g, status = BuildControlFlowInfo(g, &cf_info); if (!status.ok()) return status; + OptimizeControlFlowColocation(g); + // The map from frames to their LoopCond nodes. std::unordered_map<string, Node*> frame_cond_map; int num_node_ids = g->num_node_ids(); |