aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-29 17:50:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 17:55:17 -0800
commite3d99c92975efc2010d0e1e2dd4c3eb787a8d67c (patch)
treeb25a5fee98ef6de775bf149f1bf0bf24c8c45d32 /tensorflow/python/debug
parenta4b88a5b795d5496bffe4ff80875a5bf0954a4d6 (diff)
Remove Identity nodes if num_inputs * num_outputs <= num_inputs + num_outputs. Exceptions are Identity nodes after Variable nodes, and Identity nodes after Switch nodes when removing the node would require anchoring a control dependency on the Switch.
Another exception is Identity nodes where inputs or outputs cross a device boundary, since we are not allowed to remove Identity nodes after _Recv that might be inserted in the graph later. PiperOrigin-RevId: 183759826
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/lib/debug_gradients_test.py7
-rw-r--r--tensorflow/python/debug/lib/session_debug_grpc_test.py3
2 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py
index b6c7280a41..c1e9869d97 100644
--- a/tensorflow/python/debug/lib/debug_gradients_test.py
+++ b/tensorflow/python/debug/lib/debug_gradients_test.py
@@ -22,6 +22,7 @@ import shutil
import tempfile
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_gradients
@@ -38,7 +39,11 @@ from tensorflow.python.training import gradient_descent
class IdentifyGradientTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config)
with self.sess.as_default():
self.u = variables.Variable(2.0, name="u")
self.v = variables.Variable(3.0, name="v")
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index 367b353545..b623ee31c5 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -54,7 +54,8 @@ from tensorflow.python.training import monitored_session
def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
- arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)