aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py183
1 files changed, 61 insertions, 122 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b56483f373..2bc2a189fa 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -27,6 +27,7 @@ import random
import re
import tempfile
import threading
+import unittest
import numpy as np
import six
@@ -61,13 +62,13 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
@@ -321,32 +322,6 @@ def NCHWToNHWC(input_tensor):
return [input_tensor[a] for a in new_axes[ndims]]
-# TODO(skyewm): remove this eventually
-# pylint: disable=protected-access
-def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
- prev_value = ops._USE_C_API
- ops._USE_C_API = use_c_api
- try:
- # Reset the default graph so it has the C API enabled. We call
- # reset_default_graph() instead of creating a new default Graph context to
- # make this robust to tests that call reset_default_graph(), which requires
- # that the current default graph isn't nested.
- ops.reset_default_graph()
- fn(*args, **kwargs)
- finally:
- ops._USE_C_API = prev_value
- # Make sure default graph reflects prev_value in case next test doesn't call
- # reset_default_graph().
- ops.reset_default_graph()
-
-
-# pylint: disable=protected-access
-
-
-def c_api_and_cuda_enabled():
- return ops._USE_C_API and IsGoogleCudaEnabled()
-
-
def skip_if(condition):
"""Skips the decorated function if condition is or evaluates to True.
@@ -372,46 +347,6 @@ def skip_if(condition):
return real_skip_if
-# TODO(skyewm): remove this eventually
-def disable_c_api(fn):
- """Decorator for disabling the C API on a test.
-
- Note this disables the C API after running the test class's setup/teardown
- methods.
-
- Args:
- fn: the function to be wrapped
-
- Returns:
- The wrapped function
- """
-
- def wrapper(*args, **kwargs):
- _use_c_api_wrapper(fn, False, *args, **kwargs)
-
- return wrapper
-
-
-# TODO(skyewm): remove this eventually
-def enable_c_api(fn):
- """Decorator for enabling the C API on a test.
-
- Note this enables the C API after running the test class's setup/teardown
- methods.
-
- Args:
- fn: the function to be wrapped
-
- Returns:
- The wrapped function
- """
-
- def wrapper(*args, **kwargs):
- _use_c_api_wrapper(fn, True, *args, **kwargs)
-
- return wrapper
-
-
def enable_c_shapes(fn):
"""Decorator for enabling C shapes on a test.
@@ -425,46 +360,19 @@ def enable_c_shapes(fn):
The wrapped function
"""
+ # pylint: disable=protected-access
def wrapper(*args, **kwargs):
prev_value = ops._USE_C_SHAPES
- # Only use C shapes if the C API is already enabled.
- ops._USE_C_SHAPES = ops._USE_C_API
+ ops._USE_C_SHAPES = True
try:
fn(*args, **kwargs)
finally:
ops._USE_C_SHAPES = prev_value
+ # pylint: enable=protected-access
return wrapper
-# This decorator is a hacky way to run all the test methods in a decorated
-# class with and without C API enabled.
-# TODO(iga): Remove this and its uses once we switch to using C API by default.
-def with_c_api(cls):
- """Adds methods that call original methods but with C API enabled.
-
- Note this enables the C API in new methods after running the test class's
- setup method. This can be a problem if some objects are created in it
- before the C API is enabled.
-
- Args:
- cls: class to decorate
-
- Returns:
- cls with new test methods added
- """
- # If the C API is already enabled, don't do anything. Some tests break if the
- # same test is run twice, so this allows us to turn on the C API by default
- # without breaking these tests.
- if ops._USE_C_API:
- return cls
-
- for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCApi", enable_c_api(value))
- return cls
-
-
def with_c_shapes(cls):
"""Adds methods that call original methods but with C API shapes enabled.
@@ -507,8 +415,28 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
+ collection_sizes_before = {
+ collection: len(ops.get_collection(collection))
+ for collection in ops.get_default_graph().collections}
for _ in range(3):
f(self, **kwargs)
+ # Note that gc.get_objects misses anything that isn't subject to garbage
+ # collection (C types). Collections are a common source of leaks, so we
+ # test for collection sizes explicitly.
+ for collection_key in ops.get_default_graph().collections:
+ collection = ops.get_collection(collection_key)
+ size_before = collection_sizes_before.get(collection_key, 0)
+ if len(collection) > size_before:
+ raise AssertionError(
+ ("Collection %s increased in size from "
+ "%d to %d (current items %s).")
+ % (collection_key, size_before, len(collection), collection))
+ # Make sure our collection checks don't show up as leaked memory by
+ # removing references to temporary variables.
+ del collection
+ del collection_key
+ del size_before
+ del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
@@ -644,14 +572,15 @@ def assert_no_garbage_created(f):
def run_all_in_graph_and_eager_modes(cls):
- base_decorator = run_in_graph_and_eager_modes()
+ """Execute all test methods in the given class with and without eager."""
+ base_decorator = run_in_graph_and_eager_modes
for name, value in cls.__dict__.copy().items():
if callable(value) and name.startswith("test"):
setattr(cls, name, base_decorator(value))
return cls
-def run_in_graph_and_eager_modes(__unused__=None,
+def run_in_graph_and_eager_modes(func=None,
config=None,
use_gpu=True,
reset_test=True,
@@ -669,7 +598,7 @@ def run_in_graph_and_eager_modes(__unused__=None,
```python
class MyTests(tf.test.TestCase):
- @run_in_graph_and_eager_modes()
+ @run_in_graph_and_eager_modes
def test_foo(self):
x = tf.constant([1, 2])
y = tf.constant([3, 4])
@@ -686,7 +615,9 @@ def run_in_graph_and_eager_modes(__unused__=None,
Args:
- __unused__: Prevents silently skipping tests.
+ func: function to be annotated. If `func` is None, this method returns a
+ decorator the can be applied to a function. If `func` is not None this
+ returns the decorator applied to `func`.
config: An optional config_pb2.ConfigProto to use to configure the
session when executing graphs.
use_gpu: If True, attempt to run as many operations as possible on GPU.
@@ -708,20 +639,19 @@ def run_in_graph_and_eager_modes(__unused__=None,
eager execution enabled.
"""
- assert not __unused__, "Add () after run_in_graph_and_eager_modes."
-
def decorator(f):
- def decorated(self, **kwargs):
- with context.graph_mode():
- with self.test_session(use_gpu=use_gpu):
- f(self, **kwargs)
+ if tf_inspect.isclass(f):
+ raise ValueError(
+ "`run_test_in_graph_and_eager_modes` only supports test methods. "
+ "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?")
- if reset_test:
- # This decorator runs the wrapped test twice.
- # Reset the test environment between runs.
- self.tearDown()
- self._tempdir = None
- self.setUp()
+ def decorated(self, **kwargs):
+ try:
+ with context.graph_mode():
+ with self.test_session(use_gpu=use_gpu, config=config):
+ f(self, **kwargs)
+ except unittest.case.SkipTest:
+ pass
def run_eagerly(self, **kwargs):
if not use_gpu:
@@ -736,10 +666,20 @@ def run_in_graph_and_eager_modes(__unused__=None,
assert_no_garbage_created(run_eagerly))
with context.eager_mode():
+ if reset_test:
+ # This decorator runs the wrapped test twice.
+ # Reset the test environment between runs.
+ self.tearDown()
+ self._tempdir = None
+ self.setUp()
+
run_eagerly(self, **kwargs)
return decorated
+ if func is not None:
+ return decorator(func)
+
return decorator
@@ -922,14 +862,13 @@ class TensorFlowTestCase(googletest.TestCase):
def _eval_tensor(self, tensor):
if tensor is None:
return None
- elif isinstance(tensor, ops.EagerTensor):
- return tensor.numpy()
- elif isinstance(tensor, resource_variable_ops.ResourceVariable):
- return tensor.read_value().numpy()
elif callable(tensor):
return self._eval_helper(tensor())
else:
- raise ValueError("Unsupported type %s." % type(tensor))
+ try:
+ return tensor.numpy()
+ except AttributeError as e:
+ six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e)
def _eval_helper(self, tensors):
if tensors is None:
@@ -1334,11 +1273,11 @@ class TensorFlowTestCase(googletest.TestCase):
b,
rtol=rtol,
atol=atol,
- msg="Mismatched value: a%s is different from b%s." % (path_str,
- path_str))
+ msg=("Mismatched value: a%s is different from b%s. %s" %
+ (path_str, path_str, msg)))
except TypeError as e:
- msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
- path_str, type(b))
+ msg = ("Error: a%s has %s, but b%s has %s. %s" %
+ (path_str, type(a), path_str, type(b), msg))
e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
raise