author     Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-11 19:59:11 +0800
committer  Cao Zongyan <zongyan.cao@alibaba-inc.com>  2018-09-11 19:59:11 +0800
commit     9b3a93edf5a1f259bfe5230cc3b6c076573d4ec9 (patch)
tree       cbb0548282ba1584ed91a1be8f89b03ec882f287 /tensorflow/python/eager
parent     90cf7fb7786c8a9c135ef73482856b082e80f61a (diff)
parent     e18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--               tensorflow/python/eager/BUILD                      39
-rw-r--r--               tensorflow/python/eager/backprop.py               157
-rw-r--r--               tensorflow/python/eager/backprop_test.py          101
-rw-r--r--               tensorflow/python/eager/benchmarks_test.py         55
-rw-r--r--               tensorflow/python/eager/context.py                 41
-rw-r--r--               tensorflow/python/eager/core_test.py               24
-rw-r--r--               tensorflow/python/eager/execution_callbacks.py      8
-rw-r--r--               tensorflow/python/eager/function.py               875
-rw-r--r--               tensorflow/python/eager/function_test.py          290
-rw-r--r--               tensorflow/python/eager/graph_callable.py         435
-rw-r--r--               tensorflow/python/eager/graph_callable_test.py    249
-rw-r--r--               tensorflow/python/eager/graph_only_ops_test.py      4
-rw-r--r--               tensorflow/python/eager/imperative_grad.py         10
-rw-r--r--               tensorflow/python/eager/pywrap_tensor.cc           90
-rwxr-xr-x [-rw-r--r--]  tensorflow/python/eager/pywrap_tfe.h               29
-rw-r--r--               tensorflow/python/eager/pywrap_tfe_src.cc         375
-rw-r--r--               tensorflow/python/eager/tape.py                    26
-rw-r--r--               tensorflow/python/eager/tape_test.py                4
-rw-r--r--               tensorflow/python/eager/tensor_test.py             15
19 files changed, 1298 insertions(+), 1529 deletions(-)
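The user-visible change this merge carries into the eager subtree is the new
`watch_accessed_variables` constructor argument on `GradientTape` (see the
backprop.py hunks below); the rest is the `Function`/`FuncGraph` refactor in
function.py, the removal of `graph_callable`, and assorted cleanups. As a
minimal sketch of the tape behavior — assuming an eager-enabled TensorFlow 1.x
build that includes this merge; the variables here are illustrative only:

    import tensorflow as tf
    tf.enable_eager_execution()

    v = tf.contrib.eager.Variable(3.0)  # illustrative variables
    w = tf.contrib.eager.Variable(4.0)

    # Default behavior: trainable variables read inside the context are
    # watched automatically, so gradients exist for both v and w.
    with tf.GradientTape() as tape:
      y = v * w
    dv, dw = tape.gradient(y, [v, w])  # dv == 4.0, dw == 3.0

    # New in this merge: opt out of automatic watching and watch explicitly.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(v)  # only `v` is watched
      y = v * w
    dv, dw = tape.gradient(y, [v, w])
    # dv == 4.0; dw is None because `w` was never watched.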
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index de93b1e2e1..85da1baaf0 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -47,7 +47,6 @@ py_library(
         ":core",
         ":execute",
         ":function",
-        ":graph_callable",
         ":graph_only_ops",
         ":tape",
         ":test",
@@ -238,10 +237,11 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":graph_only_ops",
+        "//tensorflow/python:cond_v2_impl",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:gradients",
+        "//tensorflow/python:gradients_impl",
         "//tensorflow/python:graph_to_function_def",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
@@ -254,41 +254,6 @@ py_library(
 )
 
 py_library(
-    name = "graph_callable",
-    srcs = ["graph_callable.py"],
-    srcs_version = "PY2AND3",
-    visibility = ["//tensorflow:internal"],
-    deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:util",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:function",
-        "//tensorflow/python/eager:tape",
-    ],
-)
-
-py_test(
-    name = "graph_callable_test",
-    srcs = ["graph_callable_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":backprop",
-        ":graph_callable",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:function",
-        "//tensorflow/python:init_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:variable_scope",
-        "//tensorflow/python/eager:test",
-    ],
-)
-
-py_library(
     name = "backprop",
     srcs = ["backprop.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 553f761a14..be392c7a0f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -180,10 +181,10 @@ def implicit_val_and_grad(f):
   ```
 
   Args:
-   f: function to be differentiated. If `f` returns a scalar, this scalar will
-     be differentiated. If `f` returns a tensor or list of tensors, by default
-     a scalar will be computed by adding all their values to produce a single
-     scalar.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar.
 
   Returns:
     A function which, when called, returns a tuple pair.
@@ -215,9 +216,7 @@ def implicit_val_and_grad(f):
           "function was being computed.")
 
     sources = [v.handle for v in variables]
-    grad = imperative_grad.imperative_grad(_default_vspace,
-                                           this_tape,
-                                           nest.flatten(end_node),
+    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
 
@@ -255,10 +254,10 @@ def implicit_grad(f):
   ```
 
   Args:
-   f: function to be differentiated. If `f` returns a scalar, this scalar will
-     be differentiated. If `f` returns a tensor or list of tensors, by default
-     a scalar will be computed by adding all their values to produce a single
-     scalar.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar.
 
   Returns:
     A function which, when called, returns a list of (gradient, variable) pairs.
@@ -343,24 +342,24 @@ def gradients_function(f, params=None):
   Note that only tensors with real or complex dtypes are differentiable.
 
   Args:
-   f: function to be differentiated. If `f` returns a scalar, this scalar will
-     be differentiated. If `f` returns a tensor or list of tensors, by default
-     a scalar will be computed by adding all their values to produce a single
-     scalar. If desired, the tensors can be elementwise multiplied by the
-     tensors passed as the `dy` keyword argument to the returned gradient
-     function.
-   params: list of parameter names of f or list of integers indexing the
-     parameters with respect to which we'll differentiate. Passing None
-     differentiates with respect to all parameters.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar. If desired, the tensors can be elementwise multiplied by the
+      tensors passed as the `dy` keyword argument to the returned gradient
+      function.
+    params: list of parameter names of f or list of integers indexing the
+      parameters with respect to which we'll differentiate. Passing None
+      differentiates with respect to all parameters.
 
   Returns: function which, when called, returns the value of f and the gradient
-   of f with respect to all of `params`. The function takes an extra optional
-   keyword argument "dy". Setting it allows computation of vector jacobian
+    of `f` with respect to all of `params`. The function takes an extra optional
+    keyword argument `dy`. Setting it allows computation of vector jacobian
     products for vectors other than the vector of ones.
 
   Raises:
-   ValueError: if the params are not all strings or all integers.
+    ValueError: if the params are not all strings or all integers.
   """
 
   def decorated(*args, **kwds):
@@ -440,23 +439,24 @@ def val_and_grad_function(f, params=None):
   ```
 
   Args:
-   f: function to be differentiated. If `f` returns a scalar, this scalar will
-     be differentiated. If `f` returns a tensor or list of tensors, by default
-     a scalar will be computed by adding all their values to produce a single
-     scalar. If desired, the tensors can be elementwise multiplied by the
-     tensors passed as the `dy` keyword argument to the returned gradient
-     function.
-   params: list of parameter names of f or list of integers indexing the
-     parameters with respect to which we'll differentiate. Passing `None`
-     differentiates with respect to all parameters.
-
-  Returns: function which, when called, returns the value of f and the gradient
-    of f with respect to all of `params`. The function takes an extra optional
-    keyword argument "dy". Setting it allows computation of vector jacobian
-    products for vectors other than the vector of ones.
+    f: function to be differentiated. If `f` returns a scalar, this scalar will
+      be differentiated. If `f` returns a tensor or list of tensors, by default
+      a scalar will be computed by adding all their values to produce a single
+      scalar. If desired, the tensors can be elementwise multiplied by the
+      tensors passed as the `dy` keyword argument to the returned gradient
+      function.
+    params: list of parameter names of f or list of integers indexing the
+      parameters with respect to which we'll differentiate. Passing `None`
+      differentiates with respect to all parameters.
+
+  Returns:
+    function which, when called, returns the value of f and the gradient
+    of f with respect to all of `params`. The function takes an extra optional
+    keyword argument "dy". Setting it allows computation of vector jacobian
+    products for vectors other than the vector of ones.
 
   Raises:
-   ValueError: if the params are not all strings or all integers.
+    ValueError: if the params are not all strings or all integers.
   """
 
   def decorated(*args, **kwds):
@@ -520,7 +520,7 @@ def make_vjp(f, params=None, persistent=True):
       args = _ensure_unique_tensor_objects(parameter_positions, args)
       for i in parameter_positions:
         sources.append(args[i])
-        tape.watch(args[i])
+        tape.watch(this_tape, args[i])
       result = f(*args)
       if result is None:
         raise ValueError("Cannot differentiate a function that returns None; "
@@ -535,8 +535,8 @@ def make_vjp(f, params=None, persistent=True):
       if dy is not None:
         dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
       return imperative_grad.imperative_grad(
-          _default_vspace, this_tape, nest.flatten(result), sources,
-          output_gradients=dy)
+          this_tape, nest.flatten(result), sources, output_gradients=dy)
+
     return result, vjp
 
   return decorated
@@ -557,7 +557,7 @@ def _aggregate_grads(gradients):
   if len(gradients) == 1:
     return gradients[0]
   if all([isinstance(g, ops.Tensor) for g in gradients]):
-    return math_ops.add_n(gradients)
+    return gen_math_ops.add_n(gradients)
   else:
     assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
                 for g in gradients])
@@ -592,7 +592,9 @@ def _num_elements(grad):
 
 
 def _fast_fill(value, shape, dtype):
-  return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
+  return array_ops.fill(
+      constant_op.constant(shape, dtype=dtypes.int32),
+      constant_op.constant(value, dtype=dtype))
 
 
 def _zeros(shape, dtype):
@@ -627,9 +629,9 @@ def _ones(shape, dtype):
 _default_vspace = imperative_grad.VSpace(
     num_elements_fn=_num_elements,
     aggregate_fn=_aggregate_grads,
-    tensor_id=ops.tensor_id,
     zeros=_zeros,
     ones=_ones)
+pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
 
 
 def _handle_or_self(x):
@@ -691,19 +693,57 @@ class GradientTape(object):
     del g  # Drop the reference to the tape
   ```
 
+  By default GradientTape will automatically watch any trainable variables that
+  are accessed inside the context. If you want fine grained control over which
+  variables are watched you can disable automatic tracking by passing
+  `watch_accessed_variables=False` to the tape constructor:
+
+  ```python
+  with tf.GradientTape(watch_accessed_variables=False) as tape:
+    tape.watch(variable_a)
+    y = variable_a ** 2  # Gradients will be available for `variable_a`.
+    z = variable_b ** 3  # No gradients will be available since `variable_b` is
+                         # not being watched.
+  ```
+
+  Note that when using models you should ensure that your variables exist when
+  using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
+  first iteration not have any gradients:
+
+  ```python
+  a = tf.keras.layers.Dense(32)
+  b = tf.keras.layers.Dense(32)
+
+  with tf.GradientTape(watch_accessed_variables=False) as tape:
+    tape.watch(a.variables)  # Since `a.build` has not been called at this point
+                             # `a.variables` will return an empty list and the
+                             # tape will not be watching anything.
+    result = b(a(inputs))
+    tape.gradient(result, a.variables)  # The result of this computation will be
+                                        # a list of `None`s since a's variables
+                                        # are not being watched.
+  ```
+
   Note that only tensors with real or complex dtypes are differentiable.
   """
 
-  def __init__(self, persistent=False):
+  def __init__(self, persistent=False, watch_accessed_variables=True):
     """Creates a new GradientTape.
 
     Args:
       persistent: Boolean controlling whether a persistent gradient tape
        is created. False by default, which means at most one call can
        be made to the gradient() method on this object.
+      watch_accessed_variables: Boolean controlling whether the tape will
+        automatically `watch` any (trainable) variables accessed while the tape
+        is active. Defaults to True meaning gradients can be requested from any
+        result computed in the tape derived from reading a trainable `Variable`.
+        If False users must explicitly `watch` any `Variable`s they want to
+        request gradients from.
     """
     self._tape = None
     self._persistent = persistent
+    self._watch_accessed_variables = watch_accessed_variables
     self._recording = False
     context.context().start_step()
@@ -717,15 +757,15 @@ class GradientTape(object):
     if self._recording:
       self._pop_tape()
 
-  def _push_tape(self, existing_tape=False):
+  def _push_tape(self):
     if self._recording:
       raise ValueError("Tape is already recording.")
-    if existing_tape:
-      if self._tape is None:
-        raise ValueError("There is no existing tape.")
-      tape.push_tape(self._tape)
+    if self._tape is None:
+      self._tape = tape.push_new_tape(
+          persistent=self._persistent,
+          watch_accessed_variables=self._watch_accessed_variables)
     else:
-      self._tape = tape.push_new_tape(persistent=self._persistent)
+      tape.push_tape(self._tape)
     self._recording = True
 
   def _pop_tape(self):
@@ -744,7 +784,13 @@ class GradientTape(object):
       tensor: a Tensor or list of Tensors.
     """
     for t in nest.flatten(tensor):
-      tape.watch(_handle_or_self(t))
+      if hasattr(t, "handle"):
+        # There are many variable-like objects, all of them currently have
+        # `handle` attribute that points to a tensor. If this changes, internals
+        # of watch_variable need to change as well.
+        tape.watch_variable(self._tape, t)
+      else:
+        tape.watch(self._tape, t)
 
   @tf_contextlib.contextmanager
   def stop_recording(self):
@@ -776,7 +822,7 @@ class GradientTape(object):
     try:
       yield
     finally:
-      self._push_tape(existing_tape=True)
+      self._push_tape()
 
   def reset(self):
     """Clears all information stored in this tape.
@@ -810,6 +856,7 @@ class GradientTape(object):
     ```
     """
     self._pop_tape()
+    self._tape = None
     self._push_tape()
 
   def watched_variables(self):
@@ -861,7 +908,9 @@ class GradientTape(object):
                         for x in nest.flatten(output_gradients)]
 
     flat_grad = imperative_grad.imperative_grad(
-        _default_vspace, self._tape, nest.flatten(target), flat_sources,
+        self._tape,
+        nest.flatten(target),
+        flat_sources,
         output_gradients=output_gradients)
 
     if not self._persistent:
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 3d3f54b9c4..f938ed5df8 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -23,7 +23,6 @@ import numpy as np
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -65,7 +64,7 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(fn, [0])(var)[0]
     grad = self.evaluate(ops.convert_to_tensor(grad))
 
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode():
       tf_var = array_ops.constant(var_np, dtypes.float32)
       tf_ind1 = array_ops.constant([0, 1])
       tf_ind2 = array_ops.constant([2, 3])
@@ -80,14 +79,13 @@ class BackpropTest(test.TestCase):
       tf_dense_grad = math_ops.unsorted_segment_sum(
           tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
 
-      self.assertAllClose(grad, tf_dense_grad.eval())
+      self.assertAllClose(grad, self.evaluate(tf_dense_grad))
 
   def testImplicitGradWithResourceVariable(self):
     x = resource_variable_ops.ResourceVariable(
         initial_value=constant_op.constant(1.0), name='x')
 
     def fn():
-      tape.watch_variable(x)
       b = constant_op.constant(2.0)
       c = math_ops.add(x.value(), b)
       return math_ops.add(c, constant_op.constant(3.0))
@@ -194,14 +192,13 @@ class BackpropTest(test.TestCase):
         initial_value=random_init, dtype=dtypes.float32, name='embedding')
 
     def f():
-      tape.watch_variable(embedding)
       embedded_x = embedding_ops.embedding_lookup(embedding, x)
       return constant_op.constant(1.0, dtypes.float32) - embedded_x
 
     grad = backprop.implicit_grad(f)()[0][0]
     opt = training.GradientDescentOptimizer(lrn_rate)
 
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       tf_x = array_ops.ones((batch_size), dtypes.int64)
       # TODO(ashankar,apassos): Change to ResourceVariable.
       tf_embedding = variables.Variable(
@@ -316,6 +313,24 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(second, [0])(f)[0]
     self.assertAllEqual([[0.0]], grad)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testWatchingIsTapeLocal(self):
+    x1 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+    x2 = resource_variable_ops.ResourceVariable(2.0, trainable=False)
+
+    with backprop.GradientTape() as tape1:
+      with backprop.GradientTape() as tape2:
+        tape1.watch(x1)
+        tape2.watch([x1, x2])
+        y = x1 ** 3
+        z = x2 ** 2
+        dy, dz = tape2.gradient([y, z], [x1, x2])
+      d2y, d2z = tape1.gradient([dy, dz], [x1, x2])
+
+    self.evaluate([x1.initializer, x2.initializer])
+    self.assertEqual(self.evaluate(d2y), 12.0)
+    self.assertIsNone(d2z)
+
   @test_util.assert_no_new_tensors
   def testMakeVJP(self):
@@ -404,7 +419,6 @@ class BackpropTest(test.TestCase):
 
     def f():
       with context.device('gpu:0'):
-        tape.watch_variable(v)
         return v.read_value()
 
     self.assertEqual(
@@ -460,6 +474,18 @@ class BackpropTest(test.TestCase):
     self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
 
   @test_util.assert_no_new_tensors
+  def testGradientTapeReEnterContext(self):
+    g = backprop.GradientTape()
+    with g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = 2*x
+    with g:
+      z = 2*y
+    grad = g.gradient(target=z, sources=[x])
+    self.assertEqual(self.evaluate(grad), [4.0])
+
+  @test_util.assert_no_new_tensors
   @test_util.run_in_graph_and_eager_modes
   def testGradientTapeRepeatedSource(self):
     with backprop.GradientTape(persistent=False) as g:
@@ -784,7 +810,6 @@ class BackpropTest(test.TestCase):
         initial_value=array_ops.constant([1.0]), name='x')
 
     def fn():
-      tape.watch_variable(x)
       a = math_ops.add(x.value(), 1.0)
       # Make sure convert_to_tensor works correctly with list of TensorNodes.
       b = array_ops.stack([a, a], axis=0)
@@ -928,21 +953,75 @@ class BackpropTest(test.TestCase):
   def testZerosCacheDoesntLeakAcrossGraphs(self):
     with context.graph_mode():
       def get_grad():
-        with ops.Graph().as_default(), self.test_session():
+        with ops.Graph().as_default(), self.cached_session():
           t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
           x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
-          with backprop.GradientTape() as gt:
+          with backprop.GradientTape() as tape:
             tape.watch(x)
             x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
            y1 = x1**2
            y = array_ops.concat([y1, t], axis=1)
-            return self.evaluate(gt.gradient(y, x))
+            return self.evaluate(tape.gradient(y, x))
 
       grad1 = get_grad()
       grad2 = get_grad()
 
      self.assertAllEqual(grad1, grad2)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testSelectivelyWatchVariables(self):
+    x1 = resource_variable_ops.ResourceVariable(1.0)
+    x2 = resource_variable_ops.ResourceVariable(1.0)
+    with backprop.GradientTape(watch_accessed_variables=False) as tape:
+      tape.watch(x2)
+      y = x1**2
+      z = x2**3
+    self.assertTupleEqual(tape.watched_variables(), (x2,))
+    dy, dz = tape.gradient([y, z], [x1, x2])
+    self.evaluate([x1.initializer, x2.initializer])
+    self.assertIsNone(dy)
+    self.assertEqual(self.evaluate(dz), 3.0)
+
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDifferentiatingScalarCache(self):
+    # In the following test, if x2 = x1 (i.e. the objects are the exact same),
+    # then y is essentially, 2*x1, and dy/dx1 = 2.
+    # When we had a pure scalar cache in eager, this would be the case. This
+    # test prevents us from going back to that case.
+    with backprop.GradientTape(persistent=False) as g:
+      x1 = constant_op.constant(3.0)
+      x2 = constant_op.constant(3.0)
+      g.watch(x1)
+      g.watch(x2)
+      y = x1 + x2
+    grad = g.gradient(target=y, sources=[x1])
+    self.assertEqual(self.evaluate(grad), [1.0])
+
+  def testVariablesAndConstantsProduceTheSameGradients(self):
+
+    # In the following test, differentiating [y, z] against [a, b] gives:
+    # (dy/da + dz/da, dy/db + dz/db).
+    # If a and b are the same constant, dz/da will not be 0 (which it should
+    # be).
+    # This is solved by using variable since doing a read_value on a tensor will
+    # produce a new tensor and corresponding TensorHandle, and not reuse the
+    # same tensor (which would happen if we are using a cache and reusing
+    # EagerTensor objects).
+    def get_grads(a, b):
+      with backprop.GradientTape() as tape:
+        tape.watch([a, b])
+        y = a**3
+        z = b**2
+      return tape.gradient([y, z], [a, b])
+
+    gradients_constants = get_grads(
+        constant_op.constant(2.0), constant_op.constant(2.0))
+    gradients_variables = get_grads(
+        resource_variable_ops.ResourceVariable(2.0),
+        resource_variable_ops.ResourceVariable(2.0))
+    self.assertAllEqual(gradients_constants, gradients_variables)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index e2b1890c2f..3fe79ef244 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
@@ -175,6 +176,11 @@ class MicroBenchmarks(test.Benchmark):
 
     self._run(func, 30000)
 
+  def benchmark_create_constant(self):
+    func = lambda: constant_op.constant(3.0)
+
+    self._run(func, 30000)
+
   def benchmark_create_float_tensor_from_list_CPU(self):
     self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)
 
@@ -350,6 +356,21 @@ class MicroBenchmarks(test.Benchmark):
     func = lambda: f(m, m, transpose_b)
     self._run(func, num_iters, execution_mode=execution_mode)
 
+  def _benchmark_defun_matmul_forward_backward(self,
+                                               m,
+                                               transpose_b,
+                                               num_iters,
+                                               execution_mode=None):
+    f = function.defun(math_ops.matmul)
+
+    def func():
+      with backprop.GradientTape() as gt:
+        gt.watch(m)
+        y = f(m, m, transpose_b)
+        _ = gt.gradient(y, m)
+
+    self._run(func, num_iters, execution_mode=execution_mode)
+
   def _benchmark_read_variable(self, m, num_iters):
     self._run(m.value, num_iters)
 
@@ -421,6 +442,21 @@ class MicroBenchmarks(test.Benchmark):
         num_iters=self._num_iters_2_by_2,
         execution_mode=context.ASYNC)
 
+  def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self):
+    with context.device(CPU):
+      m = self._m_2_by_2.cpu()
+      self._benchmark_defun_matmul_forward_backward(
+          m, transpose_b=False, num_iters=self._num_iters_2_by_2)
+
+  def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self):
+    with context.device(CPU):
+      m = self._m_2_by_2.cpu()
+      self._benchmark_defun_matmul_forward_backward(
+          m,
+          transpose_b=False,
+          num_iters=self._num_iters_2_by_2,
+          execution_mode=context.ASYNC)
+
   def benchmark_tf_matmul_2_by_2_GPU(self):
     if not context.num_gpus():
       return
@@ -682,6 +718,25 @@ class MicroBenchmarks(test.Benchmark):
     assert np.equal(func(), make_keras_model()(data)).all()
     self._run(func, 30000)
 
+  def benchmarkScan(self):
+    elems = math_ops.range(1600)
+
+    def scan():
+      return functional_ops.scan(
+          lambda a, x: a + x, elems, parallel_iterations=1)
+
+    self._run(scan, 100)
+
+  def benchmarkScanDefun(self):
+    elems = math_ops.range(1600)
+
+    @function.defun
+    def scan():
+      return functional_ops.scan(
+          lambda a, x: a + x, elems, parallel_iterations=1)
+
+    self._run(scan, 100)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 6a327bd010..778ff85342 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -37,7 +37,7 @@ GRAPH_MODE = 0
 EAGER_MODE = 1
 
 # Default execution mode.
-_default_mode = GRAPH_MODE
+default_execution_mode = GRAPH_MODE
 
 # Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
 # new_device_spec).
@@ -56,14 +56,18 @@ SYNC = 0
 ASYNC = 1
 
 
-class _TensorCache(object):
+class _EagerTensorCache(object):
   """Simple cache which evicts items based on length in a FIFO manner."""
 
-  def __init__(self, max_items=256):
+  def __init__(self, max_items=256, max_tensor_size=10000):
     self._data = collections.OrderedDict()
-    self._max_items = max_items if max_items else 256
+    self._max_items = max_items
+    self._max_tensor_size = max_tensor_size
 
   def put(self, key, value):
+    if value._num_elements() > self._max_tensor_size:  # pylint: disable=protected-access
+      return
+
     self._data[key] = value
 
     if len(self._data) > self._max_items:
@@ -84,14 +88,14 @@ class _EagerContext(threading.local):
     super(_EagerContext, self).__init__()
     self.device_spec = pydev.DeviceSpec.from_string("")
     self.device_name = self.device_spec.to_string()
-    self.mode = _default_mode
-    self.is_eager = _default_mode == EAGER_MODE
+    self.mode = default_execution_mode
+    self.is_eager = default_execution_mode == EAGER_MODE
     self.scope_name = ""
     self.recording_summaries = False
     self.summary_writer_resource = None
     self.scalar_cache = {}
-    self.ones_rank_cache = _TensorCache()
-    self.zeros_cache = _TensorCache()
+    self.ones_rank_cache = _EagerTensorCache()
+    self.zeros_cache = _EagerTensorCache()
     self.execution_mode = None
 
 
@@ -111,8 +115,8 @@ class _ContextSwitchStack(threading.local):
       # Initialize the stack with a pointer to enter the eager context; this
      # ensures that the fact that eager execution was enabled is propagated
      # across threads, since (1) `enable_eager_execution` modifies a
-      # process-level flag (`_default_mode`) and (2) `__init__` is called each
-      # time a threading.local object is used in a separate thread.
+      # process-level flag (`default_execution_mode`) and (2) `__init__` is
+      # called each time a threading.local object is used in a separate thread.
       self.push(is_building_function=False, enter_context_fn=eager_mode)
 
   def push(self, is_building_function, enter_context_fn):
@@ -504,9 +508,7 @@ class Context(object):
     Args:
       fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
     """
-    pywrap_tensorflow.TFE_ContextAddFunction(
-        self._handle,  # pylint: disable=protected-access
-        fn)
+    pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)
 
   def add_function_def(self, fdef):
     """Add a function definition to the context.
@@ -519,9 +521,7 @@ class Context(object):
     """
     fdef_string = fdef.SerializeToString()
     pywrap_tensorflow.TFE_ContextAddFunctionDef(
-        self._handle,  # pylint: disable=protected-access
-        fdef_string,
-        len(fdef_string))
+        self._handle, fdef_string, len(fdef_string))
 
   def add_post_execution_callback(self, callback):
     """Add a post-execution callback to the context.
@@ -633,14 +633,7 @@ def context():
 
 
 def context_safe():
-  return _context
-
-
-# TODO(agarwal): remove this.
-def get_default_context():
-  """Same as context."""
-  if _context is None:
-    _initialize_context()
+  """Returns current context (or None if one hasn't been initialized)."""
   return _context
 
 
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index cc765725a4..fb5442b646 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+import pickle
 import threading
 
 import numpy as np
@@ -185,6 +187,17 @@ class TFETest(test_util.TensorFlowTestCase):
             device_count={'GPU': 0}))
     self.assertEquals(0, ctx.num_gpus())
 
+  def testPickle(self):
+    tmp_dir = self.get_temp_dir()
+    fname = os.path.join(tmp_dir, 't.pickle')
+    with open(fname, 'wb') as f:
+      t = constant_op.constant(10.0)
+      pickle.dump(t, f)
+
+    with open(fname, 'rb') as f:
+      t = pickle.load(f)
+    self.assertAllEqual(t.numpy(), 10.0)
+
   def testTensorPlacement(self):
     if not context.context().num_gpus():
       self.skipTest('No GPUs found')
@@ -676,5 +689,16 @@ class SendRecvTest(test_util.TensorFlowTestCase):
                      2.0)
 
 
+class EagerTensorCacheTest(test_util.TensorFlowTestCase):
+
+  def testCacheSkipsTensorsTooLarge(self):
+    cache = context._EagerTensorCache(max_items=100, max_tensor_size=3)
+    cache.put('1', array_ops.zeros((2, 2)))
+    self.assertEqual(cache.get('1'), None)
+
+    cache.put('2', array_ops.zeros((2)))
+    self.assertNotEqual(cache.get('2'), None)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py
index 9a08259653..80ff4459d6 100644
--- a/tensorflow/python/eager/execution_callbacks.py
+++ b/tensorflow/python/eager/execution_callbacks.py
@@ -146,7 +146,7 @@ def inf_nan_callback(op_type,
   """
   del attrs, inputs  # Not used.
 
-  ctx = context.get_default_context()
+  ctx = context.context()
 
   for index, output in enumerate(outputs):
     if not output.dtype.is_numpy_compatible:
@@ -263,12 +263,12 @@ def add_execution_callback(callback):
       Return value(s) from the callback are ignored.
   """
   execute.execute = execute.execute_with_callbacks
-  context.get_default_context().add_post_execution_callback(callback)
+  context.context().add_post_execution_callback(callback)
 
 
 def clear_execution_callbacks():
   """Clear all execution callbacks from the default eager context."""
-  context.get_default_context().clear_post_execution_callbacks()
+  context.context().clear_post_execution_callbacks()
 
 
 def seterr(inf_or_nan=None):
@@ -309,7 +309,7 @@ def seterr(inf_or_nan=None):
                      "Valid actions are %s."
                     % (inf_or_nan, _VALID_CALLBACK_ACTIONS))
 
   old_settings = {"inf_or_nan": "ignore"}
-  default_context = context.get_default_context()
+  default_context = context.context()
 
   carryover_callbacks = []
   for callback in default_context.post_execution_callbacks:
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 5afba466bc..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,12 +21,12 @@ from __future__ import print_function
 
 import collections
 import functools
+import sys
 import threading
 
 import numpy as np
 import six
 
-from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
@@ -34,10 +34,12 @@ from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
 from tensorflow.python.eager.graph_only_ops import graph_placeholder
 from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes as dtypes_module
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_impl
@@ -49,8 +51,15 @@ from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
 
+# This is to avoid a circular dependency with cond_v2_impl
+# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
+cond_v2_impl._function = sys.modules[__name__]  # pylint: disable=protected-access
 
-def create_substitute_placeholder(value, name, dtype=None):
+# This is to avoid a circular dependency with gradients_impl
+gradients_impl._function = sys.modules[__name__]  # pylint: disable=protected-access
+
+
+def _create_substitute_placeholder(value, name, dtype=None):
   """Creates a placeholder for `value` and propagates shape info to it."""
   # Note: setting ops.control_dependencies(None) ensures we always put
   # capturing placeholders outside of any control flow context.
@@ -82,82 +91,131 @@ def _create_substitute_placeholder(value, name, dtype=None):
   return placeholder
 
 
-def capture_value(tensor_map, value, dtype, name):
-  """Capture a value from outside the function, to pass in as an extra arg."""
-  captured_value = tensor_map.get(value, None)
-  if captured_value is None:
-    captured_value = create_substitute_placeholder(value, name=name,
-                                                   dtype=dtype)
-    tensor_map[value] = captured_value
-  tape.record_operation("captured_value", [captured_value], [value],
-                        lambda x: [x])
-  return captured_value
+def _get_device_functions(ctx, graph):
+  """Returns a tuple of device functions representing the device stack."""
+  if ctx.executing_eagerly():
+    return (pydev.merge_device(ctx.device_name),)
+  else:
+    return tuple(graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
 
 
-class CapturingGraph(ops.Graph):
-  """Graph that can capture tensors from other graphs.
+class FuncGraph(ops.Graph):
+  """Graph representing a function body.
 
   Attributes:
-    captures: Maps external tensor -> internal tensor (e.g. input placeholder).
+    name: The name of the function.
+    inputs: Placeholder tensors representing the inputs to this function. The
+      tensors are in this FuncGraph. This represents "regular" inputs as well as
+      captured inputs (i.e. the values of self.captures), with the regular
+      inputs coming first.
+    outputs: Tensors that will be returned by this function. The tensors are in
+      this FuncGraph.
+    structured_outputs: A possibly-nested python object which will be returned
+      by this function. The Tensors in this structure are the same as those of
+      self.outputs. Note that this structure might contain Python `None`s.
+    variables: Variables that should be watched during function execution.
+    outer_graph: The graph this function is defined in. May be another FuncGraph
+      or the global default Graph.
+    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
       The entries are in the order they were captured.
+    seed: The graph-level random seed.
   """
 
-  def __init__(self):
-    super(CapturingGraph, self).__init__()
+  def __init__(self, name):
+    """Construct a new FuncGraph.
+
+    The graph will inherit its graph key, collections, seed, device stack, and
+    distribution strategy stack from the current context or graph.
 
+    Args:
+      name: the name of the function.
+    """
+    super(FuncGraph, self).__init__()
+
+    self.name = name
+    self.inputs = []
+    self.outputs = []
+    self.structured_outputs = None
+    self.variables = []
+    self.outer_graph = ops.get_default_graph()
     self.captures = collections.OrderedDict()
-    self._building_function = True
 
+    self._building_function = True
     # Map from resource tensor name to last op (in program order) which uses
     # this tensor. Used to enforce that execution order matches program order
     # for resource tensors.
     self._last_op_using_resource_tensor = {}
 
-  # TODO(apassos) remove once the C API is used by default.
-  def _use_c_api_hack(self):
-    return True
-
-  def clear_resource_control_flow_state(self):
-    self._last_op_using_resource_tensor = {}
+    graph = self.outer_graph
 
-  # TODO(skyewm): get rid of name and use the name of `tensor`.
-  def capture(self, tensor, name=None):
-    """Capture `tensor` if it's external to this graph.
-
-    If `tensor` is from a different graph, returns a placeholder for it.
-    `tensor` and the placeholder will also appears in self.captures. Multiple
-    calls to this method with the same `tensor` argument will return the same
-    placeholder. If `tensor` is from this graph, returns `tensor`.
-
-    Args:
-      tensor: Tensor. May be from this FuncGraph or a different graph.
-      name: Optional name if a placeholder is created.
-
-    Returns:
-      Tensor from this FuncGraph.
-    """
-    if isinstance(tensor, ops.EagerTensor):
-      if name is None:
-        name = str(ops.uid())
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    if tensor.graph is not self:
-      if name is None:
-        name = tensor.op.name
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    return tensor
+    if context.executing_eagerly():
+      self.seed = context.global_seed()
+      self._xla_compile = (context.context().device_spec.device_type == "TPU")
+      self._add_device_to_stack(context.context().device_name)
+    else:
+      self.seed = graph.seed
+      self._xla_compile = getattr(graph, "_xla_compile", False)
+      self._device_function_stack = graph._device_function_stack.copy()  # pylint: disable=protected-access
+      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access
+
+    # TODO(b/112165328, b/112906995): summaries depend on inheriting collections
+    # from the default graph even in eager mode. It'd be nice to not have a
+    # default graph with eager execution, so hopefully this will go away when we
+    # remove collections.
+    # pylint: disable=protected-access
+    self._collections = graph._collections
+    # TODO(b/112906995): distribution strategy depends on inheriting this stack
+    # from the default graph even in eager mode. Maybe it should be part of the
+    # eager context?
+    self._distribution_strategy_stack = graph._distribution_strategy_stack
+    # Inherit the graph key, since this is used for matching variables in
+    # optimizers.
+    self._graph_key = graph._graph_key
+    # pylint: enable=protected-access
 
   def create_op(
       self,
       op_type,
       inputs,
-      dtypes,  # pylint: disable=redefined-outer-name
+      dtypes,
       input_types=None,
       name=None,
       attrs=None,
       op_def=None,
       compute_shapes=True,
       compute_device=True):
-    """Captures an external inputs before calling Graph.capture_op."""
+    """Like Graph.create_op, except handles external input tensors.
+
+    This overload adds functionality to create_op to "capture" any external
+    input tensors, i.e. tensors from the eager context or outer function graphs
+    if this is a nested function. See `capture` for more information.
+
+    Args:
+      op_type: The `Operation` type to create. This corresponds to the
+        `OpDef.name` field for the proto that defines the operation.
+      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+      dtypes: A list of `DType` objects that will be the types of the tensors
+        that the operation produces.
+      input_types: (Optional.) A list of `DType`s that will be the types of
+        the tensors that the operation consumes. By default, uses the base
+        `DType` of each input in `inputs`. Operations that expect
+        reference-typed inputs must specify `input_types` explicitly.
+      name: (Optional.) A string name for the operation. If not specified, a
+        name is generated based on `op_type`.
+      attrs: (Optional.) A dictionary where the key is the attribute name (a
+        string) and the value is the respective `attr` attribute of the
+        `NodeDef` proto that will represent the operation (an `AttrValue`
+        proto).
+      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+        the operation will have.
+      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+        computed).
+      compute_device: (Optional.) If True, device functions will be executed
+        to compute the device property of the Operation.
+
+    Returns:
+      An `Operation` object.
+    """
     # This capturing logic interacts poorly with control flow contexts which
     # want to replace inputs of ops far too late in the process. This can lead
     # the context to get confused and try to create an Enter for an Enter. We
@@ -171,80 +229,61 @@ class FuncGraph(ops.Graph):
     # to capture the inputs.
     ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
     for i, inp in enumerate(inputs):
+      # TPU Estimator defines a control flow context with no AddValue method.
       if ctxt is not None and hasattr(ctxt, "AddValue"):
         inp = ctxt.AddValue(inp)
       inp = self.capture(inp)
       inputs[i] = inp
-    return super(CapturingGraph, self).create_op(
+    return super(FuncGraph, self).create_op(
         op_type, inputs, dtypes, input_types, name, attrs, op_def,
         compute_device=compute_device)
 
+  def capture(self, tensor, name=None):
+    """Captures `tensor` if it's external to this graph.
 
-class FuncGraph(CapturingGraph):
-  """Graph representing a function body.
-
-  Attributes:
-    name: The name of the function.
-
-    inputs: Placeholder tensors representing the inputs to this function. The
-      tensors are in this FuncGraph. This represents "regular" inputs as well as
-      captured inputs (i.e. the values of self.captures), with the regular
-      inputs coming first.
-    outputs: Tensors that will be returned by this function. The tensors are in
-      this FuncGraph.
-    structured_outputs: A possibly-nested python object which will be returned
-      by this function. The Tensors in this structure are the same as those of
-      self.outputs. Note that this structure might contain Python `None`s.
-    variables: Variables that should be watched during function execution.
-    seed: The graph-level random seed.
-  """
-
-  def __init__(self, name, graph=None):
-    """Construct a new FuncGraph.
+    If `tensor` is from a different graph, returns a placeholder for it.
+    `tensor` and the placeholder will appear in self.captures, and the
+    placeholder will appear in self.inputs. Multiple calls to this method with
+    the same `tensor` argument will return the same placeholder. If `tensor` is
+    from this graph, returns `tensor`.
 
     Args:
-      name: the name of the function.
-      graph: if specified, this FuncGraph will inherit its graph key,
-        collections, and seed from `graph`.
-    """
-    super(FuncGraph, self).__init__()
-
-    self.name = name
-    self.inputs = []
-    self.outputs = []
-    self.structured_outputs = None
-    self.variables = []
-
-    if graph is not None:
-      # Inherit the graph key, since this is used for matching variables in
-      # optimizers.
-      self._graph_key = graph._graph_key  # pylint: disable=protected-access
-
-      # Copy the graph collections to ensure summaries and other things work.
-      # This lets the function access (but not mutate) collections of the
-      # containing graph, such as the global step and the summary writer
-      # collections.
-      for collection in graph.collections:
-        self.get_collection_ref(collection)[:] = graph.get_collection(
-            collection)
-
-      # Copy distribution strategy scope from the containing graph as well.
-      self._distribution_strategy_stack = graph._distribution_strategy_stack  # pylint: disable=protected-access
+      tensor: Tensor. May be from this FuncGraph or a different graph.
+      name: Optional name if a placeholder is created.
 
-    if context.executing_eagerly():
-      self.seed = context.global_seed()
-    else:
-      self.seed = graph.seed
+    Returns:
+      Tensor from this FuncGraph.
+ """ + if isinstance(tensor, ops.EagerTensor): + if name is None: + name = str(ops.uid()) + return self._capture_helper(tensor, name) + if tensor.graph is not self: + if name is None: + name = tensor.op.name + return self._capture_helper(tensor, name) + return tensor - def capture(self, tensor, name=None): - """Calls CapturingGraph.capture and updates self.inputs if necessary.""" - new_capture = tensor not in self.captures - internal_tensor = super(FuncGraph, self).capture(tensor, name) + def _capture_helper(self, tensor, name): + captured_tensor = self.captures.get(tensor, None) + if captured_tensor is None: + captured_tensor = _create_substitute_placeholder(tensor, name=name, + dtype=tensor.dtype) + self.captures[tensor] = captured_tensor + self.inputs.append(captured_tensor) + tape.record_operation("captured_value", [captured_tensor], [tensor], + lambda x: [x]) + return captured_tensor - if new_capture and tensor is not internal_tensor: - self.inputs.append(internal_tensor) + @property + def external_captures(self): + """External tensors captured by this function.""" + return list(self.captures.keys()) - return internal_tensor + @property + def internal_captures(self): + """Placeholders in this function corresponding captured tensors.""" + return list(self.captures.values()) def _forward_name(n): @@ -267,9 +306,6 @@ def _register(fn): context.context().add_function(fn) -_xla_compile_attr = "_XlaCompile" - - # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction # so it doesn't have the definition-generating logic and is just a container for # an already-defined function. @@ -282,18 +318,20 @@ class _EagerDefinedFunction(object): class may be provided as the value of these `func` attributes. """ - def __init__(self, name, graph, operations, inputs, outputs, attrs): + def __init__(self, name, graph, inputs, outputs, attrs): """Initializes an eager defined function. Args: name: str, the name for the created function. graph: Graph, the graph containing the operations in the function - operations: list of Operation; the subset of operations in the graph - which will be in the function inputs: the tensors in the graph to be used as inputs to the function outputs: the tensors in the graph which will be outputs to the function attrs: dict mapping names of attributes to their AttrValue values """ + operations = [ + op for op in graph.get_operations() + if op not in set(arg.op for arg in inputs) + ] fn = pywrap_tensorflow.TF_GraphToFunction_wrapper( graph._c_graph, # pylint: disable=protected-access compat.as_str(name), @@ -311,7 +349,6 @@ class _EagerDefinedFunction(object): # It might be worth creating a convenient way to re-use status. pywrap_tensorflow.TF_FunctionSetAttrValueProto( fn, compat.as_str(name), serialized) - self._xla_compile = _xla_compile_attr in attrs # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. 
@@ -327,6 +364,7 @@ class _EagerDefinedFunction(object):
     self.signature = function_def.signature
     self._num_outputs = len(self.signature.output_arg)
     self._output_types = [o.type for o in self.signature.output_arg]
+    self._output_shapes = [o.shape for o in outputs]
     self.grad_func_name = None
     self.python_grad_func = None
     self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -347,7 +385,7 @@ class _EagerDefinedFunction(object):
   def stateful_ops(self):
     return self._stateful_ops
 
-  def call(self, ctx, args, output_shapes):
+  def call(self, ctx, args):
     """Calls this function with `args` as inputs.
 
     Function execution respects device annotations only if the function won't
@@ -356,8 +394,6 @@ class _EagerDefinedFunction(object):
     Args:
       ctx: a Context object
       args: a list of arguments to supply this function with.
-      output_shapes: shapes to which outputs should be set; ignored when
-        executing eagerly.
 
     Returns:
       The outputs of the function call.
@@ -365,10 +401,7 @@ class _EagerDefinedFunction(object):
 
     executing_eagerly = ctx.executing_eagerly()
 
-    xla_compile = self._xla_compile or (executing_eagerly and
-                                        ctx.device_spec.device_type == "TPU")
-
-    if xla_compile:
+    if self._graph._xla_compile:  # pylint: disable=protected-access
       # XLA compilation relies upon a custom kernel creator to run functions.
       signature = self.signature
       if executing_eagerly:
@@ -406,7 +439,7 @@ class _EagerDefinedFunction(object):
     if executing_eagerly:
       return outputs
     else:
-      for i, shape in enumerate(output_shapes):
+      for i, shape in enumerate(self._output_shapes):
         outputs[i].set_shape(shape)
       return outputs
 
@@ -427,179 +460,117 @@ def _flatten(sequence):
   return outputs
 
 
-# TODO(akshayka): Perhaps rename to something more appropriate.
-class GraphModeFunction(object):
+class Function(object):
   """Callable object encapsulating a function definition and its gradient.
 
-  `GraphModeFunction` is a callable that encapsulates a function definition and
+  `Function` is a callable that encapsulates a function definition and
   is differentiable under `tf.GradientTape` objects.
   """
 
-  def __init__(self,
-               name,
-               input_placeholders,
-               extra_inputs,
-               graph,
-               operations,
-               outputs,
-               python_func_outputs,
-               output_shapes,
-               variables=None,
-               attrs=None):
-    """Initialize a GraphModeFunction.
+  def __init__(self, func_graph, attrs=None):
+    """Initialize a Function.
 
     Args:
-      name: str the name of the created function
-      input_placeholders: list of placeholder values (tensors) to feed when
-        calling the wrapped function.
-      extra_inputs: Tensor inputs this function definition closed over which
-        are passed as arguments. Need to track so gradients are supported
-        correctly.
-      graph: the Graph from which the operations will be pulled. Used as
-        a context when computing gradients.
-      operations: the subset of Operations in the graph used in the function
-        definition.
-      outputs: a flat list of the Tensors in the graph used as outputs to the
-        function
-      python_func_outputs: a possibly nested python object which will be
-        returned by this function. The Tensors in this structure will be
-        replaced by their corresponding values in outputs. Note that this
-        structure might contain Python `None`s.
-      output_shapes: List of shapes of all tensors in outputs
-      variables: (optional) List of variables to watch during function
-        execution.
+      func_graph: An instance of FuncGraph: the function body to wrap.
       attrs: (optional) dict mapping names of attributes to their AttrValue
        values. Attributes in `attrs` will be included in this function's
        definition.
+
+    Raises:
+      ValueError: If number of input_placeholders is not equal to the number
+        of function inputs.
     """
+    self._func_graph = func_graph
+    self._captured_inputs = list(self._func_graph.captures.keys())
+    self._num_outputs = len(self._func_graph.outputs)
+    self._output_shapes = tuple(
+        output.shape for output in self._func_graph.outputs)
     self._attrs = attrs or {}
-    defined_function = _EagerDefinedFunction(
-        name, graph, operations, input_placeholders, outputs, self._attrs)
-    if len(input_placeholders) != len(defined_function.signature.input_arg):
-      raise ValueError("Internal error: invalid lengths. %s %s" % (
-          len(input_placeholders), len(defined_function.signature.input_arg)))
-    self._input_placeholders = input_placeholders
-    self._extra_inputs = list(extra_inputs)
-    self._graph = graph
-    self._backward_function = None
-    self._func_name = name
-    self._function_def = defined_function
-    self._num_outputs = len(defined_function.signature.output_arg)
-    self._python_func_outputs = python_func_outputs
-    self._python_returns = [python_func_outputs] if isinstance(
-        python_func_outputs,
-        (ops.Tensor, type(None))) else _flatten(python_func_outputs)
-    self._output_shapes = output_shapes
-    self._variables = variables if variables is not None else []
-
-    # Find the variables that are components of something distributed and
-    # put them into a {handle_tensor -> distributed variable object} map.
+    self._device_functions = tuple(
+        self._func_graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
+
+    self._inference_function = _EagerDefinedFunction(
+        _inference_name(self._func_graph.name), self._func_graph,
+        self._func_graph.inputs, self._func_graph.outputs, self._attrs)
+    self._backward_graph_function = None
+
+    # Map holding distributed variables, keyed by resource handle tensors.
     self._distributed_variables = {}
     strategy = distribution_strategy_context.get_distribution_strategy()
-    for variable in self._variables:
+    for variable in self._func_graph.variables:
       # If variable is not distributed, unwrap returns [variable].
       component_variables = strategy.unwrap(variable)
-      # Only add to the dictionary when the variable is actually distributed,
-      # i.e. more than one component or the component is different from the
-      # variable itself. component_variables cannot be empty.
+      # Only update the dictionary when the variable is actually distributed.
       if (len(component_variables) > 1 or component_variables[0] != variable):
         for component_variable in component_variables:
           self._distributed_variables[component_variable.handle] = variable
 
-  @property
-  def variables(self):
-    return self._variables
+  def __call__(self, *args):
+    """Executes the wrapped function."""
+    ctx = context.context()
+    device_functions = _get_device_functions(ctx, ops.get_default_graph())
+    if device_functions != self._device_functions:
+      raise ValueError(
+          "The current device stack does not match the device stack under "
+          "which the TensorFlow function '%s' was created.\n"
+          "Current device stack: %s\n%s device stack: %s" %
+          (self._inference_function.name, device_functions,
+           self._inference_function.name, self._device_functions))
+
+    for v in self._func_graph.variables:
+      if v.trainable:
+        tape.variable_accessed(v)
 
-  def _construct_backprop_function(self):
-    """Constructs the backprop function object for this function."""
-    filtered_outputs = [x for x in self._python_returns if x is not None]
-    # TODO(skyewm): use FuncGraph
-    backwards_graph = CapturingGraph()
-    backwards_graph._graph_key = self._graph._graph_key  # pylint: disable=protected-access
-    for collection in self._graph.collections:
-      backwards_graph.get_collection_ref(
-          collection)[:] = self._graph.get_collection(collection)
-    backwards_graph.seed = self._graph.seed
-    with backwards_graph.as_default():
-      self._out_grad_placeholders = [
-          graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
-      in_gradients = gradients_impl._GradientsHelper(  # pylint: disable=protected-access
-          filtered_outputs,
-          self._input_placeholders,
-          grad_ys=self._out_grad_placeholders,
-          src_graph=self._graph)
-
-    backward_outputs = tuple(
-        grad for grad in _flatten(in_gradients) if grad is not None)
-    output_shapes = tuple(grad.shape for grad in backward_outputs)
-
-    extra_inputs = backwards_graph.captures.keys()
-    extra_placeholders = backwards_graph.captures.values()
-
-    forward_name = _forward_name(self._func_name)
-    # Note: we cannot have placeholder ops in the graph or the TPU compilation
-    # pass fails.
-    placeholder_ops = set([y.op for y in self._input_placeholders])
-    function_ops = [x for x in self._graph.get_operations()
-                    if x not in placeholder_ops]
-    self._forward_fdef = _EagerDefinedFunction(
-        forward_name, self._graph, function_ops,
-        self._input_placeholders, filtered_outputs + list(extra_inputs),
-        self._attrs)
-    all_inputs = self._out_grad_placeholders + list(extra_placeholders)
-    # Excluding input ops from the body as we do not intend to execute these
-    # operations when the function is executed.
-    all_ignored_ops = frozenset(x.op for x in all_inputs)
-    # Enforce a deterministic order of operations in the generated graph. This
-    # means rerunning the function-defining code will always define the same
-    # function, which is useful if we serialize this etc.
-    function_def_ops = tuple(x
-                             for x in sorted(backwards_graph.get_operations(),
-                                             key=lambda x: x.name)
-                             if x not in all_ignored_ops)
-    bname = _backward_name(self._func_name)
-    self._backward_function = GraphModeFunction(
-        bname, all_inputs, [], backwards_graph, function_def_ops,
-        backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
+    captures = self._resolve_captured_inputs()
+    tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+    args = tensor_inputs + captures
 
-  def _backprop_call(self, args):
-    """Calls the wrapped function and records the result on a tape.
+    if tape.should_record(tensor_inputs) or tape.should_record(captures):
+      return self._backprop_call(args)
 
-    (Only records results on a tape if the function has outputs)
+    outputs = self._inference_function.call(ctx, args)
+    return self._build_call_outputs(outputs)
 
-    Args:
-      args: All inputs to the function, including resolved extra inputs
-    Returns:
-      The call output.
-    """
-    ctx = context.context()
-    outputs = self._forward_fdef.call(ctx, args, self._output_shapes)
-    if isinstance(outputs, ops.Operation) or outputs is None:
-      return outputs
+  @property
+  def graph(self):
+    """Returns the graph from which this function was constructed."""
+    return self._func_graph
 
-    # `real_outputs` are the actual outputs of the inference graph function;
-    # `side_outputs` are the intermediate Tensors that were added as outputs to
-    # the forward graph function so that we can compute its gradient.
-    real_outputs = outputs[:self._num_outputs]
-    side_outputs = outputs[self._num_outputs:]
+  @property
+  def variables(self):
+    """Returns all variables touched by this function."""
+    return self._func_graph.variables
 
-    def backward_function(*args):
-      return self._backward_function(*(list(args) + side_outputs))  # pylint: disable=not-callable
+  @property
+  def inputs(self):
+    """Returns tensors in `self.graph` corresponding to arguments."""
+    return self._func_graph.inputs
 
-    tape.record_operation(
-        self._forward_fdef.signature.name,
-        real_outputs,
-        args,
-        backward_function)
+  @property
+  def outputs(self):
+    """Returns tensors in `self.graph` corresponding to return values."""
+    return self._func_graph.outputs
 
-    return self._build_call_outputs(real_outputs)
+  @property
+  def captured_inputs(self):
+    """Returns external Tensors captured by this function.
+
+    self.__call__(*args) passes `args + self.captured_inputs` to the function.
+    """
+    return self._captured_inputs
+
+  @property
+  def function_def(self):
+    """Returns a `FunctionDef` object representing this function."""
+    return self._inference_function.definition
 
   @property
   def output_shapes(self):
     """The function's output shapes."""
     # TODO(ebrevdo): Should we only keep the output shapes associated
     # with len(self._python_returns) outputs?
-    outputs_list = nest.flatten(self._python_func_outputs)
+    # TODO(akshayka): Consider removing this.
+    outputs_list = nest.flatten(self._func_graph.structured_outputs)
     j = 0
     for i, o in enumerate(outputs_list):
       if o is not None:
@@ -613,23 +584,80 @@ class GraphModeFunction(object):
       else:
         outputs_list[i] = self._output_shapes[j]
         j += 1
-    return nest.pack_sequence_as(self._python_func_outputs, outputs_list)
+    return nest.pack_sequence_as(self._func_graph.structured_outputs,
+                                 outputs_list)
 
   @property
   def output_dtypes(self):
-    return nest.map_structure(
-        lambda x: x.dtype if x is not None else None, self._python_func_outputs)
+    # TODO(akshayka): Consider removing this.
+ return nest.map_structure(lambda x: x.dtype if x is not None else None, + self._func_graph.structured_outputs) - @property - def captured_inputs(self): - return self._extra_inputs + def _construct_backprop_function(self): + """Constructs the backprop function object for this function.""" + backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) + with backwards_graph.as_default(): + gradients_wrt_outputs = [ + graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs + ] + gradients_wrt_inputs = gradients_impl._GradientsHelper( # pylint: disable=protected-access + self._func_graph.outputs, + self._func_graph.inputs, + grad_ys=gradients_wrt_outputs, + src_graph=self._func_graph) + + self._forward_function = _EagerDefinedFunction( + _forward_name( + self._func_graph.name), self._func_graph, self._func_graph.inputs, + self._func_graph.outputs + list(backwards_graph.captures.keys()), + self._attrs) - @property - def name(self): - """Returns the name of the function in Eager-compatible format.""" - return self._function_def.name.encode("utf-8") + # The ordering of `backwards_graph.inputs` is important: inputs of + # `self._backward_graph_function` correspond to outputs of + # `self._forward_function`. + backwards_graph.inputs = gradients_wrt_outputs + list( + backwards_graph.captures.values()) + # Clear captures, since we pass them in as inputs. + backwards_graph.captures = {} + backwards_graph.outputs.extend( + grad for grad in _flatten(gradients_wrt_inputs) if grad is not None) + backwards_graph.structured_outputs = gradients_wrt_inputs + self._backward_graph_function = Function( + backwards_graph, attrs=self._attrs) + + def _backprop_call(self, args): + """Calls the forward function and records the result on a tape. + + (Only records results on a tape if the function has outputs) - def _resolve_extra_inputs(self): + Args: + args: All inputs to the function, including resolved captured inputs + + Returns: + The call output. + """ + if self._backward_graph_function is None: + self._construct_backprop_function() + + ctx = context.context() + outputs = self._forward_function.call(ctx, args) + if isinstance(outputs, ops.Operation) or outputs is None: + return outputs + + # `real_outputs` are the actual outputs of the inference graph function; + # `side_outputs` are the intermediate Tensors that were added as outputs to + # the forward graph function so that we can compute its gradient. + real_outputs = outputs[:self._num_outputs] + side_outputs = outputs[self._num_outputs:] + + def backward_function(*args): + return self._backward_graph_function(*(list(args) + side_outputs)) # pylint: disable=not-callable + + tape.record_operation(self._forward_function.signature.name, real_outputs, + args, backward_function) + return self._build_call_outputs(real_outputs) + + def _resolve_captured_inputs(self): """Resolve captured distributed variables to their current values. Some inputs can be distributed variables. Such variables yield a different @@ -637,44 +665,23 @@ class GraphModeFunction(object): execution. Returns: - a list of resolved extra input tensors. + a list of resolved captured input tensors. """ if self._distributed_variables: - # Loop over each extra_inputs and check if it corresponds to something + # Loop over each captured input and check if it corresponds to something # distributed. If so, get its _distributed_container and fetch the # component appropriate for the current execution context. 
- resolved_extra_inputs = self._extra_inputs[:] - for i, extra_input in enumerate(self._extra_inputs): - distributed_var = self._distributed_variables.get(extra_input, None) + resolved_captured_inputs = self._captured_inputs[:] + for i, captured_input in enumerate(self._captured_inputs): + distributed_var = self._distributed_variables.get(captured_input, None) if distributed_var is not None: # distributed variables override __getattr__ and substitute the # right component variable. In here, `distributed_var.handle` # actually does the equivalent of # distributed_var.get_current_component_var().handle. - resolved_extra_inputs[i] = distributed_var.handle - return resolved_extra_inputs - - return self._extra_inputs - - def __call__(self, *args): - """Executes the passed function in eager mode.""" - for v in self._variables: - if v.trainable: - tape.watch_variable(v) - - resolved_extra_inputs = self._resolve_extra_inputs() - - tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] - args = tensor_inputs + resolved_extra_inputs - if tape.should_record(tensor_inputs) or tape.should_record( - resolved_extra_inputs): - if self._backward_function is None: - self._construct_backprop_function() - return self._backprop_call(args) - - ctx = context.context() - outputs = self._function_def.call(ctx, args, self._output_shapes) - return self._build_call_outputs(outputs) + resolved_captured_inputs[i] = distributed_var.handle + return resolved_captured_inputs + return self._captured_inputs def _build_call_outputs(self, result): """Maps the fdef output list to actual output structure. @@ -684,12 +691,12 @@ class GraphModeFunction(object): Returns: The actual call output. """ - if self._python_func_outputs is None: + if self._func_graph.structured_outputs is None: return result # Use `nest.flatten` instead of `_flatten` in order to preserve any - # IndexedSlices in `self._python_func_outputs`. - outputs_list = nest.flatten(self._python_func_outputs) + # IndexedSlices in `self._func_graph.structured_outputs`. + outputs_list = nest.flatten(self._func_graph.structured_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: @@ -703,13 +710,13 @@ class GraphModeFunction(object): j += 3 else: outputs_list[i] = ops.IndexedSlices( - values=result[j], - indices=result[j + 1]) + values=result[j], indices=result[j + 1]) j += 2 else: outputs_list[i] = result[j] j += 1 - ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list) + ret = nest.pack_sequence_as(self._func_graph.structured_outputs, + outputs_list) return ret @@ -725,20 +732,18 @@ def _get_defun_inputs_from_signature(signature): def _get_defun_inputs_from_args(args): """Maps python function args to graph-construction inputs.""" function_inputs = [ - graph_placeholder(arg.dtype, arg.shape) if isinstance(arg, ops.Tensor) - else arg for arg in nest.flatten(args) + graph_placeholder(arg.dtype, arg.shape) + if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args) ] return nest.pack_sequence_as(args, function_inputs) -def _trace_and_define_function(name, python_func, compiled, args, kwds, - signature=None): - """Defines and returns graph-mode version of `python_func`. +def func_graph_from_py_func(name, python_func, args, kwds, signature=None): + """Returns a `FuncGraph` generated from `python_func`. Args: name: an identifier for the function. python_func: the Python function to trace. - compiled: whether the graph function should be compiled through XLA. 
args: the positional args with which the Python function should be called; ignored if a signature is provided. kwds: the keyword args with which the Python function should be called; @@ -750,14 +755,13 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds, inputs. Returns: - A GraphModeFunction. + A FuncGraph. Raises: TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. """ - func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph()) - + func_graph = FuncGraph(name) with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) @@ -771,8 +775,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds, # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. func_graph.inputs.extend( x for x in nest.flatten(func_args) + nest.flatten(func_kwds) - if isinstance(x, ops.Tensor) - ) + if isinstance(x, ops.Tensor)) # Variables to help check whether mutation happens in calling the function # Copy the recursive list, tuple and map structure, but not base objects @@ -797,6 +800,7 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds, this_tape = tape.push_new_tape() try: func_outputs = python_func(*func_args, **func_kwds) + # invariant: `func_outputs` contains only Tensors and `None`s. func_outputs = nest.map_structure(convert, func_outputs) def check_mutation(n1, n2): @@ -816,53 +820,34 @@ def _trace_and_define_function(name, python_func, compiled, args, kwds, check_mutation(func_args_before, func_args) check_mutation(func_kwds_before, func_kwds) - finally: tape.pop_tape(this_tape) + func_graph.structured_outputs = func_outputs + # Returning a closed-over tensor does not trigger convert_to_tensor. + func_graph.outputs.extend( + func_graph.capture(x) + for x in _flatten(func_graph.structured_outputs) + if x is not None) + + # Some captured variables might be components of DistributedValues. + # Instead of storing non-distributed component variables, we + # store their distributed containers so we can retrieve the correct + # component variables at call-time. variables = list(this_tape.watched_variables()) - - # Some variables captured by the tape can come from a DistributedValue. - # At call time, DistributedValue can return another variable (e.g. if - # the function is run on a different device). Thus, instead of storing - # the specific captured variable, we replace it with its distributed - # container. strategy = distribution_strategy_context.get_distribution_strategy() for i, variable in enumerate(variables): # If variable is not distributed value_container returns itself. variables[i] = strategy.value_container(variable) - func_graph.variables = variables - # Returning a closed-over tensor as an output does not trigger a - # call to convert_to_tensor, so we manually capture all such tensors. - func_graph.outputs.extend( - func_graph.capture(x) for x in _flatten(func_graph.structured_outputs) - if x is not None - ) - - output_shapes = tuple( - x.shape if isinstance(x, ops.Tensor) else None - for x in func_graph.outputs) - - all_ignored_ops = frozenset(x.op for x in func_graph.inputs) - operations = tuple(x for x in func_graph.get_operations() - if x not in all_ignored_ops) - # Register any other functions defined in the graph - # TODO(ashankar): Oh lord, forgive me for this lint travesty. + # Register any other functions defined in the graph. 
  if context.executing_eagerly():
    for f in func_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register(f._c_func.func)  # pylint: disable=protected-access
-
-  attrs = {}
-  if compiled:
-    attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
-
-  return GraphModeFunction(
-      func_graph.name, func_graph.inputs, func_graph.captures.keys(),
-      func_graph, operations, func_graph.outputs, func_graph.structured_outputs,
-      output_shapes, func_graph.variables, attrs)
+  return func_graph


 _TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
@@ -911,13 +896,13 @@ def _deterministic_dict_values(dictionary):
   return tuple(dictionary[key] for key in sorted(dictionary))


-class _PolymorphicFunction(object):
+class PolymorphicFunction(object):
   """Wrapper class for the graph functions defined for a Python function.

   See the documentation for `defun` for more information on the semantics of
   defined functions.

-  _PolymorphicFunction class is thread-compatible meaning that minimal
+  PolymorphicFunction class is thread-compatible, meaning that minimal
   usage of defuns (defining and calling) is thread-safe, but if users call
   other methods or invoke the base `python_function` themselves, external
   synchronization is necessary.
@@ -926,8 +911,7 @@ class _PolymorphicFunction(object):
   def __init__(self,
                python_function,
                name,
-               input_signature=None,
-               compiled=False):
+               input_signature=None):
     """Initializes a polymorphic function.

     Args:
@@ -936,14 +920,10 @@ class _PolymorphicFunction(object):
       input_signature: a possibly nested sequence of `TensorSpec` objects
         specifying the input signature of this function. If `None`, a separate
         function is instantiated for each inferred input signature.
-      compiled: if True, the framework will attempt to compile func with XLA.

     Raises:
       ValueError: if `input_signature` is not None and the `python_function`'s
         argspec has keyword arguments.
-      TypeError: if `input_signature` contains anything other than
-        `TensorSpec` objects, or (if not None) is anything other than a tuple or
-        list.
     """

     if isinstance(python_function, functools.partial):
@@ -955,8 +935,7 @@ class _PolymorphicFunction(object):
       self._args_to_prepend = tuple()
       self._kwds_to_include = {}
     self._name = name
-    self._compiled = compiled
-    self._arguments_to_functions = {}
+    self._function_cache = collections.OrderedDict()
     self._variables = []

     self._lock = threading.Lock()
@@ -991,15 +970,40 @@ class _PolymorphicFunction(object):
       self._input_signature = tuple(input_signature)
       self._flat_input_signature = tuple(nest.flatten(input_signature))
-      if any(not isinstance(arg, tensor_spec.TensorSpec)
-             for arg in self._flat_input_signature):
-        raise TypeError("Invalid input_signature %s; input_signature must be "
-                        "a possibly nested sequence of TensorSpec objects.")
+
+  def __call__(self, *args, **kwds):
+    """Calls a graph function specialized to the inputs."""
+    graph_function, inputs = self._maybe_define_function(*args, **kwds)
+    return graph_function(*inputs)
+
+  @property
+  def python_function(self):
+    """Returns the wrapped Python function."""
+    return self._python_function
+
+  # TODO(akshayka): Remove this property.
+  @property
+  def variables(self):
+    """Returns the union of all variables referenced by cached `Function`s."""
+    return self._variables
+
+  def get_concrete_function(self, *args, **kwargs):
+    """Returns a `Function` object specialized to inputs and execution context.
+ + `args` and `kwargs` are ignored if this `PolymorphicFunction` was created + with an `input_signature`. + + Args: + *args: inputs to specialize on. + **kwargs: inputs to specialize on. + """ + graph_function, _ = self._maybe_define_function(*args, **kwargs) + return graph_function def __get__(self, instance, owner): """Makes it possible to defun instance methods.""" del owner - # `instance` here is the instance that this `_PolymorphicFunction` was + # `instance` here is the instance that this `PolymorphicFunction` was # accessed through; e.g., for # # class Foo(object): @@ -1009,29 +1013,42 @@ class _PolymorphicFunction(object): # ... # # foo = Foo() - # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance + # foo.bar() # `foo.bar` is a `PolymorphicFunction` instance # # then `instance` will be `foo` (and `owner` will be `Foo`). return functools.partial(self.__call__, instance) - def _cache_key(self, args, kwds): - """Computes the cache key given inputs.""" + def _cache_key(self, args, kwds, ctx, graph): + """Computes the cache key given inputs and execution context.""" if self._input_signature is None: inputs = (args, kwds) if kwds else args cache_key = tuple(_encode_arg(arg) for arg in inputs) else: del args, kwds cache_key = self._flat_input_signature + # The graph, or whether we're executing eagerly, should be a part of the # cache key so we don't improperly capture tensors such as variables. - return cache_key + (context.executing_eagerly() or ops.get_default_graph(),) + executing_eagerly = ctx.executing_eagerly() + execution_context = executing_eagerly or graph + + # Putting the device in the cache key ensures that call-site device + # annotations are respected. + device_functions = _get_device_functions(ctx, graph) + + # `ops.colocate_with` directives translate into `ops.device` directives when + # eager execution is enabled. + colocation_stack = (None if executing_eagerly else + tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access + + return cache_key + (execution_context, device_functions, colocation_stack) def _canonicalize_function_inputs(self, *args, **kwds): """Canonicalizes `args` and `kwds`. Canonicalize the inputs to the Python function using its fullargspec. In particular, we parse the varags and kwargs that this - `_PolymorphicFunction` was called with into a tuple corresponding to the + `PolymorphicFunction` was called with into a tuple corresponding to the Python function's positional (named) arguments and a dictionary corresponding to its kwargs. 
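The reworked cache key means `defun` now retraces once per call-site device stack, not just once per input signature. The following is an illustrative sketch (not part of the patch) of that behavior, assuming a TF 1.x build with eager execution enabled; `tf.contrib.eager.defun` is the public alias for `defun` here, and `_function_cache` is the internal cache this change introduces (see `testDeviceAnnotationsRespected` further below).

```python
# Minimal sketch, assuming TF 1.x with eager execution enabled.
import tensorflow as tf

tf.enable_eager_execution()
tfe = tf.contrib.eager


@tfe.defun
def square(x):
  return x * x


square(tf.constant(2.0))    # First call traces a graph function.
square(tf.constant(3.0))    # Same dtype/shape: cache hit, no retrace.

with tf.device('/cpu:0'):
  square(tf.constant(2.0))  # Different device stack: a second trace.

# Internally, len(square._function_cache) == 2 at this point, mirroring
# testDeviceAnnotationsRespected in function_test.py.
```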
@@ -1085,8 +1102,9 @@ class _PolymorphicFunction(object):
       if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
         raise ValueError("When input_signature is provided, all inputs to "
                          "the Python function must be Tensors.")
-      tensor_specs = [tensor_spec.TensorSpec.from_tensor(tensor)
-                      for tensor in flat_inputs]
+      tensor_specs = [
+          tensor_spec.TensorSpec.from_tensor(tensor) for tensor in flat_inputs
+      ]
       if any(not spec.is_compatible_with(other)
              for spec, other in zip(self._flat_input_signature, tensor_specs)):
         raise ValueError("Python inputs incompatible with input_signature: "
@@ -1111,42 +1129,34 @@ class _PolymorphicFunction(object):
     """
     args, kwds = self._canonicalize_function_inputs(*args, **kwds)
-    cache_key = self._cache_key(args, kwds)
+    cache_key = self._cache_key(args, kwds, context.context(),
+                                ops.get_default_graph())
     with self._lock:
       try:
-        graph_function = self._arguments_to_functions.get(cache_key, None)
+        graph_function = self._function_cache.get(cache_key, None)
       except TypeError:
         raise TypeError("Arguments supplied to `defun`-generated functions "
                         "must be hashable.")

       if graph_function is None:
-        graph_function = _trace_and_define_function(
-            self._name, self._python_function, self._compiled, args, kwds,
-            self._input_signature)
+        graph_function = Function(
+            func_graph_from_py_func(self._name, self._python_function, args,
+                                    kwds, self._input_signature))
         self._variables.extend(
             [v for v in graph_function.variables if v not in self._variables])
-        self._arguments_to_functions[cache_key] = graph_function
+        self._function_cache[cache_key] = graph_function
       return graph_function, (args, kwds)

-  def __call__(self, *args, **kwds):
-    """Calls a graph function specialized for this input signature."""
-    graph_function, inputs = self._maybe_define_function(*args, **kwds)
-    return graph_function(*inputs)
-
-  def call_python_function(self, *args, **kwargs):
-    """Directly calls the wrapped python function."""
-    return self._python_function(*args, **kwargs)
-
-  @property
-  def variables(self):
-    """Returns a list of variables used in any of the defined functions."""
-    return self._variables
+
+def _validate_signature(signature):
+  if any(not isinstance(arg, tensor_spec.TensorSpec)
+         for arg in nest.flatten(signature)):
+    raise TypeError("Invalid input_signature %s; input_signature must be "
+                    "a possibly nested sequence of TensorSpec objects."
+                    % (signature,))

-# TODO(akshayka): Remove the `compiled` flag and create a separate
-# API for xla compilation (`defun` is already complicated enough
-# as it is, and the keyword argument makes 'compiled' an overloaded concept)
-def defun(func=None, input_signature=None, compiled=False):
+def defun(func=None, input_signature=None):
   """Compiles a Python function into a callable TensorFlow graph.

   `defun` (short for "define function") trace-compiles a Python function
@@ -1221,6 +1230,7 @@ def defun(func=None, input_signature=None, compiled=False):
         self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
         self.keep_probability = keep_probability

+      @tf.contrib.eager.defun
       def call(self, inputs, training=True):
         x = self.dense2(self.dense1(inputs))
         if training:
@@ -1229,7 +1239,6 @@ def defun(func=None, input_signature=None, compiled=False):
         return x

       model = MyModel()
-      model.call = tf.contrib.eager.defun(model.call)
       model(x, training=True)  # executes a graph, with dropout
       model(x, training=False) # executes a graph, without dropout

@@ -1437,9 +1446,10 @@ def defun(func=None, input_signature=None, compiled=False):
     func: function to be compiled.
If `func` is None, returns a decorator that can be invoked with a single argument - `func`. The end result is equivalent to providing all the arguments up front. - In other words, defun(compiled=True)(func) is equivalent to - defun(func, compiled=True). The former allows the following use case: - @tf.contrib.eager.defun(compiled=True) + In other words, defun(input_signature=...)(func) is equivalent to + defun(func, input_signature=...). The former allows + the following use case: + @tf.contrib.eager.defun(input_signature=...) def foo(...): ... @@ -1450,17 +1460,20 @@ def defun(func=None, input_signature=None, compiled=False): signature is specified, every input to `func` must be a `Tensor`, and `func` cannot accept `**kwargs`. - compiled: If True, an attempt to compile `func` with XLA will be made. - If it fails, function will be run normally. Experimental. Currently - supported only for execution on TPUs. For the vast majority of users, - this argument should be False. - Returns: If `func` is not None, returns a callable that will execute the compiled function (and return zero or more `tf.Tensor` objects). If `func` is None, returns a decorator that, when invoked with a single `func` argument, returns a callable equivalent to the case above. + + Raises: + TypeError: If `input_signature` is neither `None` nor a sequence of + `tf.contrib.eager.TensorSpec` objects. """ + + if input_signature is not None: + _validate_signature(input_signature) + # TODO(apassos): deal with captured global state. Deal with control flow. def decorated(function): try: @@ -1469,8 +1482,7 @@ def defun(func=None, input_signature=None, compiled=False): name = "function" return tf_decorator.make_decorator( function, - _PolymorphicFunction( - function, name, input_signature=input_signature, compiled=compiled)) + PolymorphicFunction(function, name, input_signature=input_signature)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: @@ -1486,51 +1498,6 @@ def defun(func=None, input_signature=None, compiled=False): return decorated -def make_defun_op(func, *args, **kwds): - """Compile func into graph_mode, assuming func arguments are *args, **kwargs. - - `make_defun_op` converts a function that constructs a TensorFlow graph into - a function object and attaches it to the graph. The resulting function - object can be queried for its properties, and called directly with different - inputs to execute. - - More details on use cases and limitations are available in the - documentation for `defun`. - - Example: - ```python - def f(x, y): - return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) - - def g(x, y): - return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) - - z = tf.constant([[0.0, 0.0]]) - g_op = make_defun_op(g, z, z) - - assert g_op.output_shapes == tf.TensorShape([]) - assert g_op.output_types == tf.float32 - - x = tf.constant([[2.0, 3.0]]) - y = tf.constant([[3.0, -2.0]]) - - # The plain function and defun-compiled function should return the same value. - assert f(x, y).numpy() == g_op(x, y).numpy() - ``` - - Args: - func: function to be compiled. - *args: List arguments to pass to `func` when attaching to the graph. - **kwds: Keyword arguments to pass to `func` when attaching to the graph. - - Returns: - A wrapper object which can be queried for its output properties, - and which can be called directly the way a `@defun` wrapped function - can. 
- """ - return _trace_and_define_function(func.__name__, func, False, args, kwds) - - class AutomaticControlDependencies(object): """Context manager to automatically add control dependencies. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 380bcf763f..92254a2c00 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -27,7 +27,6 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function -from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -105,7 +104,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(step(), 2.0) def testGraphGradientVariable(self): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) @function.defun @@ -130,16 +129,16 @@ class FunctionTest(test.TestCase): with ops.Graph().as_default(): self.assertEqual(f().shape, ()) - def testBasicDefunOpGraphMode(self): + def testBasicGraphFunction(self): matmul = function.defun(math_ops.matmul) + @function.defun def sq(a): return matmul(a, a) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) - sq_op = function.make_defun_op(sq, t) - + sq_op = sq.get_concrete_function(t) self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) out = sq_op(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) @@ -211,33 +210,44 @@ class FunctionTest(test.TestCase): random_seed.set_random_seed(1) self.assertAllEqual(f(), x) - def testNestedInputsDefunOpGraphMode(self): + def testSymGradGatherNd(self): + with ops.Graph().as_default(), self.cached_session() as sess: + + @function.defun + def f(x): + return array_ops.gather_nd(x, [[0]]) + + c = constant_op.constant([[2.]]) + f_c = f(c) + g, = gradients_impl.gradients(f_c, c) + self.assertAllEqual(sess.run(g), [[1.0]]) + + def testNestedInputsGraphFunction(self): matmul = function.defun(math_ops.matmul) pair = collections.namedtuple('pair', ['a', 'b']) + @function.defun def a_times_b(inputs): return matmul(inputs.a['a'], inputs.b['b']) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) - inputs = pair({'a': t}, {'b': t}) - sq_op = function.make_defun_op(a_times_b, inputs) - + sq_op = a_times_b.get_concrete_function(inputs) self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) out = sq_op(inputs) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) - def testNestedOutputDefunOpGraphMode(self): + def testNestedOutputGraphFunction(self): matmul = function.defun(math_ops.matmul) + @function.defun def sq(a): return (matmul(a, a), {'b': constant_op.constant(1.0)}) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) - sq_op = function.make_defun_op(sq, t) - + sq_op = sq.get_concrete_function(t) self.assertEqual(sq_op.output_shapes, (tensor_shape.TensorShape([2, 2]), {'b': tensor_shape.TensorShape([])})) @@ -247,28 +257,28 @@ class FunctionTest(test.TestCase): self.assertAllEqual(a, math_ops.matmul(t, t).numpy()) self.assertAllEqual(b['b'].numpy(), 1.0) - def testDefunOpGraphModeWithGradients(self): + def testGraphFunctionWithGradients(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') + @function.defun def step(): def inner(): return v * v return backprop.implicit_grad(inner)()[0][0] - 
step_op = function.make_defun_op(step) - + step_op = step.get_concrete_function() self.assertEqual(step_op.output_dtypes, dtypes.float32) self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([])) self.assertAllEqual(step_op(), 2.0) - def testDefunOpGraphModeNoneOutput(self): + def testGraphFunctionNoneOutput(self): + @function.defun def fn(unused_a, unused_b): return None x = constant_op.constant(1) - fn_op = function.make_defun_op(fn, x, x) - + fn_op = fn.get_concrete_function(x, x) self.assertEqual(fn_op.output_dtypes, None) self.assertEqual(fn_op.output_shapes, None) self.assertAllEqual(fn_op(x, x), None) @@ -309,13 +319,13 @@ class FunctionTest(test.TestCase): x = random_ops.random_uniform([2, 2]).numpy() defined = function.defun(f) defined(x) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) x = random_ops.random_uniform([2, 2]).numpy() defined(x) # A NumPy array with different values but the same shape and dtype # shouldn't trigger another function definition. - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) def testDefunCapturedInt32(self): x = constant_op.constant(1, dtype=dtypes.int32) @@ -346,6 +356,47 @@ class FunctionTest(test.TestCase): self.assertEqual(3.0, float(test_assign_add())) + @test_util.run_in_graph_and_eager_modes + def testTensorInitializationInFunctionRaisesError(self): + error_msg = ('Tensor-typed variable initializers must either be ' + 'wrapped in an init_scope or callable.*') + + @function.defun + def tensor_init(): + with self.assertRaisesRegexp(ValueError, error_msg): + resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) + + tensor_init() + + @test_util.run_in_graph_and_eager_modes + def testCallableTensorInitializationInFunction(self): + + @function.defun + def tensor_init(): + v = resource_variable_ops.ResourceVariable( + lambda: constant_op.constant(2.0)) + return v.read_value() + + value = tensor_init() + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(value), 2.0) + + @test_util.run_in_graph_and_eager_modes + def testInitScopeTensorInitializationInFunction(self): + + @function.defun + def tensor_init(): + with ops.init_scope(): + const = constant_op.constant(2.0) + v = resource_variable_ops.ResourceVariable(const) + return v.read_value() + + value = tensor_init() + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(value), 2.0) + def testDefunShapeInferenceWithCapturedResourceVariable(self): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) @@ -430,7 +481,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) def testGraphModeCaptureVariable(self): - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: class HasAVar(object): @@ -458,12 +509,12 @@ class FunctionTest(test.TestCase): x = constant_op.constant(1.0) l = f(x, v) _, dv = gradients_impl.gradients(l, [x, v]) - with self.test_session(): + with self.cached_session(): v.initializer.run() self.assertAllEqual(dv.eval(), 0.0) def testGraphModeManyFunctions(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): @function.defun def f(x): @@ -564,7 +615,6 @@ class FunctionTest(test.TestCase): @function.defun def g(x): - 
tape.watch_variable(x) y = math_ops.add(x, three) f(y) @@ -578,7 +628,6 @@ class FunctionTest(test.TestCase): return math_ops.add(x, three) def g(x): - tape.watch_variable(three) return f(x) g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0] @@ -633,17 +682,19 @@ class FunctionTest(test.TestCase): def testReturningIndexedSlicesWithDefun(self): def validate(indexed_slice): + @function.defun def f(): return indexed_slice - output = function.defun(f)() + output = f() self.assertTrue(isinstance(output, ops.IndexedSlices)) self.assertAllEqual(indexed_slice.values, output.values) self.assertAllEqual(indexed_slice.indices, output.indices) self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape) self.assertEqual( - function.make_defun_op(f).output_shapes, indexed_slice.values.shape) + f.get_concrete_function().output_shapes, + indexed_slice.values.shape) arg = ops.IndexedSlices( values=constant_op.constant([1, 2]), @@ -883,7 +934,7 @@ class FunctionTest(test.TestCase): self.assertEqual(1, int(read())) def testReturnCapturedGraphTensor(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): t = constant_op.constant(1) @function.defun @@ -966,39 +1017,109 @@ class FunctionTest(test.TestCase): config=config_pb2.ConfigProto(device_count={'CPU': 4})) def testDeviceAnnotationsRespected(self): - @function.defun def multi_device_fn(): with ops.device('/cpu:0'): - s1 = iterator_ops.Iterator.from_structure( + s0 = iterator_ops.Iterator.from_structure( (dtypes.float32,)).string_handle() with ops.device('/cpu:1'): - s2 = iterator_ops.Iterator.from_structure( + s1 = iterator_ops.Iterator.from_structure( (dtypes.float32,)).string_handle() with ops.device('/cpu:2'): - s3 = iterator_ops.Iterator.from_structure( - (dtypes.float32,)).string_handle() - with ops.device(''): - # TODO(akshayka): This is unfortunate and brittle. It prevents - # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'. - # Remove this hack once we have a way of obtaining metadata about - # function execution. 
- s4 = iterator_ops.Iterator.from_structure( + s2 = iterator_ops.Iterator.from_structure( (dtypes.float32,)).string_handle() - return s1, s2, s3, s4 + s3 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + return s0, s1, s2, s3 - with ops.device('/cpu:3'): - outputs = self.evaluate(multi_device_fn()) + defined = function.defun(multi_device_fn) + outputs = self.evaluate(defined()) + self.assertEqual(len(defined._function_cache), 1) self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) - self.assertIn(compat.as_bytes('CPU:3'), outputs[3]) - with ops.device('/cpu:0'): - outputs = self.evaluate(multi_device_fn()) + with ops.device('/cpu:3'): + outputs = self.evaluate(defined()) + self.assertEqual(len(defined._function_cache), 2) self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) - self.assertIn(compat.as_bytes('CPU:0'), outputs[3]) + self.assertIn(compat.as_bytes('CPU:3'), outputs[3]) + + # This should retrieve the call-site-device agnostic function + defined() + self.assertEqual(len(defined._function_cache), 2) + + # And this should retrieve the function created for '/cpu:3' + with ops.device('/cpu:3'): + defined() + self.assertEqual(len(defined._function_cache), 2) + + @test_util.run_in_graph_and_eager_modes( + config=config_pb2.ConfigProto(device_count={'CPU': 2})) + def testCallingGraphFunctionOnIncompatibleDeviceRaisesError(self): + + def func(): + return constant_op.constant(0) + + defined = function.defun(func) + with ops.device('cpu:0'): + cpu_graph_function = defined.get_concrete_function() + + with ops.device('cpu:0'): + self.assertEqual( + self.evaluate(cpu_graph_function()), self.evaluate(func())) + + with self.assertRaisesRegexp( + ValueError, + 'The current device stack does not match the device stack under ' + 'which the TensorFlow function \'.*func.*\' was created.\n' + 'Current device stack: .*\n.*func.* device stack.*'): + with ops.device('cpu:1'): + cpu_graph_function() + + with self.assertRaisesRegexp( + ValueError, + 'The current device stack does not match the device stack under ' + 'which the TensorFlow function \'.*func.*\' was created.\n' + 'Current device stack: .*\n.*func.* device stack.*'): + with ops.device(None): + cpu_graph_function() + + default_graph_function = defined.get_concrete_function() + self.assertEqual( + self.evaluate(default_graph_function()), self.evaluate(func())) + + with self.assertRaisesRegexp( + ValueError, + 'The current device stack does not match the device stack under ' + 'which the TensorFlow function \'.*func.*\' was created.\n' + 'Current device stack: .*\n.*func.* device stack.*'): + with ops.device('cpu:1'): + default_graph_function() + + @test_util.run_in_graph_and_eager_modes + def testColocateWithRespected(self): + # TODO(b/113291792): Use multiple CPUs instead of a GPU. 
+ if not context.context().num_gpus(): + self.skipTest('No GPUs found.') + + with ops.device('cpu:0'): + x = constant_op.constant(1.0) + + with ops.device('gpu:0'): + y = constant_op.constant(1.0) + + @function.defun + def foo(): + return iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + + with ops.colocate_with(x): + self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo())) + + with ops.colocate_with(y): + self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo())) def testVariablesAreTracked(self): v = resource_variable_ops.ResourceVariable(1.0) @@ -1027,26 +1148,31 @@ class FunctionTest(test.TestCase): defined = function.defun(func) defined(0, baz=20) + + def cache_keys(): + """Sanitizes cache keys of non-input metadata.""" + return tuple(key[:3] for key in defined._function_cache) + # `True` corresponds to the fact that we're executing eagerly - self.assertIn((0, 1, 20, True), defined._arguments_to_functions) + self.assertIn((0, 1, 20), cache_keys()) defined(1) # bar=1, baz=2 - self.assertIn((1, 1, 2, True), defined._arguments_to_functions) + self.assertIn((1, 1, 2), cache_keys()) # This matches the previous call. defined(foo=1) - self.assertEqual(len(defined._arguments_to_functions), 2) + self.assertEqual(len(defined._function_cache), 2) defined(1, 2, 3) - self.assertIn((1, 2, 3, True), defined._arguments_to_functions) + self.assertIn((1, 2, 3), cache_keys()) # This matches the previous call. defined(1, bar=2, baz=3) - self.assertEqual(len(defined._arguments_to_functions), 3) + self.assertEqual(len(defined._function_cache), 3) # This matches the previous call. defined(1, baz=3, bar=2) - self.assertEqual(len(defined._arguments_to_functions), 3) + self.assertEqual(len(defined._function_cache), 3) def testFunctoolsPartialUnwrappedCorrectly(self): @@ -1072,7 +1198,7 @@ class FunctionTest(test.TestCase): defined = function.defun(foo, input_signature=signature) a = array_ops.ones([2]) out = defined(a) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) self.assertAllEqual(out, a) def bar(a): @@ -1083,13 +1209,13 @@ class FunctionTest(test.TestCase): defined = function.defun(bar, input_signature=signature) a = array_ops.ones([2, 1]) out = defined(a) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) self.assertAllEqual(out, a) # Changing the second dimension shouldn't create a new function. b = array_ops.ones([2, 3]) out = defined(b) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) self.assertAllEqual(out, b) def testNestedInputSignatures(self): @@ -1106,7 +1232,7 @@ class FunctionTest(test.TestCase): a = array_ops.ones([2, 1]) b = array_ops.ones([1]) out = defined([a, a], b) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) nest.assert_same_structure(out, [[a, a], b]) self.assertAllEqual(out[0][0], a) self.assertAllEqual(out[0][1], a) @@ -1117,7 +1243,7 @@ class FunctionTest(test.TestCase): b = array_ops.ones([2, 5]) c = array_ops.ones([1]) out = defined([a, b], c) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) nest.assert_same_structure(out, [[a, b], c]) self.assertAllEqual(out[0][0], a) self.assertAllEqual(out[0][1], b) @@ -1153,13 +1279,13 @@ class FunctionTest(test.TestCase): # Signatures must consist exclusively of `TensorSpec` objects. 
signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'): - function.defun(foo, input_signature=signature)(1, 2) + function.defun(foo, input_signature=signature) # Signatures must be either lists or tuples on their outermost levels. signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} with self.assertRaisesRegexp(TypeError, 'input_signature must be either a ' 'tuple or a list.*'): - function.defun(foo, input_signature=signature)(1, 2) + function.defun(foo, input_signature=signature) def testInputsIncompatibleWithSignatureRaisesError(self): @@ -1213,22 +1339,22 @@ class FunctionTest(test.TestCase): integer = constant_op.constant(2, dtypes.int64) out1, out2 = foo(flt, integer) - self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(len(foo._function_cache), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) out1, out2 = foo(flt=flt, integer=integer) - self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(len(foo._function_cache), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) out1, out2 = foo(integer=integer, flt=flt) - self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(len(foo._function_cache), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) out1, out2 = foo(flt, integer=integer) - self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(len(foo._function_cache), 1) self.assertEqual(out1.numpy(), 1.0) self.assertEqual(out2.numpy(), 2) @@ -1258,27 +1384,27 @@ class FunctionTest(test.TestCase): a = constant_op.constant(2.0) b = constant_op.constant([1.0, 2.0]) one = defined(a, b) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) two = defined(a=a, b=b) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) three = defined(b=b, a=a) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) four = defined(a, b=b) - self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertEqual(len(defined._function_cache), 1) # The next call corresponds to a new input signature, hence # we expect another function to be defined. 
five = defined(b, a) - self.assertEqual(len(defined._arguments_to_functions), 2) + self.assertEqual(len(defined._function_cache), 2) six = defined(a=b, b=a) - self.assertEqual(len(defined._arguments_to_functions), 2) + self.assertEqual(len(defined._function_cache), 2) seven = defined(b=a, a=b) - self.assertEqual(len(defined._arguments_to_functions), 2) + self.assertEqual(len(defined._function_cache), 2) self.assertAllEqual(one, [1.0, 2.0]) self.assertAllEqual(two, [1.0, 2.0]) @@ -1298,14 +1424,14 @@ class FunctionTest(test.TestCase): grad_t, = backprop.gradients_function(sq, [0])(t) self.assertAllEqual(grad_t, [[6, 6], [14, 14]]) - with backprop.GradientTape(persistent=True) as gtape: - gtape.watch(t) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(t) one = matmul(t, b=t, transpose_a=True) two = matmul(b=t, a=t, transpose_a=True) three = matmul(a=t, b=t, transpose_a=True) for output in [one, two, three]: - self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]]) + self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]]) def testGradientInFunctionWithKeywordArguments(self): @@ -1363,7 +1489,7 @@ class FunctionTest(test.TestCase): self.assertAllEqual(state, [0]) # Whereas calling the python function directly should create a side-effect. - side_effecting_function.call_python_function() + side_effecting_function.python_function() self.assertAllEqual(state, [0, 0]) @@ -1371,7 +1497,7 @@ class FunctionTest(test.TestCase): class AutomaticControlDependenciesTest(test.TestCase): def testBasic(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() with function.AutomaticControlDependencies() as c: @@ -1382,7 +1508,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(), 4.0) def testCondMustRun(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1403,7 +1529,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0) def testCondMustRunSeparateRead(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1426,7 +1552,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(v.read_value().eval(), 6.0) def testCondNested(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1460,7 +1586,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0) def testCondOneBranch(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1480,7 +1606,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0) def 
testCondOneBranchUpdateBefore(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1501,7 +1627,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0) def testCondOneBranchUpdateAfter(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() p = array_ops.placeholder(dtype=dtypes.bool) @@ -1537,7 +1663,7 @@ class AutomaticControlDependenciesTest(test.TestCase): self.assertAllEqual(out, [3, 4, 5]) def testDecorator(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py deleted file mode 100644 index 7105d2e399..0000000000 --- a/tensorflow/python/eager/graph_callable.py +++ /dev/null @@ -1,435 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Decorator that produces a callable object that executes a TensorFlow graph. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import contextlib - -from tensorflow.python.eager import context -from tensorflow.python.eager import function -from tensorflow.python.eager import tape -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest -from tensorflow.python.util import tf_decorator -from tensorflow.python.util import tf_inspect - - -def _default_initializer(name, shape, dtype): - """The default initializer for variables.""" - # pylint: disable=protected-access - store = variable_scope._get_default_variable_store() - initializer = store._get_default_initializer(name, shape=shape, dtype=dtype) - # pylint: enable=protected-access - return initializer[0] - - -class _CapturedVariable(object): - """Variable captured by graph_callable. - - Internal to the implementation of graph_callable. Created only by - _VariableCapturingScope and used only to read the variable values when calling - the function after the variables are initialized. 
- """ - - def __init__(self, name, initializer, shape, dtype, trainable): - self.name = name - if initializer is None: - initializer = _default_initializer(name, shape, dtype) - initial_value = lambda: initializer(shape, dtype=dtype) - - with context.eager_mode(): - self.variable = resource_variable_ops.ResourceVariable( - initial_value=initial_value, name=name, dtype=dtype, - trainable=trainable) - self.shape = shape - self.dtype = dtype - self.placeholder = None - self.trainable = trainable - - def read(self, want_gradients=True): - if want_gradients and self.trainable: - v = tape.watch_variable(self.variable) - else: - v = self.variable - return v.read_value() - - -class _VariableCapturingScope(object): - """Variable-scope-like object which captures tf.get_variable calls. - - This is responsible for the main difference between the initialization version - of a function object and the calling version of a function object. - - capturing_scope replaces calls to tf.get_variable with placeholder tensors to - be fed the variable's current value. TODO(apassos): these placeholders should - instead be objects implementing a similar API to tf.Variable, for full - compatibility. - - initializing_scope replaces calls to tf.get_variable with creation of - variables and initialization of their values. This allows eventual support of - initialized_value and friends. - - TODO(apassos): once the eager mode layers API is implemented support eager - func-to-object as well. - """ - - def __init__(self): - self.variables = {} - self.tf_variables = {} - - @contextlib.contextmanager - def capturing_scope(self): - """Context manager to capture variable creations. - - Replaces variable accesses with placeholders. - - Yields: - nothing - """ - # TODO(apassos) ignoring the regularizer and partitioner here; figure out - # how to deal with these. - def _custom_getter( # pylint: disable=missing-docstring - getter=None, - name=None, - shape=None, - dtype=dtypes.float32, - initializer=None, - regularizer=None, - reuse=None, - trainable=None, - collections=None, - caching_device=None, # pylint: disable=redefined-outer-name - partitioner=None, - validate_shape=True, - use_resource=None, - aggregation=variable_scope.VariableAggregation.NONE, - synchronization=variable_scope.VariableSynchronization.AUTO): - del getter, regularizer, partitioner, validate_shape, use_resource, dtype - del collections, initializer, trainable, reuse, caching_device, shape - del aggregation, synchronization - assert name in self.variables - v = self.variables[name] - return v.variable - - scope = variable_scope.get_variable_scope() - with variable_scope.variable_scope(scope, custom_getter=_custom_getter): - yield - - @contextlib.contextmanager - def initializing_scope(self): - """Context manager to capture variable creations. - - Forcibly initializes all created variables. - - Yields: - nothing - """ - # TODO(apassos) ignoring the regularizer and partitioner here; figure out - # how to deal with these. 
- def _custom_getter( # pylint: disable=missing-docstring - getter=None, - name=None, - shape=None, - dtype=dtypes.float32, - initializer=None, - regularizer=None, - reuse=None, - trainable=None, - collections=None, - caching_device=None, # pylint: disable=redefined-outer-name - partitioner=None, - validate_shape=True, - use_resource=None, - aggregation=variable_scope.VariableAggregation.NONE, - synchronization=variable_scope.VariableSynchronization.AUTO): - del getter, regularizer, collections, caching_device, partitioner - del use_resource, validate_shape, aggregation, synchronization - if name in self.tf_variables: - if reuse: - return self.tf_variables[name].initialized_value() - else: - raise ValueError("Specified reuse=%s but tried to reuse variables." - % reuse) - # TODO(apassos): ensure this is on the same device as above - v = _CapturedVariable(name, initializer, shape, dtype, trainable) - self.variables[name] = v - - graph_mode_resource = v.variable.handle - if initializer is None: - initializer = _default_initializer(name, shape, dtype) - resource_variable_ops.shape_safe_assign_variable_handle( - graph_mode_resource, v.variable.shape, initializer(shape, dtype)) - return v.variable - - scope = variable_scope.get_variable_scope() - with variable_scope.variable_scope(scope, custom_getter=_custom_getter): - yield - - -class _InitializingFunctionObject(object): - """Responsible for deciding which version of func-to-object to call. - - call_fn is the version which calls the function with the current values of the - variables and init_fn is the version which calls the function to initialize - all variables. - - TODO(apassos): figure out a way to support initializing only _some_ - variables. This requires a way to pull out a variable's initialization code - from the graph, which might not be possible in general. - """ - - def __init__(self, call_fn, init_fn, shape_and_dtypes): - self._init_fn = init_fn - self._call_fn = call_fn - self.shape_and_dtypes = shape_and_dtypes - self.flattened_shapes = [tensor_shape.as_shape(sd.shape) for sd in - nest.flatten(self.shape_and_dtypes)] - - @property - def variables(self): - return self._call_fn.variables - - def __call__(self, *args): - nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False) - if not all([ - shape.is_compatible_with(arg.shape) - for shape, arg in zip(self.flattened_shapes, nest.flatten(args)) - ]): - raise ValueError( - "Declared shapes do not match argument shapes: Expected %s, found %s." 
- % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)])) - - initialized = [resource_variable_ops.var_is_initialized_op( - v.handle).numpy() for v in self._call_fn.variables] - if all(x for x in initialized): - for v in self._call_fn.variables: - if v.trainable: - tape.watch_variable(v) - return self._call_fn(*args) - elif all(not x for x in initialized): - return self._init_fn(*args) - else: - raise ValueError("Some, but not all, variables are initialized.") - - -def _get_graph_callable_inputs(shape_and_dtypes): - """Maps specified shape_and_dtypes to graph inputs.""" - ret = [] - for x in shape_and_dtypes: - if isinstance(x, ShapeAndDtype): - ret.append(array_ops.placeholder(x.dtype, x.shape)) - elif isinstance(x, (tuple, list)): - ret.append(_get_graph_callable_inputs(x)) - else: - raise errors.InvalidArgumentError( - None, None, "Expected the argument to @graph_callable to be a " - "(possibly nested) list or tuple of ShapeAndDtype objects, " - "but got an object of type: %s" % type(x)) - - return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret - - -def _graph_callable_internal(func, shape_and_dtypes): - """Defines and returns a template version of func. - - Under the hood we make two function objects, each wrapping a different version - of the graph-mode code. One version immediately runs variable initialization - before making the variable's Tensors available for use, while the other - version replaces the Variables with placeholders which become function - arguments and get the current variable's value. - - Limitations in (2) and (4) are because this does not implement a graph-mode - Variable class which has a convert_to_tensor(as_ref=True) method and a - initialized_value method. This is fixable. - - Args: - func: The tfe Python function to compile. - shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects. - - Raises: - ValueError: If any one of func's outputs is not a Tensor. - - Returns: - Callable graph object. - """ - container = tf_ops.get_default_graph()._container # pylint: disable=protected-access - graph_key = tf_ops.get_default_graph()._graph_key # pylint: disable=protected-access - with context.graph_mode(): - # This graph will store both the initialization and the call version of the - # wrapped function. It will later be used by the backprop code to build the - # backprop graph, if necessary. - tmp_graph = function.CapturingGraph() - # Inherit the graph key from the original graph to ensure optimizers don't - # misbehave. - tmp_graph._container = container # pylint: disable=protected-access - tmp_graph._graph_key = graph_key # pylint: disable=protected-access - with tmp_graph.as_default(): - # Placeholders for the non-variable inputs. - func_inputs = _get_graph_callable_inputs(shape_and_dtypes) - func_num_args = len(tf_inspect.getfullargspec(func).args) - if len(func_inputs) != func_num_args: - raise TypeError("The number of arguments accepted by the decorated " - "function `%s` (%d) must match the number of " - "ShapeAndDtype objects passed to the graph_callable() " - "decorator (%d)." % - (func.__name__, func_num_args, len(func_inputs))) - - # First call the function to generate a graph which can initialize all - # variables. As a side-effect this will populate the variable capturing - # scope's view of which variables exist. 
- variable_captures = _VariableCapturingScope() - with variable_captures.initializing_scope( - ), function.AutomaticControlDependencies() as a: - func_outputs = func(*func_inputs) - outputs_list = nest.flatten(func_outputs) - for i, x in enumerate(outputs_list): - if x is not None: - outputs_list[i] = a.mark_as_return(x) - if len(outputs_list) == 1 and outputs_list[0] is None: - outputs_list = [] - output_shapes = [x.shape for x in outputs_list] - if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list): - raise ValueError("Found non-tensor output in %s" % str(outputs_list)) - initializing_operations = tmp_graph.get_operations() - - # Call the function again, now replacing usages of variables with - # placeholders. This assumes the variable capturing scope created above - # knows about all variables. - tmp_graph.clear_resource_control_flow_state() - with variable_captures.capturing_scope( - ), function.AutomaticControlDependencies() as a: - captured_outputs = func(*func_inputs) - captured_outlist = nest.flatten(captured_outputs) - for i, x in enumerate(captured_outlist): - if x is not None: - captured_outlist[i] = a.mark_as_return(x) - capturing_operations = tmp_graph.get_operations()[ - len(initializing_operations):] - - sorted_variables = sorted(variable_captures.variables.values(), - key=lambda x: x.name) - - extra_inputs = tmp_graph.captures.keys() - extra_placeholders = tmp_graph.captures.values() - - flat_inputs = [x for x in nest.flatten(func_inputs) - if isinstance(x, tf_ops.Tensor)] - placeholder_inputs = flat_inputs+ list(extra_placeholders) - - func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)] - initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access - # TODO(ashankar): Oh lord, forgive me for this lint travesty. - # Also, what about the gradient registry of these functions? Those need to be - # addressed as well. - for f in tmp_graph._functions.values(): # pylint: disable=protected-access - function._register(f._c_func.func) # pylint: disable=protected-access - initializer_function = function.GraphModeFunction( - initialization_name, - placeholder_inputs, - extra_inputs, - tmp_graph, - initializing_operations, - func_def_outputs, - func_outputs, - output_shapes) - - capture_func_def_outputs = [ - x for x in captured_outlist if isinstance(x, tf_ops.Tensor)] - captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access - captured_function = function.GraphModeFunction( - captured_function_name, - placeholder_inputs, - extra_inputs, - tmp_graph, - capturing_operations, - capture_func_def_outputs, - captured_outputs, - output_shapes, - variables=[x.variable for x in sorted_variables]) - - return _InitializingFunctionObject(captured_function, initializer_function, - shape_and_dtypes) - - -class ShapeAndDtype(object): - """Data type that packages together shape and type information. - - Used for arguments to graph callables. See graph_callable() for an example. - """ - - def __init__(self, shape, dtype): - self.shape = shape - self.dtype = dtype - - -def graph_callable(shape_and_dtypes): - """Decorator that produces a callable that executes a TensorFlow graph. - - When applied on a function that constructs a TensorFlow graph, this decorator - produces a callable object that: - - 1. Executes the graph when invoked. The first call will initialize any - variables defined in the graph. - - 2. 
Provides a .variables() method to return the list of TensorFlow variables - defined in the graph. - - Note that the wrapped function is not allowed to change the values of the - variables, just use them. - - The return value of the wrapped function must be one of the following: - (1) None, (2) a Tensor, or (3) a possibly nested sequence of Tensors. - - Example: - - ```python - @tfe.graph_callable([tfe.ShapeAndDtype(shape(), dtype=dtypes.float32)]) - def foo(x): - v = tf.get_variable('v', initializer=tf.ones_initializer(), shape=()) - return v + x - - ret = foo(tfe.Tensor(2.0)) # `ret` here is a Tensor with value 3.0. - - foo.variables[0].assign(7.0) # Modify the value of variable `v`. - ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0. - ``` - Args: - shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects - that specifies shape and type information for each of the callable's - arguments. The length of this list must be equal to the number of - arguments accepted by the wrapped function. - - Returns: - A callable graph object. - """ - # TODO(alive,apassos): support initialized_value and friends from tf.Variable. - assert context.executing_eagerly(), ( - "graph_callable can only be used when Eager execution is enabled.") - def decorator(func): - return tf_decorator.make_decorator(func, - _graph_callable_internal( - func, shape_and_dtypes)) - - return decorator diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py deleted file mode 100644 index b9e6ca2a93..0000000000 --- a/tensorflow/python/eager/graph_callable_test.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.eager import backprop -from tensorflow.python.eager import graph_callable -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope - - -class GraphCallableTest(test.TestCase): - - def testBasic(self): - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def my_function(x): - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy()) - - my_function.variables[0].assign(1.) 
- self.assertEqual( - 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy()) - - def testFunctionWithoutReturnValue(self): - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def my_function(x): - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=()) - v.assign(x) - - my_function(constant_op.constant(4, dtype=dtypes.float32)) - self.assertAllEqual(4, my_function.variables[0].read_value()) - - def testFunctionWithoutReturnValueAndArgs(self): - - @graph_callable.graph_callable([]) - def my_function(): - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=()) - v.assign(4) - - my_function() - self.assertAllEqual(4, my_function.variables[0].read_value()) - - def testVariableAPI(self): - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def my_function(x): - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=()) - return v.read_value() + x - - self.assertEqual( - 2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy()) - - my_function.variables[0].assign(1.) - self.assertEqual( - 3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy()) - - def testTensorShape(self): - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)]) - def my_function(x): - _ = x.get_shape() - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]]) - self.assertEqual(v.shape[0], x.shape[0]) - return v + x - - self.assertEqual([2.], - my_function( - constant_op.constant([2.], - dtype=dtypes.float32)).numpy()) - - def testUpdatesAreOrdered(self): - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def my_function(x): - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=()) - v.assign(x + 1) - v.assign(v * x) - return v.read_value() - - self.assertAllEqual(my_function(constant_op.constant(2.0)), 6.0) - - def testEmptyInitializer(self): - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(1), dtype=dtypes.float32)]) - def my_function(x): - v = variable_scope.get_variable("v", shape=[1]) - return x + 0 * v - - self.assertEqual([2.], - my_function( - constant_op.constant([2.], - dtype=dtypes.float32)).numpy()) - - def testMismatchingNumArgs(self): - # pylint: disable=anomalous-backslash-in-string - with self.assertRaisesRegexp(TypeError, - "The number of arguments accepted by the " - "decorated function `my_function` \(2\) must " - "match the number of ShapeAndDtype objects " - "passed to the graph_callable\(\) decorator " - "\(1\)."): - @graph_callable.graph_callable([ - graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def my_function(x, y): # pylint: disable=unused-variable - return x + y - # pylint: enable=anomalous-backslash-in-string - - def testPureFunction(self): - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)]) - def f(x): - return math_ops.add(x, constant_op.constant(3)) - - self.assertAllEqual(5, f(constant_op.constant(2))) - - def testNestedFunction(self): - # TensorFlow function (which is what would be used in TensorFlow graph - # construction). - @function.Defun(dtypes.int32, dtypes.int32) - def add(a, b): - return math_ops.add(a, b) - - # A graph_callable that will invoke the TensorFlow function. 
- @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)]) - def add_one(x): - return add(x, 1) - - self.assertAllEqual(3, add_one(constant_op.constant(2))) - - # TODO(ashankar): Make this work. - # The problem is that the two graph_callables (for add_one and add_two) - # are both trying to register the FunctionDef corresponding to "add". - def DISABLED_testRepeatedUseOfSubFunction(self): - - @function.Defun(dtypes.int32, dtypes.int32) - def add(a, b): - return math_ops.add(a, b) - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)]) - def add_one(x): - return add(x, 1) - - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)]) - def add_two(x): - return add(x, 2) - - two = constant_op.constant(2) - self.assertAllEqual(3, add_one(two)) - self.assertAllEqual(4, add_two(two)) - - def testNestedSequenceInputs(self): - sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32) - @graph_callable.graph_callable([[sd, tuple([sd, sd]), sd]]) - def my_op(inputs): - a, b, c = inputs - e, f = b - v = variable_scope.get_variable( - "my_v", initializer=init_ops.zeros_initializer(), shape=()) - return [a + a + v, tuple([e + e, f + f]), c + c], a + e + f + c + v - - inputs = [constant_op.constant(1.), - [constant_op.constant(2.), constant_op.constant(3.)], - constant_op.constant(4.)] - ret = my_op(inputs) - self.assertEqual(len(ret), 2.) - self.assertAllEqual(ret[1], 10.) - - my_op.variables[0].assign(1.) - ret = my_op(inputs) - self.assertAllEqual(ret[1], 11.) - - def testVariableShapeIsTensorShape(self): - @graph_callable.graph_callable([]) - def my_function(): - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=()) - self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape) - - my_function() - - def testIncorrectlyShapedInputs(self): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(3), dtype=dtypes.float32)]) - def my_function(x): - v = variable_scope.get_variable( - "v", initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - with self.assertRaises(ValueError): - my_function([1, 2]) - - self.assertTrue(([1, 2, 3] == my_function( - constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy()).all()) - - def testGradients(self): - @graph_callable.graph_callable([]) - def my_function(): - v = variable_scope.get_variable( - "v", initializer=init_ops.constant_initializer(3.), shape=()) - return v * v - - grad_fn = backprop.implicit_grad(my_function) - grads_and_vars = list(zip(*grad_fn())) - self.assertAllEqual(6., grads_and_vars[0][0]) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py index d2a2b4e223..3cf3a61a62 100644 --- a/tensorflow/python/eager/graph_only_ops_test.py +++ b/tensorflow/python/eager/graph_only_ops_test.py @@ -32,13 +32,13 @@ class GraphOnlyOpsTest(test_util.TensorFlowTestCase): def testGraphZerosLike(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) z_tf = graph_only_ops.graph_zeros_like(x) - with self.test_session(): + with self.cached_session(): self.assertAllClose(np.zeros((2, 3)), z_tf.eval()) def testGraphPlaceholder(self): x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,)) y_tf = math_ops.square(x_tf) - with self.test_session() as sess: + with self.cached_session() as sess: x = np.array([42]) y = sess.run(y_tf, feed_dict={x_tf: 
np.array([42])}) self.assertAllClose(np.square(x), y) diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index 000152855d..5f027d107c 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -24,12 +24,10 @@ from tensorflow.python import pywrap_tensorflow VSpace = collections.namedtuple( - "VSpace", - ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"]) + "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"]) def imperative_grad( - vspace, tape, target, sources, @@ -41,7 +39,6 @@ def imperative_grad( gradients for all sources. Args: - vspace: the vector space in which to differentiate. tape: the gradient tape which stores the trace. target: either a Tensor or list of Tensors to be differentiated. sources: list of Tensors for which we want gradients @@ -60,4 +57,7 @@ def imperative_grad( computation of target. """ return pywrap_tensorflow.TFE_Py_TapeGradient( - tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access + tape._tape, # pylint: disable=protected-access + target, + sources, + output_gradients) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 15d2ccf9d2..f34ce6af79 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" +#include "structmember.h" // NOLINT // For PyMemberDef + // forward declare struct EagerTensor; @@ -263,6 +265,14 @@ typedef struct EagerTensor { TF_Status* status; PyObject* weakreflist; /* List of weak references */ + + // Per-instance attribute dictionary, to support monkey patching + // (e.g. EagerTensor.assign when slicing variables). This dictionary is + // created by CPython the first time an attribute is assigned, pointed to by + // tp_dictoffset. Note that garbage collection is not enabled for + // EagerTensors, so assigning objects to EagerTensor attributes which require + // garbage collection is likely to cause issues. 
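What this buys from the Python side, in a minimal sketch (not part of the diff; assumes eager execution is enabled on a build containing this change, and `my_attr` is an illustrative name — the new `testTensorDir` in tensor_test.py below exercises the same behavior):

```python
import tensorflow as tf

tf.enable_eager_execution()

t = tf.constant([1.0, 2.0])
# With tp_dictoffset previously 0, this assignment raised AttributeError;
# the per-instance __dict__ introduced here now stores it.
t.my_attr = "patched"
print(t.my_attr)  # -> "patched"
```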
+ PyObject* dict; } EagerTensor; namespace { @@ -311,17 +321,42 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { Py_INCREF(Py_None); self->tensor_shape = Py_None; self->status = TF_NewStatus(); + self->dict = nullptr; self->weakreflist = nullptr; PyObject* value; PyObject* context = nullptr; PyObject* device = nullptr; PyObject* dtype = Py_None; - const char* kwlist[] = {"value", "context", "device", "dtype", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O", + PyObject* other_value = nullptr; + const char* kwlist[] = {"value", "context", "device", + "dtype", "other_value", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO", const_cast<char**>(kwlist), &value, &context, - &device, &dtype)) { + &device, &dtype, &other_value)) { return -1; } + + if (other_value != nullptr) { + if (!EagerTensor_CheckExact(other_value)) { + PyErr_SetString(PyExc_TypeError, + tensorflow::strings::StrCat( + "Expecting an EagerTensor for other_value, got ", + Py_TYPE(other_value)->tp_name) + .c_str()); + + return -1; + } + EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value); + self->handle = + TFE_TensorHandleCopySharingTensor(other->handle, self->status); + + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + return -1; + } + + return 0; + } + // Extract dtype int desired_dtype = -1; if (dtype != Py_None) { @@ -410,6 +445,10 @@ void EagerTensor_dealloc(EagerTensor* self) { Py_DECREF(self->handle_data); Py_DECREF(self->keras_mask); Py_DECREF(self->tensor_shape); + // If an attribute dictionary has been created, release it. Note that this + // is only ever created by CPython's attribute setting methods; we don't + // create it ourselves. + Py_CLEAR(self->dict); if (self->handle != nullptr) { TFE_DeleteTensorHandle(self->handle); self->handle = nullptr; @@ -474,6 +513,30 @@ static PyObject* EagerTensor_rank(EagerTensor* self) { #endif } +// Getter for `_num_elements`. +static PyObject* EagerTensor_num_elements(EagerTensor* self) { + auto handle = self->handle; + int n = TFE_TensorHandleNumDims(handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } + tensorflow::int64 value = 1; + if (PyErr_Occurred()) return nullptr; + for (int i = 0; i < n; ++i) { + int64_t dim = TFE_TensorHandleDim(handle, i, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions"); + return nullptr; + } + value *= dim; + } + return PyLong_FromLongLong(value); +} + static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) { Py_INCREF(self->handle_data); return self->handle_data; @@ -582,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = { {nullptr} /* Sentinel */ }; +#if PY_MAJOR_VERSION < 3 +// Only used for Python2 since Python3 seems to set the __dict__ correctly. 
+static PyMemberDef EagerTensor_members[] = {
+    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
+     READONLY},
+    {nullptr},
+};
+#endif
+
 static PyMethodDef EagerTensor_methods[] = {
     {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
      PyDoc_STR("_numpy")},
@@ -592,6 +664,8 @@ static PyMethodDef EagerTensor_methods[] = {
     {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
     {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
      METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
+    {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
+     PyDoc_STR("_num_elements")},
     {nullptr, nullptr},
 };
@@ -654,13 +728,13 @@ static PyTypeObject _EagerTensorType = {
     nullptr,                      /* tp_iter */
     nullptr,                      /* tp_iternext */
     EagerTensor_methods,          /* tp_methods */
-    nullptr,                      /* tp_members */
+    EagerTensor_members,          /* tp_members */
     EagerTensor_getseters,        /* tp_getset */
     nullptr,                      /* tp_base */
     nullptr,                      /* tp_dict */
     nullptr,                      /* tp_descr_get */
     nullptr,                      /* tp_descr_set */
-    0,                            /* tp_dictoffset */
+    offsetof(EagerTensor, dict),  /* tp_dictoffset */
     (initproc)EagerTensor_init,   /* tp_init */
     nullptr,                      /* tp_alloc */
     nullptr,                      /* tp_new */
@@ -788,8 +862,9 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
     PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
     return nullptr;
   }
+  EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
 #else
-  _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
+  _EagerTensorType.tp_base = base_class_type;
 
   if (PyType_Ready(&_EagerTensorType) < 0) {
     if (PyErr_Occurred()) return nullptr;
@@ -800,9 +875,6 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   EagerTensorType = &_EagerTensorType;
   Py_INCREF(EagerTensorType);
 #endif
-  // We disable instance based attribute lookup. Its not clear if these
-  // dictionaries are correctly initialized in the first place.
-  EagerTensorType->tp_dictoffset = 0;
   return reinterpret_cast<PyObject*>(EagerTensorType);
 }
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a916a75f00..f1b4042ec9 100644..100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -59,6 +59,10 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
 // This function is not thread-safe.
 PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
 
+// Registers e as the VSpace to use.
+// `e` must be an imperative_grad.py:VSpace named tuple.
+PyObject* TFE_Py_RegisterVSpace(PyObject* e);
+
 // Registers e as the Exception to be raised when the conditions of
 // TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
 // is a signal to the calling code that it should fall back to the safer (and
@@ -89,7 +93,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
                                   PyObject* exception);
 
 // Returns the string associated with the passed-in python object.
-char* TFE_GetPythonString(PyObject* o);
+const char* TFE_GetPythonString(PyObject* o);
 
 // Returns a unique id on each call.
 int64_t get_uid();
@@ -124,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
 // To unset the profiler, pass Py_None as the value of `profiler`.
 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
 
-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set.
`persistent` and +// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`). +PyObject* TFE_Py_TapeSetNew(PyObject* persistent, + PyObject* watch_accessed_variables); // Removes the passed tape from the set of active tapes. void TFE_Py_TapeSetRemove(PyObject* tape); @@ -138,7 +143,7 @@ void TFE_Py_TapeSetAdd(PyObject* tape); PyObject* TFE_Py_TapeSetIsEmpty(); PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors); -void TFE_Py_TapeSetWatch(PyObject* tensor); +void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor); void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id); // Stops any gradient recording on the current thread. @@ -158,18 +163,20 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, PyObject* input_tensor_ids, PyObject* backward_function); +// Notifies all tapes that a variable has been accessed. +void TFE_Py_TapeVariableAccessed(PyObject* variable); + // Watches the given variable object on the given tape. -void TFE_Py_TapeSetWatchVariable(PyObject* variable); +void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable); // Computes a gradient based on information recorded on the tape.`tape` must -// have been produced by TFE_Py_NewTape. `vspace` must be a -// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python +// have been produced by TFE_Py_NewTape. `target` and `sources` must be python // lists of Tensor objects. `output_gradients` is either None or a python list // of either Tensor or None, and if not None should have the same length as // target. -PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, - PyObject* target, PyObject* sources, - PyObject* output_gradients, TF_Status* status); +PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, + PyObject* sources, PyObject* output_gradients, + TF_Status* status); // Execute a tensorflow operation assuming that all provided inputs are // correctly formatted (i.e. EagerTensors). 
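In short, watching moves from "broadcast to every active tape" to explicit per-tape calls, and the VSpace is registered once up front (via `TFE_Py_RegisterVSpace`) instead of being passed into every gradient computation. A sketch of how the reworked internal helpers compose, based on the tape.py changes later in this diff (illustrative, not part of the diff; assumes eager execution is enabled):

```python
from tensorflow.python.eager import tape as tape_lib
from tensorflow.python.framework import constant_op

x = constant_op.constant(3.0)
t = tape_lib.push_new_tape(persistent=False,
                           watch_accessed_variables=False)
tape_lib.watch(t, x)  # recorded only on `t`, not on every active tape
# ... run ops on `x` that should be traced for gradients ...
tape_lib.pop_tape(t)
```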
If it doesn't find EagerTensors, diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 18fafd0de1..46dcf7c8a8 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -216,7 +216,7 @@ bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, #if PY_MAJOR_VERSION >= 3 if (PyUnicode_Check(py_value)) { Py_ssize_t size = 0; - char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); + const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); if (buf == nullptr) return false; *value = tensorflow::StringPiece(buf, size); return true; @@ -825,7 +825,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, return -1; } -char* TFE_GetPythonString(PyObject* o) { +const char* TFE_GetPythonString(PyObject* o) { if (PyBytes_Check(o)) { return PyBytes_AsString(o); } @@ -892,9 +892,10 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) { class GradientTape : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> { public: - explicit GradientTape(bool persistent) + explicit GradientTape(bool persistent, bool watch_accessed_variables) : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>( - persistent) {} + persistent), + watch_accessed_variables_(watch_accessed_variables) {} virtual ~GradientTape() { for (const IdAndVariable& v : watched_variables_) { @@ -902,6 +903,12 @@ class GradientTape } } + void VariableAccessed(PyObject* v) { + if (watch_accessed_variables_) { + WatchVariable(v); + } + } + void WatchVariable(PyObject* v) { tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle")); if (handle == nullptr) { @@ -951,6 +958,7 @@ class GradientTape } }; + bool watch_accessed_variables_; tensorflow::mutex watched_variables_mu_; std::set<IdAndVariable, CompareById> watched_variables_ GUARDED_BY(watched_variables_mu_); @@ -1056,11 +1064,13 @@ void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; } -PyObject* TFE_Py_TapeSetNew(PyObject* persistent) { +PyObject* TFE_Py_TapeSetNew(PyObject* persistent, + PyObject* watch_accessed_variables) { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); - tape->tape = new GradientTape(persistent == Py_True); + tape->tape = new GradientTape(persistent == Py_True, + watch_accessed_variables == Py_True); Py_INCREF(tape); GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape)); return reinterpret_cast<PyObject*>(tape); @@ -1154,7 +1164,7 @@ PyObject* TFE_Py_TapeSetShouldRecord(PyObject* tensors) { Py_RETURN_FALSE; } -void TFE_Py_TapeSetWatch(PyObject* tensor) { +void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { if (*ThreadTapeIsStopped()) { return; } @@ -1162,9 +1172,7 @@ void TFE_Py_TapeSetWatch(PyObject* tensor) { if (PyErr_Occurred()) { return; } - for (TFE_Py_Tape* tape : *GetTapeSet()) { - tape->tape->Watch(tensor_id); - } + reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id); } static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { @@ -1235,15 +1243,22 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) { return list; } -void TFE_Py_TapeSetWatchVariable(PyObject* variable) { +void TFE_Py_TapeVariableAccessed(PyObject* variable) { if (*ThreadTapeIsStopped()) { return; } for (TFE_Py_Tape* tape : SafeTapeSet()) { - 
tape->tape->WatchVariable(variable); + tape->tape->VariableAccessed(variable); } } +void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) { + if (*ThreadTapeIsStopped()) { + return; + } + reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable); +} + PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple(); } @@ -1350,7 +1365,9 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> { public: - explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {} + explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { + Py_INCREF(py_vspace_); + } tensorflow::Status Initialize() { num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); @@ -1378,6 +1395,8 @@ class PyVSpace Py_XDECREF(aggregate_fn_); Py_XDECREF(zeros_); Py_XDECREF(ones_); + + Py_DECREF(py_vspace_); } tensorflow::int64 NumElements(PyObject* tensor) const final { @@ -1493,6 +1512,22 @@ class PyVSpace PyObject* zeros_; PyObject* ones_; }; +PyVSpace* py_vspace = nullptr; + +PyObject* TFE_Py_RegisterVSpace(PyObject* e) { + if (py_vspace != nullptr) { + delete py_vspace; + } + + py_vspace = new PyVSpace(e); + auto status = py_vspace->Initialize(); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { + delete py_vspace; + return nullptr; + } + + Py_RETURN_NONE; +} std::vector<PyObject*> MakeTensorList(PyObject* tensors) { PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); @@ -1509,9 +1544,9 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) { return list; } -PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, - PyObject* target, PyObject* sources, - PyObject* output_gradients, TF_Status* status) { +PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, + PyObject* sources, PyObject* output_gradients, + TF_Status* status) { TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape); if (!tape_obj->tape->IsPersistent()) { auto* tape_set = GetTapeSet(); @@ -1526,10 +1561,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, return nullptr; } } - PyVSpace c_vspace(vspace); - if (!c_vspace.Initialize().ok()) { - return nullptr; - } std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target); if (PyErr_Occurred()) { @@ -1553,7 +1584,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, } std::vector<PyObject*> result; status->status = tape_obj->tape->ComputeGradient( - c_vspace, target_vec, sources_vec, outgrad_vec, &result); + *py_vspace, target_vec, sources_vec, outgrad_vec, &result); if (!status->status.ok()) { if (PyErr_Occurred()) { // Do not propagate the erroneous status as that would swallow the @@ -1709,118 +1740,169 @@ PyObject* MaybeGetDTypeForAttr(const string& attr, Py_RETURN_NONE; } -bool OpDoesntRequireOutput(const string& op_name) { - static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs = - new tensorflow::gtl::FlatSet<string>({ - "Identity", - "MatMul", - "Conv2DBackpropInput", - "Conv2DBackpropFilter", - "Conv3D", - "Conv3DBackpropInputV2", - "AvgPool3D", - "AvgPool3DGrad", - "MaxPool3D", - "MaxPool3DGrad", - "MaxPool3DGradGrad", - "BiasAdd", - "BiasAddV1", - "BiasAddGrad", - "Softplus", - "SoftplusGrad", - "Softsign", - "ReluGrad", - "LeakyRelu", - "LeakyReluGrad", - "Conv2D", - "DepthwiseConv2dNative", - "Dilation2D", - "AvgPool", - "AvgPoolGrad", - "BatchNormWithGlobalNormalization", - "L2Loss", - "Sum", - 
"Prod", - "SegmentSum", - "SegmentMean", - "SparseSegmentSum", - "SparseSegmentMean", - "SparseSegmentSqrtN", - "SegmentMin", - "SegmentMax", - "UnsortedSegmentSum", - "UnsortedSegmentMax", - "Abs", - "Neg", - "ReciprocalGrad", - "Square", - "Expm1", - "Log", - "Log1p", - "TanhGrad", - "SigmoidGrad", - "Sign", - "Sin", - "Cos", - "Tan", - "Add", - "Sub", - "Mul", - "Div", - "RealDiv", - "Maximum", - "Minimum", - "SquaredDifference", - "Select", - "SparseMatMul", - "BatchMatMul", - "Complex", - "Real", - "Imag", - "Angle", - "Conj", - "Cast", - "Cross", - "Cumsum", - "Cumprod", - "ReadVariableOp", - "VarHandleOp", - "Shape", +// Returns a pair where the first value of the pair indicates whether or not all +// outputs are unused. If the first value is false, the second value is a +// set that identifies which of the output indices are unused. +bool OpGradientDoesntRequireOutputIndices( + const string& op_name, + std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) { + static tensorflow::gtl::FlatMap< + string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m = + new tensorflow::gtl::FlatMap< + string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({ + // Ops that don't require any outputs. + {"Identity", {true, {}}}, + {"MatMul", {true, {}}}, + {"Conv2DBackpropInput", {true, {}}}, + {"Conv2DBackpropFilter", {true, {}}}, + {"Conv3D", {true, {}}}, + {"Conv3DBackpropInputV2", {true, {}}}, + {"AvgPool3D", {true, {}}}, + {"AvgPool3DGrad", {true, {}}}, + {"MaxPool3D", {true, {}}}, + {"MaxPool3DGrad", {true, {}}}, + {"MaxPool3DGradGrad", {true, {}}}, + {"BiasAdd", {true, {}}}, + {"BiasAddV1", {true, {}}}, + {"BiasAddGrad", {true, {}}}, + {"Softplus", {true, {}}}, + {"SoftplusGrad", {true, {}}}, + {"Softsign", {true, {}}}, + {"ReluGrad", {true, {}}}, + {"LeakyRelu", {true, {}}}, + {"LeakyReluGrad", {true, {}}}, + {"Conv2D", {true, {}}}, + {"DepthwiseConv2dNative", {true, {}}}, + {"Dilation2D", {true, {}}}, + {"AvgPool", {true, {}}}, + {"AvgPoolGrad", {true, {}}}, + {"BatchNormWithGlobalNormalization", {true, {}}}, + {"L2Loss", {true, {}}}, + {"Sum", {true, {}}}, + {"Prod", {true, {}}}, + {"SegmentSum", {true, {}}}, + {"SegmentMean", {true, {}}}, + {"SparseSegmentSum", {true, {}}}, + {"SparseSegmentMean", {true, {}}}, + {"SparseSegmentSqrtN", {true, {}}}, + {"SegmentMin", {true, {}}}, + {"SegmentMax", {true, {}}}, + {"UnsortedSegmentSum", {true, {}}}, + {"UnsortedSegmentMax", {true, {}}}, + {"Abs", {true, {}}}, + {"Neg", {true, {}}}, + {"ReciprocalGrad", {true, {}}}, + {"Square", {true, {}}}, + {"Expm1", {true, {}}}, + {"Log", {true, {}}}, + {"Log1p", {true, {}}}, + {"TanhGrad", {true, {}}}, + {"SigmoidGrad", {true, {}}}, + {"Sign", {true, {}}}, + {"Sin", {true, {}}}, + {"Cos", {true, {}}}, + {"Tan", {true, {}}}, + {"Add", {true, {}}}, + {"Sub", {true, {}}}, + {"Mul", {true, {}}}, + {"Div", {true, {}}}, + {"RealDiv", {true, {}}}, + {"Maximum", {true, {}}}, + {"Minimum", {true, {}}}, + {"SquaredDifference", {true, {}}}, + {"Select", {true, {}}}, + {"SparseMatMul", {true, {}}}, + {"BatchMatMul", {true, {}}}, + {"Complex", {true, {}}}, + {"Real", {true, {}}}, + {"Imag", {true, {}}}, + {"Angle", {true, {}}}, + {"Conj", {true, {}}}, + {"Cast", {true, {}}}, + {"Cross", {true, {}}}, + {"Cumsum", {true, {}}}, + {"Cumprod", {true, {}}}, + {"ReadVariableOp", {true, {}}}, + {"VarHandleOp", {true, {}}}, + {"Shape", {true, {}}}, + {"StridedSlice", {true, {}}}, + {"Fill", {true, {}}}, + + // Ops that don't require a subset of outputs. 
+ {"FusedBatchNorm", {false, {0, 1, 2}}}, }); - return ops_that_dont_require_outputs->find(op_name) != - ops_that_dont_require_outputs->end(); -} - -bool OpDoesntRequireInput(const string& op_name) { - static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs = - new tensorflow::gtl::FlatSet<string>({ - "Identity", - "Softmax", - "LogSoftmax", - "BiasAdd", - "Relu", - "Relu6", - "Elu", - "Selu", - "SparseSoftmaxCrossEntropyWithLogits", - "Neg", - "Inv", - "Reciprocal", - "Sqrt", - "Exp", - "Tanh", - "Sigmoid", - "Real", - "Imag", - "Conj", - "ReadVariableOp", - "VarHandleOp", - "Shape", + auto it = m->find(op_name); + + if (it == m->end()) return false; + + *output = &it->second; + return true; +} + +// Returns a pair where the first value of the pair indicates whether or not all +// inputs are unused. If the first value is false, the second value is a +// set that identifies which of the input indices are unused. +bool OpGradientDoesntRequireInputIndices( + const string& op_name, + std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) { + static tensorflow::gtl::FlatMap< + string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m = + new tensorflow::gtl::FlatMap< + string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({ + // Ops that don't require any inputs. + {"Identity", {true, {}}}, + {"Softmax", {true, {}}}, + {"LogSoftmax", {true, {}}}, + {"BiasAdd", {true, {}}}, + {"Relu", {true, {}}}, + {"Relu6", {true, {}}}, + {"Elu", {true, {}}}, + {"Selu", {true, {}}}, + {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}}, + {"Neg", {true, {}}}, + {"Inv", {true, {}}}, + {"Reciprocal", {true, {}}}, + {"Sqrt", {true, {}}}, + {"Exp", {true, {}}}, + {"Tanh", {true, {}}}, + {"Sigmoid", {true, {}}}, + {"Real", {true, {}}}, + {"Imag", {true, {}}}, + {"Conj", {true, {}}}, + {"ReadVariableOp", {true, {}}}, + {"VarHandleOp", {true, {}}}, + {"Shape", {true, {}}}, + {"Fill", {true, {}}}, + + // Ops that don't require a subset of inputs. 
+ {"FusedBatchNorm", {false, {2}}}, }); - return ops_that_dont_require_inputs->find(op_name) != - ops_that_dont_require_inputs->end(); + auto it = m->find(op_name); + + if (it == m->end()) return false; + + *output = &it->second; + return true; +} + +PyObject* CopySequenceSettingIndicesToNull( + PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) { + tensorflow::Safe_PyObjectPtr fast_seq( + PySequence_Fast(seq, "unable to allocate")); + PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get())); + for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) { + PyObject* item; + if (indices.find(i) != indices.end()) { + item = Py_None; + } else { + item = PySequence_Fast_GET_ITEM(fast_seq.get(), i); + } + Py_INCREF(item); + PyTuple_SET_ITEM(result, i, item); + } + return result; } PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, @@ -1840,16 +1922,35 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, if (!should_record) Py_RETURN_NONE; string c_op_name = TFE_GetPythonString(op_name); + PyObject* op_outputs; - if (OpDoesntRequireOutput(c_op_name)) { - op_outputs = Py_None; + bool op_outputs_tuple_created = false; + std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required; + + if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) { + if (outputs_not_required->first) { + op_outputs = Py_None; + } else { + op_outputs_tuple_created = true; + op_outputs = CopySequenceSettingIndicesToNull( + results, outputs_not_required->second); + } } else { op_outputs = results; } PyObject* op_inputs; - if (OpDoesntRequireInput(c_op_name)) { - op_inputs = Py_None; + bool op_inputs_tuple_created = false; + std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required; + + if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) { + if (inputs_not_required->first) { + op_inputs = Py_None; + } else { + op_inputs_tuple_created = true; + op_inputs = + CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second); + } } else { op_inputs = inputs; } @@ -1892,18 +1993,20 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, }); Py_DECREF(num_inputs); + if (op_outputs_tuple_created) Py_DECREF(op_outputs); + if (op_inputs_tuple_created) Py_DECREF(op_inputs); Py_RETURN_NONE; } -void MaybeWatchVariable(PyObject* input) { +void MaybeNotifyVariableAccessed(PyObject* input) { DCHECK(CheckResourceVariable(input)); DCHECK(PyObject_HasAttrString(input, "_trainable")); tensorflow::Safe_PyObjectPtr trainable( PyObject_GetAttrString(input, "_trainable")); if (trainable.get() == Py_False) return; - TFE_Py_TapeSetWatchVariable(input); + TFE_Py_TapeVariableAccessed(input); } bool CastTensor(const FastPathOpExecInfo& op_exec_info, @@ -1934,7 +2037,7 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info, bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, PyObject* input, tensorflow::Safe_PyObjectPtr* output, TF_Status* status) { - MaybeWatchVariable(input); + MaybeNotifyVariableAccessed(input); TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status); auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); }); diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index caa217b70c..399d90223c 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -33,9 +33,10 @@ class Tape(object): return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) -def 
push_new_tape(persistent=False): +def push_new_tape(persistent=False, watch_accessed_variables=True): """Pushes a new tape onto the tape stack.""" - tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent) + tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent, + watch_accessed_variables) return Tape(tape) @@ -44,22 +45,19 @@ def push_tape(tape): pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access -def watch(tensor): - """Marks this tensor to be watched by all tapes in the stack. +def watch(tape, tensor): + """Marks this tensor to be watched by the given tape.""" + pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access - Args: - tensor: tensor to be watched. - """ - pywrap_tensorflow.TFE_Py_TapeSetWatch(tensor) +def watch_variable(tape, variable): + """Marks this variable to be watched by the given tape.""" + pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access -def watch_variable(variable): - """Marks this variable to be watched by all tapes in the stack. - Args: - variable: variable to be watched. - """ - pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable) +def variable_accessed(variable): + """Notifies all tapes in the stack that a variable has been accessed.""" + pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable) def pop_tape(tape): diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py index 4326d5efa3..acd0e569f1 100644 --- a/tensorflow/python/eager/tape_test.py +++ b/tensorflow/python/eager/tape_test.py @@ -72,7 +72,7 @@ class TapeTest(test.TestCase): a = constant_op.constant([[1., 0.], [0., 1.]]) b = constant_op.constant([[1., 2.], [3., 4.]]) da, db = backprop.gradients_function(fn, [0, 1])(a, b) - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32) tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32) tf_c = tf_a + tf_b @@ -135,7 +135,7 @@ class TapeTest(test.TestCase): a = constant_op.constant([[1., 0.], [0., 1.]]) b = constant_op.constant([[1., 2.], [3., 4.]]) da, db = backprop.gradients_function(fn, [0, 1])(a, b) - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32) tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32) tf_mm = math_ops.matmul(tf_a, tf_b) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 871136e2c8..344a9b25bd 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops def _create_tensor(value, device=None, dtype=None): @@ -295,6 +296,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): def testFloatTensor(self): self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype) self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype) + self.assertEqual(dtypes.float16, _create_tensor(np.float16()).dtype) self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype) def testSliceDimOutOfRange(self): @@ -332,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): "but tensor at index 2 has 
rank 0"): pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testTensorDir(self): + t = array_ops.zeros(1) + t.test_attr = "Test" + + instance_dir = dir(t) + type_dir = dir(ops.EagerTensor) + + # Monkey patched attributes should show up in dir(t) + self.assertIn("test_attr", instance_dir) + instance_dir.remove("test_attr") + self.assertEqual(instance_dir, type_dir) + if __name__ == "__main__": test.main() |