path: root/tensorflow/core/kernels/training_ops.cc
diff options
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
1 files changed, 137 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index b16c9c860a..2f9714a37a 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -1937,4 +1937,141 @@ REGISTER_KERNELS(GPU, double);
+// Note: this op works on CPU only.
+//
+// Kernel that applies an RMSProp-style update to the rows of `var`, `ms` and
+// `mom` selected by `indices`. For each i in [0, N), with row = indices[i]
+// and g = grad[i]:
+//
+//   ms[row]   = rho * ms[row] + (1 - rho) * g^2
+//   mom[row]  = momentum * mom[row] + lr * g / sqrt(ms[row] + epsilon)
+//   var[row] -= mom[row]
+//
+// Inputs, by position: 0 var (ref), 1 ms (ref), 2 mom (ref), 3 lr, 4 rho,
+// 5 momentum, 6 epsilon, 7 grad, 8 indices. The ref input `var` is forwarded
+// to output 0.
+//
+// T is the element type of the value tensors, Tindex the element type of
+// `indices`.
+template <typename T, typename Tindex>
+class SparseApplyRMSPropOp : public OpKernel {
+ public:
+  // Reads the "use_locking" attribute, which decides whether the ref inputs
+  // are locked for the duration of Compute().
+  explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  // Validates every input and then updates the selected rows in place.
+  // Any failed check reports an error on `ctx` and returns before touching
+  // the variables.
+  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+
+    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(0)));
+    OP_REQUIRES(
+        ctx, ms.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(1)));
+    OP_REQUIRES(
+        ctx, mom.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(2)));
+
+    const Tensor& lr = ctx->input(3);
+    const Tensor& rho = ctx->input(4);
+    const Tensor& momentum = ctx->input(5);
+    const Tensor& epsilon = ctx->input(6);
+    const Tensor& grad = ctx->input(7);
+    const Tensor& indices = ctx->input(8);
+
+    // The hyper-parameters must all be scalars.
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
+                errors::InvalidArgument("rho is not a scalar: ",
+                                        rho.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
+                errors::InvalidArgument("momentum is not a scalar: ",
+                                        momentum.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+
+    // The two accumulators must have exactly the shape of the variable.
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(ms.shape()),
+        errors::InvalidArgument("var and ms do not have the same shape: ",
+                                var.shape().DebugString(), " ",
+                                ms.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(mom.shape()),
+        errors::InvalidArgument("var and mom do not have the same shape: ",
+                                var.shape().DebugString(), " ",
+                                mom.shape().DebugString()));
+
+    // Rows are addressed along dimension 0, so var needs at least one
+    // dimension before var.dim_size(0) may be read below.
+    OP_REQUIRES(ctx, var.dims() >= 1,
+                errors::InvalidArgument("var must be at least 1 dimensional"));
+
+    // grad holds one row per entry of indices, so only its rank and trailing
+    // dimensions have to agree with var. Its leading dimension is checked
+    // against the number of indices further down.
+    OP_REQUIRES(ctx, grad.dims() == var.dims(),
+                errors::InvalidArgument("var and grad must have the same rank: ",
+                                        var.shape().DebugString(), " ",
+                                        grad.shape().DebugString()));
+    for (int d = 1; d < var.dims(); d++) {
+      OP_REQUIRES(
+          ctx, var.dim_size(d) == grad.dim_size(d),
+          errors::InvalidArgument("var and grad must match in dimension ", d));
+    }
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+                errors::InvalidArgument("indices must be one-dimensional"));
+
+    const Tindex N = indices.dim_size(0);
+    OP_REQUIRES(
+        ctx, grad.dim_size(0) == N,
+        errors::InvalidArgument(
+            "grad must be the same size as indices in the first dimension."));
+
+    if (N > 0) {
+      const Tindex first_dim_size = var.dim_size(0);
+      // Validate all the indices are in range before any row is modified, so
+      // a bad index never leaves the variables partially updated.
+      auto indices_vec = indices.vec<Tindex>();
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
+                    errors::InvalidArgument(
+                        strings::StrCat("Index ", index, " at offset ", i,
+                                        " in indices is out of range")));
+      }
+
+      // View every tensor as a 2-D matrix of rows along dimension 0.
+      auto var_flat = var.flat_outer_dims<T>();
+      auto ms_flat = ms.flat_outer_dims<T>();
+      auto mom_flat = mom.flat_outer_dims<T>();
+      auto grad_flat = grad.flat_outer_dims<T>();
+      const T lr_scalar = lr.scalar<T>()();
+      const T rho_scalar = rho.scalar<T>()();
+      const T epsilon_scalar = epsilon.scalar<T>()();
+      const T momentum_scalar = momentum.scalar<T>()();
+
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+
+        // The accumulators and the variable are addressed by the target row
+        // `index`; the gradient is addressed by its position `i`.
+        auto ms_ = ms_flat.template chip<0>(index);
+        auto mom_ = mom_flat.template chip<0>(index);
+        auto grad_ = grad_flat.template chip<0>(i);
+
+        ms_ = ms_ * ms_.constant(rho_scalar) +
+              grad_.square() * grad_.constant(T(1) - rho_scalar);
+
+        // Uses the ms_ value that was just updated above.
+        mom_ = mom_ * mom_.constant(momentum_scalar) +
+               (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
+                   ms_.constant(lr_scalar) * grad_;
+
+        auto v = var_flat.template chip<0>(index);
+        v -= mom_;
+      }
+    }
+
+    ctx->forward_ref_input_to_ref_output(0, 0);
+  }
+
+ private:
+  // Value of the "use_locking" attribute.
+  bool use_exclusive_lock_;
+};
+
+// Registers the CPU kernel for one (value type, index type) pair.
+// NOTE(review): the REGISTER_KERNEL_BUILDER(Name(...)) opener was missing
+// from this view of the diff; the op name "SparseApplyRMSProp" is inferred
+// from the kernel class name -- confirm against the op registration.
+#define REGISTER_KERNELS(T, Tindices)                                 \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                  \
+                              .Device(DEVICE_CPU)                     \
+                              .TypeConstraint<T>("T")                 \
+                              .TypeConstraint<Tindices>("Tindices"),  \
+                          SparseApplyRMSPropOp<T, Tindices>);
+
+REGISTER_KERNELS(Eigen::half, int32);
+REGISTER_KERNELS(Eigen::half, int64);
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+
+// Keep the helper macro from leaking past its last use.
+#undef REGISTER_KERNELS
+
} // namespace tensorflow