about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/training_ops_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/training_ops_gpu.cu.cc')
-rw-r--r-- tensorflow/core/kernels/training_ops_gpu.cu.cc | 12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index ab56880cfb..589e70e76d 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -17,8 +17,8 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/training_ops.h"
+#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
@@ -84,12 +84,18 @@ struct ApplyMomentum<GPUDevice, T> {
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
typename TTypes<T>::ConstFlat grad,
- typename TTypes<T>::ConstScalar momentum) {
+ typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
- var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+ if (use_nesterov) {
+ var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
+ accum * momentum.reshape(single).broadcast(bcast) *
+ lr.reshape(single).broadcast(bcast);
+ } else {
+ var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+ }
}
};