Diffstat (limited to 'tensorflow/python/training/adam.py')
-rw-r--r--  tensorflow/python/training/adam.py  33
1 file changed, 25 insertions, 8 deletions
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 111461f784..4b0ef50df5 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
@@ -154,7 +155,7 @@ class AdamOptimizer(optimizer.Optimizer):
        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
        grad, use_locking=self._use_locking)

-  def _apply_sparse(self, grad, var):
+  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
@@ -164,23 +165,39 @@ class AdamOptimizer(optimizer.Optimizer):
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
-    m_scaled_g_values = grad.values * (1 - beta1_t)
+    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t,
                           use_locking=self._use_locking)
-    m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values,
-                                use_locking=self._use_locking)
+    with ops.control_dependencies([m_t]):
+      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
-    v_scaled_g_values = (grad.values * grad.values) * (1 - beta2_t)
-    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
-    v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values,
-                                use_locking=self._use_locking)
+    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+    v_t = state_ops.assign(v, v * beta2_t)
+    with ops.control_dependencies([v_t]):
+      v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = math_ops.sqrt(v_t)
    var_update = state_ops.assign_sub(var,
                                      lr * m_t / (v_sqrt + epsilon_t),
                                      use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])

+  def _apply_sparse(self, grad, var):
+    return self._apply_sparse_shared(
+        grad.values, var, grad.indices,
+        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
+            x, i, v, use_locking=self._use_locking))
+
+  def _resource_scatter_add(self, x, i, v):
+    with ops.control_dependencies(
+        [resource_variable_ops.resource_scatter_add(
+            x.handle, i, v)]):
+      return x.value()
+
+  def _resource_apply_sparse(self, grad, var, indices):
+    return self._apply_sparse_shared(
+        grad, var, indices, self._resource_scatter_add)
+
  def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
    with ops.control_dependencies(update_ops):
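Note on the change: the diff factors the old `_apply_sparse` body into `_apply_sparse_shared`, which takes the gradient values and indices separately plus a `scatter_add` callback, so the same arithmetic serves both the classic `state_ops.scatter_add` path and the new resource-variable path (where the optimizer base class passes values and indices directly to `_resource_apply_sparse`). The sketch below is a minimal NumPy mirror of that shared update, included only for illustration; the function name `sparse_adam_step` and its defaults are hypothetical and not part of the commit.

```python
# Standalone NumPy sketch (illustrative, not part of the commit) of the sparse
# Adam step that _apply_sparse_shared expresses with TF ops: the decay is
# applied to every row of the m/v slots (the dense assign), while the gradient
# contribution is scatter-added only into the rows named by `indices`.
import numpy as np


def sparse_adam_step(var, m, v, grad, indices, t,
                     lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
  """One Adam step for an IndexedSlices-style gradient (values + indices)."""
  # m_t = beta1 * m + (1 - beta1) * g_t  (decay all rows, scatter the update)
  m *= beta1
  np.add.at(m, indices, (1 - beta1) * grad)
  # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
  v *= beta2
  np.add.at(v, indices, (1 - beta2) * grad * grad)
  # Bias-corrected step size, matching lr_t * sqrt(1 - beta2^t) / (1 - beta1^t)
  # in the diff (beta1_power = beta1**t, beta2_power = beta2**t).
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
  # Like the assign_sub in the TF code, this updates every row of `var`,
  # not only the rows that received gradient values.
  var -= lr_t * m / (np.sqrt(v) + epsilon)
  return var, m, v
```

With duplicate entries in `indices`, `np.add.at` accumulates the repeated contributions, which matches the accumulation semantics of `state_ops.scatter_add` and `resource_variable_ops.resource_scatter_add` used by the two dispatch paths above.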