/* Copyright 2015 Google Inc. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/training_ops.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; namespace functor { static inline bool DoInline(int64 size) { return size <= (256ll << 10); } template struct ApplyGradientDescent { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::ConstScalar lr, typename TTypes::ConstFlat grad) { if (DoInline(var.size())) { var -= grad * lr(); } else { var.device(d) -= grad * lr(); } } }; template struct ApplyAdagrad { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat accum, typename TTypes::ConstScalar lr, typename TTypes::ConstFlat grad) { if (DoInline(var.size())) { accum += grad.square(); var -= grad * lr() * accum.rsqrt(); } else { accum.device(d) += grad.square(); var.device(d) -= grad * lr() * accum.rsqrt(); } } }; template struct ApplyMomentum { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat accum, typename TTypes::ConstScalar lr, typename TTypes::ConstFlat grad, typename TTypes::ConstScalar momentum) { if (DoInline(var.size())) { accum = accum * momentum() + grad; var -= accum * lr(); } else { accum.device(d) = accum * momentum() + grad; var.device(d) -= accum * lr(); } } }; template struct ApplyAdam { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat m, typename TTypes::Flat v, typename TTypes::ConstScalar beta1_power, typename TTypes::ConstScalar beta2_power, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar beta1, typename TTypes::ConstScalar beta2, typename TTypes::ConstScalar epsilon, typename TTypes::ConstFlat grad) { const T alpha = lr() * std::sqrt(1 - beta2_power()) / (1 - beta1_power()); if (DoInline(var.size())) { m += (grad - m) * (1 - beta1()); v += (grad.square() - v) * (1 - beta2()); var -= (m * alpha) / (v.sqrt() + epsilon()); } else { m.device(d) += (grad - m) * (1 - beta1()); v.device(d) += (grad.square() - v) * (1 - beta2()); var.device(d) -= (m * alpha) / (v.sqrt() + epsilon()); } } }; template struct ApplyRMSProp { void operator()(const CPUDevice& d, typename TTypes::Flat var, typename TTypes::Flat ms, typename TTypes::Flat mom, typename TTypes::ConstScalar lr, typename TTypes::ConstScalar rho, typename TTypes::ConstScalar momentum, typename TTypes::ConstScalar epsilon, typename TTypes::ConstFlat grad) { if (DoInline(var.size())) { ms += (grad.square() - ms) * (1 - rho()); mom = mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt()); var -= mom; } else { ms.device(d) += (grad.square() - ms) * (1 - rho()); mom.device(d) = mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt()); var.device(d) -= mom; } } }; } // namespace functor template class ApplyGradientDescentOp : public OpKernel { public: explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); } void Compute(OpKernelContext* ctx) override { if (use_exclusive_lock_) { mutex_lock l(*ctx->input_ref_mutex(0)); DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } else { DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } ctx->forward_ref_input_to_ref_output(0, 0); } private: bool use_exclusive_lock_; void DoValidate(OpKernelContext* ctx) { Tensor var = ctx->mutable_input(0, use_exclusive_lock_); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(0))); const Tensor& alpha = ctx->input(1); OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(alpha.shape()), errors::InvalidArgument("alpha is not a scalar: ", alpha.shape().DebugString())); const Tensor& delta = ctx->input(2); OP_REQUIRES( ctx, var.shape().IsSameSize(delta.shape()), errors::InvalidArgument("var and delta do not have the same shape", var.shape().DebugString(), " ", delta.shape().DebugString())); } void DoCompute(OpKernelContext* ctx) { const Device& device = ctx->template eigen_device(); Tensor var = ctx->mutable_input(0, use_exclusive_lock_); const Tensor& alpha = ctx->input(1); const Tensor& delta = ctx->input(2); functor::ApplyGradientDescent()( device, var.flat(), alpha.scalar(), delta.flat()); } }; #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint("T"), \ ApplyGradientDescentOp); REGISTER_KERNELS(CPU, float); REGISTER_KERNELS(CPU, double); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void ApplyGradientDescent::operator()( \ const GPUDevice& d, typename TTypes::Flat var, \ typename TTypes::ConstScalar alpha, \ typename TTypes::ConstFlat delta); \ extern template struct ApplyGradientDescent; DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); #endif #undef REGISTER_KERNELS template class ApplyAdagradOp : public OpKernel { public: explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); } void Compute(OpKernelContext* ctx) override { if (use_exclusive_lock_) { mutex_lock l1(*ctx->input_ref_mutex(0)); // Don't try to acquire a lock on the second ref as they share the same // mutex. // // mutex_lock l2(*ctx->input_ref_mutex(1)); DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } else { DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } ctx->forward_ref_input_to_ref_output(0, 0); } private: bool use_exclusive_lock_; void DoValidate(OpKernelContext* ctx) { Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(1))); const Tensor& lr = ctx->input(2); OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); const Tensor& grad = ctx->input(3); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", var.shape().DebugString(), " ", accum.shape().DebugString())); OP_REQUIRES( ctx, var.shape().IsSameSize(grad.shape()), errors::InvalidArgument("var and delta do not have the same shape", var.shape().DebugString(), " ", grad.shape().DebugString())); } void DoCompute(OpKernelContext* ctx) { const Device& device = ctx->template eigen_device(); Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); const Tensor& lr = ctx->input(2); const Tensor& grad = ctx->input(3); functor::ApplyAdagrad()(device, var.flat(), accum.flat(), lr.scalar(), grad.flat()); } }; typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint("T"), \ ApplyAdagradOp); REGISTER_KERNELS(CPU, float); REGISTER_KERNELS(CPU, double); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void ApplyAdagrad::operator()( \ const GPUDevice& d, typename TTypes::Flat var, \ typename TTypes::Flat accum, typename TTypes::ConstScalar lr, \ typename TTypes::ConstFlat grad); \ extern template struct ApplyAdagrad; DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); #endif #undef REGISTER_KERNELS // Note, this op works on cpu only. template class SparseApplyAdagradOp : public OpKernel { public: explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { mutex* mu_var = ctx->input_ref_mutex(0); // mu_accum is actually the same mutex as mu_var since currently we use a // global mutex. // // mutex* mu_accum = ctx->input_ref_mutex(1); if (use_exclusive_lock_) { mu_var->lock(); } Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(1))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", var.shape().DebugString(), " ", accum.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), errors::InvalidArgument("var must be at least 1 dimensional")); const Tensor& lr = ctx->input(2); OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); const Tensor& grad = ctx->input(3); const Tensor& indices = ctx->input(4); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), errors::InvalidArgument("indices must be one-dimensional")); for (int d = 1; d < var.dims(); d++) { OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), errors::InvalidArgument(strings::StrCat( "var and grad must match in dimension ", d))); } const Tindex N = indices.dim_size(0); OP_REQUIRES( ctx, grad.dim_size(0) == N, errors::InvalidArgument( "grad must be the same size as indices in the first dimension.")); if (N > 0) { const Tindex first_dim_size = var.dim_size(0); // Validate all the indices are in range auto indices_vec = indices.vec(); for (Tindex i = 0; i < N; i++) { const Tindex index = indices_vec(i); OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, errors::InvalidArgument( strings::StrCat("Index ", index, " at offset ", i, " in indices is out of range"))); } auto var_flat = var.flat_outer_dims(); auto accum_flat = accum.flat_outer_dims(); auto grad_flat = grad.flat_outer_dims(); T lr_scalar = lr.scalar()(); // Note(yonghui): It might be worth multi-threading square() and rsqrt(). for (Tindex i = 0; i < N; i++) { const Tindex index = indices_vec(i); auto a = accum_flat.template chip<0>(index); auto g = grad_flat.template chip<0>(i); auto v = var_flat.template chip<0>(index); a += g.square(); v -= g.constant(lr_scalar) * g * a.rsqrt(); } } if (use_exclusive_lock_) { mu_var->unlock(); } ctx->forward_ref_input_to_ref_output(0, 0); } private: bool use_exclusive_lock_; }; #define REGISTER_KERNELS(T, Tindices) \ REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("Tindices"), \ SparseApplyAdagradOp); REGISTER_KERNELS(float, int32); REGISTER_KERNELS(float, int64); REGISTER_KERNELS(double, int32); REGISTER_KERNELS(double, int64); #undef REGISTER_KERNELS template class ApplyMomentumOp : public OpKernel { public: explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); } void Compute(OpKernelContext* ctx) override { if (use_exclusive_lock_) { mutex_lock l1(*ctx->input_ref_mutex(0)); // Don't try to acquire a lock on the second ref as they share the same // mutex. // // mutex_lock l2(*ctx->input_ref_mutex(1)); DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } else { DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } ctx->forward_ref_input_to_ref_output(0, 0); } private: bool use_exclusive_lock_; void DoValidate(OpKernelContext* ctx) { Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(1))); const Tensor& lr = ctx->input(2); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); const Tensor& grad = ctx->input(3); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", var.shape().DebugString(), " ", accum.shape().DebugString())); OP_REQUIRES( ctx, var.shape().IsSameSize(grad.shape()), errors::InvalidArgument("var and delta do not have the same shape", var.shape().DebugString(), " ", grad.shape().DebugString())); const Tensor& momentum = ctx->input(4); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), errors::InvalidArgument("momentum is not a scalar: ", momentum.shape().DebugString())); } void DoCompute(OpKernelContext* ctx) { const Device& device = ctx->template eigen_device(); Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); const Tensor& lr = ctx->input(2); const Tensor& grad = ctx->input(3); const Tensor& momentum = ctx->input(4); functor::ApplyMomentum()(device, var.flat(), accum.flat(), lr.scalar(), grad.flat(), momentum.scalar()); } }; typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint("T"), \ ApplyMomentumOp); REGISTER_KERNELS(CPU, float); REGISTER_KERNELS(CPU, double); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void ApplyMomentum::operator()( \ const GPUDevice& d, typename TTypes::Flat var, \ typename TTypes::Flat accum, typename TTypes::ConstScalar lr, \ typename TTypes::ConstFlat grad, \ typename TTypes::ConstScalar momentum); \ extern template struct ApplyMomentum; DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); #endif #undef REGISTER_KERNELS // Note, this op works on cpu only. template class SparseApplyMomentumOp : public OpKernel { public: explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { mutex* mu_var = ctx->input_ref_mutex(0); // mu_accum is actually the same mutex as mu_var since currently we use a // global mutex. // // mutex* mu_accum = ctx->input_ref_mutex(1); if (use_exclusive_lock_) { mu_var->lock(); } Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor accum = ctx->mutable_input(1, use_exclusive_lock_); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(0))); OP_REQUIRES( ctx, accum.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(1))); OP_REQUIRES( ctx, var.shape().IsSameSize(accum.shape()), errors::InvalidArgument("var and accum do not have the same shape", var.shape().DebugString(), " ", accum.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), errors::InvalidArgument("var must be at least 1 dimensional")); const Tensor& lr = ctx->input(2); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); const Tensor& grad = ctx->input(3); const Tensor& indices = ctx->input(4); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), errors::InvalidArgument("indices must be one-dimensional")); for (int d = 1; d < var.dims(); d++) { OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), errors::InvalidArgument(strings::StrCat( "var and grad must match in dimension ", d))); } const Tindex N = indices.dim_size(0); OP_REQUIRES( ctx, grad.dim_size(0) == N, errors::InvalidArgument( "grad must be the same size as indices in the first dimension.")); const Tensor& momentum = ctx->input(5); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), errors::InvalidArgument("momentum is not a scalar: ", momentum.shape().DebugString())); if (N > 0) { const Tindex first_dim_size = var.dim_size(0); // Validate all the indices are in range auto indices_vec = indices.vec(); for (Tindex i = 0; i < N; i++) { const Tindex index = indices_vec(i); OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, errors::InvalidArgument( strings::StrCat("Index ", index, " at offset ", i, " in indices is out of range"))); } auto var_flat = var.flat_outer_dims(); auto accum_flat = accum.flat_outer_dims(); auto grad_flat = grad.flat_outer_dims(); T lr_scalar = lr.scalar()(); T momentum_scalar = momentum.scalar()(); for (Tindex i = 0; i < N; i++) { const Tindex index = indices_vec(i); auto a = accum_flat.template chip<0>(index); auto g = grad_flat.template chip<0>(i); auto v = var_flat.template chip<0>(index); a = a * a.constant(momentum_scalar) + g; v -= a.constant(lr_scalar) * a; } } if (use_exclusive_lock_) { mu_var->unlock(); } ctx->forward_ref_input_to_ref_output(0, 0); } private: bool use_exclusive_lock_; }; #define REGISTER_KERNELS(T, Tindices) \ REGISTER_KERNEL_BUILDER(Name("SparseApplyMomentum") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("Tindices"), \ SparseApplyMomentumOp); REGISTER_KERNELS(float, int32); REGISTER_KERNELS(float, int64); REGISTER_KERNELS(double, int32); REGISTER_KERNELS(double, int64); #undef REGISTER_KERNELS template class ApplyAdamOp : public OpKernel { public: explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); } void Compute(OpKernelContext* ctx) override { if (use_exclusive_lock_) { // all input refs share the same mutex mutex_lock l1(*ctx->input_ref_mutex(0)); DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } else { DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } ctx->forward_ref_input_to_ref_output(0, 0); } private: bool use_exclusive_lock_; void DoValidate(OpKernelContext* ctx) { Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor m = ctx->mutable_input(1, use_exclusive_lock_); Tensor v = ctx->mutable_input(2, use_exclusive_lock_); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(0))); OP_REQUIRES( ctx, m.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(1))); OP_REQUIRES( ctx, v.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(2))); const Tensor& beta1_power = ctx->input(3); const Tensor& beta2_power = ctx->input(4); const Tensor& lr = ctx->input(5); const Tensor& beta1 = ctx->input(6); const Tensor& beta2 = ctx->input(7); const Tensor& epsilon = ctx->input(8); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), errors::InvalidArgument("beta1_power is not a scalar: ", beta1_power.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()), errors::InvalidArgument("beta2_power is not a scalar: ", beta2_power.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), errors::InvalidArgument("beta1 is not a scalar: ", beta1.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), errors::InvalidArgument("beta2 is not a scalar: ", beta2.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), errors::InvalidArgument("epsilon is not a scalar: ", epsilon.shape().DebugString())); const Tensor& grad = ctx->input(9); OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), errors::InvalidArgument("var and m do not have the same shape", var.shape().DebugString(), " ", m.shape().DebugString())); OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), errors::InvalidArgument("var and v do not have the same shape", var.shape().DebugString(), " ", v.shape().DebugString())); OP_REQUIRES( ctx, var.shape().IsSameSize(grad.shape()), errors::InvalidArgument("var and grad do not have the same shape", var.shape().DebugString(), " ", grad.shape().DebugString())); } void DoCompute(OpKernelContext* ctx) { const Device& device = ctx->template eigen_device(); Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor m = ctx->mutable_input(1, use_exclusive_lock_); Tensor v = ctx->mutable_input(2, use_exclusive_lock_); const Tensor& beta1_power = ctx->input(3); const Tensor& beta2_power = ctx->input(4); const Tensor& lr = ctx->input(5); const Tensor& beta1 = ctx->input(6); const Tensor& beta2 = ctx->input(7); const Tensor& epsilon = ctx->input(8); const Tensor& grad = ctx->input(9); functor::ApplyAdam()(device, var.flat(), m.flat(), v.flat(), beta1_power.scalar(), beta2_power.scalar(), lr.scalar(), beta1.scalar(), beta2.scalar(), epsilon.scalar(), grad.flat()); } }; typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint("T"), \ ApplyAdamOp); REGISTER_KERNELS(CPU, float); REGISTER_KERNELS(CPU, double); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void ApplyAdam::operator()( \ const GPUDevice& d, typename TTypes::Flat var, \ typename TTypes::Flat m, typename TTypes::Flat v, \ typename TTypes::ConstScalar beta1_power, \ typename TTypes::ConstScalar beta2_power, \ typename TTypes::ConstScalar lr, \ typename TTypes::ConstScalar beta1, \ typename TTypes::ConstScalar beta2, \ typename TTypes::ConstScalar epsilon, \ typename TTypes::ConstFlat grad); \ extern template struct ApplyAdam; DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); #endif #undef REGISTER_KERNELS template class ApplyRMSPropOp : public OpKernel { public: explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); } void Compute(OpKernelContext* ctx) override { if (use_exclusive_lock_) { // all input refs share the same mutex mutex_lock l1(*ctx->input_ref_mutex(0)); DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } else { DoValidate(ctx); if (!ctx->status().ok()) return; DoCompute(ctx); } ctx->forward_ref_input_to_ref_output(0, 0); } private: bool use_exclusive_lock_; void DoValidate(OpKernelContext* ctx) { Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor ms = ctx->mutable_input(1, use_exclusive_lock_); Tensor mom = ctx->mutable_input(2, use_exclusive_lock_); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(0))); OP_REQUIRES( ctx, ms.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(1))); OP_REQUIRES( ctx, mom.IsInitialized(), errors::FailedPrecondition( "Attempting to use uninitialized variables: ", def().input(2))); const Tensor& lr = ctx->input(3); const Tensor& rho = ctx->input(4); const Tensor& momentum = ctx->input(5); const Tensor& epsilon = ctx->input(6); const Tensor& grad = ctx->input(7); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), errors::InvalidArgument("lr is not a scalar: ", lr.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), errors::InvalidArgument("rho is not a scalar: ", rho.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), errors::InvalidArgument("momentum is not a scalar: ", momentum.shape().DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), errors::InvalidArgument("epsilon is not a scalar: ", epsilon.shape().DebugString())); OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), errors::InvalidArgument("var and ms do not have the same shape", var.shape().DebugString(), " ", ms.shape().DebugString())); OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), errors::InvalidArgument( "var and mom do not have the same shape", var.shape().DebugString(), " ", mom.shape().DebugString())); OP_REQUIRES( ctx, var.shape().IsSameSize(grad.shape()), errors::InvalidArgument("var and grad do not have the same shape", var.shape().DebugString(), " ", grad.shape().DebugString())); } void DoCompute(OpKernelContext* ctx) { const Device& device = ctx->template eigen_device(); Tensor var = ctx->mutable_input(0, use_exclusive_lock_); Tensor ms = ctx->mutable_input(1, use_exclusive_lock_); Tensor mom = ctx->mutable_input(2, use_exclusive_lock_); const Tensor& lr = ctx->input(3); const Tensor& rho = ctx->input(4); const Tensor& momentum = ctx->input(5); const Tensor& epsilon = ctx->input(6); const Tensor& grad = ctx->input(7); functor::ApplyRMSProp()(device, var.flat(), ms.flat(), mom.flat(), lr.scalar(), rho.scalar(), momentum.scalar(), epsilon.scalar(), grad.flat()); } }; typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #define REGISTER_KERNELS(D, T) \ REGISTER_KERNEL_BUILDER( \ Name("ApplyRMSProp").Device(DEVICE_##D).TypeConstraint("T"), \ ApplyRMSPropOp); REGISTER_KERNELS(CPU, float); REGISTER_KERNELS(CPU, double); #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void ApplyRMSProp::operator()( \ const GPUDevice& d, typename TTypes::Flat var, \ typename TTypes::Flat ms, typename TTypes::Flat mom, \ typename TTypes::ConstScalar lr, typename TTypes::ConstScalar rho, \ typename TTypes::ConstScalar momentum, \ typename TTypes::ConstScalar epsilon, \ typename TTypes::ConstFlat grad); \ extern template struct ApplyRMSProp; DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor REGISTER_KERNELS(GPU, float); REGISTER_KERNELS(GPU, double); #endif #undef REGISTER_KERNELS } // namespace tensorflow