aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/training_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r-- tensorflow/core/kernels/training_ops.cc | 137
1 files changed, 137 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index b16c9c860a..2f9714a37a 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -1937,4 +1937,141 @@ REGISTER_KERNELS(GPU, double);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
+
+// Note, this op works on cpu only.
+//
+// Sparse RMSProp update. For each i in [0, indices.size()), with
+// row = indices(i):
+//   ms(row)  <- rho * ms(row) + (1 - rho) * grad(i)^2
+//   mom(row) <- momentum * mom(row) + lr * grad(i) * rsqrt(ms(row) + epsilon)
+//   var(row) <- var(row) - mom(row)
+template <typename T, typename Tindex>
+class SparseApplyRMSPropOp : public OpKernel {
+ public:
+  explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+    // Inputs 0..2 are ref inputs; when use_locking is set their mutexes are
+    // acquired in a fixed order to avoid deadlock against other training ops.
+    auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+
+    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(0)));
+    OP_REQUIRES(
+        ctx, ms.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(1)));
+    OP_REQUIRES(
+        ctx, mom.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(2)));
+
+    const Tensor& lr = ctx->input(3);
+    const Tensor& rho = ctx->input(4);
+    const Tensor& momentum = ctx->input(5);
+    const Tensor& epsilon = ctx->input(6);
+    const Tensor& grad = ctx->input(7);
+    const Tensor& indices = ctx->input(8);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
+                errors::InvalidArgument("rho is not a scalar: ",
+                                        rho.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
+                errors::InvalidArgument("momentum is not a scalar: ",
+                                        momentum.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+
+    OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
+                errors::InvalidArgument("var and ms do not have the same shape",
+                                        var.shape().DebugString(), " ",
+                                        ms.shape().DebugString()));
+
+    OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
+                errors::InvalidArgument(
+                    "var and mom do not have the same shape",
+                    var.shape().DebugString(), " ", mom.shape().DebugString()));
+
+    // For the sparse update, grad's first dimension indexes the slices being
+    // updated (its size must equal indices.size(), checked below); only the
+    // remaining dimensions must match var. Requiring the full shapes to be
+    // identical would force indices to cover every row of var, defeating the
+    // purpose of the sparse op.
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
+                errors::InvalidArgument("var must be at least 1 dimensional"));
+    for (int d = 1; d < var.dims(); d++) {
+      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
+                  errors::InvalidArgument(strings::StrCat(
+                      "var and grad must match in dimension ", d)));
+    }
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+                errors::InvalidArgument("indices must be one-dimensional"));
+
+    const Tindex N = indices.dim_size(0);
+    OP_REQUIRES(
+        ctx, grad.dim_size(0) == N,
+        errors::InvalidArgument(
+            "grad must be the same size as indices in the first dimension."));
+
+    if (N > 0) {
+      const Tindex first_dim_size = var.dim_size(0);
+      // Validate all the indices are in range
+      auto indices_vec = indices.vec<Tindex>();
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
+                    errors::InvalidArgument(
+                        strings::StrCat("Index ", index, " at offset ", i,
+                                        " in indices is out of range")));
+      }
+
+      auto var_flat = var.flat_outer_dims<T>();
+      auto ms_flat = ms.flat_outer_dims<T>();
+      auto mom_flat = mom.flat_outer_dims<T>();
+      auto grad_flat = grad.flat_outer_dims<T>();
+      const T lr_scalar = lr.scalar<T>()();
+      const T rho_scalar = rho.scalar<T>()();
+      const T epsilon_scalar = epsilon.scalar<T>()();
+      const T momentum_scalar = momentum.scalar<T>()();
+
+      // Apply the update one slice at a time. The accumulators are updated
+      // in place, so a row repeated in `indices` receives repeated updates.
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+
+        auto ms_ = ms_flat.template chip<0>(index);
+        auto mom_ = mom_flat.template chip<0>(index);
+        auto grad_ = grad_flat.template chip<0>(i);
+
+        ms_ = ms_ * ms_.constant(rho_scalar) +
+              grad_.square() * grad_.constant(T(1) - rho_scalar);
+        mom_ = mom_ * mom_.constant(momentum_scalar) +
+               (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
+                   ms_.constant(lr_scalar) * grad_;
+
+        auto v = var_flat.template chip<0>(index);
+        v -= mom_;
+      }
+    }
+
+    ctx->forward_ref_input_to_ref_output(0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;  // Value of the "use_locking" attr.
+};
+
+// Registers the CPU kernel for SparseApplyRMSProp for one combination of
+// element type T and index type Tindices.
+#define REGISTER_KERNELS(T, Tindices)                                \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                 \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyRMSPropOp<T, Tindices>);
+
+// CPU-only: all supported float element types x {int32, int64} indices.
+REGISTER_KERNELS(Eigen::half, int32);
+REGISTER_KERNELS(Eigen::half, int64);
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+
+#undef REGISTER_KERNELS
+
+
} // namespace tensorflow