aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner.cc1
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py13
2 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc
index 97f456d2a6..3311e97010 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner.cc
@@ -59,6 +59,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!nodes_to_preserve.empty()) {
std::vector<string> terminal_nodes(nodes_to_preserve.begin(),
nodes_to_preserve.end());
+ std::sort(terminal_nodes.begin(), terminal_nodes.end());
bool ill_formed = false;
std::vector<const NodeDef*> keep =
ComputeTransitiveFanin(item.graph, terminal_nodes, &ill_formed);
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index f4f781ad7e..3ee4d7807e 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -52,7 +52,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
def testKeepNodes(self):
g = ops.Graph()
with g.as_default():
- variables.Variable(
+ a1 = variables.Variable(
1.0) # Must be preserved since it's in the collection 'variables'.
a2 = constant_op.constant(0, shape=[50, 50], name='keep')
ops.add_to_collection('a2', a2) # Explicitly add to collection.
@@ -68,12 +68,11 @@ class PyWrapOptimizeGraphTest(test.TestCase):
# Check that the nodes referenced in various collections have been preserved
self.assertEqual(len(optimized_graph.node), 5)
- # Disabled this part of the test until we figure out why it fails on MacOS
- # self.assertEqual(a2.op.name, optimized_graph.node[0].name)
- # self.assertEqual(a1.op.name, optimized_graph.node[1].name)
- # self.assertEqual('Variable/initial_value', optimized_graph.node[2].name)
- # self.assertEqual(d.op.name, optimized_graph.node[3].name)
- # self.assertEqual('Variable/Assign', optimized_graph.node[4].name)
+ self.assertEqual(d.op.name, optimized_graph.node[0].name)
+ self.assertEqual(a1.op.name, optimized_graph.node[1].name)
+ self.assertEqual('Variable/initial_value', optimized_graph.node[2].name)
+ self.assertEqual(a2.op.name, optimized_graph.node[3].name)
+ self.assertEqual('Variable/Assign', optimized_graph.node[4].name)
if __name__ == '__main__':