aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-09-11 11:12:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 11:16:14 -0700
commit7e5ae7109f558cafaa87e3bcebabfc0e1f67aabc (patch)
tree276ecdab1377dbb2fcba28b1b0a359932e813bac /tensorflow/compiler/tf2xla
parent232fcbb6fcf8c5ab3713261a0ef9a771b270753e (diff)
Handle control dependencies from switch nodes as nonreachable.
In DeleteReachableNodes all the nodes reachable from nodes deleted from the graph during extraction was considered. But if a node had a control dependency on a switch, then that node doesn't conditionally execute based on the switch predicate and is not part of the conditional extracted, so it should be considered reachable for deletion. Additionally perform sweep of graph for dead nodes together with deleting the reachable nodes to keep all dead node deletion together. Also delete a dead function and ensure all graph dumps from functionalize_cond has that as prefix. PiperOrigin-RevId: 212485183
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.cc71
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.h13
2 files changed, 54 insertions, 30 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index 0911550f1f..3ad1d1d5b4 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -217,10 +217,6 @@ void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
added_node_ancestorid_mapping_[node->id()] = id;
}
-const StateMap::CondState& StateMap::LookupState(const Node* node) const {
- return *LookupCondId(node);
-}
-
void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
string StateMap::CondStateToString(const Node* node) const {
@@ -791,7 +787,6 @@ Status Conditional::BuildAndReplace(Graph* graph,
TF_RETURN_IF_ERROR(AddInputEdges(graph));
TF_RETURN_IF_ERROR(AddOutputEdges(graph));
TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
- for (Node* m : merges_) state_map_->MarkDead(m);
// Check that the if_node doesn't feed into itself.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -1056,7 +1051,6 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
" has no non-dead inputs.");
}
state_map_.MarkDead(node);
- delete_nodes_.push_back(node->id());
VLOG(5) << "removing redundant merge: " << node->name();
while (!node->out_edges().empty()) {
const Edge* oe = *node->out_edges().begin();
@@ -1132,7 +1126,6 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
}
} else if (BranchType(switch_branch) != b) {
state_map_.MarkDead(dst_node);
- delete_nodes_.push_back(dst_node->id());
continue;
}
graph_->AddEdge(
@@ -1154,7 +1147,7 @@ Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
<< " @ " << state_map_.AncestorStateToString(dst);
- if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it");
+ if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
}
return Status::OK();
}
@@ -1184,23 +1177,62 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
return Status::OK();
}
-void FunctionalizeCond::DeleteReachableNodes() {
+void FunctionalizeCond::DeleteReachableAndDeadNodes(
+ const std::vector<int>& switch_ids, const std::vector<Node*>& merge_order) {
// Delete all nodes that have been extracted or are reachable from
// deleted/dead nodes. The input and outgoing edges should have already been
// removed.
+ std::deque<int> delete_nodes;
std::vector<bool> deleted(graph_->num_node_ids(), false);
// Don't try to delete source or sink nodes.
deleted[graph_->kSourceId] = true;
deleted[graph_->kSinkId] = true;
- while (!delete_nodes_.empty()) {
- int d_id = delete_nodes_.front();
- delete_nodes_.pop_front();
+
+ // All remaining Switch nodes are not reachable from a Merge node and
+ // removed. This is to account for dead Switch nodes.
+ for (int s_id : switch_ids) {
+ Node* s = graph_->FindNodeId(s_id);
+ if (s == nullptr) continue;
+ for (const Edge* e : s->out_edges()) {
+ // Control outputs of switch nodes (which are unconditionally executed if
+ // the switch is) are not removed as they need not be part of a
+ // conditional.
+ if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+ }
+ deleted[s_id] = true;
+ graph_->RemoveNode(s);
+ }
+
+ // All merge nodes should have been transformed at this point and we remove
+ // them from the graph here.
+ for (Node* m : merge_order) {
+ for (const Edge* e : m->out_edges()) {
+ // Similar to control outputs of switch nodes don't remove control
+ // outputs of merge nodes.
+ // TODO(jpienaar): Check cases where output edges still exist here vs
+ // being removed in AddOutputEdges.
+ if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+ }
+ deleted[m->id()] = true;
+ graph_->RemoveNode(m);
+ }
+
+ // Enqueue all the dead nodes.
+ for (Node* n : graph_->nodes()) {
+ if (state_map_.IsDead(state_map_.LookupCondId(n))) {
+ delete_nodes.push_back(n->id());
+ }
+ }
+
+ while (!delete_nodes.empty()) {
+ int d_id = delete_nodes.front();
+ delete_nodes.pop_front();
if (deleted[d_id]) continue;
Node* d = graph_->FindNodeId(d_id);
// Switch and Merge nodes could have been deleted already.
if (d == nullptr) continue;
for (const Edge* e : d->out_edges()) {
- delete_nodes_.push_back(e->dst()->id());
+ delete_nodes.push_back(e->dst()->id());
}
deleted[d_id] = true;
graph_->RemoveNode(d);
@@ -1274,7 +1306,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
}
TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
- if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id");
+ if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");
// Sort the merge nodes from innermost outwards.
SortMergeNodes(&merge_order);
@@ -1312,11 +1344,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
}
- // All remaining Switch nodes are not reachable from a Merge node and
- // removed. This is to account for dead Switch nodes.
- for (int s_id : switch_ids) delete_nodes_.push_back(s_id);
- for (Node* m : merge_order) delete_nodes_.push_back(m->id());
- DeleteReachableNodes();
+ DeleteReachableAndDeadNodes(switch_ids, merge_order);
return Status::OK();
}
@@ -1331,8 +1359,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
state_map_.AncestorStateToString(n)));
}
LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
- << dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name),
- *graph_, library_);
+ << dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_cond_", name), *graph_,
+ library_);
}
Status FunctionalizeCond::Functionalize(Graph* graph,
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
index 28301150ea..1899808940 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.h
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -91,10 +91,6 @@ class StateMap {
// Resets the AncestorId for a given node.
void ResetAncestorId(const Node* node, AncestorId id);
- // Returns the CondState for a Node.
- // REQUIRES: node has a non-empty CondState.
- const CondState& LookupState(const Node* node) const;
-
// Marks `node` as dead.
void MarkDead(const Node* node);
@@ -221,8 +217,10 @@ class FunctionalizeCond {
// nesting depth.
void SortMergeNodes(std::vector<Node*>* merge_order);
- // Deletes all nodes in/consumers of `delete_nodes_`.
- void DeleteReachableNodes();
+ // Deletes all nodes in/consumers reachable from switch/merge nodes that were
+ // extracted.
+ void DeleteReachableAndDeadNodes(const std::vector<int>& switch_ids,
+ const std::vector<Node*>& merge_order);
// Member used to unique the CondState to a unique CondId (AncestorState to a
// unique AncestorId) and keep track of CondState/CondId
@@ -232,9 +230,6 @@ class FunctionalizeCond {
// Mapping from merge nodes to predicate.
std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
- // Nodes to be deleted.
- std::deque<int> delete_nodes_;
-
FunctionLibraryDefinition* library_;
Graph* graph_;