Diffstat (limited to 'tensorflow/contrib/optimizer_v2/momentum.py')
-rw-r--r--  tensorflow/contrib/optimizer_v2/momentum.py | 69
1 file changed, 12 insertions(+), 57 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/momentum.py b/tensorflow/contrib/optimizer_v2/momentum.py
index 0a5aadc2d1..0636f7e356 100644
--- a/tensorflow/contrib/optimizer_v2/momentum.py
+++ b/tensorflow/contrib/optimizer_v2/momentum.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class MomentumOptimizer(optimizer_v2.OptimizerV2):
+class MomentumOptimizer(sgd.SGD):
"""Optimizer that implements the Momentum algorithm.
Computes (if `use_nesterov = False`):
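
For reference, a minimal pure-Python sketch of the momentum update this docstring introduces (the function and variable names are illustrative, not part of the diff; the semantics follow TensorFlow's documented apply_momentum behavior):

    def momentum_step(var, accum, grad, learning_rate, momentum, use_nesterov=False):
        # accumulation = momentum * accumulation + gradient
        accum = momentum * accum + grad
        if use_nesterov:
            # Nesterov momentum steps along the gradient plus the look-ahead
            # contribution of the freshly updated accumulator.
            var = var - learning_rate * (grad + momentum * accum)
        else:
            # variable -= learning_rate * accumulation
            var = var - learning_rate * accum
        return var, accum

    # One step on a scalar parameter:
    v, a = momentum_step(var=1.0, accum=0.0, grad=0.5, learning_rate=0.1, momentum=0.9)
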
@@ -39,6 +39,10 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
when that part of the variable was used in the forward pass.
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, momentum,
use_locking=False, name="Momentum", use_nesterov=False):
"""Construct a new Momentum optimizer.
@@ -68,57 +72,8 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
optimizer functions.
@end_compatibility
"""
- super(MomentumOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("momentum", momentum)
- self._use_nesterov = use_nesterov
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
-
- def _apply_sparse(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.sparse_apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_sparse_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
+ super(MomentumOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ momentum=momentum,
+ name=name,
+ nesterov=use_nesterov)
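
With this change, MomentumOptimizer is a thin shim over the Keras-backed SGD optimizer, so the two constructions below should be interchangeable (a sketch under the import paths used in this commit):

    from tensorflow.contrib.optimizer_v2 import momentum
    from tensorflow.python.keras.optimizer_v2 import sgd

    opt_shim = momentum.MomentumOptimizer(
        learning_rate=0.1, momentum=0.9, use_nesterov=True)
    opt_direct = sgd.SGD(learning_rate=0.1, momentum=0.9, nesterov=True)
    # Both apply the same update rule; the shim additionally carries the
    # `use_locking` deprecation handling and the "Momentum" default name.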