diff options
Diffstat (limited to 'tensorflow/core/kernels/training_op_helpers.h')
-rw-r--r-- | tensorflow/core/kernels/training_op_helpers.h | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h index 071cb371a7..9f173a80f7 100644 --- a/tensorflow/core/kernels/training_op_helpers.h +++ b/tensorflow/core/kernels/training_op_helpers.h @@ -23,9 +23,42 @@ limitations under the License. namespace tensorflow { -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input); +// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. +// +// If `input` corresponds to a `DT_RESOURCE`-type variable input, +// `*maybe_resource` will be updated to contain the underlying resource, and the +// caller will be responsible for calling `Unref()` on that resource. +mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, + Var** maybe_resource); -std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( +// Utility structure that releases a sequence of borrowed mutexes when it is +// deleted. +struct VariableInputLockHolder { + public: + VariableInputLockHolder(std::vector<Var*> vars, + std::unique_ptr<std::vector<mutex_lock>> locks) + : vars_(std::move(vars)), locks_(std::move(locks)) {} + + VariableInputLockHolder(VariableInputLockHolder&& other) + : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {} + + ~VariableInputLockHolder() { + // Release the locks before unreffing the Vars, because each lock + // is potentially borrowed from a Var in vars_. + locks_.reset(); + for (Var* var : vars_) { + var->Unref(); + } + } + + private: + std::vector<Var*> vars_; + // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly, + // because a `std::vector<mutex_lock>` is not movable on all platforms. + std::unique_ptr<std::vector<mutex_lock>> locks_; +}; + +VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids); void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, |