aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2017-11-14 17:08:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-14 17:14:05 -0800
commit8ad5cc00f21eb9d6f1811d7ed771f6f042dba1ba (patch)
tree5b6a157580502448b76f5147f65ca4f9d29c0daf
parent48b25d0a1d71fb426b5765a88785b35a4327e4f5 (diff)
[TFXLA] Add source node and make GetSwitchCluster more conservative.
PiperOrigin-RevId: 175758538
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc55
1 files changed, 29 insertions, 26 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 6ef4860f35..40a484da09 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -731,11 +731,12 @@ string DebugString(const Graph& graph,
FunctionalizeCond::ClusterHandle::Vector* clusters) {
string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n";
std::map<FunctionalizeCond::ClusterHandle, string> subgraphs;
+ auto name = [](const Node* n) {
+ return strings::StrCat(n->type_string(), "_", n->id());
+ };
for (Node* n : graph.nodes()) {
- if (n->IsOp()) {
- strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(),
- " [label=\"", n->name(), "\"];\n");
- }
+ strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"",
+ name(n), "\"];\n");
}
for (auto kv : subgraphs) {
strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n",
@@ -743,16 +744,11 @@ string DebugString(const Graph& graph,
kv.first.ToString(), "\";\n", kv.second, "}\n");
}
for (Node* n : graph.nodes()) {
- if (!n->IsOp()) {
- continue;
- }
for (Node* in : n->in_nodes()) {
- if (in->IsOp()) {
- strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n");
- }
+ strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n");
}
}
- return strings::StrCat(ret, "}");
+ return strings::StrCat(ret, "} // end");
}
string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) {
@@ -761,16 +757,24 @@ string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) {
return cluster.representative.ToString();
};
for (auto kv : clustered_graph) {
- strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second),
- " (", kv.second.switch_nodes.size(), ", ",
- kv.second.merge_nodes.size(), ")\"];\n");
+ if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) {
+ strings::StrAppend(
+ &ret, kv.first.ToString(), " [label=\"", name(kv.second),
+ kv.second.switch_nodes.empty()
+ ? ""
+ : strings::StrCat(" switches=", kv.second.switch_nodes.size()),
+ kv.second.merge_nodes.empty()
+ ? ""
+ : strings::StrCat(" merges=", kv.second.merge_nodes.size()),
+ "\"];\n");
+ }
}
for (auto kv : clustered_graph) {
for (auto in : kv.second.in_nodes) {
strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n");
}
}
- return strings::StrCat(ret, "}");
+ return strings::StrCat(ret, "} // end");
}
bool IsDeadSwitch(const Node* node) {
@@ -790,9 +794,6 @@ bool IsDeadSwitch(const Node* node) {
void FunctionalizeCond::CreateClusters() {
for (Node* node : graph_->nodes()) {
- if (!node->IsOp()) {
- continue;
- }
if (IsSwitch(node)) {
switch_nodes_.insert(node);
} else if (IsMerge(node)) {
@@ -825,6 +826,10 @@ void FunctionalizeCond::CreateClusters() {
clusters_.at(node).Merge(&clusters_.at(in));
}
}
+ // Group all source clusters together.
+ if (node->IsSource() || node->in_edges().empty()) {
+ clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId)));
+ }
}
}
@@ -876,7 +881,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
for (const Node* in : node->in_nodes()) {
ClusterHandle other_repr = Representative(in);
// Skip source, sink and internal edges.
- if (!in->IsOp() || other_repr == repr) {
+ if (other_repr == repr) {
continue;
}
Cluster& cluster_node_in = clustered_graph_[other_repr];
@@ -887,7 +892,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
for (const Node* out : node->out_nodes()) {
ClusterHandle other_repr = Representative(out);
// Skip source, sink and internal edges.
- if (!out->IsOp() || other_repr == repr) {
+ if (other_repr == repr) {
continue;
}
Cluster& cluster_node_out = clustered_graph_[other_repr];
@@ -897,6 +902,7 @@ void FunctionalizeCond::CreateClusteredGraph() {
}
return cluster_node;
};
+ update_cluster_for_node(graph_->source_node());
for (Node* node : switch_nodes_) {
update_cluster_for_node(node).switch_nodes.insert(node);
}
@@ -955,7 +961,7 @@ gtl::optional<FunctionalizeCond::Cluster*> FunctionalizeCond::GetSwitchCluster(
for (Cluster* in : merge_cluster.in_nodes) {
Cluster* cluster = in;
if (in->switch_nodes.empty()) {
- if (in->in_nodes.size() != 1) {
+ if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) {
return gtl::nullopt;
}
// There is only a single `in` cluster.
@@ -1292,11 +1298,8 @@ std::vector<std::pair<int, FunctionalizeCond::Cluster*>>
FunctionalizeCond::SortedMergeNodes() {
VLOG(2) << "ProcessClusteredGraph";
std::stack<std::pair<int, Cluster*>> stack;
- for (auto& c : clustered_graph_) {
- if (c.second.in_nodes.empty()) {
- stack.push({0, &c.second});
- }
- }
+ // Initialize with the source node.
+ stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]});
// Perform a depth-first traversal of the clustered graph computing the
// switch-merge depth.