Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/training_ops.cc | 150
1 file changed, 150 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index f53c567c4d..5b13b10937 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -330,6 +330,27 @@ struct ApplyAdamSYCL {
 template <typename T>
 struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};
 
+template <typename Device, typename T>
+struct ApplyAdaMaxNonCuda {
+  void operator()(const Device& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+                  typename TTypes<T>::ConstScalar beta1_power,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar beta1,
+                  typename TTypes<T>::ConstScalar beta2,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad) {
+    m.device(d) += (grad - m) * (T(1) - beta1());
+    // Here v is u in section 7.1
+    v.device(d) = (beta2() * v).cwiseMax(grad.abs());
+    // var is θ in section 7.1
+    var.device(d) -= lr() / (T(1) - beta1_power()) * (m / (v + epsilon()));
+  }
+};
+
+template <typename T>
+struct ApplyAdaMax<CPUDevice, T> : ApplyAdaMaxNonCuda<CPUDevice, T> {};
+
 template <typename T>
 struct ApplyRMSProp<CPUDevice, T> {
   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
@@ -2752,6 +2773,135 @@ REGISTER_KERNELS(GPU, double);
 #undef REGISTER_KERNELS
 
 template <typename Device, typename T>
+class ApplyAdaMaxOp : public OpKernel {
+ public:
+  explicit ApplyAdaMaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+                                                      {0, 1, 2});
+
+    Tensor var;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+                            ctx, 0, use_exclusive_lock_, false, &var));
+    Tensor m;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+                            ctx, 1, use_exclusive_lock_, false, &m));
+    Tensor v;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+                            ctx, 2, use_exclusive_lock_, false, &v));
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(0)));
+    OP_REQUIRES(
+        ctx, m.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(1)));
+    OP_REQUIRES(
+        ctx, v.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(2)));
+
+    const Tensor& beta1_power = ctx->input(3);
+    const Tensor& lr = ctx->input(4);
+    const Tensor& beta1 = ctx->input(5);
+    const Tensor& beta2 = ctx->input(6);
+    const Tensor& epsilon = ctx->input(7);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
+                errors::InvalidArgument("beta1_power is not a scalar: ",
+                                        beta1_power.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar : ",
+                                        lr.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
+                errors::InvalidArgument("beta1 is not a scalar: ",
+                                        beta1.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
+                errors::InvalidArgument("beta2 is not a scalar: ",
+                                        beta2.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+
+    const Tensor& grad = ctx->input(8);
+    OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
+                errors::InvalidArgument("var and m do not have the same shape",
+                                        var.shape().DebugString(), " ",
+                                        m.shape().DebugString()));
+    OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
+                errors::InvalidArgument("var and v do not have the same shape",
+                                        var.shape().DebugString(), " ",
+                                        v.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(grad.shape()),
+        errors::InvalidArgument("var and grad do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                grad.shape().DebugString()));
+
+    const Device& device = ctx->template eigen_device<Device>();
+    functor::ApplyAdaMax<Device, T>()(
+        device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
+        beta1_power.scalar<T>(), lr.scalar<T>(),
+        beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
+        grad.flat<T>());
+
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(D, T)                                       \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+      ApplyAdaMaxOp<D##Device, T>);                                  \
+  REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax")                \
+                              .HostMemory("var")                     \
+                              .HostMemory("m")                       \
+                              .HostMemory("v")                       \
+                              .Device(DEVICE_##D)                    \
+                              .TypeConstraint<T>("T"),               \
+                          ApplyAdaMaxOp<D##Device, T>);
+#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
+
+TF_CALL_half(REGISTER_CPU_KERNELS);
+TF_CALL_float(REGISTER_CPU_KERNELS);
+TF_CALL_double(REGISTER_CPU_KERNELS);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                     \
+  template <>                                                   \
+  void ApplyAdaMax<GPUDevice, T>::operator()(                   \
+      const GPUDevice& d, typename TTypes<T>::Flat var,         \
+      typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,   \
+      typename TTypes<T>::ConstScalar beta1_power,              \
+      typename TTypes<T>::ConstScalar lr,                       \
+      typename TTypes<T>::ConstScalar beta1,                    \
+      typename TTypes<T>::ConstScalar beta2,                    \
+      typename TTypes<T>::ConstScalar epsilon,                  \
+      typename TTypes<T>::ConstFlat grad);                      \
+  extern template struct ApplyAdaMax<GPUDevice, T>;
+DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half);
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_CPU_KERNELS
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
 class ApplyRMSPropOp : public OpKernel {
  public:
   explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
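
For reference, the "section 7.1" comments in ApplyAdaMaxNonCuda point at the AdaMax variant of Adam (Kingma & Ba, "Adam: A Method for Stochastic Optimization", arXiv:1412.6980, section 7.1). Reading the Eigen expressions back into the paper's notation (the code's v is the paper's exponentially weighted infinity norm u_t, lr() is \alpha, and beta1_power() is \beta_1^t), the update is:

    m_t      = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    u_t      = \max(\beta_2 u_{t-1}, |g_t|)
    \theta_t = \theta_{t-1} - \frac{\alpha}{1 - \beta_1^t} \cdot \frac{m_t}{u_t}

One deliberate difference: the paper's Algorithm 2 has no \epsilon term (the max keeps u_t away from zero once any nonzero gradient has been seen), while the kernel divides by (v + epsilon()) as a guard for the all-zero initial state.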
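
As a sanity check on those expressions, here is a minimal scalar C++ sketch of one AdaMax step over flat arrays. It is illustrative only: the name ada_max_step and the raw-pointer interface are inventions for this sketch, not part of TensorFlow; the arithmetic mirrors the functor above.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>

// One AdaMax step over n elements, mirroring ApplyAdaMaxNonCuda:
//   m   <- beta1 * m + (1 - beta1) * g        (first-moment estimate)
//   v   <- max(beta2 * v, |g|)                (infinity-norm accumulator u_t)
//   var <- var - lr / (1 - beta1^t) * m / (v + epsilon)
void ada_max_step(float* var, float* m, float* v, const float* grad,
                  std::size_t n, float beta1_power, float lr, float beta1,
                  float beta2, float epsilon) {
  for (std::size_t i = 0; i < n; ++i) {
    m[i] += (grad[i] - m[i]) * (1.0f - beta1);
    v[i] = std::max(beta2 * v[i], std::fabs(grad[i]));
    var[i] -= lr / (1.0f - beta1_power) * (m[i] / (v[i] + epsilon));
  }
}

int main() {
  float var[2] = {1.0f, -2.0f};
  float m[2] = {0.0f, 0.0f};
  float v[2] = {0.0f, 0.0f};
  const float grad[2] = {0.5f, -0.25f};
  // First step, so beta1_power = beta1^1.
  ada_max_step(var, m, v, grad, 2, /*beta1_power=*/0.9f, /*lr=*/0.001f,
               /*beta1=*/0.9f, /*beta2=*/0.999f, /*epsilon=*/1e-8f);
  std::printf("var = {%g, %g}\n", var[0], var[1]);
  return 0;
}

On the first step this moves each weight by essentially lr against the gradient's sign (m / u collapses to sign(g) when v starts at zero), which matches the paper's observation that AdaMax per-step magnitudes are bounded by \alpha.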