aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-02-20 18:17:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-20 18:21:16 -0800
commit3e7ed13c2dac79c05a63a9c25e3c8eb6f1d99ac2 (patch)
treef2b7e3e480410e55a606dff894c81269272d1760 /tensorflow/python/grappler
parent205baa86fe9e559f458dcf534d18c80215890ecd (diff)
Make sure the nodes that are refered to by a collection are preserved during an
optimization PiperOrigin-RevId: 186394467
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 55dcbe2071..5683ab5a04 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -48,6 +49,31 @@ class PyWrapOptimizeGraphTest(test.TestCase):
self.assertEqual(len(graph.node), 1)
self.assertItemsEqual([node.name for node in graph.node], ['d'])
+ def testKeepNodes(self):
+ g = ops.Graph()
+ with g.as_default():
+ a1 = variables.Variable(
+ 1.0) # Must be preserved since it's in the collection 'variables'.
+ a2 = constant_op.constant(0, shape=[50, 50], name='keep')
+ ops.add_to_collection('a2', a2) # Explicitly add to collection.
+ b = constant_op.constant(1, shape=[100, 10])
+ c = constant_op.constant(0, shape=[10, 30])
+ d = math_ops.matmul(b, c)
+ ops.add_to_collection('train_op', d) # d is the fetch node.
+
+ # Optimize the graph.
+ mg = meta_graph.create_meta_graph_def(graph=g)
+ rewriter_config = rewriter_config_pb2.RewriterConfig()
+ optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
+
+ # Check that the nodes referenced in various collections have been preserved
+ self.assertEqual(len(optimized_graph.node), 5)
+ self.assertEqual(a2.op.name, optimized_graph.node[0].name)
+ self.assertEqual(a1.op.name, optimized_graph.node[1].name)
+ self.assertEqual('Variable/initial_value', optimized_graph.node[2].name)
+ self.assertEqual(d.op.name, optimized_graph.node[3].name)
+ self.assertEqual('Variable/Assign', optimized_graph.node[4].name)
+
if __name__ == '__main__':
test.main()