author    Alexandre Passos <apassos@google.com>  2017-10-13 16:30:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-13 16:35:44 -0700
commit    84580227a398e68001c2114fae966a62ac918045 (patch)
tree      ac4356e6323742da29c784585624a53fa98bbff8
parent    5dd569cf026bae92330a194c8f2895d0f48149d9 (diff)
imperative_grad takes the tape instead of popping it.
PiperOrigin-RevId: 172162006
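
In other words, the caller now pops the gradient tape and passes it to imperative_grad explicitly, instead of imperative_grad popping it off the global tape stack itself. A minimal sketch of the new calling convention, using only names that appear in the diff below (the surrounding setup is illustrative, not part of this change):

    # Before this change, imperative_grad popped the tape from the global
    # stack internally; now the caller pops it and passes it in.
    grad = imperative_grad.imperative_grad(
        _default_vspace,         # vector space used for differentiation
        tape.pop_tape(),         # gradient tape, popped by the caller
        nest.flatten(end_node),  # target tensor(s) to differentiate
        sources)                 # tensors to compute gradients with respect to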
-rw-r--r--  tensorflow/python/eager/BUILD                1
-rw-r--r--  tensorflow/python/eager/backprop.py          3
-rw-r--r--  tensorflow/python/eager/imperative_grad.py  11
3 files changed, 7 insertions, 8 deletions
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 1d20a0782f..69b96df87c 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -431,5 +431,4 @@ py_library(
name = "imperative_grad",
srcs = ["imperative_grad.py"],
srcs_version = "PY2AND3",
- deps = [":tape"],
)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 554b9a818c..0060dd0c1c 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -338,6 +338,7 @@ def implicit_val_and_grad(f):
variables = tape.top_tape_watched_variables()
sources = [x.handle for x in variables]
grad = imperative_grad.imperative_grad(_default_vspace,
+ tape.pop_tape(),
nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
@@ -574,7 +575,7 @@ def val_and_grad_function(f, params=None):
tape.watch(args[i])
result = f(*args)
return result, imperative_grad.imperative_grad(
- _default_vspace, nest.flatten(result), sources,
+ _default_vspace, tape.pop_tape(), nest.flatten(result), sources,
output_gradients=nest.flatten(dy) if dy is not None else None)
return decorated
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index dd9d691d26..d30d124040 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import collections
-from tensorflow.python.eager import tape
+from tensorflow.python.eager import tape as tape_module
# Terminology:
@@ -76,7 +76,7 @@ def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
# op is None or -1 if the tensor is a source (i.e. was watched directly)
if op is None or op == -1 or op in o_to_e:
continue
- op_trace = tape.TapeEntry(*op_to_entry[op])
+ op_trace = tape_module.TapeEntry(*op_to_entry[op])
o_to_e[op] = op_trace
for it in op_trace.input_ids:
if it in tensor_usage_counts:
@@ -125,6 +125,7 @@ VSpace = collections.namedtuple(
def imperative_grad(
vspace,
+ tape,
target,
sources,
output_gradients=None):
@@ -136,6 +137,7 @@ def imperative_grad(
Args:
vspace: the vector space in which to differentiate.
+ tape: the gradient tape which stores the trace.
target: either a Tensor or list of Tensors to be differentiated.
sources: list of Tensors for which we want gradients
output_gradients: if not None, a list of gradient provided for each Target,
@@ -152,10 +154,7 @@ def imperative_grad(
or if only non-differentiable functions of the source were used in the
computation of target.
"""
- if not tape._tape_stack.stack: # pylint: disable=protected-access
- raise RuntimeError("Computing a gradient with no tape present")
- bp_tape = tape.pop_tape()
- tensor_to_op, op_to_entry = bp_tape.export()
+ tensor_to_op, op_to_entry = tape.export()
# This overwrites the op_to_entry variable, which will release all memory used
# to keep traces that are irrelevant to the gradient computation we're doing
# here.
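
For context, a rough usage sketch of the val_and_grad_function wrapper touched above. The exact public alias for this function and the reading of params as positional-argument indices are assumptions, not established by this diff; eager execution must also be enabled in this TensorFlow version for the sketch to run.

    import tensorflow as tf
    from tensorflow.python.eager import backprop

    def loss(x, y):
      return (x - y) * (x - y)

    # Ask for the value of loss and its gradients with respect to both
    # positional arguments (assumed meaning of params).
    val_and_grads_fn = backprop.val_and_grad_function(loss, params=[0, 1])
    value, grads = val_and_grads_fn(tf.constant(3.0), tf.constant(1.0))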