diff options
author | Alexandre Passos <apassos@google.com> | 2018-01-08 17:59:58 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-08 18:03:58 -0800 |
commit | caa2a1b856d1080cfec26b6ab5756aa49114597e (patch) | |
tree | 447cce319a13ba32caf20c83d2497e4b292c7c86 /tensorflow/python/pywrap_tfe.i | |
parent | 6fda97ab7540fb003d46f7c9810d6aab6dbc6c19 (diff) |
Fix the threading model of gradient tapes.
The set of tapes needs to be global to enable multithreaded programming
(when it's natural for tensors to cross threads during reduction operations)
but each thread still needs to be able to locally pause recording while
it does gradient-related bookkeeping (like custom gradients or initialization).
Also removes a mutex from the thread-local structure since it's unnecessary
as we're always holding the GIL while calling across the python-c boundary
unless we explicitly release it.
PiperOrigin-RevId: 181246570
Diffstat (limited to 'tensorflow/python/pywrap_tfe.i')
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index d97823c17f..42e4773df3 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -26,15 +26,16 @@ limitations under the License. %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_UID; -%rename("%s") TFE_Py_TapeStackPushNew; -%rename("%s") TFE_Py_TapeStackPush; -%rename("%s") TFE_Py_TapeStackPop; -%rename("%s") TFE_Py_TapeStackIsEmpty; -%rename("%s") TFE_Py_TapeStackShouldRecord; -%rename("%s") TFE_Py_TapeStackWatch; -%rename("%s") TFE_Py_TapeStackDeleteTrace; -%rename("%s") TFE_Py_TapeStackRecordOperation; -%rename("%s") TFE_Py_TapeStackWatchVariable; +%rename("%s") TFE_Py_TapeSetNew; +%rename("%s") TFE_Py_TapeSetRemove; +%rename("%s") TFE_Py_TapeSetStopOnThread; +%rename("%s") TFE_Py_TapeSetRestartOnThread; +%rename("%s") TFE_Py_TapeSetIsEmpty; +%rename("%s") TFE_Py_TapeSetShouldRecord; +%rename("%s") TFE_Py_TapeSetWatch; +%rename("%s") TFE_Py_TapeSetDeleteTrace; +%rename("%s") TFE_Py_TapeSetRecordOperation; +%rename("%s") TFE_Py_TapeSetWatchVariable; %rename("%s") TFE_Py_TapeGradient; %rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; |