aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-10-09 16:24:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-09 16:28:26 -0700
commit88145023cea47b4a96cc04f8febe205d50a0d0d6 (patch)
tree64afb7f5b7b32f57be42ac88f951994a4e187ef6 /tensorflow/python/eager/tape.py
parent33d55122d994d12f2a066f9ec4f0f03094a59579 (diff)
Removing side outputs from tape code.
They belong better in future function objects (simplifies tape move to C) PiperOrigin-RevId: 171603665
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r--tensorflow/python/eager/tape.py19
1 files changed, 3 insertions, 16 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 84814d48fd..4578a7190d 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -32,7 +32,7 @@ def tid(tensor):
class TapeEntry(
collections.namedtuple("TapeEntry", [
"op_type",
- "output_ids", "input_ids", "side_outputs", "backward_function",
+ "output_ids", "input_ids", "backward_function",
"output_shape_and_dtype",
])):
"""Entry in the gradient tape.
@@ -43,8 +43,6 @@ class TapeEntry(
Args:
output_ids: tensor_id(t) for each output tensor T
input_ids: tensor_id(t) for each input tensor T
- side_outputs: optional tensors (not IDs) which need to be provided to the
- backward function.
backward_function: function to be called with the downstream gradients and
side outputs as arguments which computes the backward pass.
output_shape_and_dtype: a list of (shape_tuple, dtype) for every output
@@ -69,8 +67,6 @@ class Tape(object):
self._op_tape = {}
# next operation ID
self._next_op_id = 0
- # List of directly watched tensors
- self._watched = []
# Set of directly watched variables
self._watched_variables = set()
@@ -91,14 +87,13 @@ class Tape(object):
if i not in self._tensor_tape:
self._tensor_tape[i] = None
self._tensor_usage[i] = 1
- self._watched.append(tensor)
def watch_variable(self, v):
self._watched_variables.add(v)
self.watch(v.handle)
def record_operation(self, op_type, output_tensors, input_tensors,
- side_outputs, backward_function):
+ backward_function):
"""Records an operation in the tape."""
if not self.should_record(input_tensors):
return output_tensors
@@ -113,7 +108,6 @@ class Tape(object):
op_type,
[tid(t) for t in output_tensors],
[tid(t) for t in input_tensors],
- side_outputs,
backward_function,
[(_tensor_shape(t), t.dtype) for t in output_tensors])
self._next_op_id += 1
@@ -227,13 +221,11 @@ def should_record(tensors):
return any(x.should_record(tensors) for x in _tape_stack.stack)
-def record_operation(op_type, output_tensors, input_tensors, side_outputs,
- backward_function):
+def record_operation(op_type, output_tensors, input_tensors, backward_function):
"""Records the operation on all tapes in the stack."""
for t in _tape_stack.stack:
t.record_operation(op_type, output_tensors,
input_tensors,
- side_outputs,
backward_function)
@@ -243,11 +235,6 @@ def delete_trace(tensor_id):
t.delete_trace(tensor_id)
-def top_tape_watched_tensors():
- t = _tape_stack.stack[-1]
- return t._watched # pylint: disable=protected-access
-
-
def top_tape_watched_variables():
t = _tape_stack.stack[-1]
return t._watched_variables # pylint: disable=protected-access