diff options
Diffstat (limited to 'tensorflow/core/kernels/training_ops_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/training_ops_gpu.cu.cc | 12 |
1 file changed, 9 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index ab56880cfb..589e70e76d 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -17,8 +17,8 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/training_ops.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -84,12 +84,18 @@ struct ApplyMomentum<GPUDevice, T> { typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstFlat grad, - typename TTypes<T>::ConstScalar momentum) { + typename TTypes<T>::ConstScalar momentum, bool use_nesterov) { Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; bcast[0] = grad.dimension(0); Eigen::Sizes<1> single; accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad; - var.device(d) -= lr.reshape(single).broadcast(bcast) * accum; + if (use_nesterov) { + var.device(d) -= grad * lr.reshape(single).broadcast(bcast) + + accum * momentum.reshape(single).broadcast(bcast) * + lr.reshape(single).broadcast(bcast); + } else { + var.device(d) -= lr.reshape(single).broadcast(bcast) * accum; + } } }; |