path: root/tensorflow/python/training/momentum.py
Diffstat (limited to 'tensorflow/python/training/momentum.py')
-rw-r--r--  tensorflow/python/training/momentum.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
index 1586ddfdec..62f8028ce6 100644
--- a/tensorflow/python/training/momentum.py
+++ b/tensorflow/python/training/momentum.py
@@ -31,7 +31,7 @@ class MomentumOptimizer(optimizer.Optimizer):
"""
def __init__(self, learning_rate, momentum,
- use_locking=False, name="Momentum"):
+ use_locking=False, name="Momentum", use_nesterov=False):
"""Construct a new Momentum optimizer.
Args:
@@ -44,6 +44,7 @@ class MomentumOptimizer(optimizer.Optimizer):
super(MomentumOptimizer, self).__init__(use_locking, name)
self._learning_rate = learning_rate
self._momentum = momentum
+ self._use_nesterov = use_nesterov
def _create_slots(self, var_list):
for v in var_list:
@@ -62,7 +63,8 @@ class MomentumOptimizer(optimizer.Optimizer):
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
grad,
math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
- use_locking=self._use_locking).op
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
def _apply_sparse(self, grad, var):
mom = self.get_slot(var, "momentum")
@@ -71,4 +73,5 @@ class MomentumOptimizer(optimizer.Optimizer):
math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
grad.values, grad.indices,
math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
- use_locking=self._use_locking).op
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
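
For context, a minimal usage sketch of the new keyword (not part of the commit). It assumes the class is exposed as tf.train.MomentumOptimizer and that a scalar `loss` tensor is defined elsewhere; the update rules in the comments paraphrase the classic vs. Nesterov behavior of the underlying apply_momentum kernel and are not taken from this diff.

import tensorflow as tf

# Classic momentum:   accum = momentum * accum + grad;  var -= lr * accum
# Nesterov momentum:  accum = momentum * accum + grad;  var -= lr * (grad + momentum * accum)
opt = tf.train.MomentumOptimizer(
    learning_rate=0.01,
    momentum=0.9,
    use_nesterov=True)  # new flag added by this change; defaults to False

train_op = opt.minimize(loss)  # `loss` is assumed to be defined elsewhere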