diff options
author | Akshay Agrawal <akshayka@google.com> | 2018-06-29 12:00:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-29 12:03:18 -0700 |
commit | f92a5fecad1e4e7a41ffa4c3308ab3049ef708bf (patch) | |
tree | ff4282c6127a779bfe00c402d4c498e46d50285b /tensorflow | |
parent | 05571105255e2091ddfa99531cb9df33bd251d3b (diff) |
Add a method that calls the python function backing a _PolymorphicFunction.
This is at the least useful for testing behavioral differences between
a wrapped Python function and the corresponding graph functions. Prior to this change,
decorating a Python function with `@function.defun` would render the Python function
inaccessible.
PiperOrigin-RevId: 202685407
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/eager/function.py | 4 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 19 |
2 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 7edcb0931d..08470f65b0 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -801,6 +801,10 @@ class _PolymorphicFunction(object): graph_function, inputs = self._maybe_define_function(*args, **kwds) return graph_function(*inputs) + def call_python_function(self, *args, **kwargs): + """Directly calls the wrapped python function.""" + return self._python_function(*args, **kwargs) + @property def variables(self): """Returns a list of variables used in any of the defined functions.""" diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index cf32f6e7fb..e1801b7ec6 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -829,6 +829,25 @@ class FunctionTest(test.TestCase): out = foo.two(t) self.assertEqual(float(out), 1.0) + def testPythonCallWithSideEffects(self): + state = [] + + @function.defun + def side_effecting_function(): + state.append(0) + + side_effecting_function() + self.assertAllEqual(state, [0]) + + # The second invocation should call the graph function, which shouldn't + # trigger the list append. + side_effecting_function() + self.assertAllEqual(state, [0]) + + # Whereas calling the python function directly should create a side-effect. + side_effecting_function.call_python_function() + self.assertAllEqual(state, [0, 0]) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): |