Diffstat (limited to 'tensorflow/python/training/momentum.py')
-rw-r--r--  tensorflow/python/training/momentum.py  61
1 file changed, 61 insertions, 0 deletions
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
new file mode 100644
index 0000000000..fdd434359f
--- /dev/null
+++ b/tensorflow/python/training/momentum.py
@@ -0,0 +1,61 @@
+"""Momentum for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class MomentumOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the Momentum algorithm.
+
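+  The update rule that the training ops below implement is, in effect:
+
+      accumulation = momentum * accumulation + gradient
+      variable -= learning_rate * accumulation
+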
+  @@__init__
+  """
+
+  def __init__(self, learning_rate, momentum,
+               use_locking=False, name="Momentum"):
+    """Construct a new Momentum optimizer.
+
+    Args:
+      learning_rate: A `Tensor` or a floating point value. The learning rate.
+      momentum: A `Tensor` or a floating point value. The momentum.
+      use_locking: If `True` use locks for update operations.
+      name: Optional name prefix for the operations created when applying
+        gradients. Defaults to "Momentum".
+    """
+    super(MomentumOptimizer, self).__init__(use_locking, name)
+    self._learning_rate = learning_rate
+    self._momentum = momentum
+
+  def _create_slots(self, var_list):
+    for v in var_list:
+      self._zeros_slot(v, "momentum", self._name)
+
+  def _prepare(self):
+    self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+                                                       name="learning_rate")
+    self._momentum_tensor = ops.convert_to_tensor(self._momentum,
+                                                  name="momentum")
+
+  def _apply_dense(self, grad, var):
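+    # Delegates to the fused ApplyMomentum op, which in effect computes:
+    #   mom = momentum * mom + grad
+    #   var -= learning_rate * mom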
+    mom = self.get_slot(var, "momentum")
+    return training_ops.apply_momentum(
+        var, mom,
+        self._learning_rate_tensor, grad, self._momentum_tensor,
+        use_locking=self._use_locking).op
+
+  def _apply_sparse(self, grad, var):
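+    # Same update as the dense path, but applied only to the rows of
+    # var and mom selected by grad.indices (grad is an IndexedSlices).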
+    mom = self.get_slot(var, "momentum")
+    return training_ops.sparse_apply_momentum(
+        var, mom,
+        self._learning_rate_tensor, grad.values, grad.indices,
+        self._momentum_tensor, use_locking=self._use_locking).op
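
A minimal usage sketch, assuming the class is exported as
tf.train.MomentumOptimizer (the toy variable, loss, and step count are
illustrative, not part of this commit):

  import tensorflow as tf

  # Toy problem: fit w so that w * 3.0 approaches 6.0.
  w = tf.Variable(1.0, name="w")
  loss = tf.square(w * 3.0 - 6.0)

  # minimize() computes gradients and applies the momentum update;
  # a "momentum" slot variable is created for each trainable variable.
  optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
  train_op = optimizer.minimize(loss)

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for _ in range(100):
      sess.run(train_op)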