From 84580227a398e68001c2114fae966a62ac918045 Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Fri, 13 Oct 2017 16:30:53 -0700
Subject: imperative_grad takes the tape instead of popping it.

PiperOrigin-RevId: 172162006
---
 tensorflow/python/eager/BUILD              |  1 -
 tensorflow/python/eager/backprop.py        |  3 ++-
 tensorflow/python/eager/imperative_grad.py | 11 +++++------
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 1d20a0782f..69b96df87c 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -431,5 +431,4 @@ py_library(
     name = "imperative_grad",
     srcs = ["imperative_grad.py"],
     srcs_version = "PY2AND3",
-    deps = [":tape"],
 )
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 554b9a818c..0060dd0c1c 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -338,6 +338,7 @@ def implicit_val_and_grad(f):
     variables = tape.top_tape_watched_variables()
     sources = [x.handle for x in variables]
     grad = imperative_grad.imperative_grad(_default_vspace,
+                                           tape.pop_tape(),
                                            nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
@@ -574,7 +575,7 @@ def val_and_grad_function(f, params=None):
       tape.watch(args[i])
     result = f(*args)
     return result, imperative_grad.imperative_grad(
-        _default_vspace, nest.flatten(result), sources,
+        _default_vspace, tape.pop_tape(), nest.flatten(result), sources,
         output_gradients=nest.flatten(dy) if dy is not None else None)
 
   return decorated
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index dd9d691d26..d30d124040 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import collections
 
-from tensorflow.python.eager import tape
+from tensorflow.python.eager import tape as tape_module
 
 
 # Terminology:
@@ -76,7 +76,7 @@ def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
     # op is None or -1 if the tensor is a source (i.e. was watched directly)
     if op is None or op == -1 or op in o_to_e:
       continue
-    op_trace = tape.TapeEntry(*op_to_entry[op])
+    op_trace = tape_module.TapeEntry(*op_to_entry[op])
     o_to_e[op] = op_trace
     for it in op_trace.input_ids:
       if it in tensor_usage_counts:
@@ -125,6 +125,7 @@ VSpace = collections.namedtuple(
 
 def imperative_grad(
     vspace,
+    tape,
     target,
     sources,
     output_gradients=None):
@@ -136,6 +137,7 @@ def imperative_grad(
 
   Args:
     vspace: the vector space in which to differentiate.
+    tape: the gradient tape which stores the trace.
     target: either a Tensor or list of Tensors to be differentiated.
     sources: list of Tensors for which we want gradients
     output_gradients: if not None, a list of gradient provided for each Target,
@@ -152,10 +154,7 @@ def imperative_grad(
      or if only non-differentiable functions of the source were used in the
      computation of target.
   """
-  if not tape._tape_stack.stack:  # pylint: disable=protected-access
-    raise RuntimeError("Computing a gradient with no tape present")
-  bp_tape = tape.pop_tape()
-  tensor_to_op, op_to_entry = bp_tape.export()
+  tensor_to_op, op_to_entry = tape.export()
   # This overwrites the op_to_entry variable, which will release all memory used
   # to keep traces that are irrelevant to the gradient computation we're doing
   # here.
-- 
cgit v1.2.3
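
For context, a minimal sketch of the caller-side pattern after this change: the caller pops the tape and passes it to `imperative_grad.imperative_grad` explicitly, rather than `imperative_grad` popping the implicit global tape stack itself. This is an illustration, not code from the patch; `tape.push_new_tape` and `tape.watch` are assumed helpers in `tensorflow.python.eager.tape` at this revision (only `pop_tape` and `top_tape_watched_variables` appear in the diff), and `value_and_gradient` is a hypothetical wrapper name.

```python
# Sketch only: mirrors the new calling convention shown in backprop.py above.
# Assumes TensorFlow at this revision; push_new_tape/watch are assumed helpers
# from tensorflow.python.eager.tape and are not part of this diff.
from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.eager.backprop import _default_vspace
from tensorflow.python.util import nest


def value_and_gradient(f, inputs):
  """Runs f under a fresh tape and differentiates it w.r.t. `inputs`."""
  tape.push_new_tape()   # start recording operations
  for t in inputs:
    tape.watch(t)        # mark each input tensor as a gradient source
  result = f(*inputs)
  # New convention: the caller pops the tape and hands it in explicitly,
  # instead of imperative_grad reaching into the global tape stack.
  grads = imperative_grad.imperative_grad(
      _default_vspace, tape.pop_tape(), nest.flatten(result), inputs)
  return result, grads
```

One consequence of this design is that `imperative_grad` no longer needs to raise when no tape is present: whoever owns the tape decides when to stop recording and is responsible for supplying it, which is why the `RuntimeError` check is removed in the last hunk.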