about summary refs log tree commit diff homepage
path: root/tensorflow/python
diff options
context:
space:
mode:
author A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-05 12:44:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-05 12:49:14 -0700
commit ef838969b95de39353a3ba495c335cbb14a0c9b5 (patch)
tree 800857a506c3d3695a7b3da2fd269a9fec85d93b /tensorflow/python
parent 6919ab5787e6384d709adf051dc1ce99236b76bc (diff)
Brings V2 Optimizers into Keras w/ Keras signatures
PiperOrigin-RevId: 215950207
Diffstat (limited to 'tensorflow/python')
-rwxr-xr-x tensorflow/python/keras/BUILD 155
-rw-r--r-- tensorflow/python/keras/optimizer_v2/adadelta.py 116
-rw-r--r-- tensorflow/python/keras/optimizer_v2/adadelta_test.py 166
-rw-r--r-- tensorflow/python/keras/optimizer_v2/adagrad.py 119
-rw-r--r-- tensorflow/python/keras/optimizer_v2/adagrad_test.py 276
-rw-r--r-- tensorflow/python/keras/optimizer_v2/adam.py 203
-rw-r--r-- tensorflow/python/keras/optimizer_v2/adam_test.py 333
-rw-r--r-- tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py 761
-rw-r--r-- tensorflow/python/keras/optimizer_v2/optimizer_v2.py 1349
-rw-r--r-- tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py 277
-rw-r--r-- tensorflow/python/keras/optimizer_v2/rmsprop.py 239
-rw-r--r-- tensorflow/python/keras/optimizer_v2/rmsprop_test.py 444
-rw-r--r-- tensorflow/python/keras/optimizer_v2/sgd.py 170
-rw-r--r-- tensorflow/python/keras/optimizer_v2/sgd_test.py 759
14 files changed, 5367 insertions, 0 deletions
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 4a72c4b3f3..c4d23f117f 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -62,6 +62,7 @@ py_library(
":backend",
":engine",
":layers",
+ ":optimizer_v2",
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
@@ -189,6 +190,30 @@ py_library(
],
)
+py_library(
+ name = "optimizer_v2",
+ srcs = [
+ "optimizer_v2/adadelta.py",
+ "optimizer_v2/adagrad.py",
+ "optimizer_v2/adam.py",
+ "optimizer_v2/optimizer_v2.py",
+ "optimizer_v2/rmsprop.py",
+ "optimizer_v2/sgd.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distribute",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
py_test(
name = "integration_test",
size = "medium",
@@ -827,3 +852,133 @@ py_library(
"//third_party/py/numpy",
],
)
+
+cuda_py_test(
+ name = "adadelta_test",
+ size = "medium",
+ srcs = ["optimizer_v2/adadelta_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adagrad_test",
+ size = "small",
+ srcs = ["optimizer_v2/adagrad_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adam_test",
+ size = "small",
+ srcs = ["optimizer_v2/adam_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "checkpointable_utils_test",
+ srcs = ["optimizer_v2/checkpointable_utils_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@six_archive//:six",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:layers_base",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/keras",
+ ],
+ tags = ["notsan"],
+)
+
+cuda_py_test(
+ name = "sgd_test",
+ size = "medium",
+ srcs = ["optimizer_v2/sgd_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:resources",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "optimizer_v2_test",
+ size = "medium",
+ srcs = ["optimizer_v2/optimizer_v2_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:clip_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
+ name = "rmsprop_test",
+ size = "small",
+ srcs = ["optimizer_v2/rmsprop_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+ tags = ["optonly"],
+)
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
new file mode 100644
index 0000000000..d3b3c9c12e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -0,0 +1,116 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adadelta for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.training import training_ops
+
+
+class Adadelta(optimizer_v2.OptimizerV2):
+ """Adadelta optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
+ ([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate. It is recommended
+ to leave it at the default value.
+ rho: float hyperparameter >= 0. The decay rate.
+ epsilon: float hyperparameter >= 0. Fuzz factor. A constant epsilon used
+ to better condition the grad update.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adadelta'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ rho=0.95,
+ epsilon=1e-8,
+ name="Adadelta"):
+ super(Adadelta, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("rho", rho)
+ self._set_hyper("epsilon", epsilon)
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ state.zeros_slot(v, "accum")
+ state.zeros_slot(v, "accum_update")
+
+ def _apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.sparse_apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_sparse_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
new file mode 100644
index 0000000000..6e48f92e4f
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -0,0 +1,166 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adadelta Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdadeltaOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ num_updates = 4 # number of ADADELTA steps to perform
+ for dtype in [dtypes.half, dtypes.float32]:
+ for grad in [0.2, 0.1, 0.01]:
+ for lr in [1.0, 0.5, 0.1]:
+ with self.cached_session():
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+ else:
+ var0 = variables.Variable(var0_init, dtype=dtype)
+ var1 = variables.Variable(var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ adadelta_opt = adadelta.Adadelta(lr, rho, epsilon)
+ adadelta_update = adadelta_opt.apply_gradients(
+ zip([grads, grads], [var0, var1]))
+
+ opt_vars = adadelta_opt.variables()
+ self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[1].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[2].name, var1._shared_name)
+ self.assertStartsWith(opt_vars[3].name, var1._shared_name)
+ self.assertEqual(4, len(opt_vars))
+
+ variables.global_variables_initializer().run()
+
+ # Assign slots
+ slot = [None] * 2
+ slot_update = [None] * 2
+ self.assertEqual(["accum", "accum_update"],
+ adadelta_opt.get_slot_names())
+ slot[0] = adadelta_opt.get_slot(var0, "accum")
+ self.assertEquals(slot[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot[0] in variables.trainable_variables())
+
+ slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
+ self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot_update[0] in variables.trainable_variables())
+
+ slot[1] = adadelta_opt.get_slot(var1, "accum")
+ self.assertEquals(slot[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot[1] in variables.trainable_variables())
+
+ slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
+ self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot_update[1] in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, var0.eval())
+ self.assertAllClose(var1_init, var1.eval())
+
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ adadelta_update.run()
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (accum_update * rho + (update[step]**2) *
+ (1.0 - rho))
+ tot_update += update[step] * lr
+
+ # Check that the accumulators have been updated
+ for slot_idx in range(2):
+ self.assertAllCloseAccordingToType(
+ np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
+ slot[slot_idx].eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [accum_update, accum_update],
+ dtype=dtype.as_numpy_dtype()),
+ slot_update[slot_idx].eval(),
+ rtol=1e-5)
+
+ # Check that the parameters have been updated
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var0_init[0] - tot_update, var0_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var0.eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var1_init[0] - tot_update, var1_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var1.eval(),
+ rtol=1e-5)
+
+ def testBasic(self):
+ self.doTestBasic(use_resource=False)
+
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adadelta.Adadelta(1.0, 1.0, 1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of Adadelta
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-111, -138]], var0.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
new file mode 100644
index 0000000000..2d8cec2300
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -0,0 +1,119 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adagrad optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import training_ops
+
+
+class Adagrad(optimizer_v2.OptimizerV2):
+ """Adagrad optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+ or this
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
+
+ The learning_rate arg below is a hyperparameter, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ initial_accumulator_value: A floating point value. Starting value for the
+ accumulators, must be positive.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adagrad'.
+
+ Raises:
+ ValueError: If the `initial_accumulator_value` is invalid.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ initial_accumulator_value=0.1,
+ name="Adagrad"):
+ if initial_accumulator_value <= 0.0:
+ raise ValueError("initial_accumulator_value must be positive: %s" %
+ initial_accumulator_value)
+ super(Adagrad, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+
+ self._initial_accumulator_value = initial_accumulator_value
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ dtype = v.dtype.base_dtype
+ if v.get_shape().is_fully_defined():
+ init = init_ops.constant_initializer(self._initial_accumulator_value,
+ dtype=dtype)
+ else:
+ def init(v=v, dtype=dtype):
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
+ state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
+ "accumulator")
+
+ def _apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.sparse_apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_sparse_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
new file mode 100644
index 0000000000..fc4ef5c399
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -0,0 +1,276 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adagrad Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdagradOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testBasic(self):
+ self.doTestBasic()
+
+ def testBasicResource(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable(
+ [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adagrad.Adagrad(1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType(
+ [[1.0, 2.0], [3.0, 4.0]], var0.eval())
+ # Run 1 step of Adagrad
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0, 1], [3, 4]], var0.eval(), atol=0.01)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(
+ constant_op.constant(3.0), initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+ # Run 3 steps of Adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([[-1.6026098728179932], [2.0]]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[3.0], [3.715679168701172]]), var1.eval())
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def testSparseRepeatedIndicesResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var_repeated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_repeated = math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_repeated, [0, 0]))
+ var_aggregated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_aggregated = 2 * math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_aggregated, [0]))
+ update_op_repeated = adagrad.Adagrad(2.0).minimize(loss_repeated)
+ update_op_aggregated = adagrad.Adagrad(2.0).minimize(loss_aggregated)
+ variables.global_variables_initializer().run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+ for _ in range(3):
+ update_op_repeated.run()
+ update_op_aggregated.run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+
+ def testSparseStability(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ shape = [1, 6]
+ var0 = variables.Variable(
+ [[
+ 0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257,
+ -0.0105945
+ ]],
+ dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[
+ -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
+ -8.4877e-05, -9.48906e-05
+ ]],
+ shape=shape,
+ dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant(shape))
+ ada_opt = adagrad.Adagrad(1.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ init = variables.global_variables_initializer()
+ for _ in range(100):
+ init.run()
+ ada_update.run()
+ self.assertAllCloseAccordingToType(
+ np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[
+ 0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
+ -0.01029443
+ ]]), var0.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0)
+ # Apply the optimizer twice. Both applications will use
+ # the same accums.
+ ada_update1 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ ada_update2 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values.
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Mix the first and the second adagrad for 3 steps.
+ ada_update1.run()
+ ada_update2.run()
+ ada_update1.run()
+ # Validate updated params (the same as with only 1 Adagrad).
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testDynamicShapeVariable_Ok(self):
+ with self.cached_session():
+ v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
+ validate_shape=False)
+ self.assertFalse(v.shape.is_fully_defined())
+ # Creating optimizer should cause no exception.
+ adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
new file mode 100644
index 0000000000..8367228d7a
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -0,0 +1,203 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adam optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import training_ops
+
+
+class Adam(optimizer_v2.OptimizerV2):
+ r"""Adam Optimizer.
+
+ Default parameters follow those provided in the original paper.
+
+ See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+ ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Initialization:
+
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of Section 2 of the paper:
+
+ $$t := t + 1$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
+
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1. Note that since this optimizer uses the
+ formulation just before Section 2.1 of the Kingma and Ba paper rather than
+ the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+ hat" in the paper.
+
+ The sparse implementation of this algorithm (used when the gradient is an
+ IndexedSlices object, typically because of `tf.gather` or an embedding
+ lookup in the forward pass) does apply momentum to variable slices even if
+ they were not used in the forward pass (meaning they have a gradient equal
+ to zero). Momentum decay (beta1) is also applied to the entire momentum
+ accumulator. This means that the sparse behavior is equivalent to the dense
+ behavior (in contrast to some momentum implementations which ignore momentum
+ unless a variable slice was actually used).
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ beta_1: float hyperparameter, 0 < beta_1 < 1. Generally close to 1. The
+ exponential decay rate for the 1st moment estimates.
+ beta_2: float hyperparameter, 0 < beta_2 < 1. Generally close to 1. The
+ exponential decay rate for the 2nd moment estimates.
+ epsilon: float hyperparameter >= 0. Fuzz factor. This epsilon is "epsilon
+ hat" in the Kingma and Ba paper (in the formula just before Section
+ 2.1), not the epsilon in Algorithm 1 of the paper.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-8,
+ name="Adam"):
+ super(Adam, self).__init__(name)
+
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("beta_1", beta_1)
+ self._set_hyper("beta_2", beta_2)
+ self._set_hyper("epsilon", epsilon)
+
+ def _get_beta_accumulators(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return (state.get_non_slot("beta_1_power"),
+ state.get_non_slot("beta_2_power"))
+
+ def _create_vars(self, var_list, state):
+ # Non-slot variables end up on the same device(s).
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_1"), name="beta_1_power")
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_2"), name="beta_2_power")
+
+ # Create slots for the first and second moments.
+ for v in var_list:
+ state.zeros_slot(v, "m")
+ state.zeros_slot(v, "v")
+
+ def _apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.apply_adam(
+ var,
+ m,
+ v,
+ math_ops.cast(beta_1_power, var.dtype.base_dtype),
+ math_ops.cast(beta_2_power, var.dtype.base_dtype),
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("beta_1", var.dtype.base_dtype),
+ state.get_hyper("beta_2", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.resource_apply_adam(
+ var.handle,
+ m.handle,
+ v.handle,
+ math_ops.cast(beta_1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta_2_power, grad.dtype.base_dtype),
+ state.get_hyper("learning_rate", grad.dtype.base_dtype),
+ state.get_hyper("beta_1", grad.dtype.base_dtype),
+ state.get_hyper("beta_2", grad.dtype.base_dtype),
+ state.get_hyper("epsilon", grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ beta_1_power = math_ops.cast(beta_1_power, var.dtype.base_dtype)
+ beta_2_power = math_ops.cast(beta_2_power, var.dtype.base_dtype)
+ lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
+ beta_1_t = state.get_hyper("beta_1", var.dtype.base_dtype)
+ beta_2_t = state.get_hyper("beta_2", var.dtype.base_dtype)
+ epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
+ # m_t = beta_1 * m + (1 - beta_1) * g_t
+ m = state.get_slot(var, "m")
+ m_scaled_g_values = grad * (1 - beta_1_t)
+ m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
+ with ops.control_dependencies([m_t]):
+ m_t = scatter_add(m, indices, m_scaled_g_values)
+ # v_t = beta_2 * v + (1 - beta_2) * (g_t * g_t)
+ v = state.get_slot(var, "v")
+ v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
+ v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
+ with ops.control_dependencies([v_t]):
+ v_t = scatter_add(v, indices, v_scaled_g_values)
+ v_sqrt = math_ops.sqrt(v_t)
+ var_update = state_ops.assign_sub(var,
+ lr * m_t / (v_sqrt + epsilon_t),
+ use_locking=self._use_locking)
+ return control_flow_ops.group(*[var_update, m_t, v_t])
+
+ def _apply_sparse(self, grad, var, state):
+ return self._apply_sparse_shared(
+ grad.values, var, grad.indices,
+ lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
+ x, i, v, use_locking=self._use_locking),
+ state)
+
+ def _resource_scatter_add(self, x, i, v):
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(
+ x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ return self._apply_sparse_shared(
+ grad, var, indices, self._resource_scatter_add, state)
+
+ def _finish(self, state):
+ # Update the power accumulators.
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ update_beta_1 = beta_1_power.assign(
+ beta_1_power * state.get_hyper("beta_1"), use_locking=self._use_locking)
+ update_beta_2 = beta_2_power.assign(
+ beta_2_power * state.get_hyper("beta_2"), use_locking=self._use_locking)
+ return control_flow_ops.group(update_beta_1, update_beta_2)
diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py
new file mode 100644
index 0000000000..77796317a1
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -0,0 +1,333 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adam optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def adam_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ alpha=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8):
+ alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ return param_t, m_t, v_t
+
+
+class AdamOptimizerTest(test.TestCase):
+
+ def doTestSparse(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([2]))
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSparse(self):
+ self.doTestSparse(use_resource=False)
+
+ def testResourceSparse(self):
+ self.doTestSparse(use_resource=True)
+
+ def testSparseDevicePlacement(self):
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(force_gpu=test.is_gpu_available()):
+ # If a GPU is available, tests that all optimizer ops can be placed on
+ # it (i.e. they have GPU kernels).
+ var = variables.Variable([[1.0], [2.0]])
+ indices = constant_op.constant([0, 1], dtype=index_dtype)
+ gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
+ optimizer = adam.Adam(3.0)
+ minimize_op = optimizer.minimize(gathered_sum)
+ variables.global_variables_initializer().run()
+ minimize_op.run()
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adam.Adam().apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adam.Adam().apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def doTestBasic(self, use_resource=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertTrue(beta1_power is not None)
+ self.assertTrue(beta2_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps, alternating between the shared optimizer's two update ops.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = adam.Adam()
+ g = ops.Graph()
+ with g.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = adam.Adam(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables (the beta powers) and two unique
+ # slot variables ("m" and "v") for each of v1 and v2: six in total.
+ self.assertEqual(6, len(set(opt.variables())))
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
new file mode 100644
index 0000000000..338c04148b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
@@ -0,0 +1,761 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(josh11b): Forked from contrib/eager/python to test OptimizerV2 the same way
+# OptimizerV1 is tested. This file should be removed once the fork is resolved.
+
+import functools
+import os
+
+import six
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as core_saver
+from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
+
+
+class NonLayerCheckpointable(tracking.Checkpointable):
+
+ def __init__(self):
+ super(NonLayerCheckpointable, self).__init__()
+ self.a_variable = util.add_variable(
+ self, name="a_variable", shape=[])
+
+
+# pylint: disable=not-callable
+class MyModel(training.Model):
+ """A concrete Model for testing."""
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self._named_dense = core.Dense(1, use_bias=True)
+ self._second = core.Dense(1, use_bias=False)
+ # We can still track Checkpointables which aren't Layers.
+ self._non_layer = NonLayerCheckpointable()
+
+ def call(self, values):
+ ret = self._second(self._named_dense(values))
+ return ret
+
+
+class _MirroringSaveable(
+ core_saver.BaseSaverBuilder.ResourceVariableSaveable):
+
+ def __init__(self, primary_variable, mirrored_variable, name):
+ self._primary_variable = primary_variable
+ self._mirrored_variable = mirrored_variable
+ super(_MirroringSaveable, self).__init__(
+ self._primary_variable, "", name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ tensor, = restored_tensors
+ return control_flow_ops.group(
+ self._primary_variable.assign(tensor),
+ self._mirrored_variable.assign(tensor))
+
+
+class CheckpointingTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testNamingWithOptimizer(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ # A nuisance Model using the same optimizer. Its slot variables should not
+ # go in the checkpoint, since it is never depended on.
+ other_model = MyModel()
+ optimizer = adam.Adam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value),
+ global_step=optimizer_step)
+ optimizer.minimize(
+ lambda: other_model(input_value),
+ global_step=optimizer_step)
+ else:
+ train_op = optimizer.minimize(
+ model(input_value), global_step=optimizer_step)
+ optimizer.minimize(
+ other_model(input_value),
+ global_step=optimizer_step)
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ named_variables, serialized_graph, _ = (
+ util._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
+ expected_checkpoint_names = (
+ # Created in the root node, so no prefix.
+ "optimizer_step",
+ "model/_second/kernel",
+ "model/_named_dense/kernel",
+ "model/_named_dense/bias",
+ # non-Layer dependency of the model
+ "model/_non_layer/a_variable",
+ # The optimizer creates two non-slot variables
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
+ # Slot variables
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
+ )
+ suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
+ expected_checkpoint_names = [
+ name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
+ six.assertCountEqual(self, expected_checkpoint_names,
+ named_variables.keys())
+ # Check that we've mapped to the right variable objects (not exhaustive)
+ self.assertEqual(
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
+ self.assertEqual(
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
+ self.assertEqual(
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
+ # Spot check the generated protocol buffers.
+ self.assertEqual("optimizer",
+ serialized_graph.nodes[0].children[1].local_name)
+ optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
+ 1].node_id]
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .original_variable_node_id]
+ .attributes[0].full_name)
+ # We strip off the :0 suffix, as variable.name-based saving does.
+ self.assertEqual(
+ "my_model/dense/kernel/Adam",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .slot_variable_node_id]
+ .attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel/Adam:0",
+ optimizer.get_slot(
+ var=model._named_dense.kernel,
+ name="m").name)
+ self.assertEqual(
+ "model/_named_dense/kernel" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .original_variable_node_id].attributes[0].checkpoint_key)
+ self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
+ self.assertEqual(
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .slot_variable_node_id].attributes[0].checkpoint_key)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestore(self):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model)
+ input_value = constant_op.constant([[3.]])
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value))
+ else:
+ train_op = optimizer.minimize(model(input_value))
+ # TODO(allenl): Make initialization more pleasant when graph building.
+ root_checkpointable.save_counter # pylint: disable=pointless-statement
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
+ m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
+ self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
+ save_path = root_checkpointable.save(file_prefix=prefix)
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
+ self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
+ optimizer_variables = self.evaluate(optimizer.variables())
+ self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
+ # Immediate restoration
+ status = root_checkpointable.restore(save_path=save_path).assert_consumed()
+ status.run_restore_ops()
+ self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
+ self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
+ self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
+ if not context.executing_eagerly():
+ return # Restore-on-create is only supported when executing eagerly
+ on_create_model = MyModel()
+ on_create_optimizer = adam.Adam(
+ 0.001,
+ # Preserve beta_1_power and beta_2_power when applying gradients
+ # so we can test that they've been restored correctly.
+ beta_1=1.0,
+ beta_2=1.0)
+ on_create_root = util.Checkpoint(
+ optimizer=on_create_optimizer, model=on_create_model)
+ # Deferred restoration
+ status = on_create_root.restore(save_path=save_path)
+ on_create_model(constant_op.constant([[3.]])) # create variables
+ self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
+ self.assertAllEqual([42.],
+ self.evaluate(
+ on_create_model._named_dense.variables[1]))
+ on_create_m_bias_slot = on_create_optimizer.get_slot(
+ on_create_model._named_dense.variables[1], "m")
+ # Optimizer slot variables are created when the original variable is
+ # restored.
+ self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
+ self.assertAllEqual(optimizer_variables[2:],
+ self.evaluate(on_create_optimizer.variables()))
+ dummy_var = resource_variable_ops.ResourceVariable([1.])
+ on_create_optimizer.minimize(loss=dummy_var.read_value)
+ status.assert_consumed()
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
+
+ # TODO(allenl): Debug garbage created by this test in python3.
+  def testDeferredRestorationUsageEager(self):
+    """Trains in three rounds, restoring before and saving after each round."""
+    steps_per_round = 10
+    ckpt_dir = self.get_temp_dir()
+    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
+    for round_index in range(3):
+      # Every round builds fresh objects and asks them to restore from
+      # whatever `latest_checkpoint` reports for the directory.
+      model = MyModel()
+      optimizer = adam.Adam(0.001)
+      step_counter = training_util.get_or_create_global_step()
+      root = util.Checkpoint(
+          optimizer=optimizer, model=model, optimizer_step=step_counter)
+      newest_checkpoint = checkpoint_management.latest_checkpoint(ckpt_dir)
+      root.restore(newest_checkpoint)
+      for _ in range(steps_per_round):
+        # TODO(allenl): Use a Dataset and serialize/checkpoint it.
+        input_value = constant_op.constant([[3.]])
+        optimizer.minimize(
+            lambda: model(input_value),  # pylint: disable=cell-var-from-loop
+            global_step=root.optimizer_step)
+      root.save(file_prefix=ckpt_prefix)
+      # The step counter is expected to carry over between rounds, so it
+      # should equal the total number of minimize() calls made so far.
+      expected_steps = (round_index + 1) * steps_per_round
+      self.assertEqual(expected_steps, root.optimizer_step.numpy())
+
+  def testUsageGraph(self):
+    """Restores, trains and saves three times, each in a newly built graph."""
+    with context.graph_mode():
+      steps_per_round = 10
+      ckpt_dir = self.get_temp_dir()
+      ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
+      for round_index in range(3):
+        with ops.Graph().as_default():
+          model = MyModel()
+          optimizer = adam.Adam(0.001)
+          step_counter = training_util.get_or_create_global_step()
+          root = util.Checkpoint(
+              optimizer=optimizer, model=model, global_step=step_counter)
+          input_value = constant_op.constant([[3.]])
+          loss = model(input_value)
+          train_op = optimizer.minimize(loss, global_step=root.global_step)
+          ckpt_path = checkpoint_management.latest_checkpoint(ckpt_dir)
+          with self.session(graph=ops.get_default_graph()) as session:
+            status = root.restore(save_path=ckpt_path)
+            status.initialize_or_restore(session=session)
+            if ckpt_path is not None:
+              status.assert_consumed()
+            else:
+              # A missing checkpoint is only acceptable in the first round,
+              # and in that case the status must report itself as unconsumed.
+              self.assertEqual(0, round_index)
+              with self.assertRaises(AssertionError):
+                status.assert_consumed()
+            for _ in range(steps_per_round):
+              session.run(train_op)
+            root.save(file_prefix=ckpt_prefix, session=session)
+            rounds_done = round_index + 1
+            self.assertEqual(rounds_done * steps_per_round,
+                             session.run(root.global_step))
+            self.assertEqual(rounds_done, session.run(root.save_counter))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testAgnosticUsage(self):
+    """Runs one restore/train/save loop that works when graph building or eager."""
+    # Does create garbage when executing eagerly due to ops.Graph() creation.
+    steps_per_round = 10
+    ckpt_dir = self.get_temp_dir()
+    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
+    for round_index in range(3):
+      with ops.Graph().as_default():
+        with self.test_session(
+            graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+          model = MyModel()
+          optimizer = adam.Adam(0.001)
+          step_counter = training_util.get_or_create_global_step()
+          root = util.Checkpoint(
+              optimizer=optimizer, model=model, global_step=step_counter)
+          ckpt_path = checkpoint_management.latest_checkpoint(ckpt_dir)
+          status = root.restore(save_path=ckpt_path)
+          input_value = constant_op.constant([[3.]])
+          loss_fn = functools.partial(model, input_value)
+          train_fn = functools.partial(
+              optimizer.minimize, loss_fn, global_step=root.global_step)
+          if not context.executing_eagerly():
+            # When graph building, create the training op once and make
+            # `train_fn` merely evaluate it on each call.
+            train_op = train_fn()
+            train_fn = functools.partial(self.evaluate, train_op)
+          status.initialize_or_restore()
+          for _ in range(steps_per_round):
+            train_fn()
+          root.save(file_prefix=ckpt_prefix)
+          rounds_done = round_index + 1
+          self.assertEqual(rounds_done * steps_per_round,
+                           self.evaluate(root.global_step))
+          self.assertEqual(rounds_done, self.evaluate(root.save_counter))
+
+ # pylint: disable=cell-var-from-loop
+  @test_util.run_in_graph_and_eager_modes
+  def testWithDefun(self):
+    """Checkpoints a model whose forward pass is wrapped in `function.defun`."""
+    steps_per_round = 2
+    ckpt_dir = self.get_temp_dir()
+    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
+    for round_index in range(3):
+      with ops.Graph().as_default():
+        with self.test_session(
+            graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+          model = MyModel()
+          # Don't actually train so we can test variable values
+          optimizer = adam.Adam(0.)
+          step_counter = training_util.get_or_create_global_step()
+          root = util.Checkpoint(
+              optimizer=optimizer, model=model, global_step=step_counter)
+          ckpt_path = checkpoint_management.latest_checkpoint(ckpt_dir)
+          status = root.restore(save_path=ckpt_path)
+
+          def train_fn():
+            """Computes gradients through the defun and applies them."""
+            @function.defun
+            def _call_model(x):
+              return model(x)
+            with backprop.GradientTape() as tape:
+              loss = _call_model(constant_op.constant([[3.]]))
+            grads = tape.gradient(loss, model.variables)
+            return optimizer.apply_gradients(
+                zip(grads, model.variables), global_step=root.global_step)
+
+          if not context.executing_eagerly():
+            # When graph building, create the training op once and make
+            # `train_fn` merely evaluate it on each call.
+            train_op = train_fn()
+            train_fn = functools.partial(self.evaluate, train_op)
+          status.initialize_or_restore()
+          for _ in range(steps_per_round):
+            train_fn()
+          if round_index == 0:
+            # First round: plant a recognizable value for later rounds to find.
+            self.evaluate(model.variables[0].assign([[42.]]))
+          else:
+            status.assert_consumed()
+            self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
+          root.save(file_prefix=ckpt_prefix)
+          rounds_done = round_index + 1
+          self.assertEqual(rounds_done * steps_per_round,
+                           self.evaluate(root.global_step))
+          self.assertEqual(rounds_done, self.evaluate(root.save_counter))
+ # pylint: enable=cell-var-from-loop
+
+  def testAnonymousVarsInInit(self):
+    """Saves and trains a model whose variables are built without names."""
+
+    class Model(training.Model):
+      """Computes `x * w + b` with two scalar resource variables."""
+
+      def __init__(self):
+        super(Model, self).__init__()
+        self.w = resource_variable_ops.ResourceVariable(0.0)
+        self.b = resource_variable_ops.ResourceVariable(0.0)
+        self.vars = [self.w, self.b]
+
+      def call(self, x):
+        return x * self.w + self.b
+
+    with context.eager_mode():
+      model = Model()
+      optimizer = adam.Adam(learning_rate=0.05)
+      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+      checkpoint = util.Checkpoint(model=model, optimizer=optimizer)
+      for _ in range(2):
+        # Save first, then take one optimizer step, so the second save sees
+        # whatever state the first step created.
+        checkpoint.save(ckpt_prefix)
+        with backprop.GradientTape() as tape:
+          target = constant_op.constant(1.)
+          prediction = model(constant_op.constant(1.))
+          loss = (target - prediction) ** 2
+        grad = tape.gradient(loss, model.vars)
+        optimizer.apply_gradients(list(zip(grad, model.vars)))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDeferredSlotRestoration(self):
+    """Restores a slot variable whose owner and optimizer are attached late."""
+    ckpt_dir = self.get_temp_dir()
+    eager = context.executing_eagerly()
+
+    root = tracking.Checkpointable()
+    root.var = util.add_variable(root, name="var", initializer=0.)
+    optimizer = adam.Adam(0.1)
+    if eager:
+      optimizer.minimize(root.var.read_value)
+    else:
+      train_op = optimizer.minimize(root.var)
+      # Note that `optimizer` has not been added as a dependency of
+      # `root`. Create a one-off grouping so that slot variables for `root.var`
+      # get initialized too.
+      grouping = util.Checkpoint(root=root, optimizer=optimizer)
+      self.evaluate(util.gather_initializers(grouping))
+      self.evaluate(train_op)
+
+    # First checkpoint: only `root.var` (value 12) is reachable from `root`.
+    self.evaluate(state_ops.assign(root.var, 12.))
+    no_slots_path = util.CheckpointableSaver(root).save(
+        os.path.join(ckpt_dir, "no_slots"))
+
+    # Second checkpoint: the optimizer is attached, so the "m" slot (value 14)
+    # is saved alongside `root.var` (value 13).
+    root.optimizer = optimizer
+    self.evaluate(state_ops.assign(root.var, 13.))
+    m_slot = optimizer.get_slot(name="m", var=root.var)
+    self.evaluate(state_ops.assign(m_slot, 14.))
+    slots_path = util.CheckpointableSaver(root).save(
+        os.path.join(ckpt_dir, "with_slots"))
+
+    new_root = tracking.Checkpointable()
+    # Load the slot-containing checkpoint (deferred), then immediately overwrite
+    # the non-slot variable (also deferred).
+    slot_status = util.CheckpointableSaver(new_root).restore(slots_path)
+    no_slot_status = util.CheckpointableSaver(new_root).restore(no_slots_path)
+    with self.assertRaises(AssertionError):
+      no_slot_status.assert_consumed()
+    new_root.var = util.add_variable(new_root, name="var", shape=[])
+    no_slot_status.assert_consumed()
+    no_slot_status.run_restore_ops()
+    self.assertEqual(12., self.evaluate(new_root.var))
+    new_root.optimizer = adam.Adam(0.1)
+    with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
+      slot_status.assert_consumed()
+    self.assertEqual(12., self.evaluate(new_root.var))
+    if eager:
+      # Slot variables are only created with restoring initializers when
+      # executing eagerly.
+      self.assertEqual(14., self.evaluate(
+          new_root.optimizer.get_slot(name="m", var=new_root.var)))
+      new_root.optimizer.minimize(new_root.var.read_value)
+    else:
+      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
+                    None)
+      train_op = new_root.optimizer.minimize(new_root.var)
+      # The slot variable now exists; restore() didn't create it, but we should
+      # now have a restore op for it.
+      slot_status.run_restore_ops()
+      self.assertEqual(14., self.evaluate(
+          new_root.optimizer.get_slot(name="m", var=new_root.var)))
+      self.evaluate(train_op)
+    slot_status.assert_consumed()
+
+  def testManySavesGraph(self):
+    """Checks that a second save() adds no operations to the graph."""
+    with context.graph_mode():
+      graph = ops.Graph()
+      with graph.as_default(), self.session(graph):
+        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+        obj = tracking.Checkpointable()
+        obj.var = variable_scope.get_variable(name="v", initializer=0.)
+        obj.opt = adam.Adam(0.1)
+        obj.opt.minimize(obj.var.read_value())
+        self.evaluate(util.gather_initializers(obj))
+        saver = util.CheckpointableSaver(obj)
+        # The first save may build ops; snapshot the graph after it and
+        # require the second save to leave that snapshot unchanged.
+        saver.save(ckpt_prefix)
+        ops_after_first_save = graph.get_operations()
+        saver.save(ckpt_prefix)
+        self.assertEqual(ops_after_first_save, graph.get_operations())
+
+  def testManyRestoresGraph(self):
+    """Checks that a second restore() adds no operations to the graph."""
+    with context.graph_mode():
+      graph = ops.Graph()
+      with graph.as_default(), self.session(graph):
+        ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+        obj = tracking.Checkpointable()
+        obj.var = variable_scope.get_variable(name="v", initializer=0.)
+        obj.opt = adam.Adam(0.1)
+        obj.opt.minimize(obj.var.read_value())
+        self.evaluate(util.gather_initializers(obj))
+        saver = util.CheckpointableSaver(obj)
+        save_path = saver.save(ckpt_prefix)
+        # The first restore may build ops; snapshot the graph after it and
+        # require the second restore to leave that snapshot unchanged.
+        saver.restore(save_path)
+        ops_after_first_restore = graph.get_operations()
+        saver.restore(save_path)
+        self.assertEqual(ops_after_first_restore, graph.get_operations())
+
+  def testMultipleGraphsNonSlotVariables(self):
+    """Uses one optimizer in two graphs and checks their state stays separate.
+
+    The first graph gets known values for a variable, its "m" slot and
+    beta_1_power. The second graph then trains, saves and restores with the
+    same optimizer object. Finally the first graph's values are re-read to
+    confirm the work in the second graph did not touch them.
+    """
+    with context.graph_mode():
+      checkpoint_directory = self.get_temp_dir()
+      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+      optimizer = adam.Adam(0.001)
+      # Construct a model in one graph
+      first_graph = ops.Graph()
+      first_session = session_lib.Session(graph=first_graph)
+      # `as_default()` does not close the session on exit and this session is
+      # entered twice below, so release it explicitly when the test finishes.
+      self.addCleanup(first_session.close)
+      with first_graph.as_default(), first_session.as_default():
+        first_variable = resource_variable_ops.ResourceVariable([1.])
+        first_root_checkpointable = util.Checkpoint(
+            optimizer=optimizer, variable=first_variable)
+        train_op = optimizer.minimize(first_variable.read_value)
+        self.evaluate(util.gather_initializers(
+            first_root_checkpointable))
+        self.evaluate(train_op)
+        self.evaluate(first_variable.assign([1.]))
+        self.evaluate(optimizer.get_slot(
+            var=first_variable, name="m").assign([2.]))
+        beta_1_power, _ = optimizer._get_beta_accumulators()
+        self.evaluate(beta_1_power.assign(3.))
+
+      # Save and load in a second graph
+      second_graph = ops.Graph()
+      with second_graph.as_default(), session_lib.Session(graph=second_graph):
+        second_variable = resource_variable_ops.ResourceVariable([1.])
+        second_root_checkpointable = util.Checkpoint(
+            optimizer=optimizer, variable=second_variable)
+        train_op = optimizer.minimize(second_variable.read_value)
+        second_root_checkpointable.restore(None).initialize_or_restore()
+        self.evaluate(train_op)
+        self.evaluate(second_variable.assign([4.]))
+        self.evaluate(optimizer.get_slot(
+            var=second_variable, name="m").assign([5.]))
+        beta_1_power, _ = optimizer._get_beta_accumulators()
+        self.evaluate(beta_1_power.assign(6.))
+        save_path = second_root_checkpointable.save(checkpoint_prefix)
+        # Overwrite the variable and slot after saving so that the restore
+        # below has something visible to undo.
+        self.evaluate(second_variable.assign([7.]))
+        self.evaluate(optimizer.get_slot(
+            var=second_variable, name="m").assign([8.]))
+        beta_1_power, _ = optimizer._get_beta_accumulators()
+        self.assertAllEqual(6., self.evaluate(beta_1_power))
+        status = second_root_checkpointable.restore(save_path)
+        status.assert_consumed().run_restore_ops()
+        self.assertAllEqual([4.], self.evaluate(second_variable))
+        self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
+            var=second_variable, name="m")))
+        beta_1_power, _ = optimizer._get_beta_accumulators()
+        self.assertAllEqual(6., self.evaluate(beta_1_power))
+
+      # Check that the first graph is unmolested
+      with first_graph.as_default(), first_session.as_default():
+        self.assertAllEqual([1.], self.evaluate(first_variable))
+        self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
+            var=first_variable, name="m")))
+        beta_1_power, _ = optimizer._get_beta_accumulators()
+        self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+
+class TemplateTests(test.TestCase):
+  """Object-based checkpointing of variables created inside templates."""
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_checkpointable_save_restore(self):
+    """Saves through template "s1" and restores the values into "s2"."""
+
+    def _templated():
+      """Creates variables "v" and "v2" and returns (v, v + 1, v2)."""
+      first = variable_scope.get_variable(
+          "v", shape=[1], initializer=init_ops.zeros_initializer(),
+          use_resource=True)
+      second = variable_scope.get_variable(
+          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+          use_resource=True)
+      return first, first + 1., second
+
+    # Saving side: give both variables recognizable values and checkpoint.
+    save_template = template.make_template("s1", _templated)
+    v1_save, _, v2_save = save_template()
+    optimizer = adam.Adam(0.0)
+    save_root = util.Checkpoint(
+        my_template=save_template, optimizer=optimizer)
+    optimizer.minimize(v1_save.read_value)
+    self.evaluate([v.initializer for v in optimizer.variables()])
+    self.evaluate(v1_save.assign([12.]))
+    self.evaluate(v2_save.assign([14.]))
+    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    save_path = save_root.save(ckpt_prefix)
+
+    # Loading side: restore() is requested before the template has been
+    # called, i.e. before its variables exist.
+    load_template = template.make_template("s2", _templated)
+    load_optimizer = adam.Adam(0.0)
+    load_root = util.Checkpoint(
+        my_template=load_template, optimizer=load_optimizer)
+    status = load_root.restore(save_path)
+    var, var_plus_one, var2 = load_template()
+    load_optimizer.minimize(var.read_value)
+    dependencies = load_template._checkpoint_dependencies
+    self.assertEqual(2, len(dependencies))
+    self.assertEqual("v", dependencies[0].name)
+    self.assertEqual("v2", dependencies[1].name)
+    status.assert_consumed().run_restore_ops()
+    self.assertAllEqual([12.], self.evaluate(var))
+    self.assertAllEqual([13.], self.evaluate(var_plus_one))
+    self.assertAllEqual([14.], self.evaluate(var2))
+
+
+class CheckpointCompatibilityTests(test.TestCase):
+  """Checkpoints exchanged between savers and between graph and eager modes."""
+
+  def _initialized_model(self):
+    """Builds a checkpoint root, runs one train op, and pins three values.
+
+    Returns:
+      A `util.Checkpoint` holding `model`, `optimizer` and `optimizer_step`,
+      with the dense bias set to [1.], its "m" slot set to [2.] and
+      beta_1_power set to 3.
+    """
+    input_value = constant_op.constant([[3.]])
+    model = MyModel()
+    optimizer = adam.Adam(0.001)
+    optimizer_step = training_util.get_or_create_global_step()
+    root_checkpointable = util.Checkpoint(
+        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+    loss_fn = functools.partial(model, input_value)
+    train_op = optimizer.minimize(loss_fn, global_step=optimizer_step)
+    self.evaluate(util.gather_initializers(root_checkpointable))
+    self.evaluate(train_op)
+    # A regular variable, a slot variable, and a non-slot Optimizer variable
+    # with known values to check when loading.
+    self.evaluate(model._named_dense.bias.assign([1.]))
+    m_slot = optimizer.get_slot(var=model._named_dense.bias, name="m")
+    self.evaluate(m_slot.assign([2.]))
+    beta_1_power, _ = optimizer._get_beta_accumulators()
+    self.evaluate(beta_1_power.assign(3.))
+    return root_checkpointable
+
+  def _set_sentinels(self, root_checkpointable):
+    """Overwrites the three pinned values with 101, 102 and 103."""
+    model = root_checkpointable.model
+    self.evaluate(model._named_dense.bias.assign([101.]))
+    m_slot = root_checkpointable.optimizer.get_slot(
+        var=root_checkpointable.model._named_dense.bias, name="m")
+    self.evaluate(m_slot.assign([102.]))
+    beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+    self.evaluate(beta_1_power.assign(103.))
+
+  def _check_sentinels(self, root_checkpointable):
+    """Asserts the three pinned values are back at 1, 2 and 3."""
+    bias_value = self.evaluate(root_checkpointable.model._named_dense.bias)
+    self.assertAllEqual([1.], bias_value)
+    m_slot = root_checkpointable.optimizer.get_slot(
+        var=root_checkpointable.model._named_dense.bias, name="m")
+    self.assertAllEqual([2.], self.evaluate(m_slot))
+    beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+    self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+  def _write_name_based_checkpoint(self):
+    """Saves an initialized model with `core_saver.Saver`; returns the path."""
+    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    with context.graph_mode():
+      save_graph = ops.Graph()
+      with save_graph.as_default(), self.test_session(
+          graph=save_graph) as session:
+        root = self._initialized_model()
+        name_saver = core_saver.Saver()
+        return name_saver.save(
+            sess=session, save_path=ckpt_prefix,
+            global_step=root.optimizer_step)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testLoadFromNameBasedSaver(self):
+    """Save a name-based checkpoint, load it using the object-based API."""
+    with test_util.device(use_gpu=True):
+      save_path = self._write_name_based_checkpoint()
+      root = self._initialized_model()
+      # Sanity check: once sentinels are set, the pinned-value check must fail.
+      self._set_sentinels(root)
+      with self.assertRaises(AssertionError):
+        self._check_sentinels(root)
+      object_saver = util.CheckpointableSaver(root)
+      self._set_sentinels(root)
+      status = object_saver.restore(save_path)
+      if context.executing_eagerly():
+        self._check_sentinels(root)
+        with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+          status.assert_consumed()
+      else:
+        # When graph building, we haven't read any keys, so we don't know
+        # whether the restore will be complete.
+        with self.assertRaisesRegexp(AssertionError, "not restored"):
+          status.assert_consumed()
+      status.run_restore_ops()
+      self._check_sentinels(root)
+      # Repeat the restore, this time through initialize_or_restore().
+      self._set_sentinels(root)
+      status = object_saver.restore(save_path)
+      status.initialize_or_restore()
+      self._check_sentinels(root)
+
+  # TODO(allenl): Test for the core name-based saver loading object-based
+  # checkpoints once object-based checkpointing is in core.
+
+  def testSaveGraphLoadEager(self):
+    """Writes a checkpoint while graph building and reads it back eagerly."""
+    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    with context.graph_mode():
+      save_graph = ops.Graph()
+      with save_graph.as_default(), self.test_session(
+          graph=save_graph) as session:
+        root = self._initialized_model()
+        save_path = root.save(session=session, file_prefix=ckpt_prefix)
+    with context.eager_mode():
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      root.restore(save_path).assert_consumed()
+      self._check_sentinels(root)
+
+  def testSaveEagerLoadGraph(self):
+    """Writes a checkpoint eagerly and reads it back while graph building."""
+    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+    with context.eager_mode():
+      root = self._initialized_model()
+      save_path = root.save(file_prefix=ckpt_prefix)
+    with context.graph_mode():
+      save_graph = ops.Graph()
+      with save_graph.as_default(), self.test_session(
+          graph=save_graph):
+        root = self._initialized_model()
+        self._set_sentinels(root)
+        root.restore(save_path).assert_consumed().run_restore_ops()
+        self._check_sentinels(root)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
new file mode 100644
index 0000000000..bd5557f4fd
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -0,0 +1,1349 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Version 2 of class Optimizer."""
+# pylint: disable=g-bad-name
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import optimizer as optimizer_v1
+from tensorflow.python.training import slot_creator
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
+
+
+class _OptimizableVariable(object):
+  """Interface the optimizers use to treat different kinds of variables alike.
+
+  Implementations supply `target()` and `update_op()`.
+  """
+
+  # NOTE: this class derives from plain `object` with no `abc.ABCMeta`
+  # metaclass, so the `abc.abstractmethod` markers below do not prevent
+  # instantiation. An unimplemented method is only reported, through
+  # NotImplementedError, when it is actually called.
+
+  @abc.abstractmethod
+  def target(self):
+    """Returns the optimization target for this variable."""
+    raise NotImplementedError("Calling an abstract method.")
+
+  @abc.abstractmethod
+  def update_op(self, optimizer, g, *args):
+    """Returns the update ops for updating the variable.
+
+    Args:
+      optimizer: The optimizer whose apply methods perform the update.
+      g: The gradient to apply.
+      *args: Extra positional arguments for the implementation.
+    """
+    raise NotImplementedError("Calling an abstract method.")
+
+
+class _RefVariableProcessor(_OptimizableVariable):
+  """Processor that routes gradients for a wrapped `Variable` to the optimizer."""
+
+  def __init__(self, v):
+    """Stores `v`, the variable this processor updates."""
+    self._v = v
+
+  def target(self):
+    """Returns the result of the wrapped variable's `_ref()`."""
+    return self._v._ref()  # pylint: disable=protected-access
+
+  def update_op(self, optimizer, g, *args):
+    """Builds the op that applies gradient `g` to the wrapped variable.
+
+    Args:
+      optimizer: Object providing `_apply_dense` and
+        `_apply_sparse_duplicate_indices`.
+      g: The gradient, either an `ops.Tensor` or an `ops.IndexedSlices`.
+      *args: Forwarded unchanged to the optimizer's apply method.
+
+    Returns:
+      For a `Tensor` gradient, the dense update op or, when the variable has
+      a constraint, an assign of the constrained value that is ordered after
+      that update. For `IndexedSlices`, the sparse update op.
+
+    Raises:
+      RuntimeError: If `g` is `IndexedSlices` and the variable has a
+        constraint.
+    """
+    # pylint: disable=protected-access
+    var = self._v
+    if not isinstance(g, ops.Tensor):
+      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
+                                                "tensor nor IndexedSlices.")
+      if var.constraint is not None:
+        raise RuntimeError(
+            "Cannot use a constraint function on a sparse variable.")
+      return optimizer._apply_sparse_duplicate_indices(g, var, *args)
+    dense_update = optimizer._apply_dense(g, var, *args)
+    if var.constraint is None:
+      return dense_update
+    # Apply the constraint only after the optimizer's own update has run.
+    with ops.control_dependencies([dense_update]):
+      return var.assign(var.constraint(var))
+
+
+class _DenseReadResourceVariableProcessor(_OptimizableVariable):
+  """Processor that applies dense updates to the first input of `v`'s op.
+
+  NOTE(review): judging by the class name, `v` is presumably the read of a
+  ResourceVariable and `v.op.inputs[0]` its handle; confirm with callers.
+  """
+
+  def __init__(self, v):
+    """Stores `v`, the object this processor updates."""
+    self._v = v
+
+  def target(self):
+    """Returns `v` itself as the optimization target."""
+    return self._v
+
+  def update_op(self, optimizer, g, *args):
+    """Builds the dense update op for gradient `g`.
+
+    Args:
+      optimizer: Object providing `_resource_apply_dense`.
+      g: The gradient to apply.
+      *args: Forwarded unchanged to `_resource_apply_dense`.
+
+    Returns:
+      The dense update op or, when `v` has a constraint, an assign of the
+      constrained value that is ordered after that update.
+    """
+    # pylint: disable=protected-access
+    var = self._v
+    dense_update = optimizer._resource_apply_dense(g, var.op.inputs[0], *args)
+    if var.constraint is None:
+      return dense_update
+    # Apply the constraint only after the optimizer's own update has run.
+    with ops.control_dependencies([dense_update]):
+      return var.assign(var.constraint(var))
+
+
+class _DenseResourceVariableProcessor(_OptimizableVariable):
+  """Processor that updates a resource variable from dense or sparse gradients."""
+
+  def __init__(self, v):
+    """Stores `v`, the variable this processor updates."""
+    self._v = v
+
+  def target(self):
+    """Returns the variable itself as the optimization target."""
+    return self._v
+
+  def update_op(self, optimizer, g, *args):
+    """Builds the op that applies gradient `g` to the wrapped variable.
+
+    Args:
+      optimizer: Object providing `_resource_apply_dense` and
+        `_resource_apply_sparse_duplicate_indices`.
+      g: The gradient. `ops.IndexedSlices` takes the sparse path, anything
+        else the dense path.
+      *args: Forwarded unchanged to the optimizer's apply method.
+
+    Returns:
+      The sparse update op for `IndexedSlices`. Otherwise the dense update op
+      or, when the variable has a constraint, an assign of the constrained
+      value that is ordered after that update.
+
+    Raises:
+      RuntimeError: If `g` is `IndexedSlices` and the variable has a
+        constraint.
+    """
+    # pylint: disable=protected-access
+    var = self._v
+    if isinstance(g, ops.IndexedSlices):
+      if var.constraint is not None:
+        raise RuntimeError(
+            "Cannot use a constraint function on a sparse variable.")
+      return optimizer._resource_apply_sparse_duplicate_indices(
+          g.values, var, g.indices, *args)
+    dense_update = optimizer._resource_apply_dense(g, var, *args)
+    if var.constraint is None:
+      return dense_update
+    # Apply the constraint only after the optimizer's own update has run.
+    with ops.control_dependencies([dense_update]):
+      return var.assign(var.constraint(var))
+
+
+class _TensorProcessor(_OptimizableVariable):
+  """Processor for ordinary Tensors.
+
+  A Tensor cannot be updated, yet it is sometimes useful to let the optimizer
+  compute gradients with respect to one. This processor therefore exposes the
+  Tensor as a target and refuses every update request.
+  """
+
+  def __init__(self, v):
+    """Stores `v`, the Tensor exposed as an optimization target."""
+    self._v = v
+
+  def target(self):
+    """Returns the wrapped Tensor."""
+    return self._v
+
+  def update_op(self, optimizer, g, *args):
+    """Always raises NotImplementedError; Tensors cannot be updated."""
+    raise NotImplementedError("Trying to update a Tensor ", self._v)
+
+
+def _get_processor(v):
+  """Selects the `_OptimizableVariable` implementation that fits `v`.
+
+  Args:
+    v: The object to optimize.
+
+  Returns:
+    When executing eagerly, a `_TensorProcessor` for an `ops.Tensor` and a
+    `_DenseResourceVariableProcessor` for anything else. Otherwise, the first
+    match among: op type "VarHandleOp" (`_DenseResourceVariableProcessor`),
+    `variables.Variable` (`_RefVariableProcessor`) and `ops.Tensor`
+    (`_TensorProcessor`).
+
+  Raises:
+    NotImplementedError: If none of the graph-mode cases match.
+  """
+  is_tensor = isinstance(v, ops.Tensor)
+  if context.executing_eagerly():
+    processor_cls = (
+        _TensorProcessor if is_tensor else _DenseResourceVariableProcessor)
+    return processor_cls(v)
+  if v.op.type == "VarHandleOp":
+    return _DenseResourceVariableProcessor(v)
+  if isinstance(v, variables.Variable):
+    return _RefVariableProcessor(v)
+  if is_tensor:
+    return _TensorProcessor(v)
+  raise NotImplementedError("Trying to optimize unsupported type ", v)
+
+
+def _var_key_v2(var):
+  """Returns the key under which slots for primary variable `var` are stored.
+
+  If `var` has a `_distributed_container` attribute, the key comes from the
+  container it returns, otherwise from `var` itself. When executing eagerly
+  the key is a `_unique_id`; otherwise it is the container's `_shared_name`
+  or the name of `var`'s op.
+  """
+  # pylint: disable=protected-access
+  eager = context.executing_eagerly()
+  if hasattr(var, "_distributed_container"):
+    container = var._distributed_container()
+    assert container is not None
+    return container._unique_id if eager else container._shared_name
+  return var._unique_id if eager else var.op.name
+
+
+def _resolve(value, name):
+  """Converts `value` to a Tensor named `name`, calling it first if callable."""
+  resolved = value() if callable(value) else value
+  return ops.convert_to_tensor(resolved, name=name)
+
+
+def _is_dynamic(value):
+  """Returns true if __init__ arg `value` should be re-evaluated each step.
+
+  That is the case for callables and, when executing eagerly, for
+  `ResourceVariable` instances.
+  """
+  if callable(value):
+    return True
+  # Don't need to do anything special in graph mode, since dynamic values
+  # will propagate correctly automatically.
+  # TODO(josh11b): Add per-device caching across steps using variables for
+  # truly static values once we add distributed support.
+  is_eager_resource_variable = (
+      context.executing_eagerly() and
+      isinstance(value, resource_variable_ops.ResourceVariable))
+  return bool(is_eager_resource_variable)
+
+
+class _OptimizerV2State(object):
+ """Holds per-graph and per-step optimizer state.
+
+ Use _init_with_static_hyper() to create the state for a graph, and then
+ _copy_with_dynamic_hyper() to convert that to state for a particular step.
+ The difference between the two is that the former only has hyper
+ parameter values that are static and the latter also has values that
+ can change every step (according to _is_dynamic()).
+ """
+
+  def __init__(self, op_name):
+    """Records `op_name`, the default scope name used when creating slots."""
+    self._op_name = op_name
+
+  def _init_with_static_hyper(self, hyper):
+    """Populates a fresh state object from the static entries of `hyper`.
+
+    Args:
+      hyper: Dict from hyper parameter name to a `(dynamic, value)` pair. Only
+        entries whose `dynamic` flag is false are converted here.
+    """
+    # self._hyper maps each name to a dict holding the Tensor values. That
+    # inner dict starts with one item under the key None: the hyper parameter
+    # value converted to a Tensor. Other items have dtype keys with that
+    # Tensor cast to that dtype.
+    static_hyper = {}
+    with ops.init_scope():
+      for name, (dynamic, value) in sorted(hyper.items()):
+        if dynamic:
+          continue
+        static_hyper[name] = {None: ops.convert_to_tensor(value, name=name)}
+    self._hyper = static_hyper
+    self._slots = {}
+    self._non_slot_dict = {}
+    # Extra state to help Optimizers implement Checkpointable. Holds information
+    # about variables which will be restored as soon as they're created.
+    self._deferred_dependencies = {}  # Non-slot variables
+    self._deferred_slot_restorations = {}  # Slot variables
+
+  def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
+    """Creates the state object to use for one particular step.
+
+    The copy shares this object's slot, non-slot and deferred-restoration
+    dicts (same dict objects, not copies). Its hyper parameters are the
+    freshly resolved dynamic entries of `hyper` plus this object's static
+    ones.
+
+    Args:
+      hyper: Dict from hyper parameter name to a `(dynamic, value)` pair. Only
+        entries whose `dynamic` flag is true are resolved here.
+      distribution: Stored on the copy as `_distribution`.
+      non_slot_devices: Stored on the copy as `_non_slot_devices`.
+
+    Returns:
+      A new `_OptimizerV2State`.
+    """
+    # pylint: disable=protected-access
+    step_state = _OptimizerV2State(self._op_name)
+    step_state._slots = self._slots
+    step_state._non_slot_dict = self._non_slot_dict
+    step_state._deferred_dependencies = self._deferred_dependencies
+    step_state._deferred_slot_restorations = self._deferred_slot_restorations
+    step_hyper = {}
+    for name, (dynamic, value) in sorted(hyper.items()):
+      if dynamic:
+        step_hyper[name] = {None: _resolve(value, name)}
+    # Static values are merged in last, so they win on any name clash.
+    step_hyper.update(self._hyper)
+    step_state._hyper = step_hyper
+    step_state._non_slot_devices = non_slot_devices
+    step_state._distribution = distribution
+    return step_state
+
+  def _variables(self):
+    """Returns all non-slot and slot variables held by self, sorted by name."""
+    held = list(self._non_slot_dict.values())
+    for slots_by_key in self._slots.values():
+      held.extend(slots_by_key.values())
+    # Sort variables by name so that the return is deterministic.
+    return sorted(held, key=lambda v: v.name)
+
+  def _slot_dict(self, slot_name):
+    """Returns the dict caching slots created under `slot_name`.
+
+    An empty dict is created and registered on first use.
+
+    Args:
+      slot_name: Name for the slot.
+
+    Returns:
+      A dict from a primary variable's key to the slot created for that
+      variable under the given slot name. The slot-creating methods of this
+      class use `_var_key_v2(var)` as the key.
+    """
+    slots_for_name = self._slots.get(slot_name)
+    if slots_for_name is not None:
+      return slots_for_name
+    slots_for_name = {}
+    self._slots[slot_name] = slots_for_name
+    return slots_for_name
+
+  def create_slot(self, var, val, slot_name, optional_op_name=None):
+    """Returns the `slot_name` slot for `var`, creating it from `val` if new.
+
+    A newly created slot is passed to `_restore_slot_variable` before being
+    cached.
+
+    Args:
+      var: A `Variable` object.
+      val: A `Tensor`. The initial value of the slot.
+      slot_name: Name for the slot.
+      optional_op_name: Name to use when scoping the Variable that
+        needs to be created for the slot. Falls back to the op name given to
+        the constructor.
+
+    Returns:
+      A `Variable` object.
+    """
+    slots_by_key = self._slot_dict(slot_name)
+    key = _var_key_v2(var)
+    if key in slots_by_key:
+      return slots_by_key[key]
+    scope_name = optional_op_name or self._op_name
+    slot = slot_creator.create_slot(var, val, scope_name)
+    self._restore_slot_variable(
+        slot_name=slot_name, variable=var, slot_variable=slot)
+    slots_by_key[key] = slot
+    return slot
+
+  def create_slot_with_initializer(self, var, initializer, shape, dtype,
+                                   slot_name, optional_op_name=None):
+    """Returns the `slot_name` slot for `var`, creating it if new.
+
+    A new slot is built with `initializer` and passed to
+    `_restore_slot_variable` before being cached.
+
+    Args:
+      var: A `Variable` object.
+      initializer: An `Initializer`. The initial value of the slot.
+      shape: Shape of the initial value of the slot.
+      dtype: Type of the value of the slot.
+      slot_name: Name for the slot.
+      optional_op_name: Name to use when scoping the Variable that
+        needs to be created for the slot. Falls back to the op name given to
+        the constructor.
+
+    Returns:
+      A `Variable` object.
+    """
+    slots_by_key = self._slot_dict(slot_name)
+    key = _var_key_v2(var)
+    if key in slots_by_key:
+      return slots_by_key[key]
+    scope_name = optional_op_name or self._op_name
+    slot = slot_creator.create_slot_with_initializer(
+        var, initializer, shape, dtype, scope_name)
+    self._restore_slot_variable(
+        slot_name=slot_name, variable=var, slot_variable=slot)
+    slots_by_key[key] = slot
+    return slot
+
+  def zeros_slot(self, var, slot_name, optional_op_name=None):
+    """Returns the `slot_name` slot for `var`, creating a zeros slot if new.
+
+    A newly created slot is passed to `_restore_slot_variable` before being
+    cached.
+
+    Args:
+      var: A `Variable` object.
+      slot_name: Name for the slot.
+      optional_op_name: Name to use when scoping the Variable that
+        needs to be created for the slot. Falls back to the op name given to
+        the constructor.
+
+    Returns:
+      A `Variable` object.
+    """
+    slots_by_key = self._slot_dict(slot_name)
+    key = _var_key_v2(var)
+    if key in slots_by_key:
+      return slots_by_key[key]
+    scope_name = optional_op_name or self._op_name
+    slot = slot_creator.create_zeros_slot(var, scope_name)
+    self._restore_slot_variable(
+        slot_name=slot_name, variable=var, slot_variable=slot)
+    slots_by_key[key] = slot
+    return slot
+
def _create_or_restore_slot_variable(
    self, slot_variable_position, slot_name, variable,
    optional_op_name=None):
  """Restore a slot variable's value, creating the slot first if possible.

  Called when a variable which has an associated slot variable is created or
  restored. When executing eagerly, the slot variable is created with a
  restoring initializer.

  No new variables are created when graph building. Instead,
  `_restore_slot_variable` catches these after normal creation and adds
  restore ops to the graph. This method is nonetheless important when graph
  building for the case when a slot variable has already been created but
  `variable` has just been added to a dependency graph (causing us to realize
  that the slot variable needs to be restored).

  Args:
    slot_variable_position: A `checkpointable._CheckpointPosition` object
      indicating the slot variable `Checkpointable` object to be restored.
    slot_name: The name of this `Optimizer`'s slot to restore into.
    variable: The variable object this slot is being created for.
    optional_op_name: Name to use when scoping the Variable that needs to be
      created for the slot.
  """
  slot_variable = self.get_slot(var=variable, name=slot_name)

  # Slot creation is deferred while a variable creator scope is active.
  # Generally we'd like to eagerly create/restore slot variables when
  # possible, but this may mean that scopes intended to catch `variable` also
  # catch its eagerly created slot variable unintentionally (specifically
  # make_template would add a dependency on a slot variable if not for this
  # case). Deferring is mostly harmless (aside from double initialization),
  # and makes variable creator scopes behave the same way they do when graph
  # building.
  create_now = (
      slot_variable is None and context.executing_eagerly() and
      slot_variable_position.is_simple_variable() and
      not ops.get_default_graph()._variable_creator_stack)  # pylint: disable=protected-access
  if create_now:
    restoring_value = checkpointable.CheckpointInitialValue(
        checkpoint_position=slot_variable_position)
    slot_variable = self.create_slot(
        var=variable,
        val=restoring_value,
        slot_name=slot_name,
        optional_op_name=optional_op_name)
    # Optimizers do not have unconditional dependencies on their slot
    # variables (nor do any other objects). They are only saved if the
    # variables they were created for are also saved.

  if slot_variable is None:
    # No slot variable exists yet. Defer restoring until it gets created
    # normally. A list is kept rather than only the position with the highest
    # restore UID in case slot variables have their own dependencies, in
    # which case those could differ between restores.
    variable_key = _var_key_v2(variable)
    pending_by_var = self._deferred_slot_restorations.setdefault(
        slot_name, {})
    pending_by_var.setdefault(variable_key, []).append(
        slot_variable_position)
    return

  # Either the slot was just made or an existing one was found: restore it.
  slot_variable_position.restore(slot_variable)
+
def get_slot(self, var, name):
  """Look up the slot called `name` that was created for `var`.

  Some `Optimizer` subclasses use additional variables. For example
  `Momentum` and `Adagrad` use variables to accumulate updates. This method
  gives access to these `Variable` objects if for some reason you need them.

  Use `get_slot_names()` to get the list of slot names created by the
  `Optimizer`.

  Args:
    var: A variable passed to `minimize()` or `apply_gradients()`.
    name: A string.

  Returns:
    The `Variable` for the slot if it was created, `None` otherwise.
  """
  slots_by_var = self._slots.get(name)
  if slots_by_var:
    return slots_by_var.get(_var_key_v2(var))
  # Unknown slot name, or no variable has this slot yet.
  return None
+
def get_slot_names(self):
  """Return the names of all slots recorded in this state, sorted.

  See `get_slot()`.

  Returns:
    A list of strings.
  """
  # Iterating the mapping yields its keys, i.e. the slot names.
  return sorted(self._slots)
+
def create_non_slot(self, initial_value, name, colocate_with=None):
  """Find or create an extra variable that is not associated with a slot.

  Args:
    initial_value: Initial value for the variable if it has to be created.
    name: Name under which the variable is stored and looked up.
    colocate_with: What to colocate the new variable with. Defaults to this
      state's non-slot devices when `None`.

  Returns:
    The existing variable registered under `name`, or the newly created one.
  """
  existing = self._non_slot_dict.get(name, None)
  if existing is not None:
    return existing

  if colocate_with is None:
    colocate_with = self._non_slot_devices
  with self._distribution.colocate_vars_with(colocate_with):
    # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
    created = variable_scope.variable(
        initial_value, name=name, trainable=False)
  self._non_slot_dict[name] = created

  # Apply restorations that were waiting for this variable, highest restore
  # UID first.
  pending = sorted(
      self._deferred_dependencies.pop(name, ()),
      key=lambda restore: restore.checkpoint.restore_uid,
      reverse=True)
  for position in pending:
    position.restore(created)
  return created
+
def _restore_slot_variable(self, slot_name, variable, slot_variable):
  """Apply restorations deferred for a newly created slot variable.

  Args:
    slot_name: Name of the slot that `slot_variable` fills.
    variable: The variable the slot was created for.
    slot_variable: The freshly created slot variable to restore into.
  """
  key = _var_key_v2(variable)
  # Take the pending restorations for this (slot, variable) pair out of the
  # deferred table so they are applied only once.
  pending = self._deferred_slot_restorations.get(slot_name, {}).pop(key, [])
  # Iterate over restores, highest restore UID first to minimize the number
  # of assignments.
  # NOTE(review): `create_non_slot` orders by
  # `restore.checkpoint.restore_uid`, while this key reads `restore_uid`
  # straight off the position -- confirm both attributes exist.
  for position in sorted(
      pending, key=lambda position: position.restore_uid, reverse=True):
    position.restore(slot_variable)
+
def get_non_slot(self, name):
  """Return the non-slot variable stored under `name`, or `None` if absent."""
  return self._non_slot_dict.get(name)
+
def get_hyper(self, name, dtype=None):
  """Return hyper parameter `name`, optionally cast to `dtype`.

  Casts are cached per dtype, so each (name, dtype) pair is cast at most once.

  Args:
    name: Name of the hyper parameter.
    dtype: Optional dtype to cast the value to. `None` returns the uncast
      value.

  Returns:
    The hyper parameter value, cast to `dtype` when one is given.
  """
  per_dtype = self._hyper[name]
  # The entry for `None` holds the uncast value, so the lookup always hits
  # when no dtype is requested.
  if dtype not in per_dtype:
    # First request for this dtype: cast once and remember the result.
    per_dtype[dtype] = math_ops.cast(per_dtype[None], dtype)
  return per_dtype[dtype]
+
+
+class OptimizerV2(optimizer_v1.Optimizer):
+ """Updated base class for optimizers.
+
+ This class defines the API to add Ops to train a model. You never use this
+ class directly, but instead instantiate one of its subclasses such as
+ `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
+
+ ### Usage
+
+ ```python
+ # Create an optimizer with the desired parameters.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+ # Add Ops to the graph to minimize a cost by updating a list of variables.
+ # "cost" is a Tensor, and the list of variables contains tf.Variable
+ # objects.
+ opt_op = opt.minimize(cost, var_list=<list of variables>)
+ ```
+
+ In the training program you will just have to run the returned Op.
+
+ ```python
+ # Execute opt_op to do one step of training:
+ opt_op.run()
+ ```
+
+ ### Processing gradients before applying them.
+
+ Calling `minimize()` takes care of both computing the gradients and
+ applying them to the variables. If you want to process the gradients
+ before applying them you can instead use the optimizer in three steps:
+
+ 1. Compute the gradients with `compute_gradients()`.
+ 2. Process the gradients as you wish.
+ 3. Apply the processed gradients with `apply_gradients()`.
+
+ Example:
+
+ ```python
+ # Create an optimizer.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Compute the gradients for a list of variables.
+ grads_and_vars = opt.compute_gradients(loss, <list of variables>)
+
+ # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
+ # need to the 'gradient' part, for example cap them, etc.
+ capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
+
+ # Ask the optimizer to apply the capped gradients.
+ opt.apply_gradients(capped_grads_and_vars)
+ ```
+
+ ### Gating Gradients
+
+ Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
+ argument that controls the degree of parallelism during the application of
+ the gradients.
+
+ The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
+
+ <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
+ the maximum parallelism in execution, at the cost of some non-reproducibility
+ in the results. For example the two gradients of `matmul` depend on the input
+ values: With `GATE_NONE` one of the gradients could be applied to one of the
+ inputs _before_ the other gradient is computed resulting in non-reproducible
+ results.
+
+ <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
+ they are used. This prevents race conditions for Ops that generate gradients
+ for multiple inputs where the gradients depend on the inputs.
+
+ <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
+ before any one of them is used. This provides the least parallelism but can
+ be useful if you want to process all gradients before applying any of them.
+
+ ### Slots
+
+ Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
+ allocate and manage additional variables associated with the variables to
+ train. These are called <i>Slots</i>. Slots have names and you can ask the
+ optimizer for the names of the slots that it uses. Once you have a slot name
+ you can ask the optimizer for the variable it created to hold the slot value.
+
+ This can be useful if you want to log or debug a training algorithm, report
+ stats about the slots, etc.
+
+ ### Non-slot variables
+
+ Some optimizer subclasses, such as `AdamOptimizer` have variables that
+ are not associated with the variables to train, just the step itself.
+
+ ### Hyper parameters
+
+ These are arguments passed to the optimizer subclass constructor
+ (the `__init__` method), and then passed to `self._set_hyper()`.
+ They can be either regular Python values (like 1.0), tensors, or
+ callables. If they are callable, the callable will be called during
+ `apply_gradients()` to get the value for the hyper parameter.
+
+ ### State
+
+ Internal methods are passed a `state` argument with the correct
+ values to use for the slot and non-slot variables, and the hyper
+ parameters.
+ """
+
+ # Values for gate_gradients.
+ GATE_NONE = 0
+ GATE_OP = 1
+ GATE_GRAPH = 2
+
def __init__(self, name):
  """Create a new Optimizer.

  This must be called by the constructors of subclasses.
  Note that Optimizer instances should not bind to a single graph,
  and so shouldn't keep Tensors as member variables. Generally
  you should be able to use the _set_hyper()/state.get_hyper()
  facility instead.

  Args:
    name: A non-empty string. The name to use for accumulators created
      for the optimizer.

  Raises:
    ValueError: If `name` is empty.
    RuntimeError: If _create_slots has been overridden instead of
      _create_vars.
  """
  # Note: We intentionally don't call parent __init__.

  # Optimizer._create_slots was replaced by _create_vars in OptimizerV2, so a
  # subclass that supplies its own _create_slots is rejected up front.
  overrides_create_slots = (
      self.__class__._create_slots.__code__ is not  # pylint: disable=protected-access
      OptimizerV2._create_slots.__code__)
  if overrides_create_slots:
    raise RuntimeError("Override _create_vars instead of _create_slots when "
                       "descending from OptimizerV2 (class %s)" %
                       self.__class__.__name__)
  if not name:
    raise ValueError("Must specify the optimizer name")

  self._use_locking = False
  self._name = name

  # `_per_graph_state` maps graph_key to the state for that graph. The
  # graph_key is used since it works in both eager and graph mode, and gives
  # the outer graph inside functions.
  tower_context = distribution_strategy_context.get_tower_context()
  if tower_context is not None:
    # merge_call() yields a single dict shared across all model replicas
    # when running with a DistributionStrategy.
    self._per_graph_state = tower_context.merge_call(lambda _: {})
  else:
    # In a cross-tower context for a DistributionStrategy, which means
    # only one Optimizer will be created, not one per tower.
    self._per_graph_state = {}

  # Hyper parameters, and whether they should be re-evaluated every step.
  self._hyper = {}
+
def _set_hyper(self, name, value):
  """Register hyper parameter `name` as an `(is_dynamic, value)` pair."""
  is_dynamic = _is_dynamic(value)
  self._hyper[name] = (is_dynamic, value)
+
def minimize(self, loss, global_step=None, var_list=None,
             gate_gradients=GATE_OP, aggregation_method=None,
             colocate_gradients_with_ops=False, name=None,
             grad_loss=None, stop_gradients=None,
             scale_loss_by_num_towers=None):
  """Add operations to minimize `loss` by updating `var_list`.

  This method simply chains `compute_gradients()` and `apply_gradients()`.
  If you want to process the gradients before applying them, call
  `compute_gradients()` and `apply_gradients()` explicitly instead of using
  this function.

  Args:
    loss: A `Tensor` containing the value to minimize.
    global_step: Optional `Variable` to increment by one after the
      variables have been updated.
    var_list: Optional list or tuple of `Variable` objects to update to
      minimize `loss`. Defaults to the list of variables collected in
      the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
    gate_gradients: How to gate the computation of gradients. Can be
      `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    name: Optional name for the returned operation.
    grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
    stop_gradients: Optional. A Tensor or list of tensors not to differentiate
      through.
    scale_loss_by_num_towers: Optional boolean. If true, scale the loss
      down by the number of towers. By default, auto-detects whether this
      is needed.

  Returns:
    An Operation that updates the variables in `var_list`. If `global_step`
    was not `None`, that operation also increments `global_step`.

  Raises:
    ValueError: If no variable received a gradient.

  @compatibility(eager)
  When eager execution is enabled, `loss` should be a Python function that
  takes elements of `var_list` as arguments and computes the value to be
  minimized. If `var_list` is None, `loss` should take no arguments.
  Minimization (and gradient computation) is done with respect to the
  elements of `var_list` if not None, else with respect to any trainable
  variables created during the execution of the `loss` function.
  `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
  `grad_loss` are ignored when eager execution is enabled.
  @end_compatibility
  """
  grads_and_vars = self.compute_gradients(
      loss,
      var_list=var_list,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      grad_loss=grad_loss,
      stop_gradients=stop_gradients,
      scale_loss_by_num_towers=scale_loss_by_num_towers)

  # At least one variable must have received a gradient to have anything to
  # apply.
  trainable = [var for grad, var in grads_and_vars if grad is not None]
  if trainable:
    return self.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)

  raise ValueError(
      "No gradients provided for any variable, check your graph for ops"
      " that do not support gradients, between variables %s and loss %s." %
      ([str(v) for _, v in grads_and_vars], loss))
+
def compute_gradients(self, loss, var_list=None,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None, stop_gradients=None,
                      scale_loss_by_num_towers=None):
  """Compute gradients of `loss` for the variables in `var_list`.

  This is the first part of `minimize()`. It returns a list
  of (gradient, variable) pairs where "gradient" is the gradient
  for "variable". Note that "gradient" can be a `Tensor`, an
  `IndexedSlices`, or `None` if there is no gradient for the
  given variable.

  A callable `loss` is evaluated under a `GradientTape` and differentiated
  with the tape; a non-callable `loss` is differentiated with
  `gradients.gradients()` and is only accepted when not executing eagerly.

  Args:
    loss: A Tensor containing the value to minimize or a callable taking
      no arguments which returns the value to minimize. When eager execution
      is enabled it must be a callable.
    var_list: Optional list or tuple of `tf.Variable` to update to minimize
      `loss`. Defaults to the list of variables collected in the graph
      under the key `GraphKeys.TRAINABLE_VARIABLES`.
    gate_gradients: How to gate the computation of gradients. Can be
      `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
    stop_gradients: Optional. A Tensor or list of tensors not to differentiate
      through.
    scale_loss_by_num_towers: Optional boolean. If true, scale the loss
      down by the number of towers. By default, auto-detects whether this
      is needed.

  Returns:
    A list of (gradient, variable) pairs. Variable is always present, but
    gradient can be `None`.

  Raises:
    TypeError: If `var_list` contains anything else than `Variable` objects.
    ValueError: If some arguments are invalid.
    RuntimeError: If called with eager execution enabled and `loss` is
      not callable.

  @compatibility(eager)
  When eager execution is enabled, `gate_gradients`, `aggregation_method`,
  and `colocate_gradients_with_ops` are ignored.
  @end_compatibility
  """
  # TODO(josh11b): Test that we handle weight decay in a reasonable way.
  if callable(loss):
    with backprop.GradientTape() as tape:
      if var_list is not None:
        tape.watch(var_list)
      loss_value = loss()

      # Scale loss for number of towers (callable-loss case). In this case,
      # we have to be careful to call distribute_lib.get_loss_reduction()
      # *after* loss() is evaluated, so we know what loss reduction it uses.
      if scale_loss_by_num_towers is None:
        scale_loss_by_num_towers = (
            distribute_lib.get_loss_reduction() ==
            variable_scope.VariableAggregation.MEAN)
      if scale_loss_by_num_towers:
        num_towers = distribution_strategy_context.get_distribution_strategy(
        ).num_towers
        if num_towers > 1:
          loss_value *= 1. / num_towers

    # With no explicit var_list, differentiate with respect to whatever the
    # tape watched while `loss()` ran.
    if var_list is None:
      var_list = tape.watched_variables()
    grads = tape.gradient(loss_value, var_list, grad_loss)
    return list(zip(grads, var_list))
  # From here on `loss` is not callable, which is rejected in eager mode.
  if context.executing_eagerly():
    raise RuntimeError(
        "`loss` passed to Optimizer.compute_gradients should "
        "be a function when eager execution is enabled.")

  # Scale loss for number of towers (non-callable-loss case).
  if scale_loss_by_num_towers is None:
    scale_loss_by_num_towers = (
        distribute_lib.get_loss_reduction() ==
        variable_scope.VariableAggregation.MEAN)
  if scale_loss_by_num_towers:
    num_towers = distribution_strategy_context.get_distribution_strategy(
    ).num_towers
    if num_towers > 1:
      loss *= 1. / num_towers

  if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
                            optimizer_v1.Optimizer.GATE_OP,
                            optimizer_v1.Optimizer.GATE_GRAPH]:
    raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                     "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                     gate_gradients)
  self._assert_valid_dtypes([loss])
  if grad_loss is not None:
    self._assert_valid_dtypes([grad_loss])
  if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
  else:
    var_list = nest.flatten(var_list)
  # pylint: disable=protected-access
  var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
  # pylint: enable=protected-access
  processors = [_get_processor(v) for v in var_list]
  if not var_list:
    raise ValueError("No variables to optimize.")
  # Gradients are taken with respect to each processor's target.
  var_refs = [p.target() for p in processors]
  grads = gradients.gradients(
      loss, var_refs, grad_ys=grad_loss,
      gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      stop_gradients=stop_gradients)
  # GATE_GRAPH: pass all gradients through one tuple op before any is used.
  if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
    grads = control_flow_ops.tuple(grads)
  grads_and_vars = list(zip(grads, var_list))
  # Only variables that received a gradient and are not of resource dtype
  # are dtype-checked.
  self._assert_valid_dtypes(
      [v for g, v in grads_and_vars
       if g is not None and v.dtype != dtypes.resource])
  return grads_and_vars
+
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    global_step: Optional `Variable` to increment by one after the
      variables have been updated.
    name: Optional name for the returned operation. Default to the
      name passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients. If `global_step`
    was not None, that operation also increments `global_step`.

  Raises:
    TypeError: If `grads_and_vars` is malformed.
    ValueError: If no pairs are given, or none of the variables have
      gradients.
  """
  # This is a default implementation of apply_gradients() that can be shared
  # by most optimizers. It relies on the subclass implementing the following
  # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().

  # Materialize first so the pairs can be traversed more than once.
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    raise ValueError("No variables provided.")

  # Keep only the pairs that actually carry a gradient.
  with_gradient = tuple(
      (grad, var) for grad, var in grads_and_vars if grad is not None)
  if not with_gradient:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([str(v) for _, v in grads_and_vars],))

  tower_context = distribution_strategy_context.get_tower_context()
  return tower_context.merge_call(
      self._distributed_apply, with_gradient,
      global_step=global_step, name=name)
+
def _get_or_create_state(self, var_list=None):
  """Either looks up or creates `_OptimizerV2State`.

  If any variables are available, they should be passed via the `var_list`
  argument, and these will be used to determine the graph to create/retrieve
  state for. Otherwise the returned state is for the current default graph.

  Args:
    var_list: A list of variables to extract a graph from.

  Returns:
    An `_OptimizerV2State` object.
  """
  # Pick the graph whose key identifies the state: the default graph when
  # executing eagerly or when no variables were given, otherwise the graph
  # the variables belong to.
  if context.executing_eagerly() or var_list is None:
    graph = ops.get_default_graph()
  else:
    graph = ops._get_graph_from_inputs(var_list)  # pylint: disable=protected-access
  assert graph is not None
  graph_key = graph._graph_key  # pylint: disable=protected-access

  # Create and register the state on first use for this graph_key.
  if graph_key not in self._per_graph_state:
    fresh_state = _OptimizerV2State(self._name)
    fresh_state._init_with_static_hyper(self._hyper)  # pylint: disable=protected-access
    self._per_graph_state[graph_key] = fresh_state
  return self._per_graph_state[graph_key]
+
def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
  """`apply_gradients` for use with a `DistributionStrategy`.

  Sum-reduces the gradients across the distribution, creates the optimizer
  variables for the affected variables, builds one update per variable,
  runs the `_finish` hook after all updates, and optionally increments
  `global_step`.

  Args:
    distribution: The distribution strategy object used to reduce gradients
      and to place the updates.
    grads_and_vars: (gradient, variable) pairs; `apply_gradients()` passes
      only pairs whose gradient is not `None`.
    global_step: Optional variable to increment once the updates are done.
    name: Optional name for the returned operation; the optimizer's name is
      used as the default name scope.

  Returns:
    The operation that applies all updates (and increments `global_step`
    when one is given).

  Raises:
    NotImplementedError: If, when executing eagerly, one of the unwrapped
      variables is a `Tensor`.
    TypeError: If a gradient cannot be converted to a `Tensor` or
      `IndexedSlices`.
  """
  reduced_grads = distribution.batch_reduce(
      variable_scope.VariableAggregation.SUM, grads_and_vars)
  var_list = [v for _, v in grads_and_vars]
  # Re-pair the reduced gradients with their variables. This is iterated
  # exactly once, by the update loop below.
  grads_and_vars = zip(reduced_grads, var_list)

  unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
  eager_execution = context.executing_eagerly()
  if eager_execution:
    # Give a clear error in this case instead of "name not supported
    # for Eager Tensors" when we compute non_slot_devices.
    for v in unwrapped_var_list:
      if isinstance(v, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", v)

  with ops.name_scope(name, self._name) as name:
    per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
    # Include the current value of any dynamic hyper parameters in `state`.
    non_slot_devices = distribution.non_slot_devices(var_list)
    state = per_graph_state._copy_with_dynamic_hyper(  # pylint: disable=protected-access
        self._hyper, distribution, non_slot_devices)

  # Create any slot and non-slot variables we need in `state`.
  with ops.init_scope():
    self._create_vars(var_list, state)

  with ops.name_scope(name):  # Re-enter name_scope created above
    # Give the child class a chance to do something before we start
    # applying gradients.
    self._prepare(state)

    def update(v, g):
      """Update variable `v` using gradient `g`."""
      assert v is not None

      # Convert the grad to Tensor or IndexedSlices if necessary, and
      # look up a processor for each variable's type.
      try:
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError(
            "Gradient must be convertible to a Tensor"
            " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      processor = _get_processor(v)

      # We colocate all ops created in _apply_dense or _apply_sparse
      # on the same device as the variable.
      # TODO(apassos): figure out how to get the variable name here.
      scope_name = "" if eager_execution else v.op.name
      # device_policy is set because non-mirrored tensors will be read in
      # `update_op`.
      # TODO(josh11b): Make different state objects for each device to
      # avoid needing to set the device_policy.
      with ops.name_scope("update_" + scope_name), \
          context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
        return processor.update_op(self, g, state)

    # Use the processors to update the variables.
    update_ops = []
    for grad, var in grads_and_vars:
      update_ops.extend(distribution.update(var, update, grad, grouped=False))

    # Give the child class a chance to do something after applying
    # gradients
    def finish():
      # TODO(josh11b): Make different state objects for each device to
      # avoid needing to set the device_policy.
      with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
        return self._finish(state)

    # The finish hook is built under a control dependency on the grouped
    # per-variable updates.
    update_ops = control_flow_ops.group(update_ops)
    with ops.control_dependencies([update_ops]):
      finish_updates = distribution.update_non_slot(
          non_slot_devices, finish, grouped=False)
    # We said grouped=False, which means finish_updates is always a list.
    # It will be [None] when finish() returns None.
    if finish_updates == [None]:
      finish_updates = [update_ops]

    # Update `global_step` (if any).
    if global_step is None:
      apply_updates = distribution.group(finish_updates, name=name)
    else:
      with ops.control_dependencies(finish_updates):

        def update_global_step(global_step, name):
          return global_step.assign_add(1, read_value=False, name=name)

        apply_updates = distribution.update(global_step, update_global_step,
                                            name)

    # Add the training op to the TRAIN_OP graph collection in graph mode.
    if not eager_execution:
      if isinstance(apply_updates, ops.Tensor):
        apply_updates = apply_updates.op
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      if apply_updates not in train_op:
        train_op.append(apply_updates)

    return apply_updates
+
def get_slot(self, var, name):
  """Look up the slot called `name` that the Optimizer created for `var`.

  Some `Optimizer` subclasses use additional variables. For example
  `Momentum` and `Adagrad` use variables to accumulate updates. This method
  gives access to these `Variable` objects if for some reason you need them.

  Use `get_slot_names()` to get the list of slot names created by the
  `Optimizer`.

  Args:
    var: A variable passed to `minimize()` or `apply_gradients()`.
    name: A string.

  Returns:
    The `Variable` for the slot if it was created, `None` otherwise.
  """
  state = self._get_state_for_var(var)
  if state is None:
    # No state is available for this variable, so there is no slot either.
    return None
  return state.get_slot(var, name)
+
def get_slot_names(self):
  """Return the names of the slots created by the `Optimizer`.

  See `get_slot()`.

  Returns:
    A list of strings; empty when no per-graph state is available.
  """
  state = self._get_per_graph_state()
  if state is None:
    return []
  return state.get_slot_names()
+
def variables(self):
  """Return the variables which encode the current state of `Optimizer`.

  Includes slot variables and additional global variables created by the
  optimizer in the current default graph.

  Returns:
    A list of variables; empty when no per-graph state is available.
  """
  state = self._get_per_graph_state()
  if state is None:
    return []
  return state._variables()  # pylint: disable=protected-access
+
+ # --------------
+ # Methods to be implemented by subclasses if they want to use the
+ # inherited implementation of apply_gradients() or compute_gradients().
+ # --------------
def _create_vars(self, var_list, state):
  """Create all slots needed by the variables and any non-slot variables.

  The base implementation creates nothing. Subclasses that need slot or
  non-slot variables override this method and create them through `state`.

  Args:
    var_list: A list of `Variable` objects.
    state: An object with these methods:
      `create_slot(var, val, slot_name, optional_op_name)`,
      `create_slot_with_initializer(`
          `var, initializer, shape, dtype, slot_name, optional_op_name)`,
      `zeros_slot(var, slot_name, optional_op_name)`,
      `create_non_slot(initial_value, name, colocate_with)`,
      `get_hyper(name, dtype)`
  """
  # No slots needed by default
  pass
+
+ def _prepare(self, state):
+ """Code to execute before applying gradients.
+
+ Note that most uses of _prepare() in Optimizer have been subsumed
+ by explicit support for hyper parameters in OptimizerV2
+
+ Args:
+ state: An object with a `get_hyper(name)` method.
+
+ Returns:
+ Return value will be ignored.
+ """
+ pass
+
+ def _apply_dense(self, grad, var, state):
+ """Add ops to apply dense gradients to `var`.
+
+ Args:
+ grad: A `Tensor`.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ raise NotImplementedError()
+
+ def _resource_apply_dense(self, grad, handle, state):
+ """Add ops to apply dense gradients to the variable `handle`.
+
+ Args:
+ grad: a `Tensor` representing the gradient.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
def _resource_apply_sparse_duplicate_indices(
    self, grad, handle, indices, state):
  """Add ops to apply sparse gradients to `handle`, with repeated indices.

  Optimizers which override this method must deal with repeated indices. See
  the docstring of `_apply_sparse_duplicate_indices` for details. By default
  the correct behavior, to sum non-unique indices and their associated
  gradients, is enforced by first pre-processing `grad` and `indices` and
  passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
  with duplicate indices may instead override this method to avoid the
  overhead of summing.

  Args:
    grad: a `Tensor` representing the gradient for the affected indices.
    handle: a `Tensor` of dtype `resource` which points to the variable
      to be updated.
    indices: a `Tensor` of integral type representing the indices for
      which the gradient is nonzero. Indices may be repeated.
    state: An object with `get_slot(var, name)`, `get_non_slot(name)`,
      and `get_hyper(name)` methods.

  Returns:
    An `Operation` which updates the value of the variable.
  """
  # Collapse repeated indices by summing their gradients before delegating.
  # pylint: disable=protected-access
  deduplicated = optimizer_v1._deduplicate_indexed_slices(
      values=grad, indices=indices)
  # pylint: enable=protected-access
  summed_grad, unique_indices = deduplicated
  return self._resource_apply_sparse(
      summed_grad, handle, unique_indices, state)
+
+ def _resource_apply_sparse(self, grad, handle, indices, state):
+ """Add ops to apply sparse gradients to the variable `handle`.
+
+ Similar to `_apply_sparse`, the `indices` argument to this method has been
+ de-duplicated. Optimizers which deal correctly with non-unique indices may
+ instead override `_resource_apply_sparse_duplicate_indices` to avoid this
+ overhead.
+
+ Args:
+ grad: a `Tensor` representing the gradient for the affected indices.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ indices: a `Tensor` of integral type representing the indices for
+ which the gradient is nonzero. Indices are unique.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
def _apply_sparse_duplicate_indices(self, grad, var, state):
  """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

  Optimizers which override this method must deal with IndexedSlices objects
  such as the following:

    IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

  The correct interpretation is:

    IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

  Many optimizers deal incorrectly with repeated indices when updating based
  on sparse gradients (e.g. summing squares rather than squaring the sum, or
  applying momentum terms multiple times). Adding first is always the correct
  behavior, so this is enforced here by reconstructing the IndexedSlices to
  have only unique indices, then calling _apply_sparse.

  Optimizers which deal correctly with repeated indices may instead override
  this method to avoid the overhead of summing indices.

  Args:
    grad: `IndexedSlices`.
    var: A `Variable` object.
    state: An object with `get_slot(var, name)`, `get_non_slot(name)`,
      and `get_hyper(name)` methods.

  Returns:
    An `Operation`.
  """
  # pylint: disable=protected-access
  deduplicated = optimizer_v1._deduplicate_indexed_slices(
      values=grad.values, indices=grad.indices)
  # pylint: enable=protected-access
  summed_values, unique_indices = deduplicated
  # Rebuild the gradient with one entry per index, keeping its dense shape.
  unique_grad = ops.IndexedSlices(
      indices=unique_indices,
      values=summed_values,
      dense_shape=grad.dense_shape)
  return self._apply_sparse(unique_grad, var, state)
+
+   def _apply_sparse(self, grad, var, state):
+     """Hook for applying a sparse (`IndexedSlices`) gradient to `var`.
+
+     By default `grad` arrives here from `_apply_sparse_duplicate_indices`,
+     which has already summed repeated indices (see that method's docstring).
+     An optimizer that tolerates, or has correct special cases for, duplicate
+     sparse indices may override `_apply_sparse_duplicate_indices` instead of
+     this method and avoid that overhead.
+
+     Args:
+       grad: `IndexedSlices`, with no repeated indices.
+       var: A `Variable` object.
+       state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+         and `get_hyper(name)` methods.
+
+     Returns:
+       An `Operation`.
+
+     Raises:
+       NotImplementedError: always, in this base implementation.
+     """
+     raise NotImplementedError()
+
+   def _finish(self, state):
+     """Hook run at the end of an update; a no-op in this base implementation.
+
+     This is called inside a scope colocated with any non-slot variables.
+
+     Args:
+       state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+         and `get_hyper(name)` methods.
+
+     Returns:
+       The operation to apply updates, or None if no updates. This base
+       implementation always returns None.
+     """
+     return None
+
+ # --------------
+ # Utility methods for subclasses.
+ # --------------
+   def _get_per_graph_state(self):
+     """Returns the state stored for the current default graph, or None."""
+     current_graph = ops.get_default_graph()
+     graph_key = current_graph._graph_key  # pylint: disable=protected-access
+     return self._per_graph_state.get(graph_key, None)
+
+   def _get_state_for_var(self, var):
+     """Returns the state stored under `var`'s graph key, or None."""
+     graph_key = var._graph_key  # pylint: disable=protected-access
+     return self._per_graph_state.get(graph_key, None)
+
+ # --------------
+ # Overridden methods from Checkpointable.
+ # --------------
+
+   def _track_checkpointable(self, *args, **kwargs):
+     """Rejects dependency tracking: optimizers may not have dependencies.
+
+     Raises:
+       NotImplementedError: always, whatever arguments are passed.
+     """
+     message = (
+         "Optimizers may not have dependencies. File a feature request if this "
+         "limitation bothers you.")
+     raise NotImplementedError(message)
+
+   @property
+   def _checkpoint_dependencies(self):
+     """From Checkpointable. Lists the current graph's non-slot variables.
+
+     Returns:
+       A new list of `CheckpointableReference`s, one per non-slot variable held
+       by the state of the current default graph, ordered by name. The list is
+       empty when no state exists for that graph.
+     """
+     # Note: ignores super(); Optimizers may not have any dependencies outside
+     # of state objects.
+     graph_state = self._get_per_graph_state()
+     if graph_state is None:
+       return []
+     named_variables = sorted(
+         graph_state._non_slot_dict.items(),  # pylint: disable=protected-access
+         # Sort on the name only, to avoid comparing variables.
+         key=lambda pair: pair[0])
+     return [
+         checkpointable.CheckpointableReference(name=var_name, ref=var_object)
+         for var_name, var_object in named_variables
+     ]
+
+   def _lookup_dependency(self, name):
+     """From Checkpointable. Looks up a non-slot variable in the current graph.
+
+     Args:
+       name: Name of the non-slot variable to look up.
+
+     Returns:
+       The result of `get_non_slot(name)` on the current default graph's state,
+       or None when no state exists for that graph.
+     """
+     graph_state = self._get_per_graph_state()
+     if graph_state is not None:
+       return graph_state.get_non_slot(name)
+     return None
+
+   @property
+   def _deferred_dependencies(self):
+     """Tells Checkpointable where non-slot variables will be created.
+
+     Obtains the state object for the current default graph, creating it if
+     necessary, and exposes that state's deferred dependency dictionary.
+     Checkpointable adds entries to this dictionary; the state object consults
+     it when creating non-slot variables and restores their value if an entry
+     is found.
+
+     Returns:
+       A dictionary which holds deferred dependencies for the current default
+       graph.
+     """
+     graph_state = self._get_or_create_state()
+     # pylint: disable=protected-access
+     deferred = graph_state._deferred_dependencies
+     # pylint: enable=protected-access
+     return deferred
+
+   def _create_or_restore_slot_variable(
+       self, slot_variable_position, slot_name, variable):
+     """Checkpointable: restores a slot variable's value, possibly creating it.
+
+     Called when a variable which has an associated slot variable is created or
+     restored. The work is delegated to the state object obtained for
+     `variable`, passing this optimizer's name as the optional op name.
+
+     Args:
+       slot_variable_position: A `checkpointable._CheckpointPosition` object
+         indicating the slot variable `Checkpointable` object to be restored.
+       slot_name: The name of this `Optimizer`'s slot to restore into.
+       variable: The variable object this slot is being created for.
+     """
+     state_for_variable = self._get_or_create_state(var_list=[variable])
+     # pylint: disable=protected-access
+     state_for_variable._create_or_restore_slot_variable(
+         slot_variable_position=slot_variable_position,
+         slot_name=slot_name,
+         variable=variable,
+         optional_op_name=self._name)
+     # pylint: enable=protected-access
+
+ # --------------
+ # Unsupported parent methods
+ # --------------
+   def _slot_dict(self, slot_name):
+     """Unsupported parent method; always raises NotImplementedError."""
+     message = "_slot_dict() method unsupported in OptimizerV2"
+     raise NotImplementedError(message)
+
+   def _get_or_make_slot(self, var, val, slot_name, op_name):
+     """Unsupported parent method; always raises NotImplementedError."""
+     message = "_get_or_make_slot() method unsupported in OptimizerV2"
+     raise NotImplementedError(message)
+
+   def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
+                                          slot_name, op_name):
+     """Unsupported parent method; always raises NotImplementedError."""
+     message = (
+         "_get_or_make_slot_with_initializer() method unsupported in "
+         "OptimizerV2")
+     raise NotImplementedError(message)
+
+   def _create_non_slot_variable(self, initial_value, name, colocate_with):
+     """Unsupported parent method; always raises NotImplementedError."""
+     message = "_create_non_slot_variable() method unsupported in OptimizerV2"
+     raise NotImplementedError(message)
+
+   def _get_non_slot_variable(self, name, graph=None):
+     """Unsupported parent method; always raises NotImplementedError."""
+     message = "_get_non_slot_variable() method unsupported in OptimizerV2"
+     raise NotImplementedError(message)
+
+   def _non_slot_variables(self):
+     """Unsupported parent method; always raises NotImplementedError."""
+     message = "_non_slot_variables() method unsupported in OptimizerV2"
+     raise NotImplementedError(message)
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
new file mode 100644
index 0000000000..a6c939393e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -0,0 +1,277 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional test for OptimizerV2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+ class OptimizerTest(test.TestCase):
+   """Functional tests for the `OptimizerV2` entry points, driven through SGD.
+
+   Most tests use the loss `5 * var0 + 3 * var1` with a learning rate of 3.0,
+   so one SGD step moves `var0` by -15 and `var1` by -9.
+   """
+
+   @test_util.run_in_graph_and_eager_modes
+   def testBasic(self):
+     """`minimize` with a callable loss performs one SGD step."""
+     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+       def loss():
+         return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
+       # Note that for eager execution, minimize expects a function instead of a
+       # Tensor.
+       global_step = resource_variable_ops.ResourceVariable(
+           array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
+       sgd_op = sgd.SGD(3.0)
+
+       self.evaluate(variables.global_variables_initializer())
+       # Fetch params to validate initial values
+       self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+       self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+       # Run 1 step of sgd through optimizer
+       opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
+       self.evaluate(opt_op)
+       # Validate updated params
+       self.assertAllClose([-14., -13.], self.evaluate(var0))
+       self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+   def testAggregationMethod(self):
+     """`minimize` accepts an explicit gradient `aggregation_method`."""
+     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+       with self.cached_session():
+         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+         cost = 5 * var0 + 3 * var1
+         global_step = variables.Variable(
+             array_ops.zeros([], dtypes.int64), name='global_step')
+         sgd_op = sgd.SGD(3.0)
+         opt_op = sgd_op.minimize(
+             cost,
+             global_step, [var0, var1],
+             aggregation_method=gradients_impl.AggregationMethod.
+             EXPERIMENTAL_ACCUMULATE_N)
+
+         variables.global_variables_initializer().run()
+         # Fetch params to validate initial values
+         self.assertAllClose([1.0, 2.0], var0.eval())
+         self.assertAllClose([3.0, 4.0], var1.eval())
+         # Run 1 step of sgd through optimizer
+         opt_op.run()
+         # Validate updated params
+         self.assertAllClose([-14., -13.], var0.eval())
+         self.assertAllClose([-6., -5.], var1.eval())
+
+   def testPrecomputedGradient(self):
+     """`grad_loss` scales the gradients that `minimize` applies."""
+     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+       with self.cached_session():
+         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+         cost = 5 * var0 + 3 * var1
+         grad_loss = constant_op.constant([42, -42], dtype=dtype)
+         global_step = variables.Variable(
+             array_ops.zeros([], dtypes.int64), name='global_step')
+         sgd_op = sgd.SGD(3.0)
+         opt_op = sgd_op.minimize(
+             cost, global_step, [var0, var1], grad_loss=grad_loss)
+
+         variables.global_variables_initializer().run()
+         # Fetch params to validate initial values
+         self.assertAllClose([1.0, 2.0], var0.eval())
+         self.assertAllClose([3.0, 4.0], var1.eval())
+         # Run 1 step of sgd through optimizer
+         opt_op.run()
+         # Validate updated params
+         self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
+                             var0.eval())
+         self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
+                             var1.eval())
+
+   @test_util.run_in_graph_and_eager_modes
+   def testNoVariables(self):
+     """`minimize` raises ValueError when only non-trainable variables exist."""
+     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+       # pylint: disable=cell-var-from-loop
+       def loss():
+         var0 = resource_variable_ops.ResourceVariable(
+             [1.0, 2.0], dtype=dtype, trainable=False, name='a')
+         var1 = resource_variable_ops.ResourceVariable(
+             [3.0, 4.0], dtype=dtype, trainable=False, name='b')
+         return 5 * var0 + var1
+       # pylint: enable=cell-var-from-loop
+       sgd_op = sgd.SGD(3.0)
+       with self.assertRaisesRegexp(ValueError, 'No.*variables'):
+         sgd_op.minimize(loss)
+
+   @test_util.run_in_graph_and_eager_modes
+   def testNoGradients(self):
+     """`minimize` raises ValueError when `var_list` gets no gradient."""
+     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+       # pylint: disable=cell-var-from-loop
+       def loss():
+         return 5 * var0
+       # pylint: enable=cell-var-from-loop
+       sgd_op = sgd.SGD(3.0)
+       with self.assertRaisesRegexp(ValueError, 'No gradients'):
+         # var1 has no gradient
+         sgd_op.minimize(loss, var_list=[var1])
+
+   @test_util.run_in_graph_and_eager_modes
+   def testNoGradientsForAnyVariables_Minimize(self):
+     """`minimize` raises ValueError when the loss depends on no variable."""
+     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+       def loss():
+         return constant_op.constant(5.0)
+
+       sgd_op = sgd.SGD(3.0)
+       with self.assertRaisesRegexp(ValueError,
+                                    'No gradients provided for any variable'):
+         sgd_op.minimize(loss, var_list=[var0, var1])
+
+   @test_util.run_in_graph_and_eager_modes
+   def testNoGradientsForAnyVariables_ApplyGradients(self):
+     """`apply_gradients` raises ValueError when every gradient is None."""
+     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+       sgd_op = sgd.SGD(3.0)
+       with self.assertRaisesRegexp(ValueError,
+                                    'No gradients provided for any variable'):
+         sgd_op.apply_gradients([(None, var0), (None, var1)])
+
+   @test_util.run_in_graph_and_eager_modes
+   def testGradientsAsVariables(self):
+     """`apply_gradients` accepts gradients that are stored in variables."""
+     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+       var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+       var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+       def loss():
+         return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
+
+       sgd_op = sgd.SGD(3.0)
+       grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
+       # Convert gradients to tf.Variables
+       converted_grads = [
+           resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
+                                                  name='c_%d_%d' % (i, j))
+           for j, gv in enumerate(grads_and_vars)
+       ]
+       convert_ops = [
+           state_ops.assign(converted_grads[j], gv[0])
+           for j, gv in enumerate(grads_and_vars)
+       ]
+
+       self.evaluate(variables.global_variables_initializer())
+       # Run convert_ops to copy the gradients into the variables
+       self.evaluate(convert_ops)
+       # Fetch params to validate initial values
+       self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+       self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+       # Run 1 step of sgd through optimizer
+       converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+       opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
+       self.evaluate(opt_op)
+
+       # Validate updated params
+       self.assertAllClose([-14., -13.], self.evaluate(var0))
+       self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+   @test_util.run_in_graph_and_eager_modes
+   def testComputeGradientsWithTensors(self):
+     """Gradients w.r.t. a plain tensor can be computed but not applied."""
+     x = ops.convert_to_tensor(1.0)
+     def f():
+       return x * x
+
+     sgd_op = sgd.SGD(3.0)
+     grads_and_vars = sgd_op.compute_gradients(f, [x])
+     self.assertEqual(1, len(grads_and_vars))
+     grad, x_as_var = grads_and_vars[0]
+     self.assertIs(x, x_as_var)
+     self.assertEqual(2.0, self.evaluate(grad))
+
+     with self.assertRaises(NotImplementedError):
+       sgd_op.apply_gradients(grads_and_vars)
+
+   def testTrainOp(self):
+     """The op returned by `minimize` is in the TRAIN_OP collection."""
+     with self.cached_session():
+       var0 = variables.Variable([1.0, 2.0])
+       var1 = variables.Variable([3.0, 4.0])
+       cost = 5 * var0 + 3 * var1
+       global_step = variables.Variable(
+           array_ops.zeros([], dtypes.int64), name='global_step')
+       sgd_op = sgd.SGD(3.0)
+       opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+       # assertIn reports the collection contents on failure, unlike
+       # assertTrue(x in y).
+       self.assertIn(opt_op, ops.get_collection(ops.GraphKeys.TRAIN_OP))
+
+   def testConstraint(self):
+     """Variable constraints are applied to the values after the update."""
+     clip_to_neg01_0 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
+     clip_to_0_1 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
+     with self.cached_session():
+       var0 = variables.Variable([1.0, 2.0],
+                                 constraint=clip_to_neg01_0)
+       var1 = variables.Variable([3.0, 4.0],
+                                 constraint=clip_to_0_1)
+       cost = 5 * var0 + 3 * var1
+       global_step = variables.Variable(
+           array_ops.zeros([], dtypes.int64), name='global_step')
+       sgd_op = sgd.SGD(3.0)
+       opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+
+       variables.global_variables_initializer().run()
+       # Fetch params to validate initial values
+       self.assertAllClose([1.0, 2.0], var0.eval())
+       self.assertAllClose([3.0, 4.0], var1.eval())
+       # Run 1 step of sgd through optimizer
+       opt_op.run()
+       # Validate updated params: the raw updates ([-14, -13] and [-6, -5])
+       # are clipped into each variable's constraint range.
+       self.assertAllClose([-0.1, -0.1], var0.eval())
+       self.assertAllClose([0., 0.], var1.eval())
+
+   def testStopGradients(self):
+     """`stop_gradients` blocks the gradient flowing through a tensor."""
+     with self.cached_session():
+       var0 = variables.Variable([1.0, 2.0], name='var0')
+       var1 = variables.Variable([3.0, 4.0], name='var1')
+       var0_id = array_ops.identity(var0)
+       cost = 5 * var0_id + 3 * var1
+       sgd_op = sgd.SGD(3.0)
+       grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1],
+                                                 stop_gradients=[var0_id])
+       grad_dict = {var.op.name: grad for grad, var in grads_and_vars}
+       self.assertIsNone(grad_dict['var0'])
+       self.assertIsNotNone(grad_dict['var1'])
+
+   def testDoNotOverrideCreateSlots(self):
+     """Subclasses that override `_create_slots` fail at construction."""
+     class ShouldNotOverrideCreateSlots(optimizer_v2.OptimizerV2):
+
+       def _create_slots(self, var_list):
+         """In OptimizerV2 _create_slots was renamed _create_vars."""
+         return var_list
+
+     with self.assertRaises(RuntimeError):
+       ShouldNotOverrideCreateSlots('name')
+
+
+ # Run the tests when this file is executed as a script.
+ if __name__ == '__main__':
+   test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
new file mode 100644
index 0000000000..2748d8eff7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -0,0 +1,239 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RMSprop optimizer for Tensorflow.
+
+rmsprop algorithm [tieleman2012rmsprop]
+
+A detailed description of rmsprop.
+
+- maintain a moving (discounted) average of the square of gradients
+- divide gradient by the root of this average
+
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
+delta = - mom
+
+This implementation of RMSProp uses plain momentum, not Nesterov momentum.
+
+The centered version additionally maintains a moving (discounted) average of the
+gradients, and uses that average to estimate the variance:
+
+ mean_grad = rho * mean_grad{t-1} + (1-rho) * gradient
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t /
+ sqrt(mean_square - mean_grad**2)
+delta = - mom
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+
+from tensorflow.python.training import training_ops
+
+
+ class RMSProp(optimizer_v2.OptimizerV2):
+   """RMSProp optimizer.
+
+   It is recommended to leave the parameters of this optimizer at their default
+   values (except the learning rate, which can be freely tuned).
+
+   This optimizer is usually a good choice for recurrent neural networks.
+
+   Some of the args below are hyperparameters, where a hyperparameter is
+   defined as a scalar Tensor, a regular Python value, or a callable (which
+   will be evaluated when `apply_gradients` is called) returning a scalar
+   Tensor or a Python value.
+
+   Note that in the dense implementation of this algorithm, variables and their
+   corresponding accumulators (momentum, gradient moving average, square
+   gradient moving average) will be updated even if the gradient is zero
+   (i.e. accumulators will decay, momentum will be applied). The sparse
+   implementation (used when the gradient is an `IndexedSlices` object,
+   typically because of `tf.gather` or an embedding lookup in the forward pass)
+   will not update variable slices or their accumulators unless those slices
+   were used in the forward pass (nor is there an "eventual" correction to
+   account for these omitted updates). This leads to more efficient updates for
+   large embedding lookup tables (where most of the slices are not accessed in
+   a particular graph execution), but differs from the published algorithm.
+
+   Arguments:
+       learning_rate: A float hyperparameter >= 0. The learning rate.
+       rho: A float hyperparameter >= 0. Discounting factor for the
+         history/coming gradient.
+       momentum: A float hyperparameter >= 0. `None` is treated as 0.0.
+       epsilon: A float hyperparameter >= 0. Small value to initialize the
+         average square gradient variable and avoid zero denominator.
+       centered: If True, gradients are normalized by the estimated variance of
+         the gradient; if False, by the uncentered second moment. Setting this to
+         True may help with training, but is slightly more expensive in terms of
+         computation and memory. Defaults to False.
+       name: Optional name prefix for the operations created when applying
+         gradients. Defaults to "RMSProp".
+   """
+
+   def __init__(self,
+                learning_rate=0.001,
+                rho=0.9,
+                momentum=None,
+                epsilon=1e-10,
+                centered=False,
+                name="RMSProp"):
+     """Registers the hyperparameters; see the class docstring for details."""
+     super(RMSProp, self).__init__(name)
+     # Momentum default is `None` for consistency with SGD
+     # but underlying implementation uses `momentum` hyperparameter here
+     # regardless unlike SGD. Since external Keras RMSProp does not have
+     # a `momentum` weight, for compatibility with external Keras h5 files,
+     # when `momentum` was set as `None` we should ignore the `momentum`
+     # variable in `get_weights` and not require it in `set_weights`.
+     if momentum is None:
+       momentum = 0.0
+     self._set_hyper("learning_rate", learning_rate)
+     self._set_hyper("rho", rho)
+     self._set_hyper("momentum", momentum)
+     self._set_hyper("epsilon", epsilon)
+
+     self._centered = centered
+
+   def _create_vars(self, var_list, state):
+     """Creates the slots used by the update for each variable in `var_list`.
+
+     Every variable gets an "rms" slot initialized to `epsilon` (broadcast to
+     the variable's shape) and a zero-initialized "momentum" slot. When
+     `centered` is set, a zero-initialized "mg" slot is created as well.
+
+     Args:
+       var_list: Variables to create slots for.
+       state: State object used to read hyperparameters and create slots.
+     """
+     for v in var_list:
+       init_rms = state.get_hyper(
+           "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
+       state.create_slot_with_initializer(v, init_rms, v.get_shape(),
+                                          v.dtype.base_dtype, "rms")
+       if self._centered:
+         state.zeros_slot(v, "mg")
+       state.zeros_slot(v, "momentum")
+
+   def _apply_dense(self, grad, var, state):
+     """Applies a dense gradient to `var` via the (centered) RMSProp kernel."""
+     rms = state.get_slot(var, "rms")
+     mom = state.get_slot(var, "momentum")
+     if self._centered:
+       mg = state.get_slot(var, "mg")
+       return training_ops.apply_centered_rms_prop(
+           var,
+           mg,
+           rms,
+           mom,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           # epsilon is now the rms initial value and is not added to the
+           # denominator anymore, hence calling the kernel op with epsilon=0.
+           0,
+           grad,
+           use_locking=self._use_locking).op
+     else:
+       return training_ops.apply_rms_prop(
+           var,
+           rms,
+           mom,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           0,  # epsilon: see the note in the centered branch above.
+           grad,
+           use_locking=self._use_locking).op
+
+   def _resource_apply_dense(self, grad, var, state):
+     """Applies a dense gradient to the resource variable `var`."""
+     rms = state.get_slot(var, "rms")
+     mom = state.get_slot(var, "momentum")
+     if self._centered:
+       mg = state.get_slot(var, "mg")
+       return training_ops.resource_apply_centered_rms_prop(
+           var.handle,
+           mg.handle,
+           rms.handle,
+           mom.handle,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           0,  # epsilon is the "rms" slot's initial value; the kernel gets 0.
+           grad,
+           use_locking=self._use_locking)
+     else:
+       return training_ops.resource_apply_rms_prop(
+           var.handle,
+           rms.handle,
+           mom.handle,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           0,  # epsilon is the "rms" slot's initial value; the kernel gets 0.
+           grad,
+           use_locking=self._use_locking)
+
+   def _apply_sparse(self, grad, var, state):
+     """Applies an `IndexedSlices` gradient to `var`."""
+     rms = state.get_slot(var, "rms")
+     mom = state.get_slot(var, "momentum")
+     if self._centered:
+       mg = state.get_slot(var, "mg")
+       return training_ops.sparse_apply_centered_rms_prop(
+           var,
+           mg,
+           rms,
+           mom,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           0,  # epsilon is the "rms" slot's initial value; the kernel gets 0.
+           grad.values,
+           grad.indices,
+           use_locking=self._use_locking)
+     else:
+       return training_ops.sparse_apply_rms_prop(
+           var,
+           rms,
+           mom,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           0,  # epsilon is the "rms" slot's initial value; the kernel gets 0.
+           grad.values,
+           grad.indices,
+           use_locking=self._use_locking)
+
+   def _resource_apply_sparse(self, grad, var, indices, state):
+     """Applies a sparse gradient (`grad` at `indices`) to resource `var`."""
+     rms = state.get_slot(var, "rms")
+     mom = state.get_slot(var, "momentum")
+     if self._centered:
+       # Look the slot up through `state`, like every other slot access in
+       # this class, rather than through the optimizer object itself.
+       mg = state.get_slot(var, "mg")
+       return training_ops.resource_sparse_apply_centered_rms_prop(
+           var.handle,
+           mg.handle,
+           rms.handle,
+           mom.handle,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           0,  # epsilon is the "rms" slot's initial value; the kernel gets 0.
+           grad,
+           indices,
+           use_locking=self._use_locking)
+     else:
+       return training_ops.resource_sparse_apply_rms_prop(
+           var.handle,
+           rms.handle,
+           mom.handle,
+           state.get_hyper("learning_rate", var.dtype.base_dtype),
+           state.get_hyper("rho", var.dtype.base_dtype),
+           state.get_hyper("momentum", var.dtype.base_dtype),
+           0,  # epsilon is the "rms" slot's initial value; the kernel gets 0.
+           grad,
+           indices,
+           use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
new file mode 100644
index 0000000000..2c5eccdc5b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -0,0 +1,444 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rmsprop optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import math
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+# Dtypes used to generate the parameterized test cases in this file.
+_DATA_TYPES = [dtypes.half, dtypes.float32]
+
+# Hyperparameter rows for the parameterized tests. Each row is unpacked
+# positionally in the column order given by the comment inside the list.
+_TEST_PARAM_VALUES = [
+ # learning_rate, rho, momentum, epsilon, centered, use_resource
+ [0.5, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.9, 0.0, 1.0, False, False],
+ [0.5, 0.9, 0.0, 1.0, True, True],
+ [0.5, 0.9, 0.0, 1.0, False, True],
+ [0.1, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.95, 0.0, 1.0, False, False],
+ [0.5, 0.8, 0.0, 1e-3, True, False],
+ [0.5, 0.8, 0.9, 1e-3, True, False],
+]
+
+
+class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
+ """Tests rmsprop.RMSProp against NumPy references and hand-computed values."""
+
+ def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, rho, momentum,
+ centered):
+ """Returns (var, mg, rms, mom) after one dense reference RMSProp step.
+
+ When `centered` is false, `mg` is returned unchanged and the denominator
+ is the updated `rms` alone. No epsilon term is added to the denominator.
+ """
+ rms_t = rms * rho + (1 - rho) * g * g
+ if centered:
+ mg_t = mg * rho + (1 - rho) * g
+ denom_t = rms_t - mg_t * mg_t
+ else:
+ mg_t = mg
+ denom_t = rms_t
+ mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+ var_t = var - mom_t
+ return var_t, mg_t, rms_t, mom_t
+
+ def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
+ lr, rho, momentum, centered):
+ """Returns (var, mg, rms, mom) after one sparse reference RMSProp step.
+
+ Works on deep copies of the inputs and only touches the positions listed
+ in `gindexs`, using the matching entry of `gvalues` as the gradient.
+ """
+ mg_t = copy.deepcopy(mg)
+ rms_t = copy.deepcopy(rms)
+ mom_t = copy.deepcopy(mom)
+ var_t = copy.deepcopy(var)
+ for i in range(len(gindexs)):
+ gindex = gindexs[i]
+ gvalue = gvalues[i]
+ rms_t[gindex] = rms[gindex] * rho + (1 - rho) * gvalue * gvalue
+ denom_t = rms_t[gindex]
+ if centered:
+ mg_t[gindex] = mg_t[gindex] * rho + (1 - rho) * gvalue
+ # NOTE(review): this subtracts in place. The tests in this class pass
+ # 1-D arrays, so denom_t is presumably a scalar copy; if rms_t[gindex]
+ # were an array view this would also modify rms_t — confirm before
+ # reusing with higher-rank values.
+ denom_t -= mg_t[gindex] * mg_t[gindex]
+ mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
+ var_t[gindex] = var[gindex] - mom_t[gindex]
+ return var_t, mg_t, rms_t, mom_t
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testDense(self, dtype, param_value):
+ """Applies dense gradients for four steps and compares with the NumPy reference."""
+ (learning_rate, rho, momentum, epsilon, centered,
+ use_resource) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # The "mg" slot is expected to exist only when centered is true.
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Reference accumulators: mg and mom start at zero, rms starts at epsilon
+ # (presumably mirroring the optimizer's slot initialization, which is not
+ # visible in this file).
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, rho,
+ momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, rho,
+ momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=0.01, half_atol=0.01)
+ self.assertAllCloseAccordingToType(
+ var1_np, var1.eval(), half_rtol=0.01, half_atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariable(self, dtype):
+ """Runs one minimize() step on an embedding-lookup loss and checks var0."""
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ # The op keeps the name sgd_op, but the optimizer under test is RMSProp.
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.0, momentum=0.0, epsilon=0.0,
+ centered=False).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of RMSProp
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0., 1.]], var0.eval(), atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariableCentered(self, dtype):
+ """Same as the non-centered variant above, but with centered=True."""
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ # The op keeps the name sgd_op, but the optimizer under test is RMSProp.
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.1, momentum=0.0, epsilon=1.0,
+ centered=True).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of RMSProp
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-7/3.0, -4/3.0]], var0.eval(), atol=0.01)
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testSparse(self, dtype, param_value):
+ """Applies IndexedSlices gradients for four steps and compares with NumPy."""
+ # The use_resource column is ignored here; both variables below are
+ # created with variables.Variable.
+ (learning_rate, rho, momentum, epsilon, centered, _) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ # grads0 targets index 0 of var0; grads1 targets index 1 of var1.
+ grads0_np_indices = np.array([0], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([1]))
+ grads1_np_indices = np.array([1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([1]))
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # The "mg" slot is expected to exist only when centered is true.
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Reference accumulators: mg and mom start at zero, rms starts at epsilon.
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
+ var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
+ learning_rate, rho, momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
+ var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
+ learning_rate, rho, momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithoutMomentum(self, dtype):
+ """Checks two update steps with momentum=0.0 against hand-computed values."""
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the rms accumulators were 1. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithMomentum(self, dtype):
+ """Checks two update steps with momentum=0.5 against hand-computed values."""
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.5, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: rms = 1, mom = 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the momentum accumulators
+ self.assertAllCloseAccordingToType(
+ np.array([(0.1 * 2.0 / math.sqrt(0.901)),
+ (0.1 * 2.0 / math.sqrt(0.901))]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.01 * 2.0 / math.sqrt(0.90001)),
+ (0.01 * 2.0 / math.sqrt(0.90001))]), mom1.eval())
+
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ # Check the momentum accumulators: half of the step-1 value plus the new
+ # scaled gradient.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), mom1.eval())
+
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)))
+ ]), var0.eval())
+
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)))
+ ]), var1.eval())
+
+
+# Hand control to the test runner when this file is executed directly.
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/sgd.py b/tensorflow/python/keras/optimizer_v2/sgd.py
new file mode 100644
index 0000000000..f5583691f7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd.py
@@ -0,0 +1,170 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""SGD optimizer with optional (Nesterov) momentum for Keras."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import training_ops
+
+
+class SGD(optimizer_v2.OptimizerV2):
+ """Stochastic gradient descent optimizer.
+
+ Includes support for momentum and Nesterov momentum.
+
+ Computes (if `nesterov = False`):
+
+ ```
+ accumulation = momentum * accumulation + gradient
+ variable -= learning_rate * accumulation
+ ```
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Note that in the dense version of this algorithm, `accumulation` is updated
+ and applied regardless of a gradient's value, whereas the sparse version (when
+ the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
+ embedding) only updates variable slices and corresponding `accumulation` terms
+ when that part of the variable was used in the forward pass.
+
+ @compatibility(eager)
+ When eager execution is enabled, learning_rate and momentum can each be a
+ callable that takes no arguments and returns the actual value to use. This
+ can be useful for changing these values across different invocations of
+ optimizer functions.
+ @end_compatibility
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ momentum: float hyperparameter >= 0 or None. Parameter that accelerates
+ SGD in the relevant direction and dampens oscillations.
+ nesterov: boolean. Whether to apply Nesterov momentum. See [Sutskever et
+ al., 2013](http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). This
+ implementation always computes gradients at the value of the
+ variable(s) passed to the optimizer. Using Nesterov Momentum makes the
+ variable(s) track the values called `theta_t + mu*v_t` in the paper.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'SGD'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ momentum=None,
+ nesterov=False,
+ name="SGD"):
+ super(SGD, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ # Only create momentum variables and use momentum ops if needed.
+ if momentum is not None:
+ self._set_hyper("momentum", momentum)
+ self._use_nesterov = nesterov
+ self._use_momentum = True
+ else:
+ self._use_momentum = False
+
+ def _create_vars(self, var_list, state):
+ if self._use_momentum:
+ for v in var_list:
+ state.zeros_slot(v, "momentum")
+
+ def _apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return training_ops.apply_gradient_descent(
+ var,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return training_ops.resource_apply_gradient_descent(
+ var.handle, lr, grad, use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return super(SGD, self)._apply_sparse(grad, var, state)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_sparse_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ return super(SGD, self)._resource_apply_sparse(grad, var, indices, state)
+
+ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices, state):
+ if self._use_momentum:
+ return super(SGD, self)._resource_apply_sparse_duplicate_indices(
+ grad, var, indices, state)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return resource_variable_ops.resource_scatter_add(var.handle, indices,
+ -grad * lr)
+
+ def _apply_sparse_duplicate_indices(self, grad, var, state):
+ if self._use_momentum:
+ return super(SGD, self)._apply_sparse_duplicate_indices(grad, var, state)
+ else:
+ delta = ops.IndexedSlices(
+ grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.indices, grad.dense_shape)
+ return var.scatter_sub(delta, use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/sgd_test.py b/tensorflow/python/keras/optimizer_v2/sgd_test.py
new file mode 100644
index 0000000000..eb39aac283
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd_test.py
@@ -0,0 +1,759 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SGD optimizer, with and without momentum."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class GradientDescentOptimizerTest(test.TestCase):
+ """Tests sgd.SGD constructed with a learning rate only (no momentum argument)."""
+
+ def testBasic(self):
+ """One apply_gradients step on non-resource variables, for each dtype."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ optimizer = sgd.SGD(3.0)
+ sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ # The optimizer is expected to report no variables of its own here.
+ self.assertEqual(0, len(optimizer.variables()))
+
+ def testBasicResourceVariable(self):
+ """One apply_gradients step on resource variables, for each dtype."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testMinimizeResourceVariable(self):
+ """One minimize() step on the loss (var0 @ x + var1)**2 with resource vars."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(var0, x) + var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ # Expected values: d(pred**2)/d(pred) = 2 * pred, scaled by x for var0.
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testMinimizeSparseResourceVariable(self):
+ """Like testMinimizeResourceVariable, but var0 is read via embedding_lookup."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ pred += var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ # NOTE(review): unlike the sibling tests, this one calls
+ # global_variables_initializer rather than initialize_resources, so the
+ # TODO above may be a leftover — confirm.
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testTensorLearningRate(self):
+ """Same as testBasic, but the learning rate is passed as a constant tensor."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ lrate = constant_op.constant(3.0)
+ sgd_op = sgd.SGD(lrate).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testGradWrtRef(self):
+ """compute_gradients of vars_[0] + vars_[1] yields a gradient of 1.0 each."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ opt = sgd.SGD(3.0)
+ values = [1.0, 3.0]
+ vars_ = [variables.Variable([v], dtype=dtype) for v in values]
+ grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_)
+ variables.global_variables_initializer().run()
+ for grad, _ in grads_and_vars:
+ self.assertAllCloseAccordingToType([1.0], grad.eval())
+
+ def testWithGlobalStep(self):
+ """Same as testBasic, and also checks that global_step reads 1 afterwards."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ global_step = variables.Variable(0, trainable=False)
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params and global_step
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ self.assertAllCloseAccordingToType(1, global_step.eval())
+
+ def testSparseBasic(self):
+ """One step with IndexedSlices gradients; only the indexed rows change."""
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ # grads0 targets row 0 of var0; grads1 targets row 1 of var1.
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[1.0 - 3.0 * 0.1], [2.0]],
+ var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
+ var1.eval())
+
+
# NOTE: test.main() must not be invoked here. Starting the test runner at this
# point would happen before MomentumOptimizerTest (defined below) exists, so
# none of its tests would run when the file is executed as a script. The single
# `if __name__ == "__main__"` guard at the end of the file is the entry point.
+
+
+class MomentumOptimizerTest(test.TestCase):
+
+ def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
+ var = var + accum * lr * momentum
+ accum = accum * momentum + g
+ var = var - lr * accum
+ var = var - accum * lr * momentum
+ return var, accum
+
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ learning_rate = lambda: 2.0
+ momentum = lambda: 0.9
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ momentum = momentum()
+ mom_opt = sgd.SGD(learning_rate=learning_rate, momentum=momentum)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ if not context.executing_eagerly():
+ self.assertFalse(slot0 in variables.trainable_variables())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ if not context.executing_eagerly():
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+ self.evaluate(var1))
+ # Step 2: the momentum accumulators contain the previous update.
+ if context.executing_eagerly():
+ mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ else:
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), self.evaluate(var1))
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testVariablesAcrossGraphs(self):
+ optimizer = sgd.SGD(0.01, 0.5)
+ with ops.Graph().as_default():
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var0")
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var1")
+ loss = math_ops.reduce_sum(var0 + var1)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var0")
+ self.assertStartsWith(optimizer_variables[1].name, "var1")
+ self.assertEquals(2, len(optimizer_variables))
+
+ with ops.Graph().as_default():
+ var2 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var2")
+ var3 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var3")
+ loss = math_ops.reduce_sum(var2 + var3)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var2")
+ self.assertStartsWith(optimizer_variables[1].name, "var3")
+ self.assertEquals(2, len(optimizer_variables))
+
+ def testNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ cost = 5 * var0 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name="global_step")
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ opt_op = mom_op.minimize(cost, global_step, [var0, var1])
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_op.run()
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ def testSparseNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ grads = []
+ for t in range(1, 5):
+ grads.append(var0_np * 10)
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ loss = 5 * var0 * var0 + 3 * var1
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ x_feed = array_ops.placeholder(dtype)
+ y_feed = ops.IndexedSlices(
+ x_feed, constant_op.constant([0, 1]), constant_op.constant([2]))
+ grads_and_vars = [(y_feed, var0), (constant_op.constant(
+ [3.0, 3.0], dtype=dtype), var1)]
+ opt_update = mom_op.apply_gradients(grads_and_vars)
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_update.run(feed_dict={x_feed: grads[t - 1]})
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ return pred * pred
+ # pylint: enable=cell-var-from-loop
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ # Run 1 step of sgd
+ self.evaluate(sgd_op)
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+
+ def loss():
+ return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(sgd_op)
+ self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))
+
+ def testTensorLearningRateAndMomentum(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(
+ learning_rate=constant_op.constant(2.0),
+ momentum=constant_op.constant(0.9))
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ self.assertFalse(slot0 in variables.trainable_variables())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+ def _dbParamsMom01(self):
+ """Return dist-belief momentum values.
+
+ Return values been generated from the dist-belief momentum unittest,
+ running with a learning rate of 0.1 and a momentum of 0.1.
+
+ These values record how a parameter vector of size 10, initialized with 0.0,
+ gets updated with 10 consecutive momentum steps. It uses random gradients.
+
+ Returns:
+ db_grad: The gradients to apply
+ db_out: The parameters after the momentum update.
+ """
+ db_grad = [[]] * 10
+ db_out = [[]] * 10
+ # pylint: disable=line-too-long
+ db_grad[0] = [
+ 0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018,
+ 0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615
+ ]
+ db_out[0] = [
+ -9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018,
+ -0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618
+ ]
+ db_grad[1] = [
+ 0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378,
+ 0.5513742, 0.94687688, 0.16012503, 0.22159521
+ ]
+ db_out[1] = [
+ -0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884,
+ -0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544
+ ]
+ db_grad[2] = [
+ 0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965,
+ 0.31168157, 0.43203235, 0.16792089, 0.24644311
+ ]
+ db_out[2] = [
+ -0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978,
+ -0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189
+ ]
+ db_grad[3] = [
+ 0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098,
+ 0.81454384, 0.03848977, 0.89759839, 0.93665648
+ ]
+ db_out[3] = [
+ -0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105,
+ -0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303
+ ]
+ db_grad[4] = [
+ 0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359,
+ 0.69107032, 0.81897682, 0.5433259, 0.67860287
+ ]
+ db_out[4] = [
+ -0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165,
+ -0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544
+ ]
+ db_grad[5] = [
+ 0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563,
+ 0.84163809, 0.41172323, 0.83259648, 0.44941229
+ ]
+ db_out[5] = [
+ -0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094,
+ -0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717
+ ]
+ db_grad[6] = [
+ 0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221,
+ 0.73577434, 0.16014607, 0.57500273, 0.071136251
+ ]
+ db_out[6] = [
+ -0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685,
+ -0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997
+ ]
+ db_grad[7] = [
+ 0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646,
+ 0.74053431, 0.16033, 0.66625422, 0.73515922
+ ]
+ db_out[7] = [
+ -0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838,
+ -0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418
+ ]
+ db_grad[8] = [
+ 0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039,
+ 0.55561525, 0.22567581, 0.93331909, 0.29438227
+ ]
+ db_out[8] = [
+ -0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527,
+ -0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781
+ ]
+ db_grad[9] = [
+ 0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893,
+ 0.68593478, 0.50580865, 0.12602448, 0.093537711
+ ]
+ db_out[9] = [
+ -0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302,
+ -0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295
+ ]
+ # pylint: enable=line-too-long
+ return db_grad, db_out
+
+ def testLikeDistBeliefMom01(self):
+ with self.cached_session():
+ db_grad, db_out = self._dbParamsMom01()
+ num_samples = len(db_grad)
+ var0 = variables.Variable([0.0] * num_samples)
+ grads0 = constant_op.constant([0.0] * num_samples)
+ mom_opt = sgd.SGD(learning_rate=0.1, momentum=0.1)
+ mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
+ variables.global_variables_initializer().run()
+ for i in xrange(num_samples):
+ mom_update.run(feed_dict={grads0: db_grad[i]})
+ self.assertAllClose(np.array(db_out[i]), var0.eval())
+
+ def testSparse(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
+ var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.1, .1]], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([4, 2]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.01, .01], [.01, .01]], dtype=dtype),
+ constant_op.constant([2, 3]),
+ constant_op.constant([4, 2]))
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([0, 0], var0.eval()[0])
+ self.assertAllClose([0, 0], var0.eval()[1])
+ self.assertAllClose([1, 1], var1.eval()[2])
+
+ # Step 1: the momentum accumulators are 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(np.array([.1, .1]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([.01, .01]), slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([-(0.1 * 2.0), -(0.1 * 2.0)]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.01 * 2.0), 1.0 - (0.01 * 2.0)]), var1.eval()[2])
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ -(0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), -(0.1 * 2.0) - (
+ (0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.98 - ((0.9 * 0.01 + 0.01) * 2.0), 0.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval()[2])
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update1 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ mom_update2 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update1.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the second momentum accumulators contain the previous update.
+ mom_update2.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+
# Script entry point: hand control to the test runner once every test class in
# this module has been defined.
if __name__ == "__main__":
  test.main()