author    A. Unique TensorFlower <gardener@tensorflow.org>    2017-01-12 08:58:58 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-01-12 09:06:28 -0800
commit    dd9684fd9313f21ed66b02b816f2d806a1eccfd7 (patch)
tree      ee5d81398350602a7b082a035d082a656476fc0b /tensorflow/core/kernels/training_ops.cc
parent    e8f2aad0c0502fde74fc629f5b13f04d5d206700 (diff)
Kernels and ops for all optimizers when using resource variables.
Only enables for gradient descent so far. Change: 144331045
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r--  tensorflow/core/kernels/training_ops.cc | 441
1 file changed, 314 insertions(+), 127 deletions(-)
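
Before the diff, a brief note on the pattern it introduces: each optimizer kernel now resolves its variable inputs through small helpers (GetMutex, GetInputTensor, MaybeForwardRefInputToRefOutput) so that the same Compute() body serves both the legacy ref-variable ops (Apply*) and the new resource-variable ops (ResourceApply*). The standalone sketch below is a minimal model of that idea using simplified, hypothetical types; it is not TensorFlow code, just an illustration of dispatching on "ref variable vs. resource handle" before applying a gradient-descent update.

// Minimal standalone sketch (hypothetical types, not TensorFlow code) of the
// pattern this commit introduces: one optimizer body that works whether the
// variable arrives as a direct "ref" or as a resource handle to look up.
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <variant>
#include <vector>

struct Tensor { std::vector<float> values; };

// "Ref" variable: the kernel sees the tensor and its mutex directly.
struct RefVar { Tensor tensor; std::mutex mu; };

// "Resource" variable: the kernel sees a handle and must resolve it first.
struct ResourceVar { Tensor tensor; std::mutex mu; };
using ResourceHandle = ResourceVar*;  // stand-in for handle + lookup

using Input = std::variant<RefVar*, ResourceHandle>;

// Analogue of GetMutex(): one code path for both input kinds.
std::mutex* GetMutex(const Input& in) {
  if (auto* ref = std::get_if<RefVar*>(&in)) return &(*ref)->mu;
  return &std::get<ResourceHandle>(in)->mu;
}

// Analogue of GetInputTensor(): likewise resolves the underlying tensor.
Tensor* GetInputTensor(const Input& in) {
  if (auto* ref = std::get_if<RefVar*>(&in)) return &(*ref)->tensor;
  return &std::get<ResourceHandle>(in)->tensor;
}

// One gradient-descent body shared by the "Apply" and "ResourceApply" paths.
void ApplyGradientDescent(const Input& var_in, float lr,
                          const std::vector<float>& grad) {
  std::lock_guard<std::mutex> lock(*GetMutex(var_in));
  Tensor* var = GetInputTensor(var_in);
  for (std::size_t i = 0; i < var->values.size(); ++i) {
    var->values[i] -= lr * grad[i];
  }
}

int main() {
  RefVar ref_var{{{1.0f, 2.0f}}, {}};
  ResourceVar res_var{{{1.0f, 2.0f}}, {}};
  const std::vector<float> grad{0.5f, 0.5f};
  ApplyGradientDescent(&ref_var, 0.1f, grad);                  // ref path
  ApplyGradientDescent(ResourceHandle(&res_var), 0.1f, grad);  // resource path
  std::printf("%f %f\n", ref_var.tensor.values[0], res_var.tensor.values[0]);
  return 0;
}

In the actual diff, the other half of the pattern is the added REGISTER_KERNEL_BUILDER entries, which bind each new Resource* op name to the existing kernel class rather than introducing new kernel implementations.
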
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 641c991a7e..cbc44017dc 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
@@ -292,10 +293,26 @@ struct ApplyCenteredRMSProp<CPUDevice, T> {
} // namespace functor
+mutex* GetMutex(OpKernelContext* ctx, int input) {
+ if (ctx->input_dtype(input) == DT_RESOURCE) {
+ Var* var;
+ if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
+ return var->mu();
+ } else {
+ ctx->CtxFailureWithWarning(
+ errors::Internal("Invalid variable reference."));
+ return nullptr;
+ }
+ }
+ return ctx->input_ref_mutex(input);
+}
+
// MaybeLockMutexesInOrder is a helper function to acquire mutexes in address
-// order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once.
-// If do_lock is false, returns immediately.
+// order to mitigate deadlock. Returns a vector of acquired mutexes. Safe to
+// pass duplicates - will only lock each distinct mutex once. If do_lock is
+// false, returns immediately. Note that this silently doesn't lock mutexes for
+// invalid variable references; in all usages this is followed by GetInputTensor
+// which will signal a failure.
std::vector<mutex_lock> MaybeLockMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
std::vector<mutex_lock> locks;
@@ -305,7 +322,7 @@ std::vector<mutex_lock> MaybeLockMutexesInOrder(
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
- auto* mutex = ctx->input_ref_mutex(input);
+ mutex* mutex = GetMutex(ctx, input);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(input);
@@ -316,11 +333,41 @@ std::vector<mutex_lock> MaybeLockMutexesInOrder(
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
for (auto input : acquire_order) {
- locks.emplace_back(*ctx->input_ref_mutex(input));
+ mutex* mu = GetMutex(ctx, input);
+ if (mu != nullptr) {
+ locks.emplace_back(*mu);
+ }
}
return locks;
}
+Status GetInputTensor(OpKernelContext* ctx, int input, bool lock_held,
+ Tensor* out) {
+ if (ctx->input_dtype(input) == DT_RESOURCE) {
+ Var* var;
+ if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
+ if (lock_held) {
+ *out = *var->tensor();
+ } else {
+ mutex_lock ml(*var->mu());
+ *out = *var->tensor();
+ }
+ return Status::OK();
+ } else {
+ return errors::Internal("Invalid variable reference.");
+ }
+ }
+ *out = ctx->mutable_input(input, lock_held);
+ return Status::OK();
+}
+
+void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
+ int output) {
+ if (ctx->input_dtype(input) != DT_RESOURCE) {
+ ctx->forward_ref_input_to_ref_output(input, output);
+ }
+}
+
template <typename Device, typename T>
class ApplyGradientDescentOp : public OpKernel {
public:
@@ -330,7 +377,8 @@ class ApplyGradientDescentOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES(
ctx, var.IsInitialized(),
@@ -351,7 +399,7 @@ class ApplyGradientDescentOp : public OpKernel {
functor::ApplyGradientDescent<Device, T>()(
device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -361,7 +409,11 @@ class ApplyGradientDescentOp : public OpKernel {
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \
- ApplyGradientDescentOp<D##Device, T>);
+ ApplyGradientDescentOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceApplyGradientDescent") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
+ ApplyGradientDescentOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
TF_CALL_half(REGISTER_CPU_KERNELS);
@@ -406,7 +458,7 @@ class ApplyAdadeltaOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
if (use_exclusive_lock_) {
- mutex_lock l1(*ctx->input_ref_mutex(0));
+ mutex_lock l1(*GetMutex(ctx, 0));
// Don't try to acquire a lock on the second ref as they share the same
// mutex.
//
@@ -419,16 +471,20 @@ class ApplyAdadeltaOp : public OpKernel {
if (!ctx->status().ok()) return;
DoCompute(ctx);
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
void DoValidate(OpKernelContext* ctx) {
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+ Tensor accum_update;
+ OP_REQUIRES_OK(ctx,
+ GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
OP_REQUIRES(
ctx, var.IsInitialized(),
@@ -474,9 +530,13 @@ class ApplyAdadeltaOp : public OpKernel {
void DoCompute(OpKernelContext* ctx) {
const Device& device = ctx->template eigen_device<Device>();
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+ Tensor accum_update;
+ OP_REQUIRES_OK(ctx,
+ GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
const Tensor& lr = ctx->input(3);
const Tensor& rho = ctx->input(4);
@@ -492,9 +552,12 @@ class ApplyAdadeltaOp : public OpKernel {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
- Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdadeltaOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ResourceApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyAdadeltaOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -536,7 +599,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
- mutex* mu_var = ctx->input_ref_mutex(0);
+ mutex* mu_var = GetMutex(ctx, 0);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
//
@@ -544,9 +607,14 @@ class SparseApplyAdadeltaOp : public OpKernel {
if (use_exclusive_lock_) {
mu_var->lock();
}
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum_grad = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum_grad;
+ OP_REQUIRES_OK(ctx,
+ GetInputTensor(ctx, 1, use_exclusive_lock_, &accum_grad));
+ Tensor accum_update;
+ OP_REQUIRES_OK(ctx,
+ GetInputTensor(ctx, 2, use_exclusive_lock_, &accum_update));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -642,7 +710,7 @@ class SparseApplyAdadeltaOp : public OpKernel {
mu_var->unlock();
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -654,6 +722,11 @@ class SparseApplyAdadeltaOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyAdadeltaOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdadelta") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyAdadeltaOp<T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
REGISTER_KERNELS(T, int32); \
@@ -677,7 +750,8 @@ class ApplyProximalGradientDescentOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES(
ctx, var.IsInitialized(),
@@ -710,17 +784,21 @@ class ApplyProximalGradientDescentOp : public OpKernel {
device, var.flat<T>(), alpha.scalar<T>(), l1.scalar<T>(),
l2.scalar<T>(), delta.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent") \
- .Device(DEVICE_##D) \
- .TypeConstraint<T>("T"), \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
+ ApplyProximalGradientDescentOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalGradientDescent") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
ApplyProximalGradientDescentOp<D##Device, T>);
REGISTER_KERNELS(CPU, float);
@@ -738,7 +816,8 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
errors::InvalidArgument("var must be at least 1 dimensional"));
@@ -846,18 +925,23 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(T, Tindices) \
- REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyProximalGradientDescentOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalGradientDescent") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyProximalGradientDescentOp<T, Tindices>);
REGISTER_KERNELS(float, int32);
@@ -875,8 +959,10 @@ class ApplyAdagradOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -905,7 +991,7 @@ class ApplyAdagradOp : public OpKernel {
functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
lr.scalar<T>(), grad.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -915,9 +1001,12 @@ class ApplyAdagradOp : public OpKernel {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
- Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdagradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ResourceApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyAdagradOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -957,8 +1046,10 @@ class ApplyProximalAdagradOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1004,7 +1095,7 @@ class ApplyProximalAdagradOp : public OpKernel {
device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), l1.scalar<T>(),
l2.scalar<T>(), grad.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -1017,7 +1108,11 @@ using GPUDevice = Eigen::GpuDevice;
#define REGISTER_KERNELS(D, T) \
REGISTER_KERNEL_BUILDER( \
Name("ApplyProximalAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
- ApplyProximalAdagradOp<D##Device, T>);
+ ApplyProximalAdagradOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalAdagrad") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
+ ApplyProximalAdagradOp<D##Device, T>);
REGISTER_KERNELS(CPU, float);
REGISTER_KERNELS(CPU, double);
@@ -1053,8 +1148,10 @@ class SparseApplyAdagradOp : public OpKernel {
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1142,7 +1239,7 @@ class SparseApplyAdagradOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -1154,6 +1251,11 @@ class SparseApplyAdagradOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyAdagradOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyAdagradOp<T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
REGISTER_KERNELS(T, int32); \
@@ -1177,8 +1279,10 @@ class SparseApplyProximalAdagradOp : public OpKernel {
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1311,18 +1415,23 @@ class SparseApplyProximalAdagradOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(T, Tindices) \
- REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyProximalAdagradOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalAdagrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyProximalAdagradOp<T, Tindices>);
REGISTER_KERNELS(float, int32);
@@ -1340,9 +1449,14 @@ class ApplyAdagradDAOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor gradient_accum;
+ OP_REQUIRES_OK(
+ ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
+ Tensor gradient_squared_accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
+ &gradient_squared_accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1399,7 +1513,7 @@ class ApplyAdagradDAOp : public OpKernel {
global_step.scalar<int64>()(), l1.scalar<T>(), l2.scalar<T>(),
grad.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -1428,9 +1542,14 @@ class SparseApplyAdagradDAOp : public OpKernel {
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor gradient_accum = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor gradient_squared_accum = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor gradient_accum;
+ OP_REQUIRES_OK(
+ ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &gradient_accum));
+ Tensor gradient_squared_accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_,
+ &gradient_squared_accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1580,7 +1699,7 @@ class SparseApplyAdagradDAOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -1592,6 +1711,11 @@ class SparseApplyAdagradDAOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyAdagradDAOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradDA") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyAdagradDAOp<T, Tindices>);
REGISTER_KERNELS(float, int32);
@@ -1610,9 +1734,12 @@ class ApplyFtrlOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor linear = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+ Tensor linear;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1677,7 +1804,7 @@ class ApplyFtrlOp : public OpKernel {
lr.scalar<T>(), l1.scalar<T>(),
l2.scalar<T>(), lr_power.scalar<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -1687,9 +1814,12 @@ class ApplyFtrlOp : public OpKernel {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
- Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyFtrlOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ResourceApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyFtrlOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -1710,9 +1840,12 @@ class SparseApplyFtrlOp : public OpKernel {
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor linear = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
+ Tensor linear;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &linear));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1874,18 +2007,23 @@ class SparseApplyFtrlOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(T, Tindices) \
- REGISTER_KERNEL_BUILDER(Name("SparseApplyFtrl") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyFtrl") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyFtrlOp<CPUDevice, T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyFtrl") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyFtrlOp<CPUDevice, T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
REGISTER_KERNELS(T, int32); \
@@ -1909,8 +2047,10 @@ class ApplyMomentumOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -1944,7 +2084,7 @@ class ApplyMomentumOp : public OpKernel {
functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
lr.scalar<T>(), grad.flat<T>(),
momentum.scalar<T>(), use_nesterov_);
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -1955,9 +2095,12 @@ class ApplyMomentumOp : public OpKernel {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
- Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyMomentumOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ResourceApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyMomentumOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -2001,8 +2144,10 @@ class SparseApplyMomentumOp : public OpKernel {
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor accum;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &accum));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -2072,7 +2217,7 @@ class SparseApplyMomentumOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -2085,6 +2230,11 @@ class SparseApplyMomentumOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyMomentumOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyMomentum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyMomentumOp<T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
REGISTER_KERNELS(T, int32); \
@@ -2107,9 +2257,12 @@ class ApplyAdamOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor m = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor v = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor m;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &m));
+ Tensor v;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &v));
OP_REQUIRES(
ctx, var.IsInitialized(),
errors::FailedPrecondition(
@@ -2171,7 +2324,7 @@ class ApplyAdamOp : public OpKernel {
beta1.scalar<T>(), beta2.scalar<T>(),
epsilon.scalar<T>(), grad.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -2181,9 +2334,12 @@ class ApplyAdamOp : public OpKernel {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
- Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdamOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ResourceApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyAdamOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -2236,9 +2392,12 @@ class ApplyRMSPropOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor ms;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
+ Tensor mom;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),
@@ -2294,7 +2453,7 @@ class ApplyRMSPropOp : public OpKernel {
rho.scalar<T>(), momentum.scalar<T>(),
epsilon.scalar<T>(), grad.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -2312,10 +2471,14 @@ class ApplyCenteredRMSPropOp : public OpKernel {
auto locks =
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor mg = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor ms = ctx->mutable_input(2, use_exclusive_lock_);
- Tensor mom = ctx->mutable_input(3, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor mg;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
+ Tensor ms;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
+ Tensor mom;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),
@@ -2379,7 +2542,7 @@ class ApplyCenteredRMSPropOp : public OpKernel {
device, var.flat<T>(), mg.flat<T>(), ms.flat<T>(), mom.flat<T>(),
lr.scalar<T>(), rho.scalar<T>(), momentum.scalar<T>(),
epsilon.scalar<T>(), grad.flat<T>());
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -2395,7 +2558,14 @@ using GPUDevice = Eigen::GpuDevice;
ApplyRMSPropOp<D##Device, T>); \
REGISTER_KERNEL_BUILDER( \
Name("ApplyCenteredRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
- ApplyCenteredRMSPropOp<D##Device, T>);
+ ApplyCenteredRMSPropOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ResourceApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyRMSPropOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceApplyCenteredRMSProp") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
+ ApplyCenteredRMSPropOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
TF_CALL_half(REGISTER_CPU_KERNELS);
@@ -2449,9 +2619,12 @@ class SparseApplyRMSPropOp : public OpKernel {
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor ms;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &ms));
+ Tensor mom;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),
@@ -2552,7 +2725,7 @@ class SparseApplyRMSPropOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
@@ -2572,10 +2745,14 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
auto locks =
MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2, 3});
- Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
- Tensor mg = ctx->mutable_input(1, use_exclusive_lock_);
- Tensor ms = ctx->mutable_input(2, use_exclusive_lock_);
- Tensor mom = ctx->mutable_input(3, use_exclusive_lock_);
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 0, use_exclusive_lock_, &var));
+ Tensor mg;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 1, use_exclusive_lock_, &mg));
+ Tensor ms;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 2, use_exclusive_lock_, &ms));
+ Tensor mom;
+ OP_REQUIRES_OK(ctx, GetInputTensor(ctx, 3, use_exclusive_lock_, &mom));
OP_REQUIRES(
ctx, var.IsInitialized(),
@@ -2685,23 +2862,33 @@ class SparseApplyCenteredRMSPropOp : public OpKernel {
}
}
- ctx->forward_ref_input_to_ref_output(0, 0);
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(T, Tindices) \
- REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<Tindices>("Tindices"), \
- SparseApplyRMSPropOp<T, Tindices>); \
- REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .TypeConstraint<Tindices>("Tindices"), \
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyRMSPropOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyCenteredRMSPropOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyRMSProp") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyRMSPropOp<T, Tindices>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyCenteredRMSProp") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
SparseApplyCenteredRMSPropOp<T, Tindices>);
REGISTER_KERNELS(Eigen::half, int32);