From 84ace0358526bb51c04a3bef4b3072b93b9d1bec Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Tue, 9 Oct 2018 11:16:32 -0700 Subject: Improves tf.function prototype. Specifically: - renames from def_function - returns an object with well-defined methods - doesn't force-retrace twice - uses the python descriptor API ( https://docs.python.org/3/howto/descriptor.html ) to remove the need for a tf.method PiperOrigin-RevId: 216388957 --- tensorflow/python/eager/def_function.py | 188 ++++++++++++++++++++++----- tensorflow/python/eager/def_function_test.py | 32 ++++- 2 files changed, 179 insertions(+), 41 deletions(-) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 8dcacd5c99..b23891d394 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -19,8 +19,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import weakref + from tensorflow.python.eager import context -from tensorflow.python.eager import function +from tensorflow.python.eager import function as function_lib from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops @@ -165,71 +168,184 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable): self._cached_shape_as_list = None -def _defun_with_scope(scope, fn): +def _defun_with_scope(scope, fn, input_signature): def wrapped_fn(*args, **kwds): with variable_scope.variable_creator_scope(scope): return fn(*args, **kwds) - return function.defun(wrapped_fn) + return function_lib.defun(wrapped_fn, input_signature=input_signature) -def def_function(fn): - """Defines a function as per the "functions, not sessions" document.""" +def _call_concrete(fn, args, unused_kwargs): + """Calls the given concrete function with only the tensor arguments.""" + + def inner(): + # TODO(apassos) figure out what to do with kwargs and concrete functions. + return fn(*[x for x in args if isinstance(x, ops.Tensor)]) + + return inner + + +class PolymorphicFunction(object): + """Wrapper class for the graph functions defined for a Python function. + + See the documentation for `tf.function` for more information on the semantics + of defined functions. - # Wrapping the values in lists to bypass python's lack of way to mutate - # symbols from an outer scope. - first_call = [True] - function_to_call = [] + PolymorphicFunction is thread-compatible. + """ + + def __init__(self, + python_function, + input_signature=None,): + """Initializes a polymorphic function. + + Args: + python_function: the function to be wrapped. + input_signature: a possibly nested sequence of `TensorSpec` objects + specifying the input signature of this function. If `None`, a separate + function is instantiated for each inferred input signature. + + Raises: + ValueError: if `input_signature` is not None and the `python_function`'s + argspec has keyword arguments. + """ + self._python_function = python_function + self._input_signature = input_signature + self._created_variables = None + self._stateful_fn = None + self._descriptor_cache = weakref.WeakKeyDictionary() - # TODO(apassos) represent this as an object and not as a closure. - def decorated_fn(*args, **kwds): - """Graph function for fn.""" - if not first_call[0]: - return function_to_call[0](*args, **kwds) + def _initialize(self, args, kwds): + """Initializes, on the first call.""" - first_call[0] = False - created_variables = [] + self._created_variables = [] - def variable_creator_scope(unused_next_creator, **kwds): + def variable_capturing_scope(unused_next_creator, **kwds): """Creates UnliftedInitializerVariables and saves references to them.""" v = UnliftedInitializerVariable(**kwds) - created_variables.append(v) + self._created_variables.append(v) return v - first_graph_function = _defun_with_scope(variable_creator_scope, fn) + self._stateful_fn = _defun_with_scope( + variable_capturing_scope, self._python_function, self._input_signature) # Force the definition of the function for these arguments - first_concrete = first_graph_function.get_concrete_function(*args, **kwds) + self._concrete_stateful_fn = self._stateful_fn.get_concrete_function( + *args, **kwds) def invalid_creator_scope(*unused_args, **unused_kwds): """Disables variable creation.""" raise ValueError( - "def_function-decorated function tried to create " - "variables on second call.") + "tf.function-decorated function tried to create " + "variables on non-first call.") - second_graph_function = _defun_with_scope(invalid_creator_scope, fn) + self._stateless_fn = _defun_with_scope( + invalid_creator_scope, self._python_function, self._input_signature) - function_to_call.append(second_graph_function) - if not created_variables: - # Note: this retracing might be unnecessary, but running the function - # forever in the scope which disallows variable creation is safer than not - # doing so. - return second_graph_function(*args, **kwds) + def __call__(self, *args, **kwds): + """Calls the graph function.""" + if self._created_variables: + # In this case we have created variables on the first call, so we run the + # defunned version which is guaranteed to never create variables. + return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable + elif self._stateful_fn is not None: + # In this case we have not created variables on the first call. So we can + # run the first trace but we should fail if variables are created. + results = self._first_trace(*args, **kwds) + if self._created_variables: + raise ValueError("Creating variables on a non-first call to a function" + " decorated with tf.function.") + return results + + self._initialize(args, kwds) + + if not self._created_variables: + # If we did not create any variables the trace we have is good enough. + return _call_concrete(self._concrete_stateful_fn, args, kwds)() def fn_with_cond(*inner_args, **inner_kwds): """Conditionally runs initialization if it's needed.""" condition = True - for variable in created_variables: + for variable in self._created_variables: condition = condition and resource_variable_ops.var_is_initialized_op( variable.handle) - # We want to call second_graph_function if possible because it avoids - # recomputing potentially expensive initializers. + # We want to call stateless_fn if possible because it avoids recomputing + # potentially expensive initializers. return control_flow_ops.cond( condition, - lambda: second_graph_function(*inner_args, **inner_kwds), - lambda: first_concrete(*inner_args, **inner_kwds)) + lambda: self._stateless_fn(*inner_args, **inner_kwds), + _call_concrete(self._concrete_stateful_fn, inner_args, inner_kwds)) + + return function_lib.defun(fn_with_cond)(*args, **kwds) + + @property + def python_function(self): + """The python function wrapped in this tf.function.""" + return self._python_function + + def get_concrete_function(self, *args, **kwargs): + """Returns a `Function` object specialized to inputs and execution context. + + `args` and `kwargs` are ignored if this `PolymorphicFunction` was created + with an `input_signature`. + + Args: + *args: inputs to specialize on. + **kwargs: inputs to specialize on. - return function.defun(fn_with_cond)(*args, **kwds) + Raises: + ValueError: if this object has not yet been called on concrete values. + """ + # TODO(apassos) figure out how to handle this case (what should we return + # here?) + if self._stateful_fn is None: + raise ValueError( + "Call this function with concrete values before asking for a" + " concrete function. Calling the function will ensure that, in" + " case this function creates variables, that those are properly" + " initialized.") + if self._created_variables: + # In this case we have created variables on the first call, so we run the + # defunned version which is guaranteed to never create variables. + return self._stateless_fn.get_concrete_function(*args, **kwargs) + elif self._stateful_fn is not None: + # In this case we have not created variables on the first call. So we can + # run the first trace but we should fail if variables are created. + concrete = self._first_trace.get_concrete_function(*args, **kwargs) + if self._created_variables: + raise ValueError("Creating variables on a non-first call to a function" + " decorated with tf.function.") + return concrete - return decorated_fn + def __get__(self, instance, owner): + """Makes it possible to defun instance methods.""" + del owner + # `instance` here is the instance that this `PolymorphicFunction` was + # accessed through; e.g., for + # + # class Foo(object): + # + # @function.defun + # def bar(self): + # ... + # + # foo = Foo() + # foo.bar() # `foo.bar` is a `PolymorphicFunction` instance + # + # then `instance` will be `foo` (and `owner` will be `Foo`). We create a + # new instance of PolymorphicFunction here to allow different instances each + # to create variables once, thereby allowing methods to be decorated with + # tf.function. Keeps a cache to avoid retracing the function every time the + # descriptor is accessed. + if instance not in self._descriptor_cache: + self._descriptor_cache[instance] = PolymorphicFunction( + functools.partial(self.python_function, instance), + self._input_signature) + return self._descriptor_cache[instance] + + +def function(fn=None, input_signature=None): + """Defines a function as per the "functions, not sessions" document.""" + return PolymorphicFunction(fn, input_signature) diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 804436c4bb..39bad726d0 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -29,7 +29,7 @@ class DefFunctionTest(test.TestCase): def testNoVariables(self): - @def_function.def_function + @def_function.function def fn(x): return 2 * x @@ -37,7 +37,7 @@ class DefFunctionTest(test.TestCase): def testFailIfVariablesAreCreatedMoreThanOnce(self): - @def_function.def_function + @def_function.function def fn(x): return variables.Variable(1.0) + x @@ -47,7 +47,7 @@ class DefFunctionTest(test.TestCase): def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self): state = [] - @def_function.def_function + @def_function.function def fn(x): state.append(variables.Variable(1.0)) return state[-1] + x @@ -59,7 +59,7 @@ class DefFunctionTest(test.TestCase): state = [] - @def_function.def_function + @def_function.function def fn(x): if not state: state.append(variables.Variable(2.0)) @@ -72,7 +72,7 @@ class DefFunctionTest(test.TestCase): state = [] - @def_function.def_function + @def_function.function def fn(x): if not state: state.append(variables.Variable(2.0 * x)) @@ -81,6 +81,28 @@ class DefFunctionTest(test.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0) + def testMethod(self): + + class MyModel(object): + + def __init__(self): + self.var = None + + @def_function.function + def apply(self, x): + if self.var is None: + self.var = variables.Variable(2.0) + return self.var * x + + m0 = MyModel() + self.assertAllEqual(m0.apply(3.0), 6.0) + # Calling twice to exercise that we do not recreate variables. + m0.var.assign(3.0) + self.assertAllEqual(m0.apply(3.0), 9.0) + + m1 = MyModel() + self.assertAllEqual(m1.apply(3.0), 6.0) + if __name__ == '__main__': ops.enable_eager_execution() -- cgit v1.2.3