diff options
Diffstat (limited to 'tensorflow/contrib/optimizer_v2/momentum.py')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/momentum.py | 69 |
1 files changed, 12 insertions, 57 deletions
diff --git a/tensorflow/contrib/optimizer_v2/momentum.py b/tensorflow/contrib/optimizer_v2/momentum.py index 0a5aadc2d1..0636f7e356 100644 --- a/tensorflow/contrib/optimizer_v2/momentum.py +++ b/tensorflow/contrib/optimizer_v2/momentum.py @@ -18,11 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.optimizer_v2 import optimizer_v2 -from tensorflow.python.training import training_ops +from tensorflow.python.keras.optimizer_v2 import sgd +from tensorflow.python.util import deprecation -class MomentumOptimizer(optimizer_v2.OptimizerV2): +class MomentumOptimizer(sgd.SGD): """Optimizer that implements the Momentum algorithm. Computes (if `use_nesterov = False`): @@ -39,6 +39,10 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2): when that part of the variable was used in the forward pass. """ + @deprecation.deprecated_args( + "2018-10-01", + "`use_locking = True` is no longer supported and will be ignored.", + ("use_locking", [False])) def __init__(self, learning_rate, momentum, use_locking=False, name="Momentum", use_nesterov=False): """Construct a new Momentum optimizer. @@ -68,57 +72,8 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2): optimizer functions. @end_compatibility """ - super(MomentumOptimizer, self).__init__(use_locking, name) - self._set_hyper("learning_rate", learning_rate) - self._set_hyper("momentum", momentum) - self._use_nesterov = use_nesterov - - def _create_vars(self, var_list, state): - for v in var_list: - state.zeros_slot(v, "momentum") - - def _apply_dense(self, grad, var, state): - mom = state.get_slot(var, "momentum") - return training_ops.apply_momentum( - var, - mom, - state.get_hyper("learning_rate", var.dtype.base_dtype), - grad, - state.get_hyper("momentum", var.dtype.base_dtype), - use_locking=self._use_locking, - use_nesterov=self._use_nesterov).op - - def _resource_apply_dense(self, grad, var, state): - mom = state.get_slot(var, "momentum") - return training_ops.resource_apply_momentum( - var.handle, - mom.handle, - state.get_hyper("learning_rate", var.dtype.base_dtype), - grad, - state.get_hyper("momentum", var.dtype.base_dtype), - use_locking=self._use_locking, - use_nesterov=self._use_nesterov) - - def _apply_sparse(self, grad, var, state): - mom = state.get_slot(var, "momentum") - return training_ops.sparse_apply_momentum( - var, - mom, - state.get_hyper("learning_rate", var.dtype.base_dtype), - grad.values, - grad.indices, - state.get_hyper("momentum", var.dtype.base_dtype), - use_locking=self._use_locking, - use_nesterov=self._use_nesterov).op - - def _resource_apply_sparse(self, grad, var, indices, state): - mom = state.get_slot(var, "momentum") - return training_ops.resource_sparse_apply_momentum( - var.handle, - mom.handle, - state.get_hyper("learning_rate", var.dtype.base_dtype), - grad, - indices, - state.get_hyper("momentum", var.dtype.base_dtype), - use_locking=self._use_locking, - use_nesterov=self._use_nesterov) + super(MomentumOptimizer, self).__init__( + learning_rate=learning_rate, + momentum=momentum, + name=name, + nesterov=use_nesterov) |