aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph_partition.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-31 14:21:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 15:33:19 -0700
commitf26dcdc819e903267cb03c33c2f09876d0003f52 (patch)
tree0445dc522136f3fcc90c50eb24106e62c961db9f /tensorflow/core/graph/graph_partition.cc
parent5a8af550174de742eedde4288be1ba4bbb1042d1 (diff)
Split code to compute control flow structure of a graph out of graph_partition.cc and into its own module (control_flow.{cc,h}).
Change: 137756517
Diffstat (limited to 'tensorflow/core/graph/graph_partition.cc')
-rw-r--r--tensorflow/core/graph/graph_partition.cc175
1 files changed, 38 insertions, 137 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 3275cde762..488fe47f38 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -72,13 +74,6 @@ struct RecvInfo {
typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
DupRecvTable;
-// Control flow info for a graph node.
-struct ControlFlowInfo {
- const Node* frame = nullptr; // frame of a node
- const Node* parent_frame = nullptr; // parent frame of a node
- string frame_name; // frame name of a node
-};
-
struct PairIntHash {
public:
std::size_t operator()(const std::pair<int, int>& x) const {
@@ -315,144 +310,48 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
// TODO(yuanbyu): In this case, we don't respect the requested device in
// the GraphDef for these nodes. Ideally, the placer would enforce the
// colocation to render this unnecessary.
-void OptimizeControlFlowColocation(Node* node) {
- if (IsSwitch(node)) {
- for (const Edge* in_edge : node->in_edges()) {
- if (in_edge->dst_input() == 0) {
- // Colocate with the data input.
- node->set_assigned_device_name(in_edge->src()->assigned_device_name());
- return;
- }
- }
- } else if (IsExit(node)) {
- for (const Edge* in_edge : node->in_edges()) {
- if (!in_edge->IsControlEdge()) {
- // Colocate with upstream node.
- node->set_assigned_device_name(in_edge->src()->assigned_device_name());
- return;
- }
- }
- } else {
- if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
- IsNextIteration(node)) {
- const Edge* data_edge = nullptr;
- for (const Edge* out_edge : node->out_edges()) {
- if (!out_edge->IsControlEdge()) {
- if (data_edge) {
- data_edge = nullptr;
- return;
- }
- data_edge = out_edge;
+void OptimizeControlFlowColocation(Graph* graph) {
+ auto visit = [](Node* node) {
+ if (IsSwitch(node)) {
+ for (const Edge* in_edge : node->in_edges()) {
+ if (in_edge->dst_input() == 0) {
+ // Colocate with the data input.
+ node->set_assigned_device_name(
+ in_edge->src()->assigned_device_name());
+ return;
}
}
- // Colocate if there is only one downstream data node.
- if (data_edge) {
- node->set_assigned_device_name(
- data_edge->dst()->assigned_device_name());
- }
- }
- }
-}
-
-// Assign to each node the name of the frame and the level it belongs to.
-// We check the well-formedness of the graph: All inputs to a node must
-// come from the same frame and have the same "static" iteration level.
-// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level
-// 0. This essentially means there can't be multiple serial Nexts in
-// an iteration, which all sane front-ends should satisfy.
-Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
- info->clear();
- info->resize(g->num_node_ids());
-
- std::vector<const Node*> parent_nodes;
- parent_nodes.resize(g->num_node_ids());
-
- Node* src_node = g->source_node();
- ControlFlowInfo& src_info = (*info)[src_node->id()];
- src_info.frame = src_node;
- src_info.parent_frame = src_node;
-
- string frame_name;
- std::deque<Node*> ready;
- ready.push_back(src_node);
- while (!ready.empty()) {
- Node* curr_node = ready.front();
- ready.pop_front();
- const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
- const Node* frame = curr_info.frame;
- const Node* parent = curr_info.parent_frame;
- frame_name = curr_info.frame_name;
-
- if (IsExit(curr_node)) {
- // Exit to the parent frame.
- const ControlFlowInfo& parent_info = (*info)[parent->id()];
- frame = parent_info.frame;
- parent = parent_info.parent_frame;
- frame_name = parent_info.frame_name;
- }
-
- // Optimize colocation for control flow nodes.
- OptimizeControlFlowColocation(curr_node);
-
- for (const Edge* out_edge : curr_node->out_edges()) {
- Node* out = out_edge->dst();
- int out_id = out->id();
- ControlFlowInfo* out_info = &(*info)[out_id];
- const Node* out_parent = out_info->parent_frame;
- bool is_visited = (parent_nodes[out_id] != nullptr);
-
- // Skip Sink/Source nodes.
- if (!out->IsOp()) continue;
-
- // Add to ready queue if not seen.
- if (!is_visited) {
- parent_nodes[out->id()] = curr_node;
- ready.push_back(out);
+ } else if (IsExit(node)) {
+ for (const Edge* in_edge : node->in_edges()) {
+ if (!in_edge->IsControlEdge()) {
+ // Colocate with upstream node.
+ node->set_assigned_device_name(
+ in_edge->src()->assigned_device_name());
+ return;
+ }
}
-
- // Process the node 'out'.
- if (IsEnter(out)) {
- if (is_visited) {
- const string& parent_frame = (*info)[out_parent->id()].frame_name;
- if (parent_frame != frame_name) {
- return errors::InvalidArgument(
- "The node '", out->name(),
- "' has inputs from different "
- "frames. The input '",
- curr_node->name(), "' is in frame '", frame_name,
- "'. The input '", parent_nodes[out->id()]->name(),
- "' is in frame '", parent_frame, "'.");
- }
- } else {
- out_info->frame = out;
- out_info->parent_frame = frame;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(out->def(), "frame_name", &out_info->frame_name));
- if (out_info->frame_name.empty()) {
- return errors::InvalidArgument("The Enter node ", out->name(),
- " must have a frame name.");
+ } else {
+ if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
+ IsNextIteration(node)) {
+ const Edge* data_edge = nullptr;
+ for (const Edge* out_edge : node->out_edges()) {
+ if (!out_edge->IsControlEdge()) {
+ if (data_edge) {
+ data_edge = nullptr;
+ return;
+ }
+ data_edge = out_edge;
}
}
- } else {
- if (is_visited) {
- if (out_info->frame_name != frame_name) {
- return errors::InvalidArgument(
- "The node '", out->name(),
- "' has inputs from different "
- "frames. The input '",
- curr_node->name(), "' is in frame '", frame_name,
- "'. The input '", parent_nodes[out->id()]->name(),
- "' is in frame '", out_info->frame_name, "'.");
- }
- } else {
- out_info->frame = frame;
- out_info->parent_frame = parent;
- out_info->frame_name = frame_name;
+ // Colocate if there is only one downstream data node.
+ if (data_edge) {
+ node->set_assigned_device_name(
+ data_edge->dst()->assigned_device_name());
}
}
}
- }
- return Status::OK();
+ };
+ DFS(*graph, visit, {});
}
string ControlLoopName(const string& name) {
@@ -689,6 +588,8 @@ Status AddControlFlow(const PartitionOptions& opts, Graph* g,
status = BuildControlFlowInfo(g, &cf_info);
if (!status.ok()) return status;
+ OptimizeControlFlowColocation(g);
+
// The map from frames to their LoopCond nodes.
std::unordered_map<string, Node*> frame_cond_map;
int num_node_ids = g->num_node_ids();