diff options
Diffstat (limited to 'tensorflow/python/eager/pywrap_tfe.h')
-rwxr-xr-x | tensorflow/python/eager/pywrap_tfe.h | 25 |
1 files changed, 16 insertions, 9 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 16f8c3c917..f1b4042ec9 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -59,6 +59,10 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e); // This function is not thread-safe. PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e); +// Registers e as the VSpace to use. +// `vspace` must be a imperative_grad.py:VSpace named tuple. +PyObject* TFE_Py_RegisterVSpace(PyObject* e); + // Registers e as the Exception to be raised when the conditions of // TFE_Py_FastPathExecute_C have not been met. When this exception is set, it // is a signal to the calling code that it should fall back to the safer (and @@ -124,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); // To unset the profiler, pass Py_None as the value of `profiler`. PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler); -// Creates a new tape and adds it to the active set. `persistent` must be a -// PyBool_Type, i.e either Py_True or Py_False -PyObject* TFE_Py_TapeSetNew(PyObject* persistent); +// Creates a new tape and adds it to the active set. `persistent` and +// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`). +PyObject* TFE_Py_TapeSetNew(PyObject* persistent, + PyObject* watch_accessed_variables); // Removes the passed tape from the set of active tapes. void TFE_Py_TapeSetRemove(PyObject* tape); @@ -158,18 +163,20 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, PyObject* input_tensor_ids, PyObject* backward_function); +// Notifies all tapes that a variable has been accessed. +void TFE_Py_TapeVariableAccessed(PyObject* variable); + // Watches the given variable object on the given tape. -void TFE_Py_TapeSetWatchVariable(PyObject* variable); +void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable); // Computes a gradient based on information recorded on the tape.`tape` must -// have been produced by TFE_Py_NewTape. `vspace` must be a -// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python +// have been produced by TFE_Py_NewTape. `target` and `sources` must be python // lists of Tensor objects. `output_gradients` is either None or a python list // of either Tensor or None, and if not None should have the same length as // target. -PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, - PyObject* target, PyObject* sources, - PyObject* output_gradients, TF_Status* status); +PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, + PyObject* sources, PyObject* output_gradients, + TF_Status* status); // Execute a tensorflow operation assuming that all provided inputs are // correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors, |