about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/common_runtime/collective_param_resolver_local.h
diff options
context:
space:
mode:
author: Ayush Dubey <ayushd@google.com> 2018-06-27 11:56:43 -0700
committer: Gunhan Gulsoy <gunan@google.com> 2018-06-28 21:37:43 -0700
commit: d75edc93bfaf83aacbac4d25d0161141c7c928b0 (patch)
tree: 730554eb4c4a56da74f95207649d78bb1824b23b /tensorflow/core/common_runtime/collective_param_resolver_local.h
parent: 6d5668ab82cd40844c868c9a9b2433af51272857 (diff)
Fix synchronization across callbacks in collective params initialization.
During initialization of local collective params, we may issue RPCs to other workers in order to obtain device localities. Currently, we hold a mutex across these RPCs, but we do not ensure that the thread that unlocks the mutex is the same as the one that locked it. This change releases the mutex (InstanceRec::out_mu) before calling GetDeviceLocalitiesAsync. Before releasing out_mu, it marks the mutex unavailable. Any thread that wishes to acquire out_mu must wait on a condition variable if the mutex is unavailable. The callback for GetDeviceLocalitiesAsync marks the mutex as available again and notifies the condition variable.

PiperOrigin-RevId: 202346357
Diffstat (limited to 'tensorflow/core/common_runtime/collective_param_resolver_local.h')
-rw-r--r-- tensorflow/core/common_runtime/collective_param_resolver_local.h | 20
1 file changed, 16 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 43c404f2ec..0be16cb9a8 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -88,7 +88,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// permit mutex locks to be taken in more than one order.
//
// out_mu guards access to most of the fields.
- // in_mu guards access to a queue of comsumer callbacks wanting to
+ // in_mu guards access to a queue of consumer callbacks wanting to
// read the fields guarded by out_mu.
//
// The in_mu should be locked only while holding instance_mu_; the
@@ -109,8 +109,12 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
bool is_init GUARDED_BY(in_mu);
std::vector<IRConsumer> init_waiters GUARDED_BY(in_mu);
- // Values to be shared by all instances, constant after initialization.
+ // A thread that wishes to acquire out_mu must ensure that it is available
+ // by invoking WaitForOutMu().
mutex out_mu;
+ condition_variable out_cv;
+ bool out_mu_available GUARDED_BY(out_mu);
+ // Values to be shared by all instances, constant after initialization.
CollectiveParams shared GUARDED_BY(out_mu);
// If an error occurs during initialization this structure stays in
// the table with a non-OK status. Purging the table and restarting
@@ -124,7 +128,15 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
std::vector<bool> known GUARDED_BY(out_mu);
std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu);
- InstanceRec() : is_init(false), source_rank(-1), known_count(0) {}
+ InstanceRec()
+ : is_init(false),
+ out_mu_available(true),
+ source_rank(-1),
+ known_count(0) {}
+
+ // If out_mu is unavailable during distributed device locality
+ // initialization, wait on out_cv until it is available again.
+ void WaitForOutMu(mutex_lock& lock) EXCLUSIVE_LOCKS_REQUIRED(out_mu);
};
// Find the InstanceRec with the same instance_key as cp. If it doesn't
@@ -147,7 +159,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// cp is populated with all DeviceLocalities
void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp,
InstanceRec* ir, const StatusCallback& done)
- EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
+ UNLOCK_FUNCTION(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
void CallInitInstanceSharedParams(const GroupRec* gr,
const CollectiveParams* cp, InstanceRec* ir,