aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-27 15:47:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 15:50:26 -0700
commitb4742b76c386409c96c60172e6ca1c1534e2b4af (patch)
tree44d10b296e31754b6bc0ed1dc4f25fce73b320c6
parent05ddf373980fae94a2c73cf93161332484e102fd (diff)
Add node types for DFS traversal to catch more issues with deduping inputs to in-place ops.
PiperOrigin-RevId: 190687820
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 23e21855c8..5dd0b6f4b0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1089,7 +1089,8 @@ namespace {
bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) {
const std::unordered_set<string> op_types_to_traverse = {
- node.op(), "Identity", "IdentityN", "Reshape"};
+ node.op(), "Identity", "IdentityN", "Reshape",
+ "ExpandDims", "Enter", "Switch", "Merge"};
int node_idx = graph_view.index(node.name());
std::set<int> node_fanout;
graph_view.DepthFirstSearch(op_types_to_traverse, node_idx, &node_fanout);