Diffstat (limited to 'tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py')
-rw-r--r--  tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py | 34
1 file changed, 34 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index 72117c1e81..f55209ec49 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -28,6 +28,7 @@ from __future__ import print_function
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
@@ -78,3 +79,36 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
lr * m_t_slice / denominator_slice,
use_locking=self._use_locking)
return control_flow_ops.group(var_update, m_t, v_t)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+ lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+ beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+ beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+ epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+ # \\(m := beta1 * m + (1 - beta1) * g_t\\)
+ m = self.get_slot(var, "m")
+ m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
+ m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
+ indices,
+ m_t_slice)
+
+ # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
+ v = self.get_slot(var, "v")
+ v_t_slice = (beta2_t * array_ops.gather(v, indices) +
+ (1 - beta2_t) * math_ops.square(grad))
+ v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
+ indices,
+ v_t_slice)
+
+ # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
+ var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
+ var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
+ indices,
+ var_slice)
+
+ return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
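For context, here is a minimal sketch of how the new `_resource_apply_sparse` path gets exercised. It assumes a TF 1.x graph-mode session with the contrib package available; the variable name, shapes, and ids below are illustrative, not part of the commit. The key ingredients are a resource variable (`use_resource=True`) and a gradient that arrives as `IndexedSlices` (here produced by `tf.gather`), which together route the update through the method added above rather than the ref-variable `_apply_sparse` path.

```python
import tensorflow as tf

# Hypothetical toy setup: a small embedding table stored in a
# ResourceVariable, so its sparse gradient dispatches to the new
# _resource_apply_sparse shown in the diff.
embeddings = tf.get_variable(
    "embeddings", shape=[100, 8], use_resource=True)
ids = tf.constant([3, 17, 42])
rows = tf.gather(embeddings, ids)        # gradient w.r.t. embeddings
loss = tf.reduce_sum(tf.square(rows))    # is an IndexedSlices

opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate=0.01)
train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Only rows 3, 17, and 42 of var and of the m/v slots are read
  # and scattered back; all other rows are left untouched.
  sess.run(train_op)
```

The design point, as the scatter ops in the diff show, is that the lazy variant gathers and rewrites only the rows named in `indices` of `m`, `v`, and the variable itself, whereas stock Adam applies dense moment updates to every row on every step. This trades Adam's exact update semantics for a large sparsity win on embedding-style variables.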