diff options
Diffstat (limited to 'tensorflow/core/kernels/training_op_helpers.cc')
-rw-r--r-- | tensorflow/core/kernels/training_op_helpers.cc | 45 |
1 files changed, 28 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc index 83b83fcdb9..4262a5404b 100644 --- a/tensorflow/core/kernels/training_op_helpers.cc +++ b/tensorflow/core/kernels/training_op_helpers.cc @@ -15,14 +15,16 @@ limitations under the License. #include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/util/ptr_util.h" + namespace tensorflow { -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) { +mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, + Var** maybe_resource) { + *maybe_resource = nullptr; if (ctx->input_dtype(input) == DT_RESOURCE) { - Var* var; - if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) { - core::ScopedUnref scoped_unref(var); - return var->mu(); + if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { + return (*maybe_resource)->mu(); } else { ctx->CtxFailureWithWarning( errors::Internal("Invalid variable reference.")); @@ -33,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) { } // MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes -// in address order to mitigate deadlock. Returns a vector of acquired mutexes. -// Safe to pass duplicates - will only lock each distinct mutex once. If -// do_lock is false, returns immediately. Note that this silently doesn't lock -// mutexes for invalid variable references; in all usages this is followed by -// GetInputTensor which will signal a failure. -std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( +// in address order to mitigate deadlock. Returns a structure that, when +// deleted, will release the acquired mutexes. Safe to pass duplicates - will +// only lock each distinct mutex once. If do_lock is false, returns +// immediately. Note that this silently doesn't lock mutexes for invalid +// variable references; in all usages this is followed by GetInputTensor which +// will signal a failure. +VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) { bool any_resource = false; for (auto i : input_ids) { @@ -47,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( break; } } - std::vector<mutex_lock> locks; if (!do_lock && !any_resource) { - return locks; + return VariableInputLockHolder({}, {}); } + std::vector<Var*> vars; std::vector<mutex*> mutexes; std::vector<int> acquire_order; for (auto input : input_ids) { - mutex* mutex = GetTrainingVariableMutex(ctx, input); + Var* var; + mutex* mutex = GetTrainingVariableMutex(ctx, input, &var); + if (var) vars.push_back(var); // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { acquire_order.push_back(mutexes.size()); @@ -64,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( std::sort(acquire_order.begin(), acquire_order.end(), [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); + std::unique_ptr<std::vector<mutex_lock>> locks = + MakeUnique<std::vector<mutex_lock>>(); + locks->reserve(acquire_order.size()); + for (auto input : acquire_order) { - mutex* mu = GetTrainingVariableMutex(ctx, input); + Var* var; + mutex* mu = GetTrainingVariableMutex(ctx, input, &var); + core::ScopedUnref scoped_unref(var); if (mu != nullptr) { - locks.emplace_back(*mu); + locks->emplace_back(*mu); } } - return locks; + return VariableInputLockHolder(std::move(vars), std::move(locks)); } void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, |