author     Akshay Agrawal <akshayka@google.com>              2018-09-07 17:13:24 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-07 17:16:44 -0700
commit     35c38c92d0fcb458047282f1e87146ae38c21b57 (patch)
tree       b1740b1c1f4d06603093ede419054c60f40a38bf /tensorflow/python/eager
parent     22fa861e03c75c0cf4eb6ee2d81b8c1c17c0982b (diff)
Automated rollback of commit 72bbefcf1f80cd64cf873b69953a90657dabab18
PiperOrigin-RevId: 212061688
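This rollback removes the variable-creator check (added in 72bbefcf1f80cd64cf873b69953a90657dabab18) that raised a `ValueError` when a `tf.Variable` was created on a noninitial trace of a `defun`-ed function, and deletes the corresponding test; the restored docstring instead warns that each traced graph references its own set of variables. A minimal sketch of the restored behavior, assuming the TF 1.x `tf.contrib.eager` API current at this commit:

```python
import tensorflow as tf

tf.enable_eager_execution()

@tf.contrib.eager.defun
def create_var(param):
  # Each new `param` value triggers a new trace. After this rollback, a
  # variable may be created on any trace (each traced graph then references
  # its own variable); before the rollback, this raised a ValueError.
  del param
  v = tf.contrib.eager.Variable(1.0)
  return v.read_value()

create_var('one')  # first trace: creates a variable
create_var('two')  # new trace: creates another variable; no error is raised
```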
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--  tensorflow/python/eager/function.py        | 133
-rw-r--r--  tensorflow/python/eager/function_test.py   | 108
2 files changed, 92 insertions, 149 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index bc7c7f6502..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -124,14 +124,8 @@ class FuncGraph(ops.Graph):
   def __init__(self, name):
     """Construct a new FuncGraph.
 
-    The graph will inherit the following from its current context or graph:
-    * graph key,
-    * collections,
-    * seed,
-    * device stack,
-    * colocation stack,
-    * variable creator stack, and
-    * distribution strategy stack.
+    The graph will inherit its graph key, collections, seed, device stack, and
+    distribution strategy stack from the current context or graph.
 
     Args:
       name: the name of the function.
@@ -164,8 +158,6 @@ class FuncGraph(ops.Graph):
       self._device_function_stack = graph._device_function_stack.copy()  # pylint: disable=protected-access
       self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access
-      self._variable_creator_stack = graph._variable_creator_stack  # pylint: disable=protected-access
-
     # TODO(b/112165328, b/112906995): summaries depend on inheriting collections
     # from the default graph even in eager mode. It'd be nice to not have a
     # default graph with eager execution, so hopefully this will go away when we
@@ -799,7 +791,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
       except (ValueError, TypeError):
         raise TypeError(
             "To be compatible with tf.contrib.eager.defun, Python functions "
-            "must return zero or more Tensors; when tracing %s, found "
+            "must return zero or more Tensors; in compilation of %s, found "
            "return value of type %s, which is not a Tensor." %
            (str(python_func), type(x)))
      x = a.mark_as_return(x)
@@ -1049,11 +1041,7 @@ class PolymorphicFunction(object):
       colocation_stack = (None if executing_eagerly else
                           tuple(graph._colocation_stack.peek_objs()))  # pylint: disable=protected-access
-      variable_creator_stack = tuple(graph._variable_creator_stack)  # pylint: disable=protected-access
-
-    # TODO(b/114446670): Add the _distribution_strategy_stack to the key.
-    return cache_key + (execution_context, device_functions, colocation_stack,
-                        variable_creator_stack)
+    return cache_key + (execution_context, device_functions, colocation_stack)
 
   def _canonicalize_function_inputs(self, *args, **kwds):
     """Canonicalizes `args` and `kwds`.
@@ -1136,8 +1124,7 @@ class PolymorphicFunction(object):
       kwds, as well as the inputs that the object should be called with.
 
     Raises:
-      ValueError: If inputs are incompatible with the input signature or
-        if variables are created on a noninitial trace.
+      ValueError: If inputs are incompatible with the input signature.
       TypeError: If the function inputs include non-hashable objects
     """
@@ -1152,21 +1139,9 @@ class PolymorphicFunction(object):
                         "must be hashable.")
 
     if graph_function is None:
-
-      def fail_on_noninitial_creation(next_creator, **kwargs):
-        if self._function_cache:
-          raise ValueError(
-              "A `tf.Variable` was created on a noninitial trace "
-              "of the Python function %s. When generating a "
-              "function via `defun`, the encapsulated Python "
-              "function may only create `tf.Variable`s on the first "
-              "trace."
-              % self.python_function)
-        return next_creator(**kwargs)
-
-      with variable_scope.variable_creator_scope(fail_on_noninitial_creation):
-        graph_function = Function(
-            func_graph_from_py_func(self._name, self._python_function, args,
-                                    kwds, self._input_signature))
+      graph_function = Function(
+          func_graph_from_py_func(self._name, self._python_function, args,
+                                  kwds, self._input_signature))
       self._variables.extend(
           [v for v in graph_function.variables if v not in self._variables])
       self._function_cache[cache_key] = graph_function
@@ -1181,25 +1156,25 @@ def _validate_signature(signature):
 
 def defun(func=None, input_signature=None):
-  """Traces a Python function and produces a callable TensorFlow graph.
+  """Compiles a Python function into a callable TensorFlow graph.
 
-  `defun` (short for "define function") traces a Python function
-  composed of TensorFlow operations and produces a callable that executes a
-  `tf.Graph` containing those operations. The callable produced by `defun`
-  contains only the subgraph of TensorFlow operations that were executed when
-  the Python function was called with a particular input signature, defined as a
-  list of the shapes and dtypes of the Python function's Tensor-valued arguments
-  and the values of its non-Tensor Python objects. In particular, `defun` cannot
-  capture arbitrary Python code in the callables it generates.
+  `defun` (short for "define function") trace-compiles a Python function
+  composed of TensorFlow operations into a callable that executes a `tf.Graph`
+  containing those operations. The callable produced by `defun` contains only
+  the subgraph of TensorFlow operations that were executed when the Python
+  function was called with a particular input signature, defined as a list
+  of the shapes and dtypes of the Python function's Tensor-valued arguments and
+  the values of its non-Tensor Python objects. In particular, `defun` is _not_ a
+  compiler for arbitrary Python code.
 
   When eager execution is enabled, the ability to create graphs from Python
   functions makes it possible to incrementally trade off debugability and
-  interactivity for performance. Functions traced with `defun` cannot be
+  interactivity for performance. Functions compiled with `defun` cannot be
   inspected with `pdb` and `print` statements; however, executing a graph
   generated by `defun` sometimes takes less time and memory than eagerly
   executing the corresponding Python function, since specifying computations as
   graphs allows for optimizations like automatic buffer reuse and
-  parallelization among ops. Note that executing a `defun`-traced function
+  parallelization among ops. Note that executing a `defun`-compiled function
   incurs a small constant overhead, so eagerly executing sufficiently small
   Python functions might take less time than executing their corresponding
   `defun`-generated graphs.
 
@@ -1208,9 +1183,8 @@ def defun(func=None, input_signature=None):
   be hashable Python objects or lists thereof. The function itself may not
   modify the list/map structure of its arguments. Additionally, it must return
   zero or more `tf.Tensor` objects. If the Python function returns
-  a `tf.Variable`, its traced version will return the value of that variable
-  as a `tf.Tensor`. The Python function may only create `tf.Variable`s the
-  first time it is called.
+  a `tf.Variable`, its compiled version will return the value of that variable
+  as a `tf.Tensor`.
 
   Executing a graph generated by `defun` respects device annotations (i.e.,
   all `with tf.device` directives present in a Python function will also be
@@ -1237,7 +1211,7 @@ def defun(func=None, input_signature=None):
     # TensorFlow graph.
     assert f(x, y).numpy() == g(x, y).numpy()
 
-    # `defun` is capable of tracing Python functions that close over Python
+    # `defun` is capable of compiling Python functions that close over Python
     # objects, including Tensors and Variables.
     @tf.contrib.eager.defun
     def h():
@@ -1246,7 +1220,7 @@ def defun(func=None, input_signature=None):
     assert (h().numpy() == f(x, y).numpy()).all()
 
     # `defun` automatically lifts variables out of the graphs it creates,
-    # allowing you to trace the `call` methods of `tf.keras.layers.Layer` and
+    # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
    # `tf.keras.Model` objects.
     class MyModel(tf.keras.Model):
@@ -1268,7 +1242,7 @@ def defun(func=None, input_signature=None):
     model(x, training=True)  # executes a graph, with dropout
     model(x, training=False) # executes a graph, without dropout
 
-    # `defun`-traced functions are differentiable.
+    # `defun`-compiled functions are differentiable.
     optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
     with tf.GradientTape() as tape:
       outputs = model(x)
@@ -1336,7 +1310,7 @@ def defun(func=None, input_signature=None):
   ```
 
-  Python functions that are traced with an `input_signature` must only accept
+  Python functions that are compiled with an `input_signature` must only accept
   Tensors as arguments and must not take unnamed keyword arguments (**kwargs).
 
   _Tracing_
@@ -1358,8 +1332,8 @@ def defun(func=None, input_signature=None):
     return tf.eye(5) + np.random.randn(5, 5)
   ```
 
-  will return a different output everytime it is invoked, the traced function
-  `tf_function = tf.contrib.eager.defun(add_noise)` will return the same value
+  will return a different output everytime it is invoked, the compiled function
+  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
   every time it is called, since a particular random offset generated by NumPy
   will be inserted into the graph as a TensorFlow constant. The solution is to
   replace the call to `np.random.randn` with `tf.random_normal((5, 5))`.
@@ -1376,7 +1350,7 @@ def defun(func=None, input_signature=None):
   The structure of many machine learning computations depend upon whether one
   is training or validating, and it is common to nest specialized logic under
   `if training:` blocks. By mapping each input signature to a unique graph, `defun`
-  lets users transparently trace such code, as the following code snippet
+  lets users transparently compile such code, as the following code snippet
   demonstrates:
 
   ```python
@@ -1422,16 +1396,15 @@ def defun(func=None, input_signature=None):
   with `tf.cond(tensor < 10, true_fn, false_fn)`.
 
   _Variables_
-  TensorFlow operations related to the creation and initialization of
-  `tf.Variable`s are automatically lifted out of the graphs generated by
-  `defun`. In practice, this implies that variable creation and initialization
-  only happen the first time `F` is called, and that variables are reused every
-  time thereafter. Many TensorFlow APIs, like `tf.keras.layers.Layer` objects,
-  create variables the first time they are called and reuse them thereafter.
-  Automatic variable lifting makes it possible to trace these APIs without
-  extra effort, at the cost of introducing a discrepancy between the semantics
-  of executing Python functions and their corresponding trace-generated
-  functions. For example:
+  TensorFlow operations related to variable creation and initialization are
+  automatically lifted out of the graphs generated by `defun`. In practice, this
+  implies that variable creation and initialization only happen the first time
+  `F` is called, and that variables are reused every time thereafter. Many
+  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
+  first time they are called and reuse them thereafter. Automatic variable
+  lifting makes it possible to compile these APIs without extra effort, at the
+  cost of introducing a discrepancy between the semantics of executing Python
+  functions and their corresponding compiled functions. For example:
 
   ```python
   import tensorflow as tf
@@ -1447,24 +1420,30 @@ def defun(func=None, input_signature=None):
   # every invocation
   assert fn().numpy() == fn().numpy() == 1.0
 
-  traced_fn = tf.contrib.eager.defun(fn)
+  compiled = tf.contrib.eager.defun(fn)
 
-  # Tracing `fn` with `defun` hoists all variables outside of the generated
+  # Compiling `fn` with `defun` hoists all variables outside of the generated
   # graph, so initialization happens exactly once.
-  assert traced_fn().numpy() == 1.0
-  assert traced_fn().numpy() == 2.0
+  assert compiled().numpy() == 1.0
+  assert compiled().numpy() == 2.0
   ```
 
-  The wrapped Python function is only permitted to create variables on its first
-  invocation; an error will be raised if a subsequent trace creates any
-  variables. This means that if your Python function does create variables, it
-  must include logic that ensures variables are only created the first time it
-  is called. Note that this is precisely what `tf.keras.layers.Layer` objects
-  do, so we recommend using them to represent variable-bearing computations
-  whenever possible.
+  Finally, because each input signature is bound to a unique graph, if your
+  Python function constructs `tf.Variable` objects, then each graph constructed
+  for that Python function will reference a unique set of variables. To
+  circumvent this problem, we recommend against compiling Python functions that
+  create `tf.Variable` objects. Instead, Python functions should either
+  lexically close over `tf.Variable` objects or accept them as arguments,
+  preferably encapsulated in an object-oriented container. If you must create
+  variables inside your Python function and you want each graph generated for it
+  to reference the same set of variables, add logic to your Python function that
+  ensures that variables are only created the first time it is called and are
+  reused for every subsequent invocation; note that this is precisely what
+  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
+  variable-bearing computations whenever possible.
 
   Args:
-    func: function to be traced. If `func` is None, returns a
+    func: function to be compiled. If `func` is None, returns a
       decorator that can be invoked with a single argument - `func`. The
      end result is equivalent to providing all the arguments up front.
      In other words, defun(input_signature=...)(func) is equivalent to
@@ -1482,7 +1461,7 @@ def defun(func=None, input_signature=None):
       `func` cannot accept `**kwargs`.
 
   Returns:
-     If `func` is not None, returns a callable that will execute the traced
+     If `func` is not None, returns a callable that will execute the compiled
      function (and return zero or more `tf.Tensor` objects).
      If `func` is None, returns a decorator that, when invoked with a single
      `func` argument, returns a callable equivalent to the case above.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index dd6c2483cc..37a9957cea 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -92,7 +92,7 @@ class FunctionTest(test.TestCase):
     self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
 
   def testGraphModeWithGradients(self):
-    v = variables.Variable(1.0, name='v')
+    v = resource_variable_ops.ResourceVariable(1.0, name='v')
 
     @function.defun
     def step():
@@ -105,7 +105,7 @@ class FunctionTest(test.TestCase):
 
   def testGraphGradientVariable(self):
     with ops.Graph().as_default(), self.test_session():
-      v = variables.Variable(1.0)
+      v = resource_variable_ops.ResourceVariable(1.0)
 
       @function.defun
       def f():
@@ -121,18 +121,13 @@ class FunctionTest(test.TestCase):
 
     @function.defun
     def f():
-      with ops.init_scope():
-        t = constant_op.constant(1.0)
-      return t + constant_op.constant(1.0)
+      v = resource_variable_ops.ResourceVariable(1.0)
+      return v.read_value()
 
-    self.assertAllEqual(f(), 2.0)
-    self.assertEqual(len(f._function_cache), 1)
+    self.assertAllEqual(f(), 1.0)
 
     with ops.Graph().as_default():
-      # Reinvoking `f()` in graph-mode should re-trace (to avoid using
-      # the captured eager tensor).
       self.assertEqual(f().shape, ())
-    self.assertEqual(len(f._function_cache), 2)
 
   def testBasicGraphFunction(self):
     matmul = function.defun(math_ops.matmul)
@@ -178,7 +173,7 @@ class FunctionTest(test.TestCase):
 
   def testExecutingStatefulDefunConcurrently(self):
 
-    v = variables.Variable(1.0)
+    v = resource_variable_ops.ResourceVariable(1.0)
 
     @function.defun
     def stateful(x):
@@ -191,7 +186,7 @@ class FunctionTest(test.TestCase):
 
   def testExecutingManyStatefulDefunsConcurrently(self):
 
-    v = variables.Variable(1.0)
+    v = resource_variable_ops.ResourceVariable(1.0)
 
     @function.defun
     def stateful(x):
@@ -263,7 +258,7 @@ class FunctionTest(test.TestCase):
     self.assertAllEqual(b['b'].numpy(), 1.0)
 
   def testGraphFunctionWithGradients(self):
-    v = variables.Variable(1.0, name='v')
+    v = resource_variable_ops.ResourceVariable(1.0, name='v')
 
     @function.defun
     def step():
@@ -342,7 +337,7 @@ class FunctionTest(test.TestCase):
     self.assertEqual(2, int(add_int32s()))
 
   def testDefunReadVariable(self):
-    v = variables.Variable(1.0)
+    v = resource_variable_ops.ResourceVariable(1.0)
 
     @function.defun
     def f():
@@ -351,7 +346,7 @@ class FunctionTest(test.TestCase):
     self.assertEqual(1.0, float(f()))
 
   def testDefunAssignAddVariable(self):
-    v = variables.Variable(1.0)
+    v = resource_variable_ops.ResourceVariable(1.0)
     x = constant_op.constant(2.0)
 
     @function.defun
@@ -369,7 +364,7 @@ class FunctionTest(test.TestCase):
     @function.defun
     def tensor_init():
       with self.assertRaisesRegexp(ValueError, error_msg):
-        variables.Variable(constant_op.constant(2.0))
+        resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
 
     tensor_init()
 
@@ -378,7 +373,7 @@ class FunctionTest(test.TestCase):
 
     @function.defun
     def tensor_init():
-      v = variables.Variable(
+      v = resource_variable_ops.ResourceVariable(
          lambda: constant_op.constant(2.0))
      return v.read_value()
 
@@ -394,7 +389,7 @@ class FunctionTest(test.TestCase):
    def tensor_init():
      with ops.init_scope():
        const = constant_op.constant(2.0)
-      v = variables.Variable(const)
+      v = resource_variable_ops.ResourceVariable(const)
      return v.read_value()
 
    value = tensor_init()
@@ -402,40 +397,8 @@ class FunctionTest(test.TestCase):
      self.evaluate(variables.global_variables_initializer())
    self.assertEqual(self.evaluate(value), 2.0)
 
-  def testCreatingVariablesOnNoninitialTraceFails(self):
-
-    @function.defun
-    def create_var(param):
-      del param
-      v = variables.Variable(1.0)
-      return v.read_value()
-
-    create_var('one')
-    self.assertEqual(len(create_var.variables), 1)
-
-    with self.assertRaisesRegexp(
-        ValueError, 'A `tf.Variable` was created on '
-        'a noninitial trace of the Python function.*'):
-      create_var('two')
-
-    @function.defun
-    def maybe_create_var(param):
-      if param == 'two':
-        v = variables.Variable(1.0)
-        return v.read_value()
-      else:
-        return constant_op.constant(1.0)
-
-    maybe_create_var('one')
-    self.assertEqual(len(maybe_create_var.variables), 0)
-
-    with self.assertRaisesRegexp(
-        ValueError, 'A `tf.Variable` was created on '
-        'a noninitial trace of the Python function.*'):
-      maybe_create_var('two')
-
   def testDefunShapeInferenceWithCapturedResourceVariable(self):
-    v = variables.Variable([[1, 2], [3, 4]])
+    v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
 
     def f():
       x = constant_op.constant([[1, 2], [3, 4]])
@@ -462,7 +425,7 @@ class FunctionTest(test.TestCase):
 
   def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
     with context.graph_mode():
-      v = variables.Variable([[1, 2], [3, 4]])
+      v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
 
      def f():
        x = constant_op.constant([[1, 2], [3, 4]])
@@ -495,10 +458,10 @@ class FunctionTest(test.TestCase):
     defined()  # Create the variable.
     self.assertEqual(len(defined.variables), 1)
     self.assertIsInstance(
-        defined.variables[0], variables.Variable)
+        defined.variables[0], resource_variable_ops.ResourceVariable)
 
   def testDefunDifferentiable(self):
-    v = variables.Variable(1.0)
+    v = resource_variable_ops.ResourceVariable(1.0)
 
     @function.defun
     def f():
@@ -507,7 +470,7 @@ class FunctionTest(test.TestCase):
     self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
 
   def testDefunCanBeDifferentiatedTwice(self):
-    v = variables.Variable(1.0)
+    v = resource_variable_ops.ResourceVariable(1.0)
 
     @function.defun
     def f():
@@ -523,7 +486,7 @@ class FunctionTest(test.TestCase):
     class HasAVar(object):
 
       def __init__(self):
-        self.v = variables.Variable(1.0)
+        self.v = resource_variable_ops.ResourceVariable(1.0)
 
      def call(self):
        return self.v * 2
@@ -536,7 +499,7 @@ class FunctionTest(test.TestCase):
 
   def testSymbolicGradientVariableZerosLike(self):
     with ops.Graph().as_default():
-      v = variables.Variable(1.0)
+      v = resource_variable_ops.ResourceVariable(1.0)
 
      @function.defun
      def f(x, v):
@@ -642,7 +605,7 @@ class FunctionTest(test.TestCase):
       g(constant_op.constant(1.0))
 
   def testNestedDefunWithNoOutputAndTapedInput(self):
-    three = variables.Variable(3.0, name='v')
+    three = resource_variable_ops.ResourceVariable(3.0, name='v')
 
     @function.defun
     def f(x):
@@ -658,7 +621,7 @@ class FunctionTest(test.TestCase):
     g(three)
 
   def testGradientTensorConversionWithDefun(self):
-    three = variables.Variable(3.0, name='v')
+    three = resource_variable_ops.ResourceVariable(3.0, name='v')
 
     @function.defun
     def f(x):
@@ -690,7 +653,7 @@ class FunctionTest(test.TestCase):
 
   def testGatherResourceWithDefun(self):
     with ops.device('cpu:0'):
-      v = variables.Variable([0.0, 1.0, 2.0])
+      v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
 
     def sum_gather():
       return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
@@ -699,7 +662,7 @@
     self.assertAllEqual(sum_gather(), defined())
 
   def testGradientOfGatherWithDefun(self):
-    v = variables.Variable([0.0, 1.0, 2.0])
+    v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
 
     def sum_gather():
       return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
@@ -786,10 +749,10 @@ class FunctionTest(test.TestCase):
       self.skipTest('No GPUs found.')
 
     with ops.device('/cpu:0'):
-      v_cpu = variables.Variable([0.0, 1.0, 2.0])
+      v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
 
     with ops.device('/gpu:0'):
-      v_gpu = variables.Variable([0.0, 1.0, 2.0])
+      v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
 
     def sum_gather():
       cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
@@ -808,13 +771,14 @@ class FunctionTest(test.TestCase):
       self.skipTest('No GPUs found.')
 
     with ops.device('/cpu:0'):
-      v_cpu = variables.Variable(
-          [0.0, 1.0, 2.0], name='cpu', use_resource=True)
-      v_also_cpu = variables.Variable(
-          [0.0, 1.0, 2.0], name='also_cpu', use_resource=True)
+      v_cpu = resource_variable_ops.ResourceVariable(
+          [0.0, 1.0, 2.0], name='cpu')
+      v_also_cpu = resource_variable_ops.ResourceVariable(
+          [0.0, 1.0, 2.0], name='also_cpu')
 
     with ops.device('/gpu:0'):
-      v_gpu = variables.Variable([0.0, 1.0, 2.0], name='gpu', use_resource=True)
+      v_gpu = resource_variable_ops.ResourceVariable(
+          [0.0, 1.0, 2.0], name='gpu')
 
     @function.defun
     def resource_apply_adam():
@@ -948,7 +912,7 @@ class FunctionTest(test.TestCase):
     self.assertAllEqual(3, add_one(constant_op.constant(2)))
 
   def testVariableCaptureInNestedFunctions(self):
-    v = variables.Variable(1, dtype=dtypes.int32)
+    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
 
    @function.defun
    def inner_read():
@@ -1016,7 +980,7 @@ class FunctionTest(test.TestCase):
     @function.defun
     def create_variable():
       with ops.name_scope('foo'):
-        v = variables.Variable(0.0, name='bar')
+        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
      self.assertEqual(v.name, 'foo/bar:0')
 
    create_variable()
@@ -1026,7 +990,7 @@ class FunctionTest(test.TestCase):
     @function.defun
     def create_variable():
       with ops.name_scope('foo'):
-        v = variables.Variable([1.0, 2.0], name='bar')
+        v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
      self.assertEqual(v.name, 'foo/bar:0')
 
    with ops.get_default_graph().as_default():
@@ -1158,7 +1122,7 @@ class FunctionTest(test.TestCase):
     self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
 
   def testVariablesAreTracked(self):
-    v = variables.Variable(1.0)
+    v = resource_variable_ops.ResourceVariable(1.0)
 
    def foo(x):
      return v * x
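The restored _Variables_ docstring recommends that compiled functions lexically close over their `tf.Variable`s, or accept them as arguments, rather than create them internally. A minimal sketch of that pattern, under the same TF 1.x `tf.contrib.eager` API assumptions as above:

```python
import tensorflow as tf

tf.enable_eager_execution()

# Create the variable once, eagerly; the compiled function closes over it,
# so every traced graph references this same variable.
v = tf.contrib.eager.Variable(1.0)

@tf.contrib.eager.defun
def scale(x):
  return v * x

assert scale(tf.constant(2.0)).numpy() == 2.0
v.assign_add(1.0)  # updates the shared variable outside the graph
assert scale(tf.constant(2.0)).numpy() == 4.0  # the graph reads the new value
```

Because the variable read happens inside the graph on each call, updates made between calls are visible to the compiled function, which avoids the per-signature variable duplication the docstring warns about.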