aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/training_op_helpers.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/training_op_helpers.cc')
-rw-r--r--tensorflow/core/kernels/training_op_helpers.cc45
1 files changed, 28 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index 83b83fcdb9..4262a5404b 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -15,14 +15,16 @@ limitations under the License.
#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/util/ptr_util.h"
+
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource) {
+ *maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
- Var* var;
- if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
- core::ScopedUnref scoped_unref(var);
- return var->mu();
+ if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
+ return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
@@ -33,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
-// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once. If
-// do_lock is false, returns immediately. Note that this silently doesn't lock
-// mutexes for invalid variable references; in all usages this is followed by
-// GetInputTensor which will signal a failure.
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// in address order to mitigate deadlock. Returns a structure that, when
+// deleted, will release the acquired mutexes. Safe to pass duplicates - will
+// only lock each distinct mutex once. If do_lock is false, returns
+// immediately. Note that this silently doesn't lock mutexes for invalid
+// variable references; in all usages this is followed by GetInputTensor which
+// will signal a failure.
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
@@ -47,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
break;
}
}
- std::vector<mutex_lock> locks;
if (!do_lock && !any_resource) {
- return locks;
+ return VariableInputLockHolder({}, {});
}
+ std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
- mutex* mutex = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
+ if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
@@ -64,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
+ std::unique_ptr<std::vector<mutex_lock>> locks =
+ MakeUnique<std::vector<mutex_lock>>();
+ locks->reserve(acquire_order.size());
+
for (auto input : acquire_order) {
- mutex* mu = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
+ core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
- locks.emplace_back(*mu);
+ locks->emplace_back(*mu);
}
}
- return locks;
+ return VariableInputLockHolder(std::move(vars), std::move(locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,