aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 708ab1707e..3988238609 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -552,14 +552,14 @@ def assert_no_garbage_created(f):
def run_all_in_graph_and_eager_modes(cls):
"""Execute all test methods in the given class with and without eager."""
- base_decorator = run_in_graph_and_eager_modes()
+ base_decorator = run_in_graph_and_eager_modes
for name, value in cls.__dict__.copy().items():
if callable(value) and name.startswith("test"):
setattr(cls, name, base_decorator(value))
return cls
-def run_in_graph_and_eager_modes(__unused__=None,
+def run_in_graph_and_eager_modes(func=None,
config=None,
use_gpu=True,
reset_test=True,
@@ -577,7 +577,7 @@ def run_in_graph_and_eager_modes(__unused__=None,
```python
class MyTests(tf.test.TestCase):
- @run_in_graph_and_eager_modes()
+ @run_in_graph_and_eager_modes
def test_foo(self):
x = tf.constant([1, 2])
y = tf.constant([3, 4])
@@ -594,7 +594,9 @@ def run_in_graph_and_eager_modes(__unused__=None,
Args:
- __unused__: Prevents silently skipping tests.
+ func: function to be annotated. If `func` is None, this method returns a
+ decorator the can be applied to a function. If `func` is not None this
+ returns the decorator applied to `func`.
config: An optional config_pb2.ConfigProto to use to configure the
session when executing graphs.
use_gpu: If True, attempt to run as many operations as possible on GPU.
@@ -616,8 +618,6 @@ def run_in_graph_and_eager_modes(__unused__=None,
eager execution enabled.
"""
- assert not __unused__, "Add () after run_in_graph_and_eager_modes."
-
def decorator(f):
if tf_inspect.isclass(f):
raise ValueError(
@@ -626,7 +626,7 @@ def run_in_graph_and_eager_modes(__unused__=None,
def decorated(self, **kwargs):
with context.graph_mode():
- with self.test_session(use_gpu=use_gpu):
+ with self.test_session(use_gpu=use_gpu, config=config):
f(self, **kwargs)
if reset_test:
@@ -653,6 +653,9 @@ def run_in_graph_and_eager_modes(__unused__=None,
return decorated
+ if func is not None:
+ return decorator(func)
+
return decorator