diff options
author | Benoit Steiner <bsteiner@google.com> | 2018-02-21 14:29:27 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-21 14:34:01 -0800 |
commit | 042c60a564d014a19575884f2a0b2cba987b0f7a (patch) | |
tree | 2ec9a463fcaec439abb84ace832d61a52e5c6868 /tensorflow/python/grappler | |
parent | 6583044d980686c04a20085098b335c98618d106 (diff) |
Ensured that the model pruner outputs the nodes of the optimized graph in a
deterministic order
PiperOrigin-RevId: 186520272
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/tf_optimizer_test.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index f4f781ad7e..3ee4d7807e 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -52,7 +52,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): def testKeepNodes(self): g = ops.Graph() with g.as_default(): - variables.Variable( + a1 = variables.Variable( 1.0) # Must be preserved since it's in the collection 'variables'. a2 = constant_op.constant(0, shape=[50, 50], name='keep') ops.add_to_collection('a2', a2) # Explicitly add to collection. @@ -68,12 +68,11 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Check that the nodes referenced in various collections have been preserved self.assertEqual(len(optimized_graph.node), 5) - # Disabled this part of the test until we figure out why it fails on MacOS - # self.assertEqual(a2.op.name, optimized_graph.node[0].name) - # self.assertEqual(a1.op.name, optimized_graph.node[1].name) - # self.assertEqual('Variable/initial_value', optimized_graph.node[2].name) - # self.assertEqual(d.op.name, optimized_graph.node[3].name) - # self.assertEqual('Variable/Assign', optimized_graph.node[4].name) + self.assertEqual(d.op.name, optimized_graph.node[0].name) + self.assertEqual(a1.op.name, optimized_graph.node[1].name) + self.assertEqual('Variable/initial_value', optimized_graph.node[2].name) + self.assertEqual(a2.op.name, optimized_graph.node[3].name) + self.assertEqual('Variable/Assign', optimized_graph.node[4].name) if __name__ == '__main__': |