aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-06-11 10:42:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 10:45:04 -0700
commit6aeab4f5402f56e4b30540db0847256362c15e32 (patch)
tree40f916c06bc3c90a77bf062d043849c475302d90
parent7b8c64ef05c7fdddb3f3a32fd3189e1e4b7e8985 (diff)
Don't call back into python during insert (which will leave the set in a broken condition if the runtime decides to let another thread run).
Thank you for finding the bug. The watched_variables_ set should not really require a lock since all our functions hold the GIL (verified by looking at the generated SWIG). The reason that there was a concurrent access to the set is that the insert was calling back into python (which might release the GIL and let another thread run, which will also attempt to insert a variable and break the set). I included the lock to be safe though, since it's non-trivial to verify without looking at the generated SWIG wrappers that the GIL is held. PiperOrigin-RevId: 200074843
-rw-r--r--tensorflow/contrib/distribute/python/BUILD1
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc82
2 files changed, 43 insertions, 40 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 9624abd199..b572512bbb 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -312,7 +312,6 @@ cuda_py_test(
tags = [
"multi_and_single_gpu",
"no_pip",
- "noguitar", # TODO(b/109653107): test is flaky.
],
)
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index e3ce0ef9d0..52b3268903 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -873,22 +873,6 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
-static tensorflow::int64 FastHandleId(PyObject* variable) {
- PyObject* handle = PyObject_GetAttrString(variable, "handle");
- if (handle == nullptr) {
- return -1;
- }
- tensorflow::int64 id = FastTensorId(handle);
- Py_DECREF(handle);
- return id;
-}
-
-struct CompareByHandleId {
- bool operator()(PyObject* lhs, PyObject* rhs) {
- return FastHandleId(lhs) < FastHandleId(rhs);
- }
-};
-
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
@@ -897,35 +881,63 @@ class GradientTape
persistent) {}
virtual ~GradientTape() {
- for (PyObject* v : watched_variables_) {
- Py_DECREF(v);
+ for (const IdAndVariable& v : watched_variables_) {
+ Py_DECREF(v.variable);
}
}
void WatchVariable(PyObject* v) {
- auto insert_result = watched_variables_.insert(v);
- if (insert_result.second) {
- // Only increment the reference count if we aren't already watching this
- // variable.
- Py_INCREF(v);
- }
- PyObject* handle = PyObject_GetAttrString(v, "handle");
+ tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
}
- tensorflow::int64 id = FastTensorId(handle);
- Py_DECREF(handle);
+ tensorflow::int64 id = FastTensorId(handle.get());
+
if (!PyErr_Occurred()) {
this->Watch(id);
}
+
+ tensorflow::mutex_lock l(watched_variables_mu_);
+ auto insert_result = watched_variables_.emplace(id, v);
+
+ if (insert_result.second) {
+ // Only increment the reference count if we aren't already watching this
+ // variable.
+ Py_INCREF(v);
+ }
}
- const std::set<PyObject*, CompareByHandleId> WatchedVariables() {
- return watched_variables_;
+ PyObject* GetVariablesAsPyTuple() {
+ tensorflow::mutex_lock l(watched_variables_mu_);
+ PyObject* result = PyTuple_New(watched_variables_.size());
+ Py_ssize_t pos = 0;
+ for (const IdAndVariable& id_and_variable : watched_variables_) {
+ PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
+ Py_INCREF(id_and_variable.variable);
+ }
+ return result;
}
private:
- std::set<PyObject*, CompareByHandleId> watched_variables_;
+ // We store an IdAndVariable in the map since the map needs to be locked
+ // during insert, but should not call back into python during insert to avoid
+ // deadlocking with the GIL.
+ struct IdAndVariable {
+ tensorflow::int64 id;
+ PyObject* variable;
+
+ IdAndVariable(tensorflow::int64 id, PyObject* variable)
+ : id(id), variable(variable) {}
+ };
+ struct CompareById {
+ bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) {
+ return lhs.id < rhs.id;
+ }
+ };
+
+ tensorflow::mutex watched_variables_mu_;
+ std::set<IdAndVariable, CompareById> watched_variables_
+ GUARDED_BY(watched_variables_mu_);
};
typedef struct {
@@ -1217,15 +1229,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
}
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
- const auto& watched_variables =
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
- PyObject* result = PyTuple_New(watched_variables.size());
- Py_ssize_t pos = 0;
- for (PyObject* variable : watched_variables) {
- PyTuple_SET_ITEM(result, pos++, variable);
- Py_INCREF(variable);
- }
- return result;
+ return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
namespace {