| author | Alexandre Passos <apassos@google.com> | 2017-10-13 16:30:53 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-13 16:35:44 -0700 |
| commit | 84580227a398e68001c2114fae966a62ac918045 (patch) | |
| tree | ac4356e6323742da29c784585624a53fa98bbff8 | |
| parent | 5dd569cf026bae92330a194c8f2895d0f48149d9 (diff) | |
imperative_grad takes the tape instead of popping it.
PiperOrigin-RevId: 172162006
| -rw-r--r-- | tensorflow/python/eager/BUILD | 1 |
| -rw-r--r-- | tensorflow/python/eager/backprop.py | 3 |
| -rw-r--r-- | tensorflow/python/eager/imperative_grad.py | 11 |

3 files changed, 7 insertions, 8 deletions
```diff
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 1d20a0782f..69b96df87c 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -431,5 +431,4 @@ py_library(
     name = "imperative_grad",
     srcs = ["imperative_grad.py"],
     srcs_version = "PY2AND3",
-    deps = [":tape"],
 )
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 554b9a818c..0060dd0c1c 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -338,6 +338,7 @@ def implicit_val_and_grad(f):
     variables = tape.top_tape_watched_variables()
     sources = [x.handle for x in variables]
     grad = imperative_grad.imperative_grad(_default_vspace,
+                                           tape.pop_tape(),
                                            nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
@@ -574,7 +575,7 @@ def val_and_grad_function(f, params=None):
         tape.watch(args[i])
     result = f(*args)
     return result, imperative_grad.imperative_grad(
-        _default_vspace, nest.flatten(result), sources,
+        _default_vspace, tape.pop_tape(), nest.flatten(result), sources,
         output_gradients=nest.flatten(dy) if dy is not None else None)
 
   return decorated
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index dd9d691d26..d30d124040 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -20,7 +20,7 @@
 from __future__ import print_function
 
 import collections
 
-from tensorflow.python.eager import tape
+from tensorflow.python.eager import tape as tape_module
 
 # Terminology:
@@ -76,7 +76,7 @@ def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
     # op is None or -1 if the tensor is a source (i.e. was watched directly)
     if op is None or op == -1 or op in o_to_e:
       continue
-    op_trace = tape.TapeEntry(*op_to_entry[op])
+    op_trace = tape_module.TapeEntry(*op_to_entry[op])
     o_to_e[op] = op_trace
     for it in op_trace.input_ids:
       if it in tensor_usage_counts:
@@ -125,6 +125,7 @@ VSpace = collections.namedtuple(
 
 def imperative_grad(
     vspace,
+    tape,
     target,
     sources,
     output_gradients=None):
@@ -136,6 +137,7 @@ def imperative_grad(
 
   Args:
    vspace: the vector space in which to differentiate.
+   tape: the gradient tape which stores the trace.
    target: either a Tensor or list of Tensors to be differentiated.
    sources: list of Tensors for which we want gradients
    output_gradients: if not None, a list of gradient provided for each Target,
@@ -152,10 +154,7 @@ def imperative_grad(
     or if only non-differentiable functions of the source were used in the
     computation of target.
   """
-  if not tape._tape_stack.stack:  # pylint: disable=protected-access
-    raise RuntimeError("Computing a gradient with no tape present")
-  bp_tape = tape.pop_tape()
-  tensor_to_op, op_to_entry = bp_tape.export()
+  tensor_to_op, op_to_entry = tape.export()
   # This overwrites the op_to_entry variable, which will release all memory used
   # to keep traces that are irrelevant to the gradient computation we're doing
   # here.
```
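For context, the commit inverts who owns the tape: `imperative_grad` previously popped the top of the module-level tape stack itself, while after this change the caller (e.g. `backprop.py`) pops the tape and passes it in as an argument. Below is a minimal, self-contained sketch of that calling pattern; the `Tape` class, the module-level stack, and the trivial `imperative_grad` body are simplified stand-ins that only borrow names from this diff, not the actual TensorFlow eager implementation.

```python
# Minimal stand-ins mirroring the names in this diff (Tape, push_new_tape,
# pop_tape, export, imperative_grad); illustrative only, the real gradient
# computation is omitted.


class Tape(object):
  """Records the trace of operations needed to compute gradients later."""

  def __init__(self):
    self._tensor_to_op = {}
    self._op_to_entry = {}

  def export(self):
    # imperative_grad now calls export() on whatever tape it is handed.
    return self._tensor_to_op, self._op_to_entry


_tape_stack = []


def push_new_tape():
  _tape_stack.append(Tape())


def pop_tape():
  # The "no tape present" check now lives at the call site; imperative_grad
  # itself no longer touches the module-level stack.
  if not _tape_stack:
    raise RuntimeError("Computing a gradient with no tape present")
  return _tape_stack.pop()


def imperative_grad(vspace, tape, target, sources, output_gradients=None):
  """After this change the tape is an explicit argument, not popped inside."""
  tensor_to_op, op_to_entry = tape.export()
  del vspace, target, sources, output_gradients  # gradient math elided here
  return tensor_to_op, op_to_entry


# Caller-side pattern, as in backprop.py after this change: the caller pops
# the tape it pushed and passes it in explicitly.
push_new_tape()
result = object()  # stand-in for the output of the traced computation
grads = imperative_grad(None, pop_tape(), [result], sources=[])
```

Passing the tape explicitly keeps the stack manipulation and the "no tape present" error at the call site, and lets the same gradient routine be reused with a tape obtained some other way.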