diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-29 17:50:56 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-29 17:55:17 -0800 |
commit | e3d99c92975efc2010d0e1e2dd4c3eb787a8d67c (patch) | |
tree | b25a5fee98ef6de775bf149f1bf0bf24c8c45d32 /tensorflow/python/debug | |
parent | a4b88a5b795d5496bffe4ff80875a5bf0954a4d6 (diff) |
Remove Identity nodes if num_inputs * num_outputs <= num_inputs + num_outputs. Exceptions are Identity nodes after Variable nodes, and Identity nodes after Switch nodes when removing the node would require anchoring a control dependency on the Switch.
Another exception is Identity nodes where inputs or outputs cross a device boundary, since we are not allowed to remove Identity nodes after _Recv that might be inserted in the graph later.
PiperOrigin-RevId: 183759826
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r-- | tensorflow/python/debug/lib/debug_gradients_test.py | 7 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/session_debug_grpc_test.py | 3 |
2 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index b6c7280a41..c1e9869d97 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -22,6 +22,7 @@ import shutil import tempfile from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_gradients @@ -38,7 +39,11 @@ from tensorflow.python.training import gradient_descent class IdentifyGradientTest(test_util.TensorFlowTestCase): def setUp(self): - self.sess = session.Session() + rewriter_config = rewriter_config_pb2.RewriterConfig( + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + config = config_pb2.ConfigProto(graph_options=graph_options) + self.sess = session.Session(config=config) with self.sess.as_default(): self.u = variables.Variable(2.0, name="u") self.v = variables.Variable(3.0, name="v") diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index 367b353545..b623ee31c5 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -54,7 +54,8 @@ from tensorflow.python.training import monitored_session def no_rewrite_session_config(): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) |