aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tom Hennigan <tomhennigan@google.com>2018-06-21 12:56:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 12:59:27 -0700
commit505d2018a34bbffbca7a17dc47ea968787938174 (patch)
treee8aa7b681343eab943bbaded55b4bd2ccfd7fe5d
parent846520326327d5eb8e1be13cad7d5526adf67db2 (diff)
Allow run_in_graph_and_eager_modes annotation without ().
PiperOrigin-RevId: 201571378
-rw-r--r--tensorflow/python/framework/test_util.py17
-rw-r--r--tensorflow/python/framework/test_util_test.py27
2 files changed, 34 insertions, 10 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 708ab1707e..3988238609 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -552,14 +552,14 @@ def assert_no_garbage_created(f):
def run_all_in_graph_and_eager_modes(cls):
"""Execute all test methods in the given class with and without eager."""
- base_decorator = run_in_graph_and_eager_modes()
+ base_decorator = run_in_graph_and_eager_modes
for name, value in cls.__dict__.copy().items():
if callable(value) and name.startswith("test"):
setattr(cls, name, base_decorator(value))
return cls
-def run_in_graph_and_eager_modes(__unused__=None,
+def run_in_graph_and_eager_modes(func=None,
config=None,
use_gpu=True,
reset_test=True,
@@ -577,7 +577,7 @@ def run_in_graph_and_eager_modes(__unused__=None,
```python
class MyTests(tf.test.TestCase):
- @run_in_graph_and_eager_modes()
+ @run_in_graph_and_eager_modes
def test_foo(self):
x = tf.constant([1, 2])
y = tf.constant([3, 4])
@@ -594,7 +594,9 @@ def run_in_graph_and_eager_modes(__unused__=None,
Args:
- __unused__: Prevents silently skipping tests.
+ func: function to be annotated. If `func` is None, this method returns a
+ decorator the can be applied to a function. If `func` is not None this
+ returns the decorator applied to `func`.
config: An optional config_pb2.ConfigProto to use to configure the
session when executing graphs.
use_gpu: If True, attempt to run as many operations as possible on GPU.
@@ -616,8 +618,6 @@ def run_in_graph_and_eager_modes(__unused__=None,
eager execution enabled.
"""
- assert not __unused__, "Add () after run_in_graph_and_eager_modes."
-
def decorator(f):
if tf_inspect.isclass(f):
raise ValueError(
@@ -626,7 +626,7 @@ def run_in_graph_and_eager_modes(__unused__=None,
def decorated(self, **kwargs):
with context.graph_mode():
- with self.test_session(use_gpu=use_gpu):
+ with self.test_session(use_gpu=use_gpu, config=config):
f(self, **kwargs)
if reset_test:
@@ -653,6 +653,9 @@ def run_in_graph_and_eager_modes(__unused__=None,
return decorated
+ if func is not None:
+ return decorator(func)
+
return decorator
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 2a7cf88d6e..5498376181 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -569,7 +569,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(a_np_rand, b_np_rand)
self.assertEqual(a_rand, b_rand)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_callable_evaluate(self):
def model():
return resource_variable_ops.ResourceVariable(
@@ -578,7 +578,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with context.eager_mode():
self.assertEqual(2, self.evaluate(model))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_nested_tensors_evaluate(self):
expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
nested = {"a": constant_op.constant(1),
@@ -588,6 +588,27 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(expected, self.evaluate(nested))
+ def test_run_in_graph_and_eager_modes(self):
+ l = []
+ def inc(self, with_brackets):
+ del self # self argument is required by run_in_graph_and_eager_modes.
+ mode = "eager" if context.executing_eagerly() else "graph"
+ with_brackets = "with_brackets" if with_brackets else "without_brackets"
+ l.append((with_brackets, mode))
+
+ f = test_util.run_in_graph_and_eager_modes(inc)
+ f(self, with_brackets=False)
+ f = test_util.run_in_graph_and_eager_modes()(inc)
+ f(self, with_brackets=True)
+
+ self.assertEqual(len(l), 4)
+ self.assertEqual(set(l), {
+ ("with_brackets", "graph"),
+ ("with_brackets", "eager"),
+ ("without_brackets", "graph"),
+ ("without_brackets", "eager"),
+ })
+
def test_get_node_def_from_graph(self):
graph_def = graph_pb2.GraphDef()
node_foo = graph_def.node.add()
@@ -627,7 +648,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
ReferenceCycleTest().test_has_no_cycle()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_no_leaked_tensor_decorator(self):
class LeakedTensorTest(object):