aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-22 05:15:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-22 05:19:06 -0700
commitca552d54ac67be8837aeabdb43269846d9df4eb5 (patch)
tree11d592685766ab64187b520d91d7dfa2b6f231fc /tensorflow/python/debug
parente317152dad1aa66bc493abc046a60dbbf650de92 (diff)
Add PinToHostOptimizer to grappler: force small ops to happen on CPU (instead of
GPU). This avoids many unnecessary CPU<->GPU memcpy and syncs. PiperOrigin-RevId: 214108484
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py3
-rw-r--r--tensorflow/python/debug/lib/debug_graph_reconstruction_test.py3
2 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 55231954d1..4630bda590 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -57,7 +57,8 @@ def no_rewrite_session_config():
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
index 676097fde9..1f67f8a0d4 100644
--- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
+++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
@@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
min_graph_nodes=-1)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
@@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
sess, cond, expected_output=21.0)
def testReconstructGraphWithWhileLoop(self):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
loop_body = lambda i: math_ops.add(i, 2)
loop_cond = lambda i: math_ops.less(i, 16)
i = constant_op.constant(10, name="i")