diff options
author | 2018-06-21 12:56:14 -0700 | |
---|---|---|
committer | 2018-06-21 12:59:27 -0700 | |
commit | 505d2018a34bbffbca7a17dc47ea968787938174 (patch) | |
tree | e8aa7b681343eab943bbaded55b4bd2ccfd7fe5d | |
parent | 846520326327d5eb8e1be13cad7d5526adf67db2 (diff) |
Allow run_in_graph_and_eager_modes annotation without ().
PiperOrigin-RevId: 201571378
-rw-r--r-- | tensorflow/python/framework/test_util.py | 17 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util_test.py | 27 |
2 files changed, 34 insertions, 10 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 708ab1707e..3988238609 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -552,14 +552,14 @@ def assert_no_garbage_created(f): def run_all_in_graph_and_eager_modes(cls): """Execute all test methods in the given class with and without eager.""" - base_decorator = run_in_graph_and_eager_modes() + base_decorator = run_in_graph_and_eager_modes for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): setattr(cls, name, base_decorator(value)) return cls -def run_in_graph_and_eager_modes(__unused__=None, +def run_in_graph_and_eager_modes(func=None, config=None, use_gpu=True, reset_test=True, @@ -577,7 +577,7 @@ def run_in_graph_and_eager_modes(__unused__=None, ```python class MyTests(tf.test.TestCase): - @run_in_graph_and_eager_modes() + @run_in_graph_and_eager_modes def test_foo(self): x = tf.constant([1, 2]) y = tf.constant([3, 4]) @@ -594,7 +594,9 @@ def run_in_graph_and_eager_modes(__unused__=None, Args: - __unused__: Prevents silently skipping tests. + func: function to be annotated. If `func` is None, this method returns a + decorator the can be applied to a function. If `func` is not None this + returns the decorator applied to `func`. config: An optional config_pb2.ConfigProto to use to configure the session when executing graphs. use_gpu: If True, attempt to run as many operations as possible on GPU. @@ -616,8 +618,6 @@ def run_in_graph_and_eager_modes(__unused__=None, eager execution enabled. """ - assert not __unused__, "Add () after run_in_graph_and_eager_modes." - def decorator(f): if tf_inspect.isclass(f): raise ValueError( @@ -626,7 +626,7 @@ def run_in_graph_and_eager_modes(__unused__=None, def decorated(self, **kwargs): with context.graph_mode(): - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, config=config): f(self, **kwargs) if reset_test: @@ -653,6 +653,9 @@ def run_in_graph_and_eager_modes(__unused__=None, return decorated + if func is not None: + return decorator(func) + return decorator diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 2a7cf88d6e..5498376181 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -569,7 +569,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(a_np_rand, b_np_rand) self.assertEqual(a_rand, b_rand) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_callable_evaluate(self): def model(): return resource_variable_ops.ResourceVariable( @@ -578,7 +578,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): with context.eager_mode(): self.assertEqual(2, self.evaluate(model)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_tensors_evaluate(self): expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}} nested = {"a": constant_op.constant(1), @@ -588,6 +588,27 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(expected, self.evaluate(nested)) + def test_run_in_graph_and_eager_modes(self): + l = [] + def inc(self, with_brackets): + del self # self argument is required by run_in_graph_and_eager_modes. + mode = "eager" if context.executing_eagerly() else "graph" + with_brackets = "with_brackets" if with_brackets else "without_brackets" + l.append((with_brackets, mode)) + + f = test_util.run_in_graph_and_eager_modes(inc) + f(self, with_brackets=False) + f = test_util.run_in_graph_and_eager_modes()(inc) + f(self, with_brackets=True) + + self.assertEqual(len(l), 4) + self.assertEqual(set(l), { + ("with_brackets", "graph"), + ("with_brackets", "eager"), + ("without_brackets", "graph"), + ("without_brackets", "eager"), + }) + def test_get_node_def_from_graph(self): graph_def = graph_pb2.GraphDef() node_foo = graph_def.node.add() @@ -627,7 +648,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): ReferenceCycleTest().test_has_no_cycle() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_no_leaked_tensor_decorator(self): class LeakedTensorTest(object): |