aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-09-05 10:50:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 10:56:47 -0700
commit587808a8ad12fdb20270bb4fefbf85a48702383b (patch)
treee996b579b96db54cf9587b5591f8ea6fd17145aa /tensorflow/python/framework
parent5032036e1f2a7060848aed64bce94a1f882142d5 (diff)
test_util.py: Allow use_gpu to change between calls to self.cached_session()
use_gpu does not affect the creation of the session, it only affects the context manager in which nodes are added to the graph, so it should not be included in the consistency check. PiperOrigin-RevId: 211659833
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/test_util.py156
-rw-r--r--tensorflow/python/framework/test_util_test.py3
2 files changed, 66 insertions, 93 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 3b63e49a84..0925598e33 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1073,13 +1073,9 @@ class TensorFlowTestCase(googletest.TestCase):
if context.executing_eagerly():
yield None
else:
- sess = self._create_session(graph, config, use_gpu, force_gpu)
- with self._constrain_devices_and_set_default(
- sess, use_gpu, force_gpu) as constrained_sess:
- # We need to do this to make sure the session closes, otherwise, even
- # if the user does with self.session():, it will not close the session.
- with constrained_sess:
- yield constrained_sess
+ with self._create_session(graph, config, force_gpu) as sess:
+ with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
+ yield sess
@contextlib.contextmanager
def cached_session(self,
@@ -1127,10 +1123,11 @@ class TensorFlowTestCase(googletest.TestCase):
if context.executing_eagerly():
yield None
else:
- with self._get_cached_session(
- graph, config, use_gpu, force_gpu,
- crash_if_inconsistent_args=True) as sess:
- yield sess
+ sess = self._get_cached_session(
+ graph, config, force_gpu, crash_if_inconsistent_args=True)
+ with self._constrain_devices_and_set_default(sess, use_gpu,
+ force_gpu) as cached:
+ yield cached
@contextlib.contextmanager
def test_session(self,
@@ -1146,10 +1143,11 @@ class TensorFlowTestCase(googletest.TestCase):
yield None
else:
if graph is None:
- with self._get_cached_session(
- graph, config, use_gpu, force_gpu,
- crash_if_inconsistent_args=False) as sess:
- yield sess
+ sess = self._get_cached_session(
+ graph, config, force_gpu, crash_if_inconsistent_args=False)
+ with self._constrain_devices_and_set_default(sess, use_gpu,
+ force_gpu) as cached:
+ yield cached
else:
with self.session(graph, config, use_gpu, force_gpu) as sess:
yield sess
@@ -1835,91 +1833,69 @@ class TensorFlowTestCase(googletest.TestCase):
with sess.graph.device("/cpu:0"):
yield sess
- def _create_session(self, graph, config, use_gpu, force_gpu):
+ def _create_session(self, graph, config, force_gpu):
"""See session() for details."""
- if context.executing_eagerly():
- return None
- else:
+ def prepare_config(config):
+ """Returns a config for sessions.
- def prepare_config(config):
- """Returns a config for sessions.
-
- Args:
- config: An optional config_pb2.ConfigProto to use to configure the
- session.
- Returns:
- A config_pb2.ConfigProto object.
- """
- if config is None:
- config = config_pb2.ConfigProto()
- config.allow_soft_placement = not force_gpu
- config.gpu_options.per_process_gpu_memory_fraction = 0.3
- elif force_gpu and config.allow_soft_placement:
- config = config_pb2.ConfigProto().CopyFrom(config)
- config.allow_soft_placement = False
- # Don't perform optimizations for tests so we don't inadvertently run
- # gpu ops on cpu
- config.graph_options.optimizer_options.opt_level = -1
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- config.graph_options.rewrite_options.arithmetic_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
- return config
-
- return ErrorLoggingSession(graph=graph, config=prepare_config(config))
+ Args:
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+
+ Returns:
+ A config_pb2.ConfigProto object.
+ """
+ if config is None:
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = not force_gpu
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ elif force_gpu and config.allow_soft_placement:
+ config = config_pb2.ConfigProto().CopyFrom(config)
+ config.allow_soft_placement = False
+ # Don't perform optimizations for tests so we don't inadvertently run
+ # gpu ops on cpu
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+ return ErrorLoggingSession(graph=graph, config=prepare_config(config))
- @contextlib.contextmanager
def _get_cached_session(self,
graph=None,
config=None,
- use_gpu=False,
force_gpu=False,
crash_if_inconsistent_args=True):
"""See cached_session() for documentation."""
- if context.executing_eagerly():
- yield None
+ if self._cached_session is None:
+ sess = self._create_session(
+ graph=graph, config=config, force_gpu=force_gpu)
+ self._cached_session = sess
+ self._cached_graph = graph
+ self._cached_config = config
+ self._cached_force_gpu = force_gpu
+ return sess
else:
- if self._cached_session is None:
- sess = self._create_session(
- graph=graph, config=config, use_gpu=use_gpu, force_gpu=force_gpu)
- self._cached_session = sess
- self._cached_graph = graph
- self._cached_config = config
- self._cached_use_gpu = use_gpu
- self._cached_force_gpu = force_gpu
- with self._constrain_devices_and_set_default(
- sess, use_gpu, force_gpu) as constrained_sess:
- yield constrained_sess
- else:
- if crash_if_inconsistent_args and self._cached_graph is not graph:
- raise ValueError("The graph used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- if crash_if_inconsistent_args and self._cached_config is not config:
- raise ValueError("The config used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- if crash_if_inconsistent_args and self._cached_use_gpu is not use_gpu:
- raise ValueError(
- "The use_gpu value used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- if crash_if_inconsistent_args and (self._cached_force_gpu is
- not force_gpu):
- raise ValueError(
- "The force_gpu value used to get the cached session is "
- "different than the one that was used to create the "
- "session. Maybe create a new session with "
- "self.session()")
- # If you modify this logic, make sure to modify it in _create_session
- # as well.
- sess = self._cached_session
- with self._constrain_devices_and_set_default(
- sess, use_gpu, force_gpu) as constrained_sess:
- yield constrained_sess
+ if crash_if_inconsistent_args and self._cached_graph is not graph:
+ raise ValueError("The graph used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and self._cached_config is not config:
+ raise ValueError("The config used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and (self._cached_force_gpu is
+ not force_gpu):
+ raise ValueError(
+ "The force_gpu value used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ return self._cached_session
@tf_export("test.create_local_cluster")
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index a0939f98b2..c4f8fa9108 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -71,9 +71,6 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.cached_session(graph=ops.Graph()) as sess2:
pass
with self.assertRaises(ValueError):
- with self.cached_session(use_gpu=True) as sess2:
- pass
- with self.assertRaises(ValueError):
with self.cached_session(force_gpu=True) as sess2:
pass
# We make sure that test_session will cache the session even after the