Diffstat (limited to 'tensorflow/python/training/adam.py')
-rw-r--r-- | tensorflow/python/training/adam.py | 33
1 file changed, 25 insertions, 8 deletions
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 111461f784..4b0ef50df5 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import optimizer
@@ -154,7 +155,7 @@ class AdamOptimizer(optimizer.Optimizer):
         math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
         grad, use_locking=self._use_locking)
 
-  def _apply_sparse(self, grad, var):
+  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
     beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
     beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
     lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
@@ -164,23 +165,39 @@ class AdamOptimizer(optimizer.Optimizer):
     lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
     # m_t = beta1 * m + (1 - beta1) * g_t
     m = self.get_slot(var, "m")
-    m_scaled_g_values = grad.values * (1 - beta1_t)
+    m_scaled_g_values = grad * (1 - beta1_t)
     m_t = state_ops.assign(m, m * beta1_t,
                            use_locking=self._use_locking)
-    m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values,
-                                use_locking=self._use_locking)
+    with ops.control_dependencies([m_t]):
+      m_t = scatter_add(m, indices, m_scaled_g_values)
     # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
     v = self.get_slot(var, "v")
-    v_scaled_g_values = (grad.values * grad.values) * (1 - beta2_t)
-    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
-    v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values,
-                                use_locking=self._use_locking)
+    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+    v_t = state_ops.assign(v, v * beta2_t)
+    with ops.control_dependencies([v_t]):
+      v_t = scatter_add(v, indices, v_scaled_g_values)
     v_sqrt = math_ops.sqrt(v_t)
     var_update = state_ops.assign_sub(var,
                                       lr * m_t / (v_sqrt + epsilon_t),
                                       use_locking=self._use_locking)
     return control_flow_ops.group(*[var_update, m_t, v_t])
 
+  def _apply_sparse(self, grad, var):
+    return self._apply_sparse_shared(
+        grad.values, var, grad.indices,
+        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
+            x, i, v, use_locking=self._use_locking))
+
+  def _resource_scatter_add(self, x, i, v):
+    with ops.control_dependencies(
+        [resource_variable_ops.resource_scatter_add(
+            x.handle, i, v)]):
+      return x.value()
+
+  def _resource_apply_sparse(self, grad, var, indices):
+    return self._apply_sparse_shared(
+        grad, var, indices, self._resource_scatter_add)
+
   def _finish(self, update_ops, name_scope):
     # Update the power accumulators.
     with ops.control_dependencies(update_ops):
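
The change factors the sparse Adam update out of _apply_sparse into _apply_sparse_shared, which is parameterized by a scatter_add callable: the ref-variable path passes state_ops.scatter_add, while the new _resource_apply_sparse passes _resource_scatter_add, a wrapper around resource_variable_ops.resource_scatter_add that returns x.value() once the scatter has run. The control-dependency blocks ensure the decayed m * beta1_t and v * beta2_t assignments complete before the scaled gradient values are scattered in, so the later sqrt and assign_sub read fully updated slots.

As a rough illustration of when this code path is exercised (not part of the commit; the variable name, shapes, learning rate, and session setup below are assumptions made for the example), here is a TF 1.x-style graph in which the gradient of an embedding lookup arrives as IndexedSlices, so AdamOptimizer dispatches to _apply_sparse for a ref variable or, with a resource variable, to the new _resource_apply_sparse:

# Illustrative sketch only (TF 1.x graph mode); names and shapes are assumptions.
import tensorflow as tf

# An embedding lookup produces an IndexedSlices gradient for "embeddings",
# so AdamOptimizer takes the sparse path. With use_resource=True the variable
# is a resource variable and the new _resource_apply_sparse handles it;
# without it, the _apply_sparse / state_ops.scatter_add path is used instead.
embeddings = tf.get_variable(
    "embeddings", shape=[100, 8],
    initializer=tf.random_normal_initializer(),
    use_resource=True)
ids = tf.constant([3, 7, 7, 42])
loss = tf.reduce_sum(tf.square(tf.nn.embedding_lookup(embeddings, ids)))

train_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)  # one Adam step through the sparse update path

Passing the scatter op in as a callable keeps a single copy of the Adam math for both variable kinds. Note that in this version the m and v slots are still decayed densely (state_ops.assign over the whole slot) and only the gradient contribution is scattered at the given indices, so a step with sparse gradients still touches every row of the accumulators and of the variable itself.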