diff options
author | 2018-03-27 15:47:23 -0700 | |
---|---|---|
committer | 2018-03-27 15:50:26 -0700 | |
commit | b4742b76c386409c96c60172e6ca1c1534e2b4af (patch) | |
tree | 44d10b296e31754b6bc0ed1dc4f25fce73b320c6 | |
parent | 05ddf373980fae94a2c73cf93161332484e102fd (diff) |
Add node types for DFS traversal to catch more issues with deduping inputs to in-place ops.
PiperOrigin-RevId: 190687820
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 23e21855c8..5dd0b6f4b0 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1089,7 +1089,8 @@ namespace { bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) { const std::unordered_set<string> op_types_to_traverse = { - node.op(), "Identity", "IdentityN", "Reshape"}; + node.op(), "Identity", "IdentityN", "Reshape", + "ExpandDims", "Enter", "Switch", "Merge"}; int node_idx = graph_view.index(node.name()); std::set<int> node_fanout; graph_view.DepthFirstSearch(op_types_to_traverse, node_idx, &node_fanout); |