aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/pywrap_tfe.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/pywrap_tfe.h')
-rwxr-xr-xtensorflow/python/eager/pywrap_tfe.h25
1 files changed, 16 insertions, 9 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 16f8c3c917..f1b4042ec9 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -59,6 +59,10 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
+// Registers e as the VSpace to use.
+// `vspace` must be a imperative_grad.py:VSpace named tuple.
+PyObject* TFE_Py_RegisterVSpace(PyObject* e);
+
// Registers e as the Exception to be raised when the conditions of
// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
// is a signal to the calling code that it should fall back to the safer (and
@@ -124,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
// To unset the profiler, pass Py_None as the value of `profiler`.
PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` and
+// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables);
// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);
@@ -158,18 +163,20 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
+// Notifies all tapes that a variable has been accessed.
+void TFE_Py_TapeVariableAccessed(PyObject* variable);
+
// Watches the given variable object on the given tape.
-void TFE_Py_TapeSetWatchVariable(PyObject* variable);
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
// Computes a gradient based on information recorded on the tape.`tape` must
-// have been produced by TFE_Py_NewTape. `vspace` must be a
-// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
+// have been produced by TFE_Py_NewTape. `target` and `sources` must be python
// lists of Tensor objects. `output_gradients` is either None or a python list
// of either Tensor or None, and if not None should have the same length as
// target.
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
- PyObject* target, PyObject* sources,
- PyObject* output_gradients, TF_Status* status);
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+ PyObject* sources, PyObject* output_gradients,
+ TF_Status* status);
// Execute a tensorflow operation assuming that all provided inputs are
// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,