diff options
-rw-r--r-- | tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 55 |
1 files changed, 29 insertions, 26 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 6ef4860f35..40a484da09 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -731,11 +731,12 @@ string DebugString(const Graph& graph, FunctionalizeCond::ClusterHandle::Vector* clusters) { string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; std::map<FunctionalizeCond::ClusterHandle, string> subgraphs; + auto name = [](const Node* n) { + return strings::StrCat(n->type_string(), "_", n->id()); + }; for (Node* n : graph.nodes()) { - if (n->IsOp()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), - " [label=\"", n->name(), "\"];\n"); - } + strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"", + name(n), "\"];\n"); } for (auto kv : subgraphs) { strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", @@ -743,16 +744,11 @@ string DebugString(const Graph& graph, kv.first.ToString(), "\";\n", kv.second, "}\n"); } for (Node* n : graph.nodes()) { - if (!n->IsOp()) { - continue; - } for (Node* in : n->in_nodes()) { - if (in->IsOp()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } + strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { @@ -761,16 +757,24 @@ string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { return cluster.representative.ToString(); }; for (auto kv : clustered_graph) { - strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), - " (", kv.second.switch_nodes.size(), ", ", - kv.second.merge_nodes.size(), ")\"];\n"); + if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) { + strings::StrAppend( + &ret, kv.first.ToString(), " [label=\"", name(kv.second), + kv.second.switch_nodes.empty() + ? "" + : strings::StrCat(" switches=", kv.second.switch_nodes.size()), + kv.second.merge_nodes.empty() + ? "" + : strings::StrCat(" merges=", kv.second.merge_nodes.size()), + "\"];\n"); + } } for (auto kv : clustered_graph) { for (auto in : kv.second.in_nodes) { strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } bool IsDeadSwitch(const Node* node) { @@ -790,9 +794,6 @@ bool IsDeadSwitch(const Node* node) { void FunctionalizeCond::CreateClusters() { for (Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; - } if (IsSwitch(node)) { switch_nodes_.insert(node); } else if (IsMerge(node)) { @@ -825,6 +826,10 @@ void FunctionalizeCond::CreateClusters() { clusters_.at(node).Merge(&clusters_.at(in)); } } + // Group all source clusters together. + if (node->IsSource() || node->in_edges().empty()) { + clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId))); + } } } @@ -876,7 +881,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* in : node->in_nodes()) { ClusterHandle other_repr = Representative(in); // Skip source, sink and internal edges. - if (!in->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_in = clustered_graph_[other_repr]; @@ -887,7 +892,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* out : node->out_nodes()) { ClusterHandle other_repr = Representative(out); // Skip source, sink and internal edges. - if (!out->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_out = clustered_graph_[other_repr]; @@ -897,6 +902,7 @@ void FunctionalizeCond::CreateClusteredGraph() { } return cluster_node; }; + update_cluster_for_node(graph_->source_node()); for (Node* node : switch_nodes_) { update_cluster_for_node(node).switch_nodes.insert(node); } @@ -955,7 +961,7 @@ gtl::optional<FunctionalizeCond::Cluster*> FunctionalizeCond::GetSwitchCluster( for (Cluster* in : merge_cluster.in_nodes) { Cluster* cluster = in; if (in->switch_nodes.empty()) { - if (in->in_nodes.size() != 1) { + if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) { return gtl::nullopt; } // There is only a single `in` cluster. @@ -1292,11 +1298,8 @@ std::vector<std::pair<int, FunctionalizeCond::Cluster*>> FunctionalizeCond::SortedMergeNodes() { VLOG(2) << "ProcessClusteredGraph"; std::stack<std::pair<int, Cluster*>> stack; - for (auto& c : clustered_graph_) { - if (c.second.in_nodes.empty()) { - stack.push({0, &c.second}); - } - } + // Initialize with the source node. + stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]}); // Perform a depth-first traversal of the clustered graph computing the // switch-merge depth. |