aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r--tensorflow/python/eager/tape.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index a06f5e1a67..c16aa8c2f7 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -72,7 +72,7 @@ class Tape(object):
True if any of the tensors is in the tape.
"""
return pywrap_tensorflow.TFE_Py_TapeShouldRecord(
- self._tape, tensors)
+ self._tape, [x._id for x in tensors]) # pylint: disable=protected-access
def watch(self, tensor):
"""Adds a tensor to the tape."""
@@ -99,6 +99,16 @@ class Tape(object):
"""Deletes any trace we have for this tensor."""
self._delete_tensor_id(tensor_id)
+ def export(self):
+ """Exports the internal state of this tape.
+
+ Returns:
+ tensor_tape: a map from tensor_id(tensor) to <identifier for op>
+ responsible for generating that tensor.
+ op_tape: a map from <identifier for op> to TapeEntry for that op.
+ """
+ return pywrap_tensorflow.TFE_Py_TapeExport(self._tape)
+
class _TapeStack(threading.local):