aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/functionalize_cond.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/functionalize_cond.h')
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.h13
1 files changed, 4 insertions, 9 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
index 28301150ea..1899808940 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.h
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -91,10 +91,6 @@ class StateMap {
// Resets the AncestorId for a given node.
void ResetAncestorId(const Node* node, AncestorId id);
- // Returns the CondState for a Node.
- // REQUIRES: node has a non-empty CondState.
- const CondState& LookupState(const Node* node) const;
-
// Marks `node` as dead.
void MarkDead(const Node* node);
@@ -221,8 +217,10 @@ class FunctionalizeCond {
// nesting depth.
void SortMergeNodes(std::vector<Node*>* merge_order);
- // Deletes all nodes in/consumers of `delete_nodes_`.
- void DeleteReachableNodes();
+ // Deletes all nodes in/consumers reachable from switch/merge nodes that were
+ // extracted.
+ void DeleteReachableAndDeadNodes(const std::vector<int>& switch_ids,
+ const std::vector<Node*>& merge_order);
// Member used to unique the CondState to a unique CondId (AncestorState to a
// unique AncestorId) and keep track of CondState/CondId
@@ -232,9 +230,6 @@ class FunctionalizeCond {
// Mapping from merge nodes to predicate.
std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
- // Nodes to be deleted.
- std::deque<int> delete_nodes_;
-
FunctionLibraryDefinition* library_;
Graph* graph_;