diff options
Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/training_ops.cc | 137 |
1 file changed, 137 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index b16c9c860a..2f9714a37a 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -1937,4 +1937,141 @@ REGISTER_KERNELS(GPU, double); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS + +// Note, this op works on cpu only. +template <typename T, typename Tindex> +class SparseApplyRMSPropOp : public OpKernel { + public: + explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { + auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2}); + + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor ms = ctx->mutable_input(1, use_exclusive_lock_); + Tensor mom = ctx->mutable_input(2, use_exclusive_lock_); + + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, ms.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + OP_REQUIRES( + ctx, mom.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(2))); + + const Tensor& lr = ctx->input(3); + const Tensor& rho = ctx->input(4); + const Tensor& momentum = ctx->input(5); + const Tensor& epsilon = ctx->input(6); + const Tensor& grad = ctx->input(7); + const Tensor& indices = ctx->input(8); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), + errors::InvalidArgument("rho is not a scalar: ", + rho.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), + errors::InvalidArgument("momentum is 
not a scalar: ", + momentum.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon.shape().DebugString())); + + OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), + errors::InvalidArgument("var and ms do not have the same shape", + var.shape().DebugString(), " ", + ms.shape().DebugString())); + + OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), + errors::InvalidArgument( + "var and mom do not have the same shape", + var.shape().DebugString(), " ", mom.shape().DebugString())); + + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and grad do not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices must be one-dimensional")); + + const Tindex N = indices.dim_size(0); + OP_REQUIRES( + ctx, grad.dim_size(0) == N, + errors::InvalidArgument( + "grad must be the same size as indices in the first dimension.")); + + if (N > 0) { + const Tindex first_dim_size = var.dim_size(0); + // Validate all the indices are in range + auto indices_vec = indices.vec<Tindex>(); + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, + errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in indices is out of range"))); + } + + auto var_flat = var.flat_outer_dims<T>(); + auto ms_flat = ms.flat_outer_dims<T>(); + auto mom_flat = mom.flat_outer_dims<T>(); + auto grad_flat = grad.flat_outer_dims<T>(); + const T lr_scalar = lr.scalar<T>()(); + const T rho_scalar = rho.scalar<T>()(); + const T epsilon_scalar = epsilon.scalar<T>()(); + const T momentum_scalar = momentum.scalar<T>()(); + + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + + auto ms_ = ms_flat.template chip<0>(index); + auto mom_ = 
mom_flat.template chip<0>(index); + auto grad_ = grad_flat.template chip<0>(i); + + ms_ = ms_ * ms_.constant(rho_scalar) + + grad_.square() * grad_.constant(T(1) - rho_scalar); + mom_ = mom_ * mom_.constant(momentum_scalar) + + (ms_ + ms_.constant(epsilon_scalar)).rsqrt() * + ms_.constant(lr_scalar) * grad_; + + auto v = var_flat.template chip<0>(index); + v -= mom_; + } + } + + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; +}; + +#define REGISTER_KERNELS(T, Tindices) \ + REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<Tindices>("Tindices"), \ + SparseApplyRMSPropOp<T, Tindices>); + +REGISTER_KERNELS(Eigen::half, int32); +REGISTER_KERNELS(Eigen::half, int64); +REGISTER_KERNELS(float, int32); +REGISTER_KERNELS(float, int64); +REGISTER_KERNELS(double, int32); +REGISTER_KERNELS(double, int64); + +#undef REGISTER_KERNELS + + } // namespace tensorflow |