author Akshay Agrawal <akshayka@google.com> 2018-09-07 17:13:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-07 17:16:44 -0700
commit 35c38c92d0fcb458047282f1e87146ae38c21b57 (patch)
tree b1740b1c1f4d06603093ede419054c60f40a38bf /tensorflow/python/eager
parent 22fa861e03c75c0cf4eb6ee2d81b8c1c17c0982b (diff)
Automated rollback of commit 72bbefcf1f80cd64cf873b69953a90657dabab18
PiperOrigin-RevId: 212061688
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- tensorflow/python/eager/function.py      | 133
-rw-r--r-- tensorflow/python/eager/function_test.py | 108
2 files changed, 92 insertions(+), 149 deletions(-)
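For context, a minimal sketch of the behavior this rollback restores (hedged: it assumes a TF 1.x-era build with eager execution enabled, and `create_var` is an illustrative name, not part of the commit). With the noninitial-trace check removed, `defun` no longer raises when a later trace creates a `tf.Variable`; per the restored docstring, each traced graph simply references its own set of variables.

```python
import tensorflow as tf

tf.enable_eager_execution()

@tf.contrib.eager.defun
def create_var(param):
  del param  # Unused, but its value is part of the function cache key.
  v = tf.Variable(1.0)  # Creation/initialization are lifted out of the graph.
  return v.read_value()

create_var('one')  # First trace: creates a variable.
create_var('two')  # New cache key, new trace: previously raised ValueError,
                   # now creates a second, independent variable.
```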
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index bc7c7f6502..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -124,14 +124,8 @@ class FuncGraph(ops.Graph):
def __init__(self, name):
"""Construct a new FuncGraph.
- The graph will inherit the following from its current context or graph:
- * graph key,
- * collections,
- * seed,
- * device stack,
- * colocation stack,
- * variable creator stack, and
- * distribution strategy stack.
+ The graph will inherit its graph key, collections, seed, device stack, and
+ distribution strategy stack from the current context or graph.
Args:
name: the name of the function.
@@ -164,8 +158,6 @@ class FuncGraph(ops.Graph):
self._device_function_stack = graph._device_function_stack.copy() # pylint: disable=protected-access
self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
- self._variable_creator_stack = graph._variable_creator_stack # pylint: disable=protected-access
-
# TODO(b/112165328, b/112906995): summaries depend on inheriting collections
# from the default graph even in eager mode. It'd be nice to not have a
# default graph with eager execution, so hopefully this will go away when we
@@ -799,7 +791,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
except (ValueError, TypeError):
raise TypeError(
"To be compatible with tf.contrib.eager.defun, Python functions "
- "must return zero or more Tensors; when tracing %s, found "
+ "must return zero or more Tensors; in compilation of %s, found "
"return value of type %s, which is not a Tensor." %
(str(python_func), type(x)))
x = a.mark_as_return(x)
@@ -1049,11 +1041,7 @@ class PolymorphicFunction(object):
colocation_stack = (None if executing_eagerly else
tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
- variable_creator_stack = tuple(graph._variable_creator_stack) # pylint: disable=protected-access
-
- # TODO(b/114446670): Add the _distribution_strategy_stack to the key.
- return cache_key + (execution_context, device_functions, colocation_stack,
- variable_creator_stack)
+ return cache_key + (execution_context, device_functions, colocation_stack)
def _canonicalize_function_inputs(self, *args, **kwds):
"""Canonicalizes `args` and `kwds`.
@@ -1136,8 +1124,7 @@ class PolymorphicFunction(object):
kwds, as well as the inputs that the object should be called with.
Raises:
- ValueError: If inputs are incompatible with the input signature or
- if variables are created on a noninitial trace.
+ ValueError: If inputs are incompatible with the input signature.
TypeError: If the function inputs include non-hashable objects.
"""
@@ -1152,21 +1139,9 @@ class PolymorphicFunction(object):
"must be hashable.")
if graph_function is None:
-
- def fail_on_noninitial_creation(next_creator, **kwargs):
- if self._function_cache:
- raise ValueError(
- "A `tf.Variable` was created on a noninitial trace "
- "of the Python function %s. When generating a "
- "function via `defun`, the encapsulated Python "
- "function may only create `tf.Variable`s on the first "
- "trace." % self.python_function)
- return next_creator(**kwargs)
-
- with variable_scope.variable_creator_scope(fail_on_noninitial_creation):
- graph_function = Function(
- func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature))
+ graph_function = Function(
+ func_graph_from_py_func(self._name, self._python_function, args,
+ kwds, self._input_signature))
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
self._function_cache[cache_key] = graph_function
@@ -1181,25 +1156,25 @@ def _validate_signature(signature):
def defun(func=None, input_signature=None):
- """Traces a Python function and produces a callable TensorFlow graph.
+ """Compiles a Python function into a callable TensorFlow graph.
- `defun` (short for "define function") traces a Python function
- composed of TensorFlow operations and produces a callable that executes a
- `tf.Graph` containing those operations. The callable produced by `defun`
- contains only the subgraph of TensorFlow operations that were executed when
- the Python function was called with a particular input signature, defined as a
- list of the shapes and dtypes of the Python function's Tensor-valued arguments
- and the values of its non-Tensor Python objects. In particular, `defun` cannot
- capture arbitrary Python code in the callables it generates.
+ `defun` (short for "define function") trace-compiles a Python function
+ composed of TensorFlow operations into a callable that executes a `tf.Graph`
+ containing those operations. The callable produced by `defun` contains only
+ the subgraph of TensorFlow operations that were executed when the Python
+ function was called with a particular input signature, defined as a list
+ of the shapes and dtypes of the Python function's Tensor-valued arguments and
+ the values of its non-Tensor Python objects. In particular, `defun` is _not_ a
+ compiler for arbitrary Python code.
When eager execution is enabled, the ability to create graphs from Python
functions makes it possible to incrementally trade off debuggability and
- interactivity for performance. Functions traced with `defun` cannot be
+ interactivity for performance. Functions compiled with `defun` cannot be
inspected with `pdb` and `print` statements; however, executing a graph
generated by `defun` sometimes takes less time and memory than eagerly
executing the corresponding Python function, since specifying computations as
graphs allows for optimizations like automatic buffer reuse and
- parallelization among ops. Note that executing a `defun`-traced function
+ parallelization among ops. Note that executing a `defun`-compiled function
incurs a small constant overhead, so eagerly executing sufficiently small
Python functions might take less time than executing their corresponding
`defun`-generated graphs.
@@ -1208,9 +1183,8 @@ def defun(func=None, input_signature=None):
be hashable Python objects or lists thereof. The function itself may not
modify the list/map structure of its arguments. Additionally, it must return
zero or more `tf.Tensor` objects. If the Python function returns
- a `tf.Variable`, its traced version will return the value of that variable
- as a `tf.Tensor`. The Python function may only create `tf.Variable`s the
- first time it is called.
+ a `tf.Variable`, its compiled version will return the value of that variable
+ as a `tf.Tensor`.
Executing a graph generated by `defun` respects device annotations (i.e.,
all `with tf.device` directives present in a Python function will also be
@@ -1237,7 +1211,7 @@ def defun(func=None, input_signature=None):
# TensorFlow graph.
assert f(x, y).numpy() == g(x, y).numpy()
- # `defun` is capable of tracing Python functions that close over Python
+ # `defun` is capable of compiling Python functions that close over Python
# objects, including Tensors and Variables.
@tf.contrib.eager.defun
def h():
@@ -1246,7 +1220,7 @@ def defun(func=None, input_signature=None):
assert (h().numpy() == f(x, y).numpy()).all()
# `defun` automatically lifts variables out of the graphs it creates,
- # allowing you to trace the `call` methods of `tf.keras.layers.Layer` and
+ # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
# `tf.keras.Model` objects.
class MyModel(tf.keras.Model):
@@ -1268,7 +1242,7 @@ def defun(func=None, input_signature=None):
model(x, training=True) # executes a graph, with dropout
model(x, training=False) # executes a graph, without dropout
- # `defun`-traced functions are differentiable.
+ # `defun`-compiled functions are differentiable.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
with tf.GradientTape() as tape:
outputs = model(x)
@@ -1336,7 +1310,7 @@ def defun(func=None, input_signature=None):
```
- Python functions that are traced with an `input_signature` must only accept
+ Python functions that are compiled with an `input_signature` must only accept
Tensors as arguments and must not take unnamed keyword arguments (**kwargs).
_Tracing_
@@ -1358,8 +1332,8 @@ def defun(func=None, input_signature=None):
return tf.eye(5) + np.random.randn(5, 5)
```
- will return a different output every time it is invoked, the traced function
- `tf_function = tf.contrib.eager.defun(add_noise)` will return the same value
+ will return a different output every time it is invoked, the compiled function
+ `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
every time it is called, since a particular random offset generated by NumPy
will be inserted into the graph as a TensorFlow constant. The solution is to
replace the call to `np.random.randn` with `tf.random_normal((5, 5))`.
@@ -1376,7 +1350,7 @@ def defun(func=None, input_signature=None):
The structure of many machine learning computations depends upon whether one is
training or validating, and it is common to nest specialized logic under `if
training:` blocks. By mapping each input signature to a unique graph, `defun`
- lets users transparently trace such code, as the following code snippet
+ lets users transparently compile such code, as the following code snippet
demonstrates:
```python
@@ -1422,16 +1396,15 @@ def defun(func=None, input_signature=None):
with `tf.cond(tensor < 10, true_fn, false_fn)`.
_Variables_
- TensorFlow operations related to the creation and initialization of
- `tf.Variable`s are automatically lifted out of the graphs generated by
- `defun`. In practice, this implies that variable creation and initialization
- only happen the first time `F` is called, and that variables are reused every
- time thereafter. Many TensorFlow APIs, like `tf.keras.layers.Layer` objects,
- create variables the first time they are called and reuse them thereafter.
- Automatic variable lifting makes it possible to trace these APIs without
- extra effort, at the cost of introducing a discrepancy between the semantics
- of executing Python functions and their corresponding trace-generated
- functions. For example:
+ TensorFlow operations related to variable creation and initialization are
+ automatically lifted out of the graphs generated by `defun`. In practice, this
+ implies that variable creation and initialization only happen the first time
+ `F` is called, and that variables are reused every time thereafter. Many
+ TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
+ first time they are called and reuse them thereafter. Automatic variable
+ lifting makes it possible to compile these APIs without extra effort, at the
+ cost of introducing a discrepancy between the semantics of executing Python
+ functions and their corresponding compiled functions. For example:
```python
import tensorflow as tf
@@ -1447,24 +1420,30 @@ def defun(func=None, input_signature=None):
# every invocation
assert fn().numpy() == fn().numpy() == 1.0
- traced_fn = tf.contrib.eager.defun(fn)
+ compiled = tf.contrib.eager.defun(fn)
- # Tracing `fn` with `defun` hoists all variables outside of the generated
+ # Compiling `fn` with `defun` hoists all variables outside of the generated
# graph, so initialization happens exactly once.
- assert traced_fn().numpy() == 1.0
- assert traced_fn().numpy() == 2.0
+ assert compiled().numpy() == 1.0
+ assert compiled().numpy() == 2.0
```
- The wrapped Python function is only permitted to create variables on its first
- invocation; an error will be raised if a subsequent trace creates any
- variables. This means that if your Python function does create variables, it
- must include logic that ensures variables are only created the first time it
- is called. Note that this is precisely what `tf.keras.layers.Layer` objects
- do, so we recommend using them to represent variable-bearing computations
- whenever possible.
+ Finally, because each input signature is bound to a unique graph, if your
+ Python function constructs `tf.Variable` objects, then each graph constructed
+ for that Python function will reference a unique set of variables. To
+ circumvent this problem, we recommend against compiling Python functions that
+ create `tf.Variable` objects. Instead, Python functions should either
+ lexically close over `tf.Variable` objects or accept them as arguments,
+ preferably encapsulated in an object-oriented container. If you must create
+ variables inside your Python function and you want each graph generated for it
+ to reference the same set of variables, add logic to your Python function that
+ ensures that variables are only created the first time it is called and are
+ reused for every subsequent invocation; note that this is precisely what
+ `tf.keras.layers.Layer` objects do, so we recommend using them to represent
+ variable-bearing computations whenever possible.
Args:
- func: function to be traced. If `func` is None, returns a
+ func: function to be compiled. If `func` is None, returns a
decorator that can be invoked with a single argument - `func`. The
end result is equivalent to providing all the arguments up front.
In other words, defun(input_signature=...)(func) is equivalent to
@@ -1482,7 +1461,7 @@ def defun(func=None, input_signature=None):
`func` cannot accept `**kwargs`.
Returns:
- If `func` is not None, returns a callable that will execute the traced
+ If `func` is not None, returns a callable that will execute the compiled
function (and return zero or more `tf.Tensor` objects).
If `func` is None, returns a decorator that, when invoked with a single
`func` argument, returns a callable equivalent to the case above.
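The pattern the restored docstring recommends is sketched below (hedged: TF 1.x eager APIs are assumed, and `add_and_read` is an illustrative name): create the variable once, outside the function, and let the compiled function lexically close over it so that every trace shares the same state.

```python
import tensorflow as tf

tf.enable_eager_execution()

v = tf.Variable(1.0)  # Created once, outside the compiled function.

@tf.contrib.eager.defun
def add_and_read(x):
  v.assign_add(x)        # Mutates the single shared variable.
  return v.read_value()  # Automatic control dependencies sequence the read.

assert add_and_read(tf.constant(2.0)).numpy() == 3.0
assert add_and_read(tf.constant(2.0)).numpy() == 5.0
```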
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index dd6c2483cc..37a9957cea 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -92,7 +92,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testGraphModeWithGradients(self):
- v = variables.Variable(1.0, name='v')
+ v = resource_variable_ops.ResourceVariable(1.0, name='v')
@function.defun
def step():
@@ -105,7 +105,7 @@ class FunctionTest(test.TestCase):
def testGraphGradientVariable(self):
with ops.Graph().as_default(), self.test_session():
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
def f():
@@ -121,18 +121,13 @@ class FunctionTest(test.TestCase):
@function.defun
def f():
- with ops.init_scope():
- t = constant_op.constant(1.0)
- return t + constant_op.constant(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
+ return v.read_value()
- self.assertAllEqual(f(), 2.0)
- self.assertEqual(len(f._function_cache), 1)
+ self.assertAllEqual(f(), 1.0)
with ops.Graph().as_default():
- # Reinvoking `f()` in graph-mode should re-trace (to avoid using
- # the captured eager tensor).
self.assertEqual(f().shape, ())
- self.assertEqual(len(f._function_cache), 2)
def testBasicGraphFunction(self):
matmul = function.defun(math_ops.matmul)
@@ -178,7 +173,7 @@ class FunctionTest(test.TestCase):
def testExecutingStatefulDefunConcurrently(self):
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
def stateful(x):
@@ -191,7 +186,7 @@ class FunctionTest(test.TestCase):
def testExecutingManyStatefulDefunsConcurrently(self):
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
def stateful(x):
@@ -263,7 +258,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(b['b'].numpy(), 1.0)
def testGraphFunctionWithGradients(self):
- v = variables.Variable(1.0, name='v')
+ v = resource_variable_ops.ResourceVariable(1.0, name='v')
@function.defun
def step():
@@ -342,7 +337,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(2, int(add_int32s()))
def testDefunReadVariable(self):
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
def f():
@@ -351,7 +346,7 @@ class FunctionTest(test.TestCase):
self.assertEqual(1.0, float(f()))
def testDefunAssignAddVariable(self):
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
x = constant_op.constant(2.0)
@function.defun
@@ -369,7 +364,7 @@ class FunctionTest(test.TestCase):
@function.defun
def tensor_init():
with self.assertRaisesRegexp(ValueError, error_msg):
- variables.Variable(constant_op.constant(2.0))
+ resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
tensor_init()
@@ -378,7 +373,7 @@ class FunctionTest(test.TestCase):
@function.defun
def tensor_init():
- v = variables.Variable(
+ v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
return v.read_value()
@@ -394,7 +389,7 @@ class FunctionTest(test.TestCase):
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
- v = variables.Variable(const)
+ v = resource_variable_ops.ResourceVariable(const)
return v.read_value()
value = tensor_init()
@@ -402,40 +397,8 @@ class FunctionTest(test.TestCase):
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(value), 2.0)
- def testCreatingVariablesOnNoninitialTraceFails(self):
-
- @function.defun
- def create_var(param):
- del param
- v = variables.Variable(1.0)
- return v.read_value()
-
- create_var('one')
- self.assertEqual(len(create_var.variables), 1)
-
- with self.assertRaisesRegexp(
- ValueError, 'A `tf.Variable` was created on '
- 'a noninitial trace of the Python function.*'):
- create_var('two')
-
- @function.defun
- def maybe_create_var(param):
- if param == 'two':
- v = variables.Variable(1.0)
- return v.read_value()
- else:
- return constant_op.constant(1.0)
-
- maybe_create_var('one')
- self.assertEqual(len(maybe_create_var.variables), 0)
-
- with self.assertRaisesRegexp(
- ValueError, 'A `tf.Variable` was created on '
- 'a noninitial trace of the Python function.*'):
- maybe_create_var('two')
-
def testDefunShapeInferenceWithCapturedResourceVariable(self):
- v = variables.Variable([[1, 2], [3, 4]])
+ v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
@@ -462,7 +425,7 @@ class FunctionTest(test.TestCase):
def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
with context.graph_mode():
- v = variables.Variable([[1, 2], [3, 4]])
+ v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
@@ -495,10 +458,10 @@ class FunctionTest(test.TestCase):
defined() # Create the variable.
self.assertEqual(len(defined.variables), 1)
self.assertIsInstance(
- defined.variables[0], variables.Variable)
+ defined.variables[0], resource_variable_ops.ResourceVariable)
def testDefunDifferentiable(self):
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
def f():
@@ -507,7 +470,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
def testDefunCanBeDifferentiatedTwice(self):
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
def f():
@@ -523,7 +486,7 @@ class FunctionTest(test.TestCase):
class HasAVar(object):
def __init__(self):
- self.v = variables.Variable(1.0)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
def call(self):
return self.v * 2
@@ -536,7 +499,7 @@ class FunctionTest(test.TestCase):
def testSymbolicGradientVariableZerosLike(self):
with ops.Graph().as_default():
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
@function.defun
def f(x, v):
@@ -642,7 +605,7 @@ class FunctionTest(test.TestCase):
g(constant_op.constant(1.0))
def testNestedDefunWithNoOutputAndTapedInput(self):
- three = variables.Variable(3.0, name='v')
+ three = resource_variable_ops.ResourceVariable(3.0, name='v')
@function.defun
def f(x):
@@ -658,7 +621,7 @@ class FunctionTest(test.TestCase):
g(three)
def testGradientTensorConversionWithDefun(self):
- three = variables.Variable(3.0, name='v')
+ three = resource_variable_ops.ResourceVariable(3.0, name='v')
@function.defun
def f(x):
@@ -690,7 +653,7 @@ class FunctionTest(test.TestCase):
def testGatherResourceWithDefun(self):
with ops.device('cpu:0'):
- v = variables.Variable([0.0, 1.0, 2.0])
+ v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
@@ -699,7 +662,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(sum_gather(), defined())
def testGradientOfGatherWithDefun(self):
- v = variables.Variable([0.0, 1.0, 2.0])
+ v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
@@ -786,10 +749,10 @@ class FunctionTest(test.TestCase):
self.skipTest('No GPUs found.')
with ops.device('/cpu:0'):
- v_cpu = variables.Variable([0.0, 1.0, 2.0])
+ v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
with ops.device('/gpu:0'):
- v_gpu = variables.Variable([0.0, 1.0, 2.0])
+ v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
@@ -808,13 +771,14 @@ class FunctionTest(test.TestCase):
self.skipTest('No GPUs found.')
with ops.device('/cpu:0'):
- v_cpu = variables.Variable(
- [0.0, 1.0, 2.0], name='cpu', use_resource=True)
- v_also_cpu = variables.Variable(
- [0.0, 1.0, 2.0], name='also_cpu', use_resource=True)
+ v_cpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='cpu')
+ v_also_cpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='also_cpu')
with ops.device('/gpu:0'):
- v_gpu = variables.Variable([0.0, 1.0, 2.0], name='gpu', use_resource=True)
+ v_gpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='gpu')
@function.defun
def resource_apply_adam():
@@ -948,7 +912,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(3, add_one(constant_op.constant(2)))
def testVariableCaptureInNestedFunctions(self):
- v = variables.Variable(1, dtype=dtypes.int32)
+ v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
@function.defun
def inner_read():
@@ -1016,7 +980,7 @@ class FunctionTest(test.TestCase):
@function.defun
def create_variable():
with ops.name_scope('foo'):
- v = variables.Variable(0.0, name='bar')
+ v = resource_variable_ops.ResourceVariable(0.0, name='bar')
self.assertEqual(v.name, 'foo/bar:0')
create_variable()
@@ -1026,7 +990,7 @@ class FunctionTest(test.TestCase):
@function.defun
def create_variable():
with ops.name_scope('foo'):
- v = variables.Variable([1.0, 2.0], name='bar')
+ v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
self.assertEqual(v.name, 'foo/bar:0')
with ops.get_default_graph().as_default():
@@ -1158,7 +1122,7 @@ class FunctionTest(test.TestCase):
self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
def testVariablesAreTracked(self):
- v = variables.Variable(1.0)
+ v = resource_variable_ops.ResourceVariable(1.0)
def foo(x):
return v * x
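Finally, a hedged standalone sketch of the test-style usage after this migration, mirroring `testDefunReadVariable` above (it assumes a TF 1.x-era checkout, since these are internal modules):

```python
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

ops.enable_eager_execution()

# The tests now construct resource variables directly instead of going
# through `variables.Variable(...)`.
v = resource_variable_ops.ResourceVariable(1.0)

@function.defun
def f():
  return v.read_value()

assert float(f()) == 1.0
```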