diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-22 05:15:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-22 05:19:06 -0700 |
commit | ca552d54ac67be8837aeabdb43269846d9df4eb5 (patch) | |
tree | 11d592685766ab64187b520d91d7dfa2b6f231fc /tensorflow/python/debug | |
parent | e317152dad1aa66bc493abc046a60dbbf650de92 (diff) |
Add PinToHostOptimizer to grappler: force small ops to happen on CPU (instead of
GPU). This avoids many unnecessary CPU<->GPU memcpy and syncs.
PiperOrigin-RevId: 214108484
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r-- | tensorflow/python/debug/cli/analyzer_cli_test.py | 3 | ||||
-rw-r--r-- | tensorflow/python/debug/lib/debug_graph_reconstruction_test.py | 3 |
2 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index 55231954d1..4630bda590 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -57,7 +57,8 @@ def no_rewrite_session_config(): disable_model_pruning=True, constant_folding=rewriter_config_pb2.RewriterConfig.OFF, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, - dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py index 676097fde9..1f67f8a0d4 100644 --- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py +++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py @@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase): def _no_rewrite_session_config(self): rewriter_config = rewriter_config_pb2.RewriterConfig( dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF, min_graph_nodes=-1) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) @@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase): sess, cond, expected_output=21.0) def testReconstructGraphWithWhileLoop(self): - with session.Session() as sess: + with session.Session(config=self._no_rewrite_session_config()) as sess: loop_body = lambda i: math_ops.add(i, 2) loop_cond = lambda i: math_ops.less(i, 16) i = constant_op.constant(10, name="i") |