diff options
Diffstat (limited to 'tensorflow/python/training/momentum.py')
-rw-r--r-- | tensorflow/python/training/momentum.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py index 1586ddfdec..62f8028ce6 100644 --- a/tensorflow/python/training/momentum.py +++ b/tensorflow/python/training/momentum.py @@ -31,7 +31,7 @@ class MomentumOptimizer(optimizer.Optimizer): """ def __init__(self, learning_rate, momentum, - use_locking=False, name="Momentum"): + use_locking=False, name="Momentum", use_nesterov=False): """Construct a new Momentum optimizer. Args: @@ -44,6 +44,7 @@ class MomentumOptimizer(optimizer.Optimizer): super(MomentumOptimizer, self).__init__(use_locking, name) self._learning_rate = learning_rate self._momentum = momentum + self._use_nesterov = use_nesterov def _create_slots(self, var_list): for v in var_list: @@ -62,7 +63,8 @@ class MomentumOptimizer(optimizer.Optimizer): math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), grad, math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), - use_locking=self._use_locking).op + use_locking=self._use_locking, + use_nesterov=self._use_nesterov).op def _apply_sparse(self, grad, var): mom = self.get_slot(var, "momentum") @@ -71,4 +73,5 @@ class MomentumOptimizer(optimizer.Optimizer): math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), grad.values, grad.indices, math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), - use_locking=self._use_locking).op + use_locking=self._use_locking, + use_nesterov=self._use_nesterov).op |