aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/client/session_test.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 040cc33315..5a42c50fff 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -31,6 +31,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
@@ -45,6 +46,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -2099,6 +2101,43 @@ class SessionTest(test_util.TensorFlowTestCase):
res = sess.partial_run(h, r2, feed_dict={c: 3})
self.assertEqual(9, res)
+ def testGraphOptimizer(self):
+ rewrite_options = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=False, constant_folding=True)
+ graph_options = config_pb2.GraphOptions(
+ rewrite_options=rewrite_options, build_cost_model=1)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+
+ with ops.Graph().as_default() as g:
+ r1 = random_ops.random_normal(shape=[2, 3], name='R1')
+ r2 = random_ops.random_normal(shape=[2, 3], name='R2')
+ copy1 = array_ops.stop_gradient(r1)
+ copy2 = array_ops.identity(r2)
+ result = copy1 + copy2
+
+ with session.Session(graph=g, config=config) as sess:
+ metadata = config_pb2.RunMetadata()
+ sess.run(result, run_metadata=metadata)
+
+ # Check that we optimized the graph by looking at the cost model: the add
+ # node should have been reconnected directly to the R1 and R2 nodes.
+ found_valid_nodes = 0
+ for node in metadata.cost_graph.node:
+ if node.name == 'R1':
+ r1_cost_id = node.id
+ found_valid_nodes += 1
+ if node.name == 'R2':
+ r2_cost_id = node.id
+ found_valid_nodes += 1
+ if node.name == 'add':
+ if node.input_info[0].preceding_node == r1_cost_id:
+ self.assertEqual(node.input_info[1].preceding_node, r2_cost_id)
+ found_valid_nodes += 1
+ elif node.input_info[0].preceding_node == r2_cost_id:
+ self.assertEqual(node.input_info[1].preceding_node, r1_cost_id)
+ found_valid_nodes += 1
+ self.assertEqual(3, found_valid_nodes)
+
if __name__ == '__main__':
googletest.main()