diff options
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 708ab1707e..3988238609 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -552,14 +552,14 @@ def assert_no_garbage_created(f): def run_all_in_graph_and_eager_modes(cls): """Execute all test methods in the given class with and without eager.""" - base_decorator = run_in_graph_and_eager_modes() + base_decorator = run_in_graph_and_eager_modes for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): setattr(cls, name, base_decorator(value)) return cls -def run_in_graph_and_eager_modes(__unused__=None, +def run_in_graph_and_eager_modes(func=None, config=None, use_gpu=True, reset_test=True, @@ -577,7 +577,7 @@ def run_in_graph_and_eager_modes(__unused__=None, ```python class MyTests(tf.test.TestCase): - @run_in_graph_and_eager_modes() + @run_in_graph_and_eager_modes def test_foo(self): x = tf.constant([1, 2]) y = tf.constant([3, 4]) @@ -594,7 +594,9 @@ def run_in_graph_and_eager_modes(__unused__=None, Args: - __unused__: Prevents silently skipping tests. + func: function to be annotated. If `func` is None, this method returns a + decorator the can be applied to a function. If `func` is not None this + returns the decorator applied to `func`. config: An optional config_pb2.ConfigProto to use to configure the session when executing graphs. use_gpu: If True, attempt to run as many operations as possible on GPU. @@ -616,8 +618,6 @@ def run_in_graph_and_eager_modes(__unused__=None, eager execution enabled. """ - assert not __unused__, "Add () after run_in_graph_and_eager_modes." - def decorator(f): if tf_inspect.isclass(f): raise ValueError( @@ -626,7 +626,7 @@ def run_in_graph_and_eager_modes(__unused__=None, def decorated(self, **kwargs): with context.graph_mode(): - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, config=config): f(self, **kwargs) if reset_test: @@ -653,6 +653,9 @@ def run_in_graph_and_eager_modes(__unused__=None, return decorated + if func is not None: + return decorator(func) + return decorator |