diff options
-rw-r--r-- | tensorflow/python/client/session_test.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 040cc33315..5a42c50fff 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -31,6 +31,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op @@ -45,6 +46,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -2099,6 +2101,43 @@ class SessionTest(test_util.TensorFlowTestCase): res = sess.partial_run(h, r2, feed_dict={c: 3}) self.assertEqual(9, res) + def testGraphOptimizer(self): + rewrite_options = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=False, constant_folding=True) + graph_options = config_pb2.GraphOptions( + rewrite_options=rewrite_options, build_cost_model=1) + config = config_pb2.ConfigProto(graph_options=graph_options) + + with ops.Graph().as_default() as g: + r1 = random_ops.random_normal(shape=[2, 3], name='R1') + r2 = random_ops.random_normal(shape=[2, 3], name='R2') + copy1 = array_ops.stop_gradient(r1) + copy2 = array_ops.identity(r2) + result = copy1 + copy2 + + with session.Session(graph=g, config=config) as sess: + metadata = config_pb2.RunMetadata() + sess.run(result, run_metadata=metadata) + + # Check that we optimized the graph by looking at the cost model: the add + # node should have been reconnected directly to the R1 and R2 nodes. + found_valid_nodes = 0 + for node in metadata.cost_graph.node: + if node.name == 'R1': + r1_cost_id = node.id + found_valid_nodes += 1 + if node.name == 'R2': + r2_cost_id = node.id + found_valid_nodes += 1 + if node.name == 'add': + if node.input_info[0].preceding_node == r1_cost_id: + self.assertEqual(node.input_info[1].preceding_node, r2_cost_id) + found_valid_nodes += 1 + elif node.input_info[0].preceding_node == r2_cost_id: + self.assertEqual(node.input_info[1].preceding_node, r1_cost_id) + found_valid_nodes += 1 + self.assertEqual(3, found_valid_nodes) + if __name__ == '__main__': googletest.main() |