aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/training_op_helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/training_op_helpers.h')
-rw-r--r--tensorflow/core/kernels/training_op_helpers.h37
1 files changed, 35 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 071cb371a7..9f173a80f7 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -23,9 +23,42 @@ limitations under the License.
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
+// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
+//
+// If `input` corresponds to a `DT_RESOURCE`-type variable input,
+// `*maybe_resource` will be updated to contain the underlying resource, and the
+// caller will be responsible for calling `Unref()` on that resource.
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource);
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// Utility structure that releases a sequence of borrowed mutexes when it is
+// deleted.
+struct VariableInputLockHolder {
+ public:
+ VariableInputLockHolder(std::vector<Var*> vars,
+ std::unique_ptr<std::vector<mutex_lock>> locks)
+ : vars_(std::move(vars)), locks_(std::move(locks)) {}
+
+ VariableInputLockHolder(VariableInputLockHolder&& other)
+ : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
+
+ ~VariableInputLockHolder() {
+ // Release the locks before unreffing the Vars, because each lock
+ // is potentially borrowed from a Var in vars_.
+ locks_.reset();
+ for (Var* var : vars_) {
+ var->Unref();
+ }
+ }
+
+ private:
+ std::vector<Var*> vars_;
+ // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
+ // because a `std::vector<mutex_lock>` is not movable on all platforms.
+ std::unique_ptr<std::vector<mutex_lock>> locks_;
+};
+
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,