author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-20 17:06:41 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-20 17:16:54 -0700
commit    9962eb5e84b15e309410071b06c2ed2d6148ed44 (patch)
tree      01dbe9a94f5ef80ad36d9019e934346daed88b81 /tensorflow
parent    a4404f6158a5489f4cf8769b2d7ec1b336f37cc7 (diff)
Deprecate test_session() and introduce session() and cached_session() instead.

Implicitly, test_session(graph=None, ...) caches the session by default, so the session object is not closed when used in a with statement:

    with self.test_session():
      # Some code
    # Some other code, where by default we would expect the session to be
    # closed by now.

The fact that the session is not closed when we exit the with scope has surprised users and created hard-to-diagnose bugs.

To keep backward compatibility, we are deprecating test_session and introducing two new functions: cached_session, which by default always caches the session (even when graph is not None), and session, in which the session is always closed. Note that the behavior of test_session has changed slightly: it now always sets the graph and session as global defaults (with as_default). The documentation of test_session has also been updated, since it did not match its previous behavior.

PiperOrigin-RevId: 209511997
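To make the new contract concrete, here is a minimal sketch of a test exercising the two helpers. It is not part of the commit; `MyLifetimeTest` is a hypothetical class, and `_closed` is the same private Session field that the new test in this commit inspects.

```python
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class MyLifetimeTest(test_util.TensorFlowTestCase):

  def testCachedSessionIsReused(self):
    # cached_session() hands back the same underlying session within a test;
    # it is only closed at the end of the test, not when the with block exits.
    with self.cached_session() as first:
      pass
    with self.cached_session() as second:
      self.assertIs(second, first)

  def testSessionIsClosedOnExit(self):
    # session() always creates a fresh session and closes it on exit from
    # the with block.
    with self.session() as sess:
      pass
    self.assertTrue(sess._closed)  # private field, used only for illustration


if __name__ == "__main__":
  googletest.main()
```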
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/python/framework/test_util.py       | 310
-rw-r--r--  tensorflow/python/framework/test_util_test.py  |  28
2 files changed, 244 insertions, 94 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index d2d18222ba..d690f08d88 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -369,6 +369,7 @@ def enable_c_shapes(fn):
fn(*args, **kwargs)
finally:
ops._USE_C_SHAPES = prev_value
+
# pylint: enable=protected-access
return wrapper
@@ -418,7 +419,8 @@ def assert_no_new_pyobjects_executing_eagerly(f):
previous_count = len(gc.get_objects())
collection_sizes_before = {
collection: len(ops.get_collection(collection))
- for collection in ops.get_default_graph().collections}
+ for collection in ops.get_default_graph().collections
+ }
for _ in range(3):
f(self, **kwargs)
# Note that gc.get_objects misses anything that isn't subject to garbage
@@ -430,8 +432,8 @@ def assert_no_new_pyobjects_executing_eagerly(f):
if len(collection) > size_before:
raise AssertionError(
("Collection %s increased in size from "
- "%d to %d (current items %s).")
- % (collection_key, size_before, len(collection), collection))
+ "%d to %d (current items %s).") % (collection_key, size_before,
+ len(collection), collection))
# Make sure our collection checks don't show up as leaked memory by
# removing references to temporary variables.
del collection
@@ -446,8 +448,8 @@ def assert_no_new_pyobjects_executing_eagerly(f):
# Using plain assert because not all classes using this decorator
# have assertLessEqual
assert new_count <= previous_count, (
- "new_count(%d) is not less than or equal to previous_count(%d)" % (
- new_count, previous_count))
+ "new_count(%d) is not less than or equal to previous_count(%d)" %
+ (new_count, previous_count))
gc.enable()
return decorator
@@ -547,10 +549,12 @@ def assert_no_garbage_created(f):
return "<%s %d>" % (obj.__class__.__name__, id(obj))
logging.error(" Object type: %s", _safe_object_str(obj))
- logging.error(" Referrer types: %s", ", ".join(
- [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
- logging.error(" Referent types: %s", ", ".join(
- [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
+ logging.error(
+ " Referrer types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
+ logging.error(
+ " Referent types: %s", ", ".join(
+ [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
logging.error(" Object attribute names: %s", dir(obj))
logging.error(" Object __str__:")
logging.error(obj)
@@ -629,9 +633,8 @@ def generate_combinations_with_testcase_name(**kwargs):
for combination in combinations:
assert isinstance(combination, OrderedDict)
name = "".join([
- "_{}_{}".format(
- "".join(filter(str.isalnum, key)),
- "".join(filter(str.isalnum, str(value))))
+ "_{}_{}".format("".join(filter(str.isalnum, key)), "".join(
+ filter(str.isalnum, str(value))))
for key, value in combination.items()
])
named_combinations.append(
@@ -971,21 +974,64 @@ class TensorFlowTestCase(googletest.TestCase):
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
- def test_session(self,
- graph=None,
- config=None,
- use_gpu=False,
- force_gpu=False):
+ def session(self, graph=None, config=None, use_gpu=False, force_gpu=False):
"""Returns a TensorFlow Session for use in executing tests.
- This method should be used for all functional tests.
+ Note that this will set this session and the graph as global defaults.
- This method behaves different than session.Session: for performance reasons
- `test_session` will by default (if `graph` is None) reuse the same session
- across tests. This means you may want to either call the function
- `reset_default_graph()` before tests, or if creating an explicit new graph,
- pass it here (simply setting it with `as_default()` won't do it), which will
- trigger the creation of a new session.
+ Use the `use_gpu` and `force_gpu` options to control where ops are run. If
+ `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
+ `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
+    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned to
+ the CPU.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ with self.session(use_gpu=True):
+ valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ result = MyOperator(valid_input).eval()
+          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
+ invalid_input = [-1.0, 2.0, 7.0]
+ with self.assertRaisesOpError("negative input not supported"):
+ MyOperator(invalid_input).eval()
+ ```
+
+ Args:
+ graph: Optional graph to use during the returned session.
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+ use_gpu: If True, attempt to run as many ops as possible on GPU.
+ force_gpu: If True, pin all ops to `/device:GPU:0`.
+
+ Yields:
+ A Session object that should be used as a context manager to surround
+ the graph building and execution code in a test case.
+ """
+ if context.executing_eagerly():
+ yield None
+ else:
+ sess = self._create_session(graph, config, use_gpu, force_gpu)
+ with self._constrain_devices_and_set_default(
+ sess, use_gpu, force_gpu) as constrained_sess:
+ # We need to do this to make sure the session closes, otherwise, even
+ # if the user does with self.session():, it will not close the session.
+ with constrained_sess:
+ yield constrained_sess
+
+ @contextlib.contextmanager
+ def cached_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Returns a TensorFlow Session for use in executing tests.
+
+ This method behaves differently than self.session(): for performance reasons
+ `cached_session` will by default reuse the same session within the same
+ test. The session returned by this function will only be closed at the end
+ of the test (in the TearDown function).
Use the `use_gpu` and `force_gpu` options to control where ops are run. If
`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
@@ -997,7 +1043,7 @@ class TensorFlowTestCase(googletest.TestCase):
```python
class MyOperatorTest(test_util.TensorFlowTestCase):
def testMyOperator(self):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True) as sess:
valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
@@ -1013,74 +1059,39 @@ class TensorFlowTestCase(googletest.TestCase):
use_gpu: If True, attempt to run as many ops as possible on GPU.
force_gpu: If True, pin all ops to `/device:GPU:0`.
- Returns:
+ Yields:
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case.
"""
+ if context.executing_eagerly():
+ yield None
+ else:
+ with self._get_cached_session(
+ graph, config, use_gpu, force_gpu,
+ crash_if_inconsistent_args=True) as sess:
+ yield sess
+
+ @contextlib.contextmanager
+ def test_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False):
+ """Use cached_session instead."""
if self.id().endswith(".test_session"):
self.skipTest("Not a test.")
- def prepare_config(config):
- """Returns a config for sessions.
-
- Args:
- config: An optional config_pb2.ConfigProto to use to configure the
- session.
- Returns:
- A config_pb2.ConfigProto object.
- """
- if config is None:
- config = config_pb2.ConfigProto()
- config.allow_soft_placement = not force_gpu
- config.gpu_options.per_process_gpu_memory_fraction = 0.3
- elif force_gpu and config.allow_soft_placement:
- config = config_pb2.ConfigProto().CopyFrom(config)
- config.allow_soft_placement = False
- # Don't perform optimizations for tests so we don't inadvertently run
- # gpu ops on cpu
- config.graph_options.optimizer_options.opt_level = -1
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- config.graph_options.rewrite_options.arithmetic_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
- return config
-
if context.executing_eagerly():
yield None
- elif graph is None:
- if self._cached_session is None:
- self._cached_session = session.Session(
- graph=None, config=prepare_config(config))
- sess = self._cached_session
- with sess.graph.as_default(), sess.as_default():
- if force_gpu:
- # Use the name of an actual device if one is detected, or '/device:GPU:0'
- # otherwise
- gpu_name = gpu_device_name()
- if not gpu_name:
- gpu_name = "/device:GPU:0"
- with sess.graph.device(gpu_name):
- yield sess
- elif use_gpu:
- yield sess
- else:
- with sess.graph.device("/cpu:0"):
- yield sess
else:
- with session.Session(graph=graph, config=prepare_config(config)) as sess:
- if force_gpu:
- # Use the name of an actual device if one is detected, or '/device:GPU:0'
- # otherwise
- gpu_name = gpu_device_name()
- if not gpu_name:
- gpu_name = "/device:GPU:0"
- with sess.graph.device(gpu_name):
- yield sess
- elif use_gpu:
+ if graph is None:
+ with self._get_cached_session(
+ graph, config, use_gpu, force_gpu,
+ crash_if_inconsistent_args=False) as sess:
+ yield sess
+ else:
+ with self.session(graph, config, use_gpu, force_gpu) as sess:
yield sess
- else:
- with sess.graph.device("/cpu:0"):
- yield sess
# pylint: enable=g-doc-return-or-yield
@@ -1206,9 +1217,10 @@ class TensorFlowTestCase(googletest.TestCase):
msg: An optional string message to append to the failure message.
"""
# f1 == f2 is needed here as we might have: f1, f2 = inf, inf
- self.assertTrue(f1 == f2 or math.fabs(f1 - f2) <= err,
- "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
- if msg is not None else ""))
+ self.assertTrue(
+ f1 == f2 or math.fabs(f1 - f2) <= err,
+ "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
+ if msg is not None else ""))
def assertArrayNear(self, farray1, farray2, err, msg=None):
"""Asserts that two float arrays are near each other.
@@ -1254,8 +1266,9 @@ class TensorFlowTestCase(googletest.TestCase):
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
- (a.shape, b.shape))
+ self.assertEqual(
+ a.shape, b.shape,
+ "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Prints more details than np.testing.assert_allclose.
#
@@ -1457,8 +1470,9 @@ class TensorFlowTestCase(googletest.TestCase):
msg = msg if msg else ""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
- self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s."
- " %s" % (a.shape, b.shape, msg))
+ self.assertEqual(
+ a.shape, b.shape, "Shape mismatch: expected %s, got %s."
+ " %s" % (a.shape, b.shape, msg))
same = (a == b)
if (a.dtype in [
@@ -1686,8 +1700,8 @@ class TensorFlowTestCase(googletest.TestCase):
self.fail(exception_type.__name__ + " not raised")
except Exception as e: # pylint: disable=broad-except
if not isinstance(e, exception_type) or not predicate(e):
- raise AssertionError("Exception of type %s: %s" % (str(type(e)),
- str(e)))
+ raise AssertionError(
+ "Exception of type %s: %s" % (str(type(e)), str(e)))
# pylint: enable=g-doc-return-or-yield
@@ -1723,8 +1737,9 @@ class TensorFlowTestCase(googletest.TestCase):
"""
device1 = pydev.canonical_name(device1)
device2 = pydev.canonical_name(device2)
- self.assertEqual(device1, device2, "Devices %s and %s are not equal. %s" %
- (device1, device2, msg))
+ self.assertEqual(
+ device1, device2,
+ "Devices %s and %s are not equal. %s" % (device1, device2, msg))
# Fix Python 3 compatibility issues
if six.PY3:
@@ -1738,6 +1753,113 @@ class TensorFlowTestCase(googletest.TestCase):
# pylint: enable=invalid-name
+ @contextlib.contextmanager
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """Set the session and its graph to global default and constrain devices."""
+ if context.executing_eagerly():
+ yield None
+ else:
+ with sess.graph.as_default(), sess.as_default():
+ if force_gpu:
+ # Use the name of an actual device if one is detected, or
+ # '/device:GPU:0' otherwise
+ gpu_name = gpu_device_name()
+ if not gpu_name:
+ gpu_name = "/device:GPU:0"
+ with sess.graph.device(gpu_name):
+ yield sess
+ elif use_gpu:
+ yield sess
+ else:
+ with sess.graph.device("/cpu:0"):
+ yield sess
+
+ def _create_session(self, graph, config, use_gpu, force_gpu):
+ """See session() for details."""
+ if context.executing_eagerly():
+ return None
+ else:
+
+ def prepare_config(config):
+ """Returns a config for sessions.
+
+ Args:
+ config: An optional config_pb2.ConfigProto to use to configure the
+ session.
+ Returns:
+ A config_pb2.ConfigProto object.
+ """
+ if config is None:
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = not force_gpu
+ config.gpu_options.per_process_gpu_memory_fraction = 0.3
+ elif force_gpu and config.allow_soft_placement:
+ config = config_pb2.ConfigProto().CopyFrom(config)
+ config.allow_soft_placement = False
+ # Don't perform optimizations for tests so we don't inadvertently run
+ # gpu ops on cpu
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+ return session.Session(graph=graph, config=prepare_config(config))
+
+ @contextlib.contextmanager
+ def _get_cached_session(self,
+ graph=None,
+ config=None,
+ use_gpu=False,
+ force_gpu=False,
+ crash_if_inconsistent_args=True):
+ """See cached_session() for documentation."""
+ if context.executing_eagerly():
+ yield None
+ else:
+ if self._cached_session is None:
+ sess = self._create_session(
+ graph=graph, config=config, use_gpu=use_gpu, force_gpu=force_gpu)
+ self._cached_session = sess
+ self._cached_graph = graph
+ self._cached_config = config
+ self._cached_use_gpu = use_gpu
+ self._cached_force_gpu = force_gpu
+ with self._constrain_devices_and_set_default(
+ sess, use_gpu, force_gpu) as constrained_sess:
+ yield constrained_sess
+ else:
+ if crash_if_inconsistent_args and self._cached_graph is not graph:
+ raise ValueError("The graph used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and self._cached_config is not config:
+ raise ValueError("The config used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and self._cached_use_gpu is not use_gpu:
+ raise ValueError(
+ "The use_gpu value used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ if crash_if_inconsistent_args and (self._cached_force_gpu is
+ not force_gpu):
+ raise ValueError(
+ "The force_gpu value used to get the cached session is "
+ "different than the one that was used to create the "
+ "session. Maybe create a new session with "
+ "self.session()")
+ # If you modify this logic, make sure to modify it in _create_session
+ # as well.
+ sess = self._cached_session
+ with self._constrain_devices_and_set_default(
+ sess, use_gpu, force_gpu) as constrained_sess:
+ yield constrained_sess
+
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
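Before the test changes below, a short sketch of the consistency contract that `_get_cached_session` enforces: requesting the cached session with different arguments raises ValueError rather than silently reusing a mismatched session. This mirrors the new `test_session_functions` test that follows; the class name here is hypothetical and the snippet is not part of the diff.

```python
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest


class CachedSessionArgsTest(test_util.TensorFlowTestCase):

  def testInconsistentArgsRaise(self):
    with self.cached_session():
      # The cache was created with graph=None, use_gpu=False, force_gpu=False.
      # Asking for it again with a different graph (or config/use_gpu/
      # force_gpu) is rejected instead of handing back the old session.
      with self.assertRaises(ValueError):
        with self.cached_session(graph=ops.Graph()):
          pass


if __name__ == "__main__":
  googletest.main()
```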
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 3a34dd9505..f68c0ddecb 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -22,6 +22,7 @@ import collections
import copy
import random
import threading
+import weakref
import numpy as np
@@ -58,6 +59,33 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertRaises(ValueError, test_util.assert_ops_in_graph,
{"hello": "Variable"}, ops.get_default_graph())
+ def test_session_functions(self):
+ with self.test_session() as sess:
+ sess_ref = weakref.ref(sess)
+ with self.cached_session(graph=None, config=None) as sess2:
+ # We make sure that sess2 is sess.
+ assert sess2 is sess
+ # We make sure we raise an exception if we use cached_session with
+ # different values.
+ with self.assertRaises(ValueError):
+ with self.cached_session(graph=ops.Graph()) as sess2:
+ pass
+ with self.assertRaises(ValueError):
+ with self.cached_session(use_gpu=True) as sess2:
+ pass
+ with self.assertRaises(ValueError):
+ with self.cached_session(force_gpu=True) as sess2:
+ pass
+ # We make sure that test_session will cache the session even after the
+ # with scope.
+ assert not sess_ref()._closed
+ with self.session() as unique_sess:
+ unique_sess_ref = weakref.ref(unique_sess)
+ with self.session() as sess2:
+ assert sess2 is not unique_sess
+ # We make sure the session is closed when we leave the with statement.
+ assert unique_sess_ref()._closed
+
def test_assert_equal_graph_def(self):
with ops.Graph().as_default() as g:
def_empty = g.as_graph_def()