diff options
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 183 |
1 files changed, 61 insertions, 122 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index b56483f373..2bc2a189fa 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -27,6 +27,7 @@ import random import re import tempfile import threading +import unittest import numpy as np import six @@ -61,13 +62,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export @@ -321,32 +322,6 @@ def NCHWToNHWC(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] -# TODO(skyewm): remove this eventually -# pylint: disable=protected-access -def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs): - prev_value = ops._USE_C_API - ops._USE_C_API = use_c_api - try: - # Reset the default graph so it has the C API enabled. We call - # reset_default_graph() instead of creating a new default Graph context to - # make this robust to tests that call reset_default_graph(), which requires - # that the current default graph isn't nested. - ops.reset_default_graph() - fn(*args, **kwargs) - finally: - ops._USE_C_API = prev_value - # Make sure default graph reflects prev_value in case next test doesn't call - # reset_default_graph(). - ops.reset_default_graph() - - -# pylint: disable=protected-access - - -def c_api_and_cuda_enabled(): - return ops._USE_C_API and IsGoogleCudaEnabled() - - def skip_if(condition): """Skips the decorated function if condition is or evaluates to True. @@ -372,46 +347,6 @@ def skip_if(condition): return real_skip_if -# TODO(skyewm): remove this eventually -def disable_c_api(fn): - """Decorator for disabling the C API on a test. - - Note this disables the C API after running the test class's setup/teardown - methods. - - Args: - fn: the function to be wrapped - - Returns: - The wrapped function - """ - - def wrapper(*args, **kwargs): - _use_c_api_wrapper(fn, False, *args, **kwargs) - - return wrapper - - -# TODO(skyewm): remove this eventually -def enable_c_api(fn): - """Decorator for enabling the C API on a test. - - Note this enables the C API after running the test class's setup/teardown - methods. - - Args: - fn: the function to be wrapped - - Returns: - The wrapped function - """ - - def wrapper(*args, **kwargs): - _use_c_api_wrapper(fn, True, *args, **kwargs) - - return wrapper - - def enable_c_shapes(fn): """Decorator for enabling C shapes on a test. @@ -425,46 +360,19 @@ def enable_c_shapes(fn): The wrapped function """ + # pylint: disable=protected-access def wrapper(*args, **kwargs): prev_value = ops._USE_C_SHAPES - # Only use C shapes if the C API is already enabled. - ops._USE_C_SHAPES = ops._USE_C_API + ops._USE_C_SHAPES = True try: fn(*args, **kwargs) finally: ops._USE_C_SHAPES = prev_value + # pylint: enable=protected-access return wrapper -# This decorator is a hacky way to run all the test methods in a decorated -# class with and without C API enabled. -# TODO(iga): Remove this and its uses once we switch to using C API by default. -def with_c_api(cls): - """Adds methods that call original methods but with C API enabled. - - Note this enables the C API in new methods after running the test class's - setup method. This can be a problem if some objects are created in it - before the C API is enabled. - - Args: - cls: class to decorate - - Returns: - cls with new test methods added - """ - # If the C API is already enabled, don't do anything. Some tests break if the - # same test is run twice, so this allows us to turn on the C API by default - # without breaking these tests. - if ops._USE_C_API: - return cls - - for name, value in cls.__dict__.copy().items(): - if callable(value) and name.startswith("test"): - setattr(cls, name + "WithCApi", enable_c_api(value)) - return cls - - def with_c_shapes(cls): """Adds methods that call original methods but with C API shapes enabled. @@ -507,8 +415,28 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections} for _ in range(3): f(self, **kwargs) + # Note that gc.get_objects misses anything that isn't subject to garbage + # collection (C types). Collections are a common source of leaks, so we + # test for collection sizes explicitly. + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") + % (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. new_count = len(gc.get_objects()) @@ -644,14 +572,15 @@ def assert_no_garbage_created(f): def run_all_in_graph_and_eager_modes(cls): - base_decorator = run_in_graph_and_eager_modes() + """Execute all test methods in the given class with and without eager.""" + base_decorator = run_in_graph_and_eager_modes for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): setattr(cls, name, base_decorator(value)) return cls -def run_in_graph_and_eager_modes(__unused__=None, +def run_in_graph_and_eager_modes(func=None, config=None, use_gpu=True, reset_test=True, @@ -669,7 +598,7 @@ def run_in_graph_and_eager_modes(__unused__=None, ```python class MyTests(tf.test.TestCase): - @run_in_graph_and_eager_modes() + @run_in_graph_and_eager_modes def test_foo(self): x = tf.constant([1, 2]) y = tf.constant([3, 4]) @@ -686,7 +615,9 @@ def run_in_graph_and_eager_modes(__unused__=None, Args: - __unused__: Prevents silently skipping tests. + func: function to be annotated. If `func` is None, this method returns a + decorator the can be applied to a function. If `func` is not None this + returns the decorator applied to `func`. config: An optional config_pb2.ConfigProto to use to configure the session when executing graphs. use_gpu: If True, attempt to run as many operations as possible on GPU. @@ -708,20 +639,19 @@ def run_in_graph_and_eager_modes(__unused__=None, eager execution enabled. """ - assert not __unused__, "Add () after run_in_graph_and_eager_modes." - def decorator(f): - def decorated(self, **kwargs): - with context.graph_mode(): - with self.test_session(use_gpu=use_gpu): - f(self, **kwargs) + if tf_inspect.isclass(f): + raise ValueError( + "`run_test_in_graph_and_eager_modes` only supports test methods. " + "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?") - if reset_test: - # This decorator runs the wrapped test twice. - # Reset the test environment between runs. - self.tearDown() - self._tempdir = None - self.setUp() + def decorated(self, **kwargs): + try: + with context.graph_mode(): + with self.test_session(use_gpu=use_gpu, config=config): + f(self, **kwargs) + except unittest.case.SkipTest: + pass def run_eagerly(self, **kwargs): if not use_gpu: @@ -736,10 +666,20 @@ def run_in_graph_and_eager_modes(__unused__=None, assert_no_garbage_created(run_eagerly)) with context.eager_mode(): + if reset_test: + # This decorator runs the wrapped test twice. + # Reset the test environment between runs. + self.tearDown() + self._tempdir = None + self.setUp() + run_eagerly(self, **kwargs) return decorated + if func is not None: + return decorator(func) + return decorator @@ -922,14 +862,13 @@ class TensorFlowTestCase(googletest.TestCase): def _eval_tensor(self, tensor): if tensor is None: return None - elif isinstance(tensor, ops.EagerTensor): - return tensor.numpy() - elif isinstance(tensor, resource_variable_ops.ResourceVariable): - return tensor.read_value().numpy() elif callable(tensor): return self._eval_helper(tensor()) else: - raise ValueError("Unsupported type %s." % type(tensor)) + try: + return tensor.numpy() + except AttributeError as e: + six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) def _eval_helper(self, tensors): if tensors is None: @@ -1334,11 +1273,11 @@ class TensorFlowTestCase(googletest.TestCase): b, rtol=rtol, atol=atol, - msg="Mismatched value: a%s is different from b%s." % (path_str, - path_str)) + msg=("Mismatched value: a%s is different from b%s. %s" % + (path_str, path_str, msg))) except TypeError as e: - msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a), - path_str, type(b)) + msg = ("Error: a%s has %s, but b%s has %s. %s" % + (path_str, type(a), path_str, type(b), msg)) e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) raise |