author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-01-12 08:58:58 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-01-12 09:06:28 -0800
commit     dd9684fd9313f21ed66b02b816f2d806a1eccfd7 (patch)
tree       ee5d81398350602a7b082a035d082a656476fc0b /tensorflow/core/kernels/training_ops.cc
parent     e8f2aad0c0502fde74fc629f5b13f04d5d206700 (diff)
Kernels and ops for all optimizers when using resource variables.
So far this is enabled only for gradient descent.
Change: 144331045
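
The pattern this change introduces is a dispatch on the input's dtype: each training kernel is registered a second time under a "Resource..." op name, and the variable is fetched either through the legacy mutable-ref input or through a DT_RESOURCE handle lookup. A minimal self-contained sketch of that dispatch idea (the toy Context/Variable types below are hypothetical stand-ins, not the TensorFlow API):

    #include <cassert>

    // Toy stand-ins for OpKernelContext and Var; for illustration only.
    enum class DType { kRef, kResource };

    struct Variable { float value = 0.0f; };

    struct Context {
      DType input_dtype;       // what kind of input the op received
      Variable* ref_var;       // storage behind a legacy ref input
      Variable* resource_var;  // storage reached via a resource handle
    };

    // Mirrors the GetInputTensor helper in the diff below: resource inputs
    // are resolved through a handle lookup, ref inputs are read directly.
    Variable* GetInputVariable(Context* ctx) {
      if (ctx->input_dtype == DType::kResource) {
        return ctx->resource_var;  // stands in for LookupResource(...)
      }
      return ctx->ref_var;  // stands in for ctx->mutable_input(...)
    }

    int main() {
      Variable v{1.0f};
      Context ctx{DType::kResource, /*ref_var=*/nullptr, &v};
      assert(GetInputVariable(&ctx) == &v);  // the handle path resolves to v
      return 0;
    }

Because both registrations share one kernel class, a single Compute body serves ref and resource variables; only input fetching, locking, and output forwarding differ, which is exactly what the helpers in the diff below factor out.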
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r--  tensorflow/core/kernels/training_ops.cc | 441
1 file changed, 314 insertions(+), 127 deletions(-)
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 641c991a7e..cbc44017dc 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/variable_ops.h"

 namespace tensorflow {

@@ -292,10 +293,26 @@ struct ApplyCenteredRMSProp<CPUDevice, T> {

 }  // namespace functor

+mutex* GetMutex(OpKernelContext* ctx, int input) {
+  if (ctx->input_dtype(input) == DT_RESOURCE) {
+    Var* var;
+    if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
+      return var->mu();
+    } else {
+      ctx->CtxFailureWithWarning(
+          errors::Internal("Invalid variable reference."));
+      return nullptr;
+    }
+  }
+  return ctx->input_ref_mutex(input);
+}
+
 // MaybeLockMutexesInOrder is a helper function to acquire mutexes in address
-// order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once.
-// If do_lock is false, returns immediately.
+// order to mitigate deadlock. Returns a vector of acquired mutexes. Safe to
+// pass duplicates - will only lock each distinct mutex once. If do_lock is
+// false, returns immediately. Note that this silently doesn't lock mutexes for
+// invalid variable references; in all usages this is followed by GetInputTensor
+// which will signal a failure.
 std::vector<mutex_lock> MaybeLockMutexesInOrder(
     OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
   std::vector<mutex_lock> locks;
@@ -305,7 +322,7 @@ std::vector<mutex_lock> MaybeLockMutexesInOrder(
   std::vector<mutex*> mutexes;
   std::vector<int> acquire_order;
   for (auto input : input_ids) {
-    auto* mutex = ctx->input_ref_mutex(input);
+    mutex* mutex = GetMutex(ctx, input);
     // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
     if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
       acquire_order.push_back(input);
@@ -316,11 +333,41 @@ std::vector<mutex_lock> MaybeLockMutexesInOrder(
             [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

   for (auto input : acquire_order) {
-    locks.emplace_back(*ctx->input_ref_mutex(input));
+    mutex* mu = GetMutex(ctx, input);
+    if (mu != nullptr) {
+      locks.emplace_back(*mu);
+    }
   }
   return locks;
 }

+Status GetInputTensor(OpKernelContext* ctx, int input, bool lock_held,
+                      Tensor* out) {
+  if (ctx->input_dtype(input) == DT_RESOURCE) {
+    Var* var;
+    if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
+      if (lock_held) {
+        *out = *var->tensor();
+      } else {
+        mutex_lock ml(*var->mu());
+        *out = *var->tensor();
+      }
+      return Status::OK();
+    } else {
+      return errors::Internal("Invalid variable reference.");
+    }
+  }
+  *out = ctx->mutable_input(input, lock_held);
+  return Status::OK();
+}
+
+void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
+                                     int output) {
+  if (ctx->input_dtype(input) != DT_RESOURCE) {
+    ctx->forward_ref_input_to_ref_output(input, output);
+  }
+}
+
 template <typename Device, typename T>
 class ApplyGradientDescentOp : public OpKernel {
  public:
@@ -330,7 +377,8 @@ class ApplyGradientDescentOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -351,7 +399,7 @@ class ApplyGradientDescentOp : public OpKernel {
     functor::ApplyGradientDescent<Device, T>()(
         device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -361,7 +409,11 @@ class ApplyGradientDescentOp : public OpKernel {
 #define REGISTER_KERNELS(D, T)                                                \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \
-      ApplyGradientDescentOp<D##Device, T>);
+      ApplyGradientDescentOp<D##Device, T>);                                  \
+  REGISTER_KERNEL_BUILDER(Name("ResourceApplyGradientDescent")                \
+                              .Device(DEVICE_##D)                             \
+                              .TypeConstraint<T>("T"),                        \
+                          ApplyGradientDescentOp<D##Device, T>);

 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
 TF_CALL_half(REGISTER_CPU_KERNELS);
@@ -406,7 +458,7 @@ class ApplyAdadeltaOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     if (use_exclusive_lock_) {
-      mutex_lock l1(*ctx->input_ref_mutex(0));
+      mutex_lock l1(*GetMutex(ctx, 0));
       // Don't try to acquire a lock on the second ref as they share the same
       // mutex.
       //
@@ -419,16 +471,20 @@ class ApplyAdadeltaOp : public OpKernel {
       if (!ctx->status().ok()) return;
       DoCompute(ctx);
     }
-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
   bool use_exclusive_lock_;

   void DoValidate(OpKernelContext* ctx) {
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    Tensor accum_update;
+    OP_REQUIRES_OK(ctx,
+                   GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -474,9 +530,13 @@ class ApplyAdadeltaOp : public OpKernel {

   void DoCompute(OpKernelContext* ctx) {
     const Device& device = ctx->template eigen_device<Device>();
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    Tensor accum_update;
+    OP_REQUIRES_OK(ctx,
+                   GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));

     const Tensor& lr = ctx->input(3);
     const Tensor& rho = ctx->input(4);
@@ -492,9 +552,12 @@ class ApplyAdadeltaOp : public OpKernel {
 using CPUDevice = Eigen::ThreadPoolDevice;
 using GPUDevice = Eigen::GpuDevice;

-#define REGISTER_KERNELS(D, T)                                          \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"),  \
+#define REGISTER_KERNELS(D, T)                                                 \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
+      ApplyAdadeltaOp<D##Device, T>);                                          \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("ResourceApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
       ApplyAdadeltaOp<D##Device, T>);

 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -536,7 +599,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
   }

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
-    mutex* mu_var = ctx->input_ref_mutex(0);
+    mutex* mu_var = GetMutex(ctx, 0);
     // mu_accum is actually the same mutex as mu_var since currently we use a
     // global mutex.
     //
@@ -544,9 +607,14 @@ class SparseApplyAdadeltaOp : public OpKernel {
     if (use_exclusive_lock_) {
       mu_var->lock();
     }
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum_grad = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum_grad;
+    OP_REQUIRES_OK(ctx,
+                   GetInputTensor(ctx, 1, use_exclusive_lock_, &accum_grad));
+    Tensor accum_update;
+    OP_REQUIRES_OK(ctx,
+                   GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -642,7 +710,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
       mu_var->unlock();
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -654,6 +722,11 @@ class SparseApplyAdadeltaOp : public OpKernel {
                               .Device(DEVICE_CPU)                    \
                               .TypeConstraint<T>("T")                \
                               .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyAdadeltaOp<T, Tindices>);       \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdadelta")        \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
                           SparseApplyAdadeltaOp<T, Tindices>);
 #define REGISTER_CPU_KERNELS(T) \
   REGISTER_KERNELS(T, int32);   \
@@ -677,7 +750,8 @@ class ApplyProximalGradientDescentOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -710,17 +784,21 @@ class ApplyProximalGradientDescentOp : public OpKernel {
         device, var.flat<T>(), alpha.scalar<T>(), l1.scalar<T>(),
         l2.scalar<T>(), delta.flat<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
   bool use_exclusive_lock_;
 };

-#define REGISTER_KERNELS(D, T)                                  \
-  REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent")  \
-                              .Device(DEVICE_##D)               \
-                              .TypeConstraint<T>("T"),          \
+#define REGISTER_KERNELS(D, T)                                            \
+  REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent")            \
+                              .Device(DEVICE_##D)                         \
+                              .TypeConstraint<T>("T"),                    \
+                          ApplyProximalGradientDescentOp<D##Device, T>);  \
+  REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalGradientDescent")    \
+                              .Device(DEVICE_##D)                         \
+                              .TypeConstraint<T>("T"),                    \
                           ApplyProximalGradientDescentOp<D##Device, T>);

 REGISTER_KERNELS(CPU, float);
@@ -738,7 +816,8 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                 errors::InvalidArgument("var must be at least 1 dimensional"));
@@ -846,18 +925,23 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
   bool use_exclusive_lock_;
 };

-#define REGISTER_KERNELS(T, Tindices)                                \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent") \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices)                                         \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent")          \
+                              .Device(DEVICE_CPU)                             \
+                              .TypeConstraint<T>("T")                         \
+                              .TypeConstraint<Tindices>("Tindices"),          \
+                          SparseApplyProximalGradientDescentOp<T, Tindices>); \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalGradientDescent")  \
+                              .Device(DEVICE_CPU)                             \
+                              .TypeConstraint<T>("T")                         \
+                              .TypeConstraint<Tindices>("Tindices"),          \
                           SparseApplyProximalGradientDescentOp<T, Tindices>);

 REGISTER_KERNELS(float, int32);
@@ -875,8 +959,10 @@ class ApplyAdagradOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -905,7 +991,7 @@ class ApplyAdagradOp : public OpKernel {
     functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
                                        lr.scalar<T>(), grad.flat<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -915,9 +1001,12 @@ class ApplyAdagradOp : public OpKernel {
 using CPUDevice = Eigen::ThreadPoolDevice;
 using GPUDevice = Eigen::GpuDevice;

-#define REGISTER_KERNELS(D, T)                                         \
-  REGISTER_KERNEL_BUILDER(                                             \
-      Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"),  \
+#define REGISTER_KERNELS(D, T)                                                \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
+      ApplyAdagradOp<D##Device, T>);                                          \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("ResourceApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
       ApplyAdagradOp<D##Device, T>);

 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -957,8 +1046,10 @@ class ApplyProximalAdagradOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1004,7 +1095,7 @@ class ApplyProximalAdagradOp : public OpKernel {
         device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), l1.scalar<T>(),
         l2.scalar<T>(), grad.flat<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -1017,7 +1108,11 @@ using GPUDevice = Eigen::GpuDevice;
 #define REGISTER_KERNELS(D, T)                                                \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("ApplyProximalAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
-      ApplyProximalAdagradOp<D##Device, T>);
+      ApplyProximalAdagradOp<D##Device, T>);                                  \
+  REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalAdagrad")                \
+                              .Device(DEVICE_##D)                             \
+                              .TypeConstraint<T>("T"),                        \
+                          ApplyProximalAdagradOp<D##Device, T>);

 REGISTER_KERNELS(CPU, float);
 REGISTER_KERNELS(CPU, double);
@@ -1053,8 +1148,10 @@ class SparseApplyAdagradOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1142,7 +1239,7 @@ class SparseApplyAdagradOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -1154,6 +1251,11 @@ class SparseApplyAdagradOp : public OpKernel {
                               .Device(DEVICE_CPU)                    \
                               .TypeConstraint<T>("T")                \
                               .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyAdagradOp<T, Tindices>);        \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad")         \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
                           SparseApplyAdagradOp<T, Tindices>);
 #define REGISTER_CPU_KERNELS(T) \
   REGISTER_KERNELS(T, int32);   \
@@ -1177,8 +1279,10 @@ class SparseApplyProximalAdagradOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1311,18 +1415,23 @@ class SparseApplyProximalAdagradOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
   bool use_exclusive_lock_;
 };

-#define REGISTER_KERNELS(T, Tindices)                                \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad")         \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices)                                 \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad")          \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
+                          SparseApplyProximalAdagradOp<T, Tindices>); \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalAdagrad")  \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
                           SparseApplyProximalAdagradOp<T, Tindices>);

 REGISTER_KERNELS(float, int32);
@@ -1340,9 +1449,14 @@ class ApplyAdagradDAOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor gradient_accum;
+    OP_REQUIRES_OK(
+        ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
+    Tensor gradient_squared_accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
+                                       &gradient_squared_accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1399,7 +1513,7 @@ class ApplyAdagradDAOp : public OpKernel {
         global_step.scalar<int64>()(), l1.scalar<T>(), l2.scalar<T>(),
         grad.flat<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -1428,9 +1542,14 @@ class SparseApplyAdagradDAOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor gradient_accum;
+    OP_REQUIRES_OK(
+        ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
+    Tensor gradient_squared_accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
+                                       &gradient_squared_accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1580,7 +1699,7 @@ class SparseApplyAdagradDAOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -1592,6 +1711,11 @@ class SparseApplyAdagradDAOp : public OpKernel {
                               .Device(DEVICE_CPU)                    \
                               .TypeConstraint<T>("T")                \
                               .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyAdagradDAOp<T, Tindices>);      \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradDA")       \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
                           SparseApplyAdagradDAOp<T, Tindices>);

 REGISTER_KERNELS(float, int32);
@@ -1610,9 +1734,12 @@ class ApplyFtrlOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor linear = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    Tensor linear;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1677,7 +1804,7 @@ class ApplyFtrlOp : public OpKernel {
                                    lr.scalar<T>(), l1.scalar<T>(),
                                    l2.scalar<T>(), lr_power.scalar<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -1687,9 +1814,12 @@ class ApplyFtrlOp : public OpKernel {
 using CPUDevice = Eigen::ThreadPoolDevice;
 using GPUDevice = Eigen::GpuDevice;

-#define REGISTER_KERNELS(D, T)                                      \
-  REGISTER_KERNEL_BUILDER(                                          \
-      Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"),  \
+#define REGISTER_KERNELS(D, T)                                             \
+  REGISTER_KERNEL_BUILDER(                                                 \
+      Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
+      ApplyFtrlOp<D##Device, T>);                                          \
+  REGISTER_KERNEL_BUILDER(                                                 \
+      Name("ResourceApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
       ApplyFtrlOp<D##Device, T>);

 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -1710,9 +1840,12 @@ class SparseApplyFtrlOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor linear = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+    Tensor linear;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1874,18 +2007,23 @@ class SparseApplyFtrlOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
   bool use_exclusive_lock_;
 };

-#define REGISTER_KERNELS(T, Tindices)                                \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyFtrl")                    \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices)                                 \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyFtrl")                     \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
+                          SparseApplyFtrlOp<CPUDevice, T, Tindices>); \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyFtrl")             \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
                           SparseApplyFtrlOp<CPUDevice, T, Tindices>);

 #define REGISTER_CPU_KERNELS(T) \
   REGISTER_KERNELS(T, int32);   \
@@ -1909,8 +2047,10 @@ class ApplyMomentumOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -1944,7 +2084,7 @@ class ApplyMomentumOp : public OpKernel {
     functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
                                         lr.scalar<T>(), grad.flat<T>(),
                                         momentum.scalar<T>(), use_nesterov_);
-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -1955,9 +2095,12 @@ class ApplyMomentumOp : public OpKernel {
 using CPUDevice = Eigen::ThreadPoolDevice;
 using GPUDevice = Eigen::GpuDevice;

-#define REGISTER_KERNELS(D, T)                                          \
-  REGISTER_KERNEL_BUILDER(                                              \
-      Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"),  \
+#define REGISTER_KERNELS(D, T)                                                 \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
+      ApplyMomentumOp<D##Device, T>);                                          \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("ResourceApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
       ApplyMomentumOp<D##Device, T>);

 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -2001,8 +2144,10 @@ class SparseApplyMomentumOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor accum;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -2072,7 +2217,7 @@ class SparseApplyMomentumOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -2085,6 +2230,11 @@ class SparseApplyMomentumOp : public OpKernel {
                               .Device(DEVICE_CPU)                    \
                               .TypeConstraint<T>("T")                \
                               .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyMomentumOp<T, Tindices>);       \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyMomentum")        \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
                           SparseApplyMomentumOp<T, Tindices>);
 #define REGISTER_CPU_KERNELS(T) \
   REGISTER_KERNELS(T, int32);   \
@@ -2107,9 +2257,12 @@ class ApplyAdamOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor m = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor v = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor m;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &m));
+    Tensor v;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &v));
     OP_REQUIRES(
         ctx, var.IsInitialized(),
         errors::FailedPrecondition(
@@ -2171,7 +2324,7 @@ class ApplyAdamOp : public OpKernel {
         beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
         grad.flat<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -2181,9 +2334,12 @@ class ApplyAdamOp : public OpKernel {
 using CPUDevice = Eigen::ThreadPoolDevice;
 using GPUDevice = Eigen::GpuDevice;

-#define REGISTER_KERNELS(D, T)                                      \
-  REGISTER_KERNEL_BUILDER(                                          \
-      Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"),  \
+#define REGISTER_KERNELS(D, T)                                             \
+  REGISTER_KERNEL_BUILDER(                                                 \
+      Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
+      ApplyAdamOp<D##Device, T>);                                          \
+  REGISTER_KERNEL_BUILDER(                                                 \
+      Name("ResourceApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
       ApplyAdamOp<D##Device, T>);

 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -2236,9 +2392,12 @@ class ApplyRMSPropOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor ms;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
+    Tensor mom;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2294,7 +2453,7 @@ class ApplyRMSPropOp : public OpKernel {
                                        rho.scalar<T>(), momentum.scalar<T>(),
                                        epsilon.scalar<T>(), grad.flat<T>());

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -2312,10 +2471,14 @@ class ApplyCenteredRMSPropOp : public OpKernel {

     auto locks =
         MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor mg = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor ms = ctx->mutable_input(2, use_exclusive_lock_);
-    Tensor mom = ctx->mutable_input(3, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor mg;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
+    Tensor ms;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
+    Tensor mom;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2379,7 +2542,7 @@ class ApplyCenteredRMSPropOp : public OpKernel {
         device, var.flat<T>(), mg.flat<T>(), ms.flat<T>(), mom.flat<T>(),
         lr.scalar<T>(), rho.scalar<T>(), momentum.scalar<T>(),
         epsilon.scalar<T>(), grad.flat<T>());
-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -2395,7 +2558,14 @@ using GPUDevice = Eigen::GpuDevice;
       ApplyRMSPropOp<D##Device, T>);                                          \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("ApplyCenteredRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
-      ApplyCenteredRMSPropOp<D##Device, T>);
+      ApplyCenteredRMSPropOp<D##Device, T>);                                  \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("ResourceApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+      ApplyRMSPropOp<D##Device, T>);                                          \
+  REGISTER_KERNEL_BUILDER(Name("ResourceApplyCenteredRMSProp")                \
+                              .Device(DEVICE_##D)                             \
+                              .TypeConstraint<T>("T"),                        \
+                          ApplyCenteredRMSPropOp<D##Device, T>);

 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
 TF_CALL_half(REGISTER_CPU_KERNELS);
@@ -2449,9 +2619,12 @@ class SparseApplyRMSPropOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
     auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor ms;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
+    Tensor mom;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2552,7 +2725,7 @@ class SparseApplyRMSPropOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
@@ -2572,10 +2745,14 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {

     auto locks =
         MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
-    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
-    Tensor mg = ctx->mutable_input(1, use_exclusive_lock_);
-    Tensor ms = ctx->mutable_input(2, use_exclusive_lock_);
-    Tensor mom = ctx->mutable_input(3, use_exclusive_lock_);
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+    Tensor mg;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
+    Tensor ms;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
+    Tensor mom;
+    OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));

     OP_REQUIRES(
         ctx, var.IsInitialized(),
@@ -2685,23 +2862,33 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
       }
     }

-    ctx->forward_ref_input_to_ref_output(0, 0);
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   }

  private:
   bool use_exclusive_lock_;
 };

-#define REGISTER_KERNELS(T, Tindices)                                \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                 \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
-                          SparseApplyRMSPropOp<T, Tindices>);        \
-  REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp")         \
-                              .Device(DEVICE_CPU)                    \
-                              .TypeConstraint<T>("T")                \
-                              .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices)                                 \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                  \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
+                          SparseApplyRMSPropOp<T, Tindices>);         \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp")          \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
+                          SparseApplyCenteredRMSPropOp<T, Tindices>); \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyRMSProp")          \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
+                          SparseApplyRMSPropOp<T, Tindices>);         \
+  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyCenteredRMSProp")  \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
                           SparseApplyCenteredRMSPropOp<T, Tindices>);

 REGISTER_KERNELS(Eigen::half, int32);