aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-06-29 12:00:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 12:03:18 -0700
commitf92a5fecad1e4e7a41ffa4c3308ab3049ef708bf (patch)
treeff4282c6127a779bfe00c402d4c498e46d50285b /tensorflow
parent05571105255e2091ddfa99531cb9df33bd251d3b (diff)
Add a method that calls the python function backing a _PolymorphicFunction.
This is at the least useful for testing behavioral differences between a wrapped Python function and the corresponding graph functions. Prior to this change, decorating a Python function with `@function.defun` would render the Python function inaccessible. PiperOrigin-RevId: 202685407
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/eager/function.py4
-rw-r--r--tensorflow/python/eager/function_test.py19
2 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 7edcb0931d..08470f65b0 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -801,6 +801,10 @@ class _PolymorphicFunction(object):
graph_function, inputs = self._maybe_define_function(*args, **kwds)
return graph_function(*inputs)
+ def call_python_function(self, *args, **kwargs):
+ """Directly calls the wrapped python function."""
+ return self._python_function(*args, **kwargs)
+
@property
def variables(self):
"""Returns a list of variables used in any of the defined functions."""
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index cf32f6e7fb..e1801b7ec6 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -829,6 +829,25 @@ class FunctionTest(test.TestCase):
out = foo.two(t)
self.assertEqual(float(out), 1.0)
+ def testPythonCallWithSideEffects(self):
+ state = []
+
+ @function.defun
+ def side_effecting_function():
+ state.append(0)
+
+ side_effecting_function()
+ self.assertAllEqual(state, [0])
+
+ # The second invocation should call the graph function, which shouldn't
+ # trigger the list append.
+ side_effecting_function()
+ self.assertAllEqual(state, [0])
+
+ # Whereas calling the python function directly should create a side-effect.
+ side_effecting_function.call_python_function()
+ self.assertAllEqual(state, [0, 0])
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):