author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-04 11:10:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-04 11:25:22 -0700
commit    4cd79b3f6361b6518463349a51fe33f7520f3b49 (patch)
tree      2228bd30766c9aa26542f44a59fc26d3b9044bd0
parent    0e9af928f7a6711971ade159a511da093f307a81 (diff)
Fix LazyAdamOptimizer for sparse updates on resource variables.
PiperOrigin-RevId: 211488610
-rw-r--r--  tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py       63
-rw-r--r--  tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py  17
2 files changed, 63 insertions, 17 deletions
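
Before this change LazyAdamOptimizer only overrode _apply_sparse, which uses the ref-variable scatter ops, so sparse gradients on resource variables did not go through the lazy update path. The hunks below factor the update into _apply_sparse_shared and add a _resource_apply_sparse override. As a hedged usage sketch of the case being fixed (assuming the TF 1.x contrib API in graph mode; the values, learning rate, and session setup are illustrative, not from this commit):

import numpy as np
import tensorflow as tf
from tensorflow.contrib.opt import LazyAdamOptimizer
from tensorflow.python.ops import resource_variable_ops

# A resource variable receiving a sparse (IndexedSlices) gradient.
var = resource_variable_ops.ResourceVariable(
    np.array([1.0, 2.0], dtype=np.float32))
grad = tf.IndexedSlices(
    values=tf.constant([0.1], dtype=tf.float32),
    indices=tf.constant([0], dtype=tf.int64),
    dense_shape=tf.constant([2], dtype=tf.int64))

opt = LazyAdamOptimizer(learning_rate=0.01)
# With this commit, apply_gradients dispatches to _resource_apply_sparse,
# so only row 0 of the m/v slots and of the variable is updated.
train_op = opt.apply_gradients([(grad, var)])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  print(sess.run(var))
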
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index 72117c1e81..f026f437dc 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -25,9 +25,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
@@ -46,7 +48,12 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
may lead to different empirical results.
"""
- def _apply_sparse(self, grad, var):
+ def _apply_sparse_shared(self,
+ grad,
+ var,
+ indices,
+ scatter_update,
+ scatter_sub):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
@@ -58,23 +65,51 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
# \\(m := beta1 * m + (1 - beta1) * g_t\\)
m = self.get_slot(var, "m")
- m_t = state_ops.scatter_update(m, grad.indices,
- beta1_t * array_ops.gather(m, grad.indices) +
- (1 - beta1_t) * grad.values,
- use_locking=self._use_locking)
+ m_t = scatter_update(m, indices,
+ beta1_t * array_ops.gather(m, indices) +
+ (1 - beta1_t) * grad)
# \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
v = self.get_slot(var, "v")
- v_t = state_ops.scatter_update(v, grad.indices,
- beta2_t * array_ops.gather(v, grad.indices) +
- (1 - beta2_t) * math_ops.square(grad.values),
- use_locking=self._use_locking)
+ v_t = scatter_update(v, indices,
+ beta2_t * array_ops.gather(v, indices) +
+ (1 - beta2_t) * math_ops.square(grad))
# \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
- m_t_slice = array_ops.gather(m_t, grad.indices)
- v_t_slice = array_ops.gather(v_t, grad.indices)
+ m_t_slice = array_ops.gather(m_t, indices)
+ v_t_slice = array_ops.gather(v_t, indices)
denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
- var_update = state_ops.scatter_sub(var, grad.indices,
- lr * m_t_slice / denominator_slice,
- use_locking=self._use_locking)
+ var_update = scatter_sub(var, indices,
+ lr * m_t_slice / denominator_slice)
return control_flow_ops.group(var_update, m_t, v_t)
+
+ def _apply_sparse(self, grad, var):
+ return self._apply_sparse_shared(
+ grad.values, var, grad.indices,
+ self._scatter_update,
+ self._scatter_sub)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ return self._apply_sparse_shared(
+ grad, var, indices,
+ self._resource_scatter_update,
+ self._resource_scatter_sub)
+
+ # Utility functions for updating resource or non-resource variables.
+ def _scatter_update(self, x, i, v):
+ return state_ops.scatter_update(
+ x, i, v, use_locking=self._use_locking)
+
+ def _scatter_sub(self, x, i, v):
+ return state_ops.scatter_sub(
+ x, i, v, use_locking=self._use_locking)
+
+ def _resource_scatter_update(self, x, i, v):
+ update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v)
+ with ops.control_dependencies([update_op]):
+ return x.value()
+
+ def _resource_scatter_sub(self, x, i, v):
+ sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v)
+ with ops.control_dependencies([sub_op]):
+ return x.value()
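
For reference, here is a minimal numpy sketch of the lazy sparse Adam step that _apply_sparse_shared expresses above: only the rows named by the gradient's indices have their m/v slots and parameter values touched. The bias-corrected step size follows the standard Adam formula from the unchanged surrounding code; the function name and default hyperparameters are illustrative.

import numpy as np

def lazy_adam_step(var, m, v, indices, grad_values, t,
                   lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
  # Bias-corrected step size, as in standard Adam (t is the 1-based step).
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
  # Update only the touched rows; every other row of m, v, and var is
  # left as-is, which is the "lazy" behaviour.
  m[indices] = beta1 * m[indices] + (1 - beta1) * grad_values
  v[indices] = beta2 * v[indices] + (1 - beta2) * grad_values ** 2
  var[indices] -= lr_t * m[indices] / (np.sqrt(v[indices]) + eps)
  return var, m, v
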
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index dc4c462ce4..d3e9e89502 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -51,7 +52,7 @@ def adam_update_numpy(param,
class AdamOptimizerTest(test.TestCase):
- def testSparse(self):
+ def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
# Initialize variables for numpy implementation.
@@ -61,8 +62,12 @@ class AdamOptimizerTest(test.TestCase):
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np),
@@ -94,6 +99,12 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ def testSparse(self):
+ self.doTestSparse(use_resource=False)
+
+ def testResourceSparse(self):
+ self.doTestSparse(use_resource=True)
+
def testSparseDevicePlacement(self):
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(force_gpu=test.is_gpu_available()):
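
The pair of _resource_scatter_* helpers added above exists because resource_scatter_update/resource_scatter_sub return an Operation rather than the updated tensor, whereas the ref-variable scatter ops return a tensor the later gather can read directly. A hedged sketch of that unification, using a hypothetical helper name on the same internal modules:

from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops

def scatter_update_any(var, indices, values, use_locking=False):
  # Hypothetical helper mirroring the dispatch the optimizer now performs.
  if isinstance(var, resource_variable_ops.ResourceVariable):
    update_op = resource_variable_ops.resource_scatter_update(
        var.handle, indices, values)
    with ops.control_dependencies([update_op]):
      return var.value()  # re-read so callers see the updated rows
  # Ref variables: the scatter op itself yields the updated tensor.
  return state_ops.scatter_update(var, indices, values,
                                  use_locking=use_locking)
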