about summary refs log tree commitdiff homepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar Ayush Dubey <ayushd@google.com>2018-06-27 11:56:43 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-28 21:37:43 -0700
commitd75edc93bfaf83aacbac4d25d0161141c7c928b0 (patch)
tree730554eb4c4a56da74f95207649d78bb1824b23b /tensorflow/core/distributed_runtime
parent6d5668ab82cd40844c868c9a9b2433af51272857 (diff)
Fix synchronization across callbacks in collective params initialization.
During initialization of local collective params, we may issue RPCs to other workers in order to obtain device localities. Currently, we hold a mutex across these RPCs, but we do not ensure that the thread that unlocks the mutex is the same as the one that locked it. This change releases the mutex (InstanceRec::out_mu) before calling GetDeviceLocalitiesAsync. Before releasing out_mu, it marks the mutex unavailable. Any thread that wishes to acquire out_mu must wait on a condition variable if the mutex is unavailable. The callback for GetDeviceLocalitiesAsync marks the mutex as available again and notifies the condition variable. PiperOrigin-RevId: 202346357
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc2
1 file changed, 2 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 612ac14e22..422d142f04 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -176,6 +176,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
const Status& fi_status, InstanceRec* ir) {
if (fi_status.ok()) {
mutex_lock l(ir->out_mu);
+ ir->WaitForOutMu(l);
response->set_instance_key(cp->instance.instance_key);
response->set_source_rank(ir->source_rank);
done_and_cleanup(fi_status);
@@ -289,6 +290,7 @@ void CollectiveParamResolverDistributed::UpdateInstanceCache(
Status status;
do {
mutex_lock l(ir->out_mu);
+ ir->WaitForOutMu(l);
if (ir->source_rank != source_rank) {
if (ir->source_rank >= 0) {
ir->status = errors::Internal(