aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/pywrap_tfe.i
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-01-08 17:59:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-08 18:03:58 -0800
commitcaa2a1b856d1080cfec26b6ab5756aa49114597e (patch)
tree447cce319a13ba32caf20c83d2497e4b292c7c86 /tensorflow/python/pywrap_tfe.i
parent6fda97ab7540fb003d46f7c9810d6aab6dbc6c19 (diff)
Fix the threading model of gradient tapes.
The set of tapes needs to be global to enable multithreaded programming (when it's natural for tensors to cross threads during reduction operations) but each thread still needs to be able to locally pause recording while it does gradient-related bookkeeping (like custom gradients or initialization). Also removes a mutex from the thread-local structure since it's unnecessary as we're always holding the GIL while calling across the python-c boundary unless we explicitly release it. PiperOrigin-RevId: 181246570
Diffstat (limited to 'tensorflow/python/pywrap_tfe.i')
-rw-r--r--tensorflow/python/pywrap_tfe.i19
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index d97823c17f..42e4773df3 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -26,15 +26,16 @@ limitations under the License.
%rename("%s") TFE_Py_RegisterExceptionClass;
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_Py_UID;
-%rename("%s") TFE_Py_TapeStackPushNew;
-%rename("%s") TFE_Py_TapeStackPush;
-%rename("%s") TFE_Py_TapeStackPop;
-%rename("%s") TFE_Py_TapeStackIsEmpty;
-%rename("%s") TFE_Py_TapeStackShouldRecord;
-%rename("%s") TFE_Py_TapeStackWatch;
-%rename("%s") TFE_Py_TapeStackDeleteTrace;
-%rename("%s") TFE_Py_TapeStackRecordOperation;
-%rename("%s") TFE_Py_TapeStackWatchVariable;
+%rename("%s") TFE_Py_TapeSetNew;
+%rename("%s") TFE_Py_TapeSetRemove;
+%rename("%s") TFE_Py_TapeSetStopOnThread;
+%rename("%s") TFE_Py_TapeSetRestartOnThread;
+%rename("%s") TFE_Py_TapeSetIsEmpty;
+%rename("%s") TFE_Py_TapeSetShouldRecord;
+%rename("%s") TFE_Py_TapeSetWatch;
+%rename("%s") TFE_Py_TapeSetDeleteTrace;
+%rename("%s") TFE_Py_TapeSetRecordOperation;
+%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;