aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-09 11:16:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:21:01 -0700
commit84ace0358526bb51c04a3bef4b3072b93b9d1bec (patch)
treead94e250a62f4d4bd1dbe095ab34741646af6add
parent1e4a3baad388b5d5250efdb19f91d5b670816fbe (diff)
Improves tf.function prototype.
Specifically: - renames from def_function - returns an object with well-defined methods - doesn't force-retrace twice - uses the python descriptor API ( https://docs.python.org/3/howto/descriptor.html ) to remove the need for a tf.method PiperOrigin-RevId: 216388957
-rw-r--r--tensorflow/python/eager/def_function.py188
-rw-r--r--tensorflow/python/eager/def_function_test.py32
2 files changed, 179 insertions, 41 deletions
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 8dcacd5c99..b23891d394 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -19,8 +19,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+import weakref
+
from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
@@ -165,71 +168,184 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
self._cached_shape_as_list = None
-def _defun_with_scope(scope, fn):
+def _defun_with_scope(scope, fn, input_signature):
def wrapped_fn(*args, **kwds):
with variable_scope.variable_creator_scope(scope):
return fn(*args, **kwds)
- return function.defun(wrapped_fn)
+ return function_lib.defun(wrapped_fn, input_signature=input_signature)
-def def_function(fn):
- """Defines a function as per the "functions, not sessions" document."""
+def _call_concrete(fn, args, unused_kwargs):
+ """Calls the given concrete function with only the tensor arguments."""
+
+ def inner():
+ # TODO(apassos) figure out what to do with kwargs and concrete functions.
+ return fn(*[x for x in args if isinstance(x, ops.Tensor)])
+
+ return inner
+
+
+class PolymorphicFunction(object):
+ """Wrapper class for the graph functions defined for a Python function.
+
+ See the documentation for `tf.function` for more information on the semantics
+ of defined functions.
- # Wrapping the values in lists to bypass python's lack of way to mutate
- # symbols from an outer scope.
- first_call = [True]
- function_to_call = []
+ PolymorphicFunction is thread-compatible.
+ """
+
+ def __init__(self,
+ python_function,
+ input_signature=None,):
+ """Initializes a polymorphic function.
+
+ Args:
+ python_function: the function to be wrapped.
+ input_signature: a possibly nested sequence of `TensorSpec` objects
+ specifying the input signature of this function. If `None`, a separate
+ function is instantiated for each inferred input signature.
+
+ Raises:
+ ValueError: if `input_signature` is not None and the `python_function`'s
+ argspec has keyword arguments.
+ """
+ self._python_function = python_function
+ self._input_signature = input_signature
+ self._created_variables = None
+ self._stateful_fn = None
+ self._descriptor_cache = weakref.WeakKeyDictionary()
- # TODO(apassos) represent this as an object and not as a closure.
- def decorated_fn(*args, **kwds):
- """Graph function for fn."""
- if not first_call[0]:
- return function_to_call[0](*args, **kwds)
+ def _initialize(self, args, kwds):
+ """Initializes, on the first call."""
- first_call[0] = False
- created_variables = []
+ self._created_variables = []
- def variable_creator_scope(unused_next_creator, **kwds):
+ def variable_capturing_scope(unused_next_creator, **kwds):
"""Creates UnliftedInitializerVariables and saves references to them."""
v = UnliftedInitializerVariable(**kwds)
- created_variables.append(v)
+ self._created_variables.append(v)
return v
- first_graph_function = _defun_with_scope(variable_creator_scope, fn)
+ self._stateful_fn = _defun_with_scope(
+ variable_capturing_scope, self._python_function, self._input_signature)
# Force the definition of the function for these arguments
- first_concrete = first_graph_function.get_concrete_function(*args, **kwds)
+ self._concrete_stateful_fn = self._stateful_fn.get_concrete_function(
+ *args, **kwds)
def invalid_creator_scope(*unused_args, **unused_kwds):
"""Disables variable creation."""
raise ValueError(
- "def_function-decorated function tried to create "
- "variables on second call.")
+ "tf.function-decorated function tried to create "
+ "variables on non-first call.")
- second_graph_function = _defun_with_scope(invalid_creator_scope, fn)
+ self._stateless_fn = _defun_with_scope(
+ invalid_creator_scope, self._python_function, self._input_signature)
- function_to_call.append(second_graph_function)
- if not created_variables:
- # Note: this retracing might be unnecessary, but running the function
- # forever in the scope which disallows variable creation is safer than not
- # doing so.
- return second_graph_function(*args, **kwds)
+ def __call__(self, *args, **kwds):
+ """Calls the graph function."""
+ if self._created_variables:
+ # In this case we have created variables on the first call, so we run the
+ # defunned version which is guaranteed to never create variables.
+ return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
+ elif self._stateful_fn is not None:
+ # In this case we have not created variables on the first call. So we can
+ # run the first trace but we should fail if variables are created.
+ results = self._first_trace(*args, **kwds)
+ if self._created_variables:
+ raise ValueError("Creating variables on a non-first call to a function"
+ " decorated with tf.function.")
+ return results
+
+ self._initialize(args, kwds)
+
+ if not self._created_variables:
+ # If we did not create any variables the trace we have is good enough.
+ return _call_concrete(self._concrete_stateful_fn, args, kwds)()
def fn_with_cond(*inner_args, **inner_kwds):
"""Conditionally runs initialization if it's needed."""
condition = True
- for variable in created_variables:
+ for variable in self._created_variables:
condition = condition and resource_variable_ops.var_is_initialized_op(
variable.handle)
- # We want to call second_graph_function if possible because it avoids
- # recomputing potentially expensive initializers.
+ # We want to call stateless_fn if possible because it avoids recomputing
+ # potentially expensive initializers.
return control_flow_ops.cond(
condition,
- lambda: second_graph_function(*inner_args, **inner_kwds),
- lambda: first_concrete(*inner_args, **inner_kwds))
+ lambda: self._stateless_fn(*inner_args, **inner_kwds),
+ _call_concrete(self._concrete_stateful_fn, inner_args, inner_kwds))
+
+ return function_lib.defun(fn_with_cond)(*args, **kwds)
+
+ @property
+ def python_function(self):
+ """The python function wrapped in this tf.function."""
+ return self._python_function
+
+ def get_concrete_function(self, *args, **kwargs):
+ """Returns a `Function` object specialized to inputs and execution context.
+
+ `args` and `kwargs` are ignored if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
+ Args:
+ *args: inputs to specialize on.
+ **kwargs: inputs to specialize on.
- return function.defun(fn_with_cond)(*args, **kwds)
+ Raises:
+ ValueError: if this object has not yet been called on concrete values.
+ """
+ # TODO(apassos) figure out how to handle this case (what should we return
+ # here?)
+ if self._stateful_fn is None:
+ raise ValueError(
+ "Call this function with concrete values before asking for a"
+ " concrete function. Calling the function will ensure that, in"
+ " case this function creates variables, that those are properly"
+ " initialized.")
+ if self._created_variables:
+ # In this case we have created variables on the first call, so we run the
+ # defunned version which is guaranteed to never create variables.
+ return self._stateless_fn.get_concrete_function(*args, **kwargs)
+ elif self._stateful_fn is not None:
+ # In this case we have not created variables on the first call. So we can
+ # run the first trace but we should fail if variables are created.
+ concrete = self._first_trace.get_concrete_function(*args, **kwargs)
+ if self._created_variables:
+ raise ValueError("Creating variables on a non-first call to a function"
+ " decorated with tf.function.")
+ return concrete
- return decorated_fn
+ def __get__(self, instance, owner):
+ """Makes it possible to defun instance methods."""
+ del owner
+ # `instance` here is the instance that this `PolymorphicFunction` was
+ # accessed through; e.g., for
+ #
+ # class Foo(object):
+ #
+ # @function.defun
+ # def bar(self):
+ # ...
+ #
+ # foo = Foo()
+ # foo.bar() # `foo.bar` is a `PolymorphicFunction` instance
+ #
+ # then `instance` will be `foo` (and `owner` will be `Foo`). We create a
+ # new instance of PolymorphicFunction here to allow different instances each
+ # to create variables once, thereby allowing methods to be decorated with
+ # tf.function. Keeps a cache to avoid retracing the function every time the
+ # descriptor is accessed.
+ if instance not in self._descriptor_cache:
+ self._descriptor_cache[instance] = PolymorphicFunction(
+ functools.partial(self.python_function, instance),
+ self._input_signature)
+ return self._descriptor_cache[instance]
+
+
+def function(fn=None, input_signature=None):
+ """Defines a function as per the "functions, not sessions" document."""
+ return PolymorphicFunction(fn, input_signature)
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 804436c4bb..39bad726d0 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -29,7 +29,7 @@ class DefFunctionTest(test.TestCase):
def testNoVariables(self):
- @def_function.def_function
+ @def_function.function
def fn(x):
return 2 * x
@@ -37,7 +37,7 @@ class DefFunctionTest(test.TestCase):
def testFailIfVariablesAreCreatedMoreThanOnce(self):
- @def_function.def_function
+ @def_function.function
def fn(x):
return variables.Variable(1.0) + x
@@ -47,7 +47,7 @@ class DefFunctionTest(test.TestCase):
def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
state = []
- @def_function.def_function
+ @def_function.function
def fn(x):
state.append(variables.Variable(1.0))
return state[-1] + x
@@ -59,7 +59,7 @@ class DefFunctionTest(test.TestCase):
state = []
- @def_function.def_function
+ @def_function.function
def fn(x):
if not state:
state.append(variables.Variable(2.0))
@@ -72,7 +72,7 @@ class DefFunctionTest(test.TestCase):
state = []
- @def_function.def_function
+ @def_function.function
def fn(x):
if not state:
state.append(variables.Variable(2.0 * x))
@@ -81,6 +81,28 @@ class DefFunctionTest(test.TestCase):
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+ def testMethod(self):
+
+ class MyModel(object):
+
+ def __init__(self):
+ self.var = None
+
+ @def_function.function
+ def apply(self, x):
+ if self.var is None:
+ self.var = variables.Variable(2.0)
+ return self.var * x
+
+ m0 = MyModel()
+ self.assertAllEqual(m0.apply(3.0), 6.0)
+ # Calling twice to exercise that we do not recreate variables.
+ m0.var.assign(3.0)
+ self.assertAllEqual(m0.apply(3.0), 9.0)
+
+ m1 = MyModel()
+ self.assertAllEqual(m1.apply(3.0), 6.0)
+
if __name__ == '__main__':
ops.enable_eager_execution()