diff options
author | Alexandre Passos <apassos@google.com> | 2017-10-09 16:24:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-09 16:28:26 -0700 |
commit | 88145023cea47b4a96cc04f8febe205d50a0d0d6 (patch) | |
tree | 64afb7f5b7b32f57be42ac88f951994a4e187ef6 /tensorflow/python/eager/tape.py | |
parent | 33d55122d994d12f2a066f9ec4f0f03094a59579 (diff) |
Removing side outputs from tape code.
They belong better in future function objects (simplifies tape move to C)
PiperOrigin-RevId: 171603665
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- | tensorflow/python/eager/tape.py | 19 |
1 files changed, 3 insertions, 16 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 84814d48fd..4578a7190d 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -32,7 +32,7 @@ def tid(tensor): class TapeEntry( collections.namedtuple("TapeEntry", [ "op_type", - "output_ids", "input_ids", "side_outputs", "backward_function", + "output_ids", "input_ids", "backward_function", "output_shape_and_dtype", ])): """Entry in the gradient tape. @@ -43,8 +43,6 @@ class TapeEntry( Args: output_ids: tensor_id(t) for each output tensor T input_ids: tensor_id(t) for each input tensor T - side_outputs: optional tensors (not IDs) which need to be provided to the - backward function. backward_function: function to be called with the downstream gradients and side outputs as arguments which computes the backward pass. output_shape_and_dtype: a list of (shape_tuple, dtype) for every output @@ -69,8 +67,6 @@ class Tape(object): self._op_tape = {} # next operation ID self._next_op_id = 0 - # List of directly watched tensors - self._watched = [] # Set of directly watched variables self._watched_variables = set() @@ -91,14 +87,13 @@ class Tape(object): if i not in self._tensor_tape: self._tensor_tape[i] = None self._tensor_usage[i] = 1 - self._watched.append(tensor) def watch_variable(self, v): self._watched_variables.add(v) self.watch(v.handle) def record_operation(self, op_type, output_tensors, input_tensors, - side_outputs, backward_function): + backward_function): """Records an operation in the tape.""" if not self.should_record(input_tensors): return output_tensors @@ -113,7 +108,6 @@ class Tape(object): op_type, [tid(t) for t in output_tensors], [tid(t) for t in input_tensors], - side_outputs, backward_function, [(_tensor_shape(t), t.dtype) for t in output_tensors]) self._next_op_id += 1 @@ -227,13 +221,11 @@ def should_record(tensors): return any(x.should_record(tensors) for x in _tape_stack.stack) -def record_operation(op_type, output_tensors, input_tensors, side_outputs, - backward_function): +def record_operation(op_type, output_tensors, input_tensors, backward_function): """Records the operation on all tapes in the stack.""" for t in _tape_stack.stack: t.record_operation(op_type, output_tensors, input_tensors, - side_outputs, backward_function) @@ -243,11 +235,6 @@ def delete_trace(tensor_id): t.delete_trace(tensor_id) -def top_tape_watched_tensors(): - t = _tape_stack.stack[-1] - return t._watched # pylint: disable=protected-access - - def top_tape_watched_variables(): t = _tape_stack.stack[-1] return t._watched_variables # pylint: disable=protected-access |