diff options
author | Yifei Feng <yifeif@google.com> | 2018-04-23 21:19:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 21:21:38 -0700 |
commit | 22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch) | |
tree | d16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/core/kernels/training_ops.h | |
parent | 24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff) |
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/core/kernels/training_ops.h')
-rw-r--r-- | tensorflow/core/kernels/training_ops.h | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index 7ee956053a..f536a61eb0 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -140,6 +140,18 @@ struct ApplyAdam { }; template <typename Device, typename T> +struct ApplyAdaMax { + void operator()(const Device& d, typename TTypes<T>::Flat var, + typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, + typename TTypes<T>::ConstScalar beta1_power, + typename TTypes<T>::ConstScalar lr, + typename TTypes<T>::ConstScalar beta1, + typename TTypes<T>::ConstScalar beta2, + typename TTypes<T>::ConstScalar epsilon, + typename TTypes<T>::ConstFlat grad); +}; + +template <typename Device, typename T> struct ApplyRMSProp { void operator()(const Device& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, |