aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-06-14 11:23:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 11:25:43 -0700
commit3d7b33f7576216adeb6ea345dc2b41bc921fcf52 (patch)
tree7518dd220279f8ebbe7991aa9580bcac36f3668d
parentf596bcc78639bb59894fd8e97779e6f53eeef190 (diff)
Make it possible to retrieve the variables used in a defined function.
Creates a class that encapsulates the graph functions created for a particular Python function. This class has a `.variables` property that fetches the variables used in any of the graph functions defined for the Python function. The class is internal for now. PiperOrigin-RevId: 200588595
-rw-r--r--tensorflow/python/eager/function.py76
-rw-r--r--tensorflow/python/eager/function_test.py17
2 files changed, 65 insertions, 28 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 03393bcd46..dd3166735c 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -222,6 +222,11 @@ def _inference_name(n):
return "__inference_%s_%s" % (n, ops.uid())
+def _register(fn):
+ """Registers the function `fn`."""
+ context.context().add_function(fn)
+
+
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
@@ -591,7 +596,7 @@ def _get_defun_inputs(args):
return nest.pack_sequence_as(args, ret)
-def _defun_internal(name, func, compiled, args, kwds):
+def _trace_and_define_function(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with context.graph_mode():
@@ -699,42 +704,57 @@ def _cache_key(x):
return x
-def _register(fn):
- """Registers the function `fn`."""
- context.context().add_function(fn)
+class _PolymorphicFunction(object):
+ """Wrapper class for the graph functions defined for a Python function.
+ See the documentation for `defun` for more information on the semantics of
+ defined functions.
+ """
-# TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name, compiled=False):
- """Defines a function with a given name.
+ def __init__(self, python_function, name, compiled=False):
+ """Initializes a polymorphic function.
- See the documentation for `defun` for more information on the semantics of
- this function.
+ Args:
+ python_function: the function to be wrapped.
+ name: the name given to it.
+ compiled: if True, the framework will attempt to compile func with XLA.
+ """
- Args:
- func: the function to be wrapped.
- name: the name given to it.
- compiled: if true, the framework will attempt to compile func with XLA.
+ self._python_function = python_function
+ self._name = name
+ self._compiled = compiled
+ self._arguments_to_functions = {}
+ self._variables = []
- Returns:
- the wrapped function.
- """
- arguments_to_functions = {}
+ def _maybe_define_function(self, *args, **kwds):
+ """Gets a function for these inputs, defining it if necessary."""
- def decorated(*args, **kwds):
- """Decorated version of func."""
- # Macroexpand on non-Tensor arguments
- cache_key = tuple(_cache_key(x) for x in args)
+ # TODO(akshayka): Remove this restriction.
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
+
+ # TODO(apassos): Better error messages for non-hashable arguments.
+ cache_key = tuple(_cache_key(x) for x in args)
cache_key = (cache_key, tuple(kwds.items()))
- if cache_key not in arguments_to_functions:
- arguments_to_functions[cache_key] = _defun_internal(
- name, func, compiled, args, kwds)
- return arguments_to_functions[cache_key](*args)
+ if cache_key not in self._arguments_to_functions:
+ graph_function = _trace_and_define_function(
+ self._name, self._python_function, self._compiled, args, kwds)
+ self._arguments_to_functions[cache_key] = graph_function
+ self._variables.extend(
+ [v for v in graph_function.variables if v not in self._variables])
+ return graph_function
+ else:
+ return self._arguments_to_functions[cache_key]
- return decorated
+ def __call__(self, *args, **kwds):
+ """Calls a graph function specialized for this input signature."""
+ return self._maybe_define_function(*args, **kwds)(*args)
+
+ @property
+ def variables(self):
+ """Returns a list of variables used in any of the defined functions."""
+ return self._variables
# TODO(akshayka): Remove the `compiled` flag and create a separate
@@ -991,7 +1011,7 @@ def defun(func=None, compiled=False):
except AttributeError:
name = "function"
return tf_decorator.make_decorator(
- function, named_defun(function, name, compiled=compiled))
+ function, _PolymorphicFunction(function, name, compiled=compiled))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1056,7 +1076,7 @@ def make_defun_op(func, *args, **kwds):
name = func.__name__
if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
raise ValueError("Tensor keyword arguments are not supported.")
- return _defun_internal(name, func, False, args, kwds)
+ return _trace_and_define_function(name, func, False, args, kwds)
class AutomaticControlDependencies(object):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index cfdbe5f079..6ce2ceffda 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -633,6 +633,23 @@ class FunctionTest(test.TestCase):
y = model(x)
self.assertAllEqual([[[[4.0]]]], y.numpy())
+ def testVariablesAreTracked(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ def foo(x):
+ return v * x
+
+ defined = function.defun(foo)
+
+ x = constant_op.constant([1.0])
+ self.assertAllEqual(defined.variables, [])
+ _ = defined(x)
+ self.assertAllEqual(defined.variables, [v])
+
+ x = constant_op.constant([1.0, 2.0])
+ _ = defined(x) # ensure the variables list remains the same
+ self.assertAllEqual(defined.variables, [v])
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):