diff options
author | Akshay Agrawal <akshayka@google.com> | 2018-06-14 11:23:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-14 11:25:43 -0700 |
commit | 3d7b33f7576216adeb6ea345dc2b41bc921fcf52 (patch) | |
tree | 7518dd220279f8ebbe7991aa9580bcac36f3668d | |
parent | f596bcc78639bb59894fd8e97779e6f53eeef190 (diff) |
Make it possible to retrieve the variables used in a defined function.
Creates a class that encapsulates the graph functions created for a particular
Python function. This class has a `.variables` property that fetches the
variables used in any of the graph functions defined for the Python function.
The class is internal for now.
PiperOrigin-RevId: 200588595
-rw-r--r-- | tensorflow/python/eager/function.py | 76 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 17 |
2 files changed, 65 insertions, 28 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 03393bcd46..dd3166735c 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -222,6 +222,11 @@ def _inference_name(n): return "__inference_%s_%s" % (n, ops.uid()) +def _register(fn): + """Registers the function `fn`.""" + context.context().add_function(fn) + + # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction # so it doesn't have the definition-generating logic and is just a container for # an already-defined function. @@ -591,7 +596,7 @@ def _get_defun_inputs(args): return nest.pack_sequence_as(args, ret) -def _defun_internal(name, func, compiled, args, kwds): +def _trace_and_define_function(name, func, compiled, args, kwds): """Defines and returns graph-mode version of func.""" graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access with context.graph_mode(): @@ -699,42 +704,57 @@ def _cache_key(x): return x -def _register(fn): - """Registers the function `fn`.""" - context.context().add_function(fn) +class _PolymorphicFunction(object): + """Wrapper class for the graph functions defined for a Python function. + See the documentation for `defun` for more information on the semantics of + defined functions. + """ -# TODO(apassos): better error messages for non-hashable arguments. -def named_defun(func, name, compiled=False): - """Defines a function with a given name. + def __init__(self, python_function, name, compiled=False): + """Initializes a polymorphic function. - See the documentation for `defun` for more information on the semantics of - this function. + Args: + python_function: the function to be wrapped. + name: the name given to it. + compiled: if True, the framework will attempt to compile func with XLA. + """ - Args: - func: the function to be wrapped. - name: the name given to it. - compiled: if true, the framework will attempt to compile func with XLA. + self._python_function = python_function + self._name = name + self._compiled = compiled + self._arguments_to_functions = {} + self._variables = [] - Returns: - the wrapped function. - """ - arguments_to_functions = {} + def _maybe_define_function(self, *args, **kwds): + """Gets a function for these inputs, defining it if necessary.""" - def decorated(*args, **kwds): - """Decorated version of func.""" - # Macroexpand on non-Tensor arguments - cache_key = tuple(_cache_key(x) for x in args) + # TODO(akshayka): Remove this restriction. if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): raise ValueError("Tensor keyword arguments are not supported.") + + # TODO(apassos): Better error messages for non-hashable arguments. + cache_key = tuple(_cache_key(x) for x in args) cache_key = (cache_key, tuple(kwds.items())) - if cache_key not in arguments_to_functions: - arguments_to_functions[cache_key] = _defun_internal( - name, func, compiled, args, kwds) - return arguments_to_functions[cache_key](*args) + if cache_key not in self._arguments_to_functions: + graph_function = _trace_and_define_function( + self._name, self._python_function, self._compiled, args, kwds) + self._arguments_to_functions[cache_key] = graph_function + self._variables.extend( + [v for v in graph_function.variables if v not in self._variables]) + return graph_function + else: + return self._arguments_to_functions[cache_key] - return decorated + def __call__(self, *args, **kwds): + """Calls a graph function specialized for this input signature.""" + return self._maybe_define_function(*args, **kwds)(*args) + + @property + def variables(self): + """Returns a list of variables used in any of the defined functions.""" + return self._variables # TODO(akshayka): Remove the `compiled` flag and create a separate @@ -991,7 +1011,7 @@ def defun(func=None, compiled=False): except AttributeError: name = "function" return tf_decorator.make_decorator( - function, named_defun(function, name, compiled=compiled)) + function, _PolymorphicFunction(function, name, compiled=compiled)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: @@ -1056,7 +1076,7 @@ def make_defun_op(func, *args, **kwds): name = func.__name__ if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): raise ValueError("Tensor keyword arguments are not supported.") - return _defun_internal(name, func, False, args, kwds) + return _trace_and_define_function(name, func, False, args, kwds) class AutomaticControlDependencies(object): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index cfdbe5f079..6ce2ceffda 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -633,6 +633,23 @@ class FunctionTest(test.TestCase): y = model(x) self.assertAllEqual([[[[4.0]]]], y.numpy()) + def testVariablesAreTracked(self): + v = resource_variable_ops.ResourceVariable(1.0) + + def foo(x): + return v * x + + defined = function.defun(foo) + + x = constant_op.constant([1.0]) + self.assertAllEqual(defined.variables, []) + _ = defined(x) + self.assertAllEqual(defined.variables, [v]) + + x = constant_op.constant([1.0, 2.0]) + _ = defined(x) # ensure the variables list remains the same + self.assertAllEqual(defined.variables, [v]) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): |