Diffstat (limited to 'tensorflow/core/kernels/training_ops.cc')
-rw-r--r--  tensorflow/core/kernels/training_ops.cc | 150
1 file changed, 150 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index f53c567c4d..5b13b10937 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -330,6 +330,27 @@ struct ApplyAdamSYCL {
template <typename T>
struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};

+template <typename Device, typename T>
+struct ApplyAdaMaxNonCuda {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
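+ // First-moment update, equivalent to m = beta1 * m + (1 - beta1) * grad.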
+ m.device(d) += (grad - m) * (T(1) - beta1());
+ // Here v is u_t in section 7.1 of the Adam paper (Kingma & Ba, 2015).
+ v.device(d) = (beta2() * v).cwiseMax(grad.abs());
+ // var is θ_t in section 7.1 of the same paper.
+ var.device(d) -= lr() / (T(1) - beta1_power()) * (m / (v + epsilon()));
+ }
+};
+
+template <typename T>
+struct ApplyAdaMax<CPUDevice, T> : ApplyAdaMaxNonCuda<CPUDevice, T> {};
+
template <typename T>
struct ApplyRMSProp<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
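For reference, the functor added above implements the AdaMax update from section 7.1 of Kingma & Ba (2015), "Adam: A Method for Stochastic Optimization". A sketch of the per-step rule, with g_t the gradient and u_t the quantity stored in v:

\begin{align*}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t \\
u_t &= \max(\beta_2\, u_{t-1},\; |g_t|) \\
\theta_t &= \theta_{t-1} - \frac{\alpha}{1 - \beta_1^{t}} \cdot \frac{m_t}{u_t + \epsilon}
\end{align*}

The code writes the first line in the algebraically equivalent in-place form m += (grad - m) * (1 - beta1). The epsilon guard in the denominator is a kernel addition; the paper divides by u_t directly, arguing no bias correction of u_t is needed.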
@@ -2752,6 +2773,135 @@ REGISTER_KERNELS(GPU, double);
#undef REGISTER_KERNELS

template <typename Device, typename T>
+class ApplyAdaMaxOp : public OpKernel {
+ public:
+ explicit ApplyAdaMaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+ {0, 1, 2});
+
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+ ctx, 0, use_exclusive_lock_, false, &var));
+ Tensor m;
+ OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+ ctx, 1, use_exclusive_lock_, false, &m));
+ Tensor v;
+ OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+ ctx, 2, use_exclusive_lock_, false, &v));
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variable: ", requested_input(0)));
+ OP_REQUIRES(
+ ctx, m.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variable: ", requested_input(1)));
+ OP_REQUIRES(
+ ctx, v.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variable: ", requested_input(2)));
+
+ const Tensor& beta1_power = ctx->input(3);
+ const Tensor& lr = ctx->input(4);
+ const Tensor& beta1 = ctx->input(5);
+ const Tensor& beta2 = ctx->input(6);
+ const Tensor& epsilon = ctx->input(7);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
+ errors::InvalidArgument("beta1_power is not a scalar: ",
+ beta1_power.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar : ",
+ lr.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
+ errors::InvalidArgument("beta1 is not a scalar: ",
+ beta1.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
+ errors::InvalidArgument("beta2 is not a scalar: ",
+ beta2.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon.shape().DebugString()));
+
+ const Tensor& grad = ctx->input(8);
+ OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var.shape().DebugString(), " ",
+ m.shape().DebugString()));
+ OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
+ errors::InvalidArgument("var and v do not have the same shape",
+ var.shape().DebugString(), " ",
+ v.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and grad do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+
+ const Device& device = ctx->template eigen_device<Device>();
+ functor::ApplyAdaMax<Device, T>()(
+ device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
+ beta1_power.scalar<T>(), lr.scalar<T>(),
+ beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
+ grad.flat<T>());
+
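+ // If var was passed by reference, forward it to output 0 so the
+ // updated variable can be chained by downstream ops.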
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdaMaxOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \
+ .HostMemory("var") \
+ .HostMemory("m") \
+ .HostMemory("v") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
+ ApplyAdaMaxOp<D##Device, T>);
+#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
+
+TF_CALL_half(REGISTER_CPU_KERNELS);
+TF_CALL_float(REGISTER_CPU_KERNELS);
+TF_CALL_double(REGISTER_CPU_KERNELS);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ApplyAdaMax<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat var, \
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
+ typename TTypes<T>::ConstScalar beta1_power, \
+ typename TTypes<T>::ConstScalar lr, \
+ typename TTypes<T>::ConstScalar beta1, \
+ typename TTypes<T>::ConstScalar beta2, \
+ typename TTypes<T>::ConstScalar epsilon, \
+ typename TTypes<T>::ConstFlat grad); \
+ extern template struct ApplyAdaMax<GPUDevice, T>;
+DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half);
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_CPU_KERNELS
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
class ApplyRMSPropOp : public OpKernel {
public:
explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
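To see the arithmetic without the TensorFlow plumbing, here is a minimal standalone sketch (not part of the patch) of the same step on rank-1 Eigen tensors. The function name AdaMaxStep is illustrative only, and the sketch assumes the Eigen Tensor module is available:

#include <unsupported/Eigen/CXX11/Tensor>

// Illustrative sketch: one AdaMax step, mirroring ApplyAdaMaxNonCuda above.
void AdaMaxStep(Eigen::Tensor<float, 1>& var, Eigen::Tensor<float, 1>& m,
                Eigen::Tensor<float, 1>& v,
                const Eigen::Tensor<float, 1>& grad, float beta1_power,
                float lr, float beta1, float beta2, float epsilon) {
  // First-moment EMA: m <- beta1 * m + (1 - beta1) * grad.
  m += (grad - m) * (1.0f - beta1);
  // Infinity-norm moment: v <- max(beta2 * v, |grad|).
  v = (v * beta2).cwiseMax(grad.abs());
  // Bias-corrected update: var <- var - lr / (1 - beta1^t) * m / (v + eps).
  var -= lr / (1.0f - beta1_power) * (m / (v + epsilon));
}

Seeding v with zeros (u_0 = 0 in the paper) and calling this once per step on equally shaped tensors reproduces the kernel's per-element arithmetic for a single update.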