diff options
author | 2018-03-29 15:28:24 -0700 | |
---|---|---|
committer | 2018-03-29 15:32:01 -0700 | |
commit | 6f5d7a97cd2c0741ddfa756853ce5321377b5d53 (patch) | |
tree | e79afd91cd68bc9ed75bfe278511312da3918fe6 /tensorflow/contrib/optimizer_v2 | |
parent | 40f8291db5c0b05b31d7bbe23b847cdbb2408718 (diff) |
Add tf.contrib.distribute, which defines classes DistributionStrategy
and MirroredStrategy, and related functionality.
Also add tf.contrib.optimizer_v2, an update to the Optimizer API.
RELNOTES: Can now pass tf.contrib.distribute.MirroredStrategy() to
tf.estimator.RunConfig() to run an Estimator model on multiple GPUs
on one machine.
PiperOrigin-RevId: 190996247
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
17 files changed, 5454 insertions, 0 deletions
# Prototype of OptimizerV2.

load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test")

package(
    default_visibility = ["//tensorflow:internal"],
)

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

py_library(
    name = "optimizer_v2_py",
    srcs = ["optimizer_v2_symbols.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":training",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "training",
    srcs = [
        "adadelta.py",
        "adagrad.py",
        "adam.py",
        "gradient_descent.py",
        "momentum.py",
        "optimizer_v2.py",
        "rmsprop.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
    ],
)

cuda_py_test(
    name = "adadelta_test",
    size = "medium",
    srcs = ["adadelta_test.py"],
    additional_deps = [
        ":training",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:variables",
        "//third_party/py/numpy",
    ],
)

cuda_py_test(
    name = "adagrad_test",
    size = "small",
    srcs = ["adagrad_test.py"],
    additional_deps = [
        ":training",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)

cuda_py_test(
    name = "adam_test",
    size = "small",
    srcs = ["adam_test.py"],
    additional_deps = [
        ":training",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)

cuda_py_test(
    name = "checkpointable_utils_test",
    srcs = ["checkpointable_utils_test.py"],
    additional_deps = [
        ":training",
        "@six_archive//:six",
        "//tensorflow/contrib/eager/python:checkpointable_utils",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:layers",
        "//tensorflow/python:layers_base",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/keras",
    ],
    tags = ["notsan"],
)

cuda_py_test(
    name = "gradient_descent_test",
    size = "medium",
    srcs = ["gradient_descent_test.py"],
    additional_deps = [
        ":training",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:resources",
        "//tensorflow/python:variables",
    ],
)

cuda_py_test(
    name = "momentum_test",
    size = "medium",
    srcs = ["momentum_test.py"],
    additional_deps = [
        ":training",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:resources",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
    ],
)

cuda_py_test(
    name = "optimizer_v2_test",
    size = "medium",
    srcs = ["optimizer_v2_test.py"],
    additional_deps = [
        ":training",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:clip_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:variables",
    ],
)

cuda_py_test(
    name = "rmsprop_test",
    size = "small",
    srcs = ["rmsprop_test.py"],
    additional_deps = [
        ":training",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)
class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Adadelta algorithm.

  See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
  ([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
  """

  def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-8,
               use_locking=False, name="Adadelta"):
    """Construct a new Adadelta optimizer.

    Some of the args below are hyperparameters, where a hyperparameter is
    defined as a scalar Tensor, a regular Python value or a callable (which
    will be evaluated when `apply_gradients` is called) returning a scalar
    Tensor or a Python value.

    Args:
      learning_rate: A float hyperparameter. The learning rate.
        To match the exact form in the original paper use 1.0.
      rho: A float hyperparameter. The decay rate.
      epsilon: A float hyperparameter. A constant epsilon used to better
        condition the grad update.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Adadelta".
    """
    super(AdadeltaOptimizer, self).__init__(use_locking, name)
    # All three knobs are registered as hyperparameters so they may be
    # plain values, Tensors, or callables resolved at apply time.
    for hyper_name, hyper_value in (("learning_rate", learning_rate),
                                    ("rho", rho),
                                    ("epsilon", epsilon)):
      self._set_hyper(hyper_name, hyper_value)

  def _create_vars(self, var_list, state):
    # Adadelta keeps two zero-initialized accumulators per variable:
    # squared gradients ("accum") and squared updates ("accum_update").
    for var in var_list:
      state.zeros_slot(var, "accum")
      state.zeros_slot(var, "accum_update")

  def _slots_and_hypers(self, var, state):
    # Fetch both slots plus the three hyperparameters cast to var's dtype;
    # shared by all four apply paths below.
    dtype = var.dtype.base_dtype
    return (state.get_slot(var, "accum"),
            state.get_slot(var, "accum_update"),
            state.get_hyper("learning_rate", dtype),
            state.get_hyper("rho", dtype),
            state.get_hyper("epsilon", dtype))

  def _apply_dense(self, grad, var, state):
    acc, acc_update, lr, rho, epsilon = self._slots_and_hypers(var, state)
    return training_ops.apply_adadelta(
        var, acc, acc_update, lr, rho, epsilon, grad,
        use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var, state):
    acc, acc_update, lr, rho, epsilon = self._slots_and_hypers(var, state)
    return training_ops.resource_apply_adadelta(
        var.handle, acc.handle, acc_update.handle, lr, rho, epsilon, grad,
        use_locking=self._use_locking)

  def _apply_sparse(self, grad, var, state):
    acc, acc_update, lr, rho, epsilon = self._slots_and_hypers(var, state)
    return training_ops.sparse_apply_adadelta(
        var, acc, acc_update, lr, rho, epsilon, grad.values, grad.indices,
        use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, state):
    acc, acc_update, lr, rho, epsilon = self._slots_and_hypers(var, state)
    return training_ops.resource_sparse_apply_adadelta(
        var.handle, acc.handle, acc_update.handle, lr, rho, epsilon,
        grad, indices, use_locking=self._use_locking)
class AdadeltaOptimizerTest(test.TestCase):
  """Functional tests for the v2 Adadelta optimizer.

  Fixes over the original: the deprecated `assertEquals` unittest alias is
  replaced with `assertEqual`, and `assertFalse(x in seq)` with the
  clearer `assertNotIn(x, seq)`.  Test behavior is otherwise unchanged.
  """

  def doTestBasic(self, use_resource=False):
    """Compares optimizer steps against a NumPy reference implementation.

    Sweeps dtype, gradient magnitude and learning rate; after every step the
    slot accumulators and the variables themselves are checked against
    values recomputed with plain NumPy arithmetic.
    """
    num_updates = 4  # number of ADADELTA steps to perform
    for dtype in [dtypes.half, dtypes.float32]:
      for grad in [0.2, 0.1, 0.01]:
        for lr in [1.0, 0.5, 0.1]:
          with self.test_session():
            var0_init = [1.0, 2.0]
            var1_init = [3.0, 4.0]
            if use_resource:
              var0 = resource_variable_ops.ResourceVariable(
                  var0_init, dtype=dtype)
              var1 = resource_variable_ops.ResourceVariable(
                  var1_init, dtype=dtype)
            else:
              var0 = variables.Variable(var0_init, dtype=dtype)
              var1 = variables.Variable(var1_init, dtype=dtype)

            grads = constant_op.constant([grad, grad], dtype=dtype)

            # Reference accumulators, updated in Python alongside the op.
            accum = 0.0
            accum_update = 0.0

            # ADADELTA gradient optimizer
            rho = 0.95
            epsilon = 1e-8
            adadelta_opt = adadelta.AdadeltaOptimizer(lr, rho, epsilon)
            adadelta_update = adadelta_opt.apply_gradients(
                zip([grads, grads], [var0, var1]))

            # Two slots per variable -> four optimizer-owned variables,
            # each named after the variable it belongs to.
            opt_vars = adadelta_opt.variables()
            self.assertStartsWith(opt_vars[0].name, var0._shared_name)
            self.assertStartsWith(opt_vars[1].name, var0._shared_name)
            self.assertStartsWith(opt_vars[2].name, var1._shared_name)
            self.assertStartsWith(opt_vars[3].name, var1._shared_name)
            self.assertEqual(4, len(opt_vars))

            variables.global_variables_initializer().run()

            # Assign slots
            slot = [None] * 2
            slot_update = [None] * 2
            self.assertEqual(["accum", "accum_update"],
                             adadelta_opt.get_slot_names())
            slot[0] = adadelta_opt.get_slot(var0, "accum")
            self.assertEqual(slot[0].get_shape(), var0.get_shape())
            self.assertNotIn(slot[0], variables.trainable_variables())

            slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
            self.assertEqual(slot_update[0].get_shape(), var0.get_shape())
            self.assertNotIn(slot_update[0], variables.trainable_variables())

            slot[1] = adadelta_opt.get_slot(var1, "accum")
            self.assertEqual(slot[1].get_shape(), var1.get_shape())
            self.assertNotIn(slot[1], variables.trainable_variables())

            slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
            self.assertEqual(slot_update[1].get_shape(), var1.get_shape())
            self.assertNotIn(slot_update[1], variables.trainable_variables())

            # Fetch params to validate initial values
            self.assertAllClose(var0_init, var0.eval())
            self.assertAllClose(var1_init, var1.eval())

            update = [None] * num_updates
            tot_update = 0
            for step in range(num_updates):
              # Run adadelta update for comparison
              adadelta_update.run()

              # Perform initial update without previous accum values
              accum = accum * rho + (grad**2) * (1 - rho)
              update[step] = (np.sqrt(accum_update + epsilon) *
                              (1. / np.sqrt(accum + epsilon)) * grad)
              accum_update = (accum_update * rho + (update[step]**2) *
                              (1.0 - rho))
              tot_update += update[step] * lr

              # Check that the accumulators have been updated
              for slot_idx in range(2):
                self.assertAllCloseAccordingToType(
                    np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
                    slot[slot_idx].eval(),
                    rtol=1e-5)

                self.assertAllCloseAccordingToType(
                    np.array(
                        [accum_update, accum_update],
                        dtype=dtype.as_numpy_dtype()),
                    slot_update[slot_idx].eval(),
                    rtol=1e-5)

              # Check that the parameters have been updated
              self.assertAllCloseAccordingToType(
                  np.array(
                      [var0_init[0] - tot_update, var0_init[1] - tot_update],
                      dtype=dtype.as_numpy_dtype()),
                  var0.eval(),
                  rtol=1e-5)

              self.assertAllCloseAccordingToType(
                  np.array(
                      [var1_init[0] - tot_update, var1_init[1] - tot_update],
                      dtype=dtype.as_numpy_dtype()),
                  var1.eval(),
                  rtol=1e-5)

  def testBasic(self):
    self.doTestBasic(use_resource=False)

  def testResourceBasic(self):
    self.doTestBasic(use_resource=True)

  def testMinimizeSparseResourceVariable(self):
    """Minimizes a sparse embedding-lookup loss on a resource variable."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        # lr=rho=epsilon=1.0 makes the update large and easy to pin down.
        sgd_op = adadelta.AdadeltaOptimizer(
            1.0, 1.0, 1.0).minimize(loss)
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params
        self.assertAllCloseAccordingToType(
            [[-111, -138]], var0.eval())
class AdagradOptimizer(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Adagrad algorithm.

  See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
  or this
  [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
  """

  def __init__(self, learning_rate, initial_accumulator_value=0.1,
               use_locking=False, name="Adagrad"):
    """Construct a new Adagrad optimizer.

    The learning_rate arg below is a hyperparameter, where a hyperparameter is
    defined as a scalar Tensor, a regular Python value or a callable (which
    will be evaluated when `apply_gradients` is called) returning a scalar
    Tensor or a Python value.

    Args:
      learning_rate: A float hyperparameter. The learning rate.
      initial_accumulator_value: A floating point value.
        Starting value for the accumulators, must be positive.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Adagrad".

    Raises:
      ValueError: If the `initial_accumulator_value` is invalid.
    """
    # Validate eagerly so a bad constant fails at construction, not apply.
    if initial_accumulator_value <= 0.0:
      raise ValueError("initial_accumulator_value must be positive: %s" %
                       initial_accumulator_value)
    super(AdagradOptimizer, self).__init__(use_locking, name)
    self._set_hyper("learning_rate", learning_rate)

    self._initial_accumulator_value = initial_accumulator_value

  def _create_vars(self, var_list, state):
    # One "accumulator" slot per variable, filled with the (positive)
    # initial accumulator value, colocated with its variable.
    for var in var_list:
      with ops.colocate_with(var):
        dtype = var.dtype.base_dtype
        if var.get_shape().is_fully_defined():
          init = init_ops.constant_initializer(
              self._initial_accumulator_value, dtype=dtype)
        else:
          # Static shape unknown: materialize the initial value as a Tensor
          # of the runtime shape instead of using an initializer.
          init = math_ops.cast(
              gen_array_ops.fill(array_ops.shape(var),
                                 self._initial_accumulator_value),
              dtype)
        state.create_slot_with_initializer(var, init, var.get_shape(), dtype,
                                           "accumulator")

  def _acc_and_lr(self, var, state):
    # Accumulator slot plus learning rate cast to the variable's dtype;
    # shared by all four apply paths below.
    return (state.get_slot(var, "accumulator"),
            state.get_hyper("learning_rate", var.dtype.base_dtype))

  def _apply_dense(self, grad, var, state):
    acc, lr = self._acc_and_lr(var, state)
    return training_ops.apply_adagrad(
        var, acc, lr, grad, use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var, state):
    acc, lr = self._acc_and_lr(var, state)
    return training_ops.resource_apply_adagrad(
        var.handle, acc.handle, lr, grad, use_locking=self._use_locking)

  def _apply_sparse(self, grad, var, state):
    acc, lr = self._acc_and_lr(var, state)
    return training_ops.sparse_apply_adagrad(
        var, acc, lr, grad.values, grad.indices,
        use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, state):
    acc, lr = self._acc_and_lr(var, state)
    return training_ops.resource_sparse_apply_adagrad(
        var.handle, acc.handle, lr, grad, indices,
        use_locking=self._use_locking)
class AdagradOptimizerTest(test.TestCase):
  """Functional tests for the v2 Adagrad optimizer.

  Fix over the original: the deprecated `assertEquals` unittest alias is
  replaced with `assertEqual` (in testSharing).  Test behavior is otherwise
  unchanged.
  """

  def doTestBasic(self, use_locking=False, use_resource=False):
    """Runs three Adagrad steps and checks against precomputed values."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        if use_resource:
          var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
          var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        else:
          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
          var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        ada_opt = adagrad.AdagradOptimizer(
            3.0, initial_accumulator_value=0.1, use_locking=use_locking)
        ada_update = ada_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Run 3 steps of adagrad
        for _ in range(3):
          ada_update.run()
        # Validate updated params
        self.assertAllCloseAccordingToType(
            np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([2.715679168701172, 3.715679168701172]), var1.eval())

  def testBasic(self):
    self.doTestBasic(use_locking=False)

  def testBasicResource(self):
    self.doTestBasic(use_locking=False, use_resource=True)

  def testBasicLocked(self):
    self.doTestBasic(use_locking=True)

  def testMinimizeSparseResourceVariable(self):
    """Minimizes a sparse embedding-lookup loss on a resource variable."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = resource_variable_ops.ResourceVariable(
            [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        sgd_op = adagrad.AdagradOptimizer(1.0).minimize(loss)
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType(
            [[1.0, 2.0], [3.0, 4.0]], var0.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params: only the looked-up row moves.
        self.assertAllCloseAccordingToType(
            [[0, 1], [3, 4]], var0.eval(), atol=0.01)

  def testTensorLearningRate(self):
    """Same as doTestBasic but with the learning rate passed as a Tensor."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        ada_opt = adagrad.AdagradOptimizer(
            constant_op.constant(3.0), initial_accumulator_value=0.1)
        ada_update = ada_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Run 3 steps of adagrad
        for _ in range(3):
          ada_update.run()
        # Validate updated params
        self.assertAllCloseAccordingToType(
            np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([2.715679168701172, 3.715679168701172]), var1.eval())

  def testSparseBasic(self):
    """Sparse gradients: only the indexed rows are updated."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
        var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
        grads0 = ops.IndexedSlices(
            constant_op.constant(
                [0.1], shape=[1, 1], dtype=dtype),
            constant_op.constant([0]),
            constant_op.constant([2, 1]))
        grads1 = ops.IndexedSlices(
            constant_op.constant(
                [0.01], shape=[1, 1], dtype=dtype),
            constant_op.constant([1]),
            constant_op.constant([2, 1]))
        ada_opt = adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
        ada_update = ada_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllClose([[1.0], [2.0]], var0.eval())
        self.assertAllClose([[3.0], [4.0]], var1.eval())
        # Run 3 step of sgd
        for _ in range(3):
          ada_update.run()
        # Validate updated params: the untouched rows keep their values.
        self.assertAllCloseAccordingToType(
            np.array([[-1.6026098728179932], [2.0]]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([[3.0], [3.715679168701172]]), var1.eval())

  def testSparseRepeatedIndices(self):
    """Repeated indices in one IndexedSlices behave like their sum."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        repeated_index_update_var = variables.Variable(
            [[1.0], [2.0]], dtype=dtype)
        aggregated_update_var = variables.Variable(
            [[1.0], [2.0]], dtype=dtype)
        grad_repeated_index = ops.IndexedSlices(
            constant_op.constant(
                [0.1, 0.1], shape=[2, 1], dtype=dtype),
            constant_op.constant([1, 1]),
            constant_op.constant([2, 1]))
        grad_aggregated = ops.IndexedSlices(
            constant_op.constant(
                [0.2], shape=[1, 1], dtype=dtype),
            constant_op.constant([1]),
            constant_op.constant([2, 1]))
        repeated_update = adagrad.AdagradOptimizer(3.0).apply_gradients(
            [(grad_repeated_index, repeated_index_update_var)])
        aggregated_update = adagrad.AdagradOptimizer(3.0).apply_gradients(
            [(grad_aggregated, aggregated_update_var)])
        variables.global_variables_initializer().run()
        self.assertAllClose(aggregated_update_var.eval(),
                            repeated_index_update_var.eval())
        for _ in range(3):
          repeated_update.run()
          aggregated_update.run()
        self.assertAllClose(aggregated_update_var.eval(),
                            repeated_index_update_var.eval())

  def testSparseRepeatedIndicesResourceVariable(self):
    """Repeated lookups of one row match an equivalent aggregated loss."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var_repeated = resource_variable_ops.ResourceVariable(
            [1.0, 2.0], dtype=dtype)
        loss_repeated = math_ops.reduce_sum(
            embedding_ops.embedding_lookup(var_repeated, [0, 0]))
        var_aggregated = resource_variable_ops.ResourceVariable(
            [1.0, 2.0], dtype=dtype)
        loss_aggregated = 2 * math_ops.reduce_sum(
            embedding_ops.embedding_lookup(var_aggregated, [0]))
        update_op_repeated = adagrad.AdagradOptimizer(
            2.0).minimize(loss_repeated)
        update_op_aggregated = adagrad.AdagradOptimizer(
            2.0).minimize(loss_aggregated)
        variables.global_variables_initializer().run()
        self.assertAllCloseAccordingToType(
            var_repeated.eval(), var_aggregated.eval())
        for _ in range(3):
          update_op_repeated.run()
          update_op_aggregated.run()
        self.assertAllCloseAccordingToType(
            var_repeated.eval(), var_aggregated.eval())

  def testSparseStability(self):
    """Tiny sparse gradients stay numerically stable (notably for half)."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        shape = [1, 6]
        var0 = variables.Variable(
            [[
                0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257,
                -0.0105945
            ]],
            dtype=dtype)
        grads0 = ops.IndexedSlices(
            constant_op.constant(
                [[
                    -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
                    -8.4877e-05, -9.48906e-05
                ]],
                shape=shape,
                dtype=dtype),
            constant_op.constant([0]),
            constant_op.constant(shape))
        ada_opt = adagrad.AdagradOptimizer(1.0, initial_accumulator_value=0.1)
        ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
        self.assertEqual(["accumulator"], ada_opt.get_slot_names())
        slot0 = ada_opt.get_slot(var0, "accumulator")
        init = variables.global_variables_initializer()
        # Re-initializing every iteration makes each update a fresh first
        # step; 100 repetitions shake out any nondeterministic instability.
        for _ in range(100):
          init.run()
          ada_update.run()
        self.assertAllCloseAccordingToType(
            np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
        self.assertAllCloseAccordingToType(
            np.array([[
                0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
                -0.01029443
            ]]), var0.eval())

  def testSharing(self):
    """Two apply_gradients calls on one optimizer share accumulators."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        ada_opt = adagrad.AdagradOptimizer(3.0)
        # Apply the optimizer twice.  Both applications will use
        # the same accums.
        ada_update1 = ada_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        ada_update2 = ada_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        self.assertEqual(["accumulator"], ada_opt.get_slot_names())
        slot0 = ada_opt.get_slot(var0, "accumulator")
        self.assertEqual(slot0.get_shape(), var0.get_shape())
        slot1 = ada_opt.get_slot(var1, "accumulator")
        self.assertEqual(slot1.get_shape(), var1.get_shape())
        variables.global_variables_initializer().run()

        # Fetch params to validate initial values.
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Mix the first and the second adagrad for 3 steps.
        ada_update1.run()
        ada_update2.run()
        ada_update1.run()
        # Validate updated params (the same as with only 1 Adagrad).
        self.assertAllCloseAccordingToType(
            np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([2.715679168701172, 3.715679168701172]), var1.eval())

  def testDynamicShapeVariable_Ok(self):
    """Constructing the optimizer must not require static variable shapes."""
    with self.test_session():
      v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
                                      validate_shape=False)
      self.assertFalse(v.shape.is_fully_defined())
      # Creating optimizer should cause no exception.
      adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
# ==============================================================================

"""Adam optimizer for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.optimizer_v2 import optimizer_v2
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_ops


class AdamOptimizer(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Adam algorithm.

  See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
  ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
  """

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam"):
    """Construct a new Adam optimizer.

    Initialization:

    ```
    m_0 <- 0 (Initialize initial 1st moment vector)
    v_0 <- 0 (Initialize initial 2nd moment vector)
    t <- 0 (Initialize timestep)
    ```

    Each variable update with gradient `g` follows the optimization described
    at the end of section 2 of the paper:

    ```
    t <- t + 1
    lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)

    m_t <- beta1 * m_{t-1} + (1 - beta1) * g
    v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
    variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
    ```

    Note that 1e-8 is not a universally good epsilon; for example, when
    training an Inception network on ImageNet, 1.0 or 0.1 works better.
    Because this class follows the formulation just before Section 2.1 of the
    paper (rather than Algorithm 1), "epsilon" here corresponds to
    "epsilon hat" in the paper.

    When a gradient arrives as an IndexedSlices (typically via `tf.gather` or
    an embedding lookup), the sparse update still applies momentum to variable
    slices that received a zero gradient, and the beta1 decay is applied to
    the whole momentum accumulator — so sparse and dense behavior match
    (unlike momentum implementations that skip unused slices).

    Some of the args below are hyperparameters, where a hyperparameter is a
    scalar Tensor, a regular Python value, or a callable (evaluated when
    `apply_gradients` is called) returning a scalar Tensor or Python value.

    Args:
      learning_rate: A float hyperparameter. The learning rate.
      beta1: A float hyperparameter. The exponential decay rate for the 1st
        moment estimates.
      beta2: A float hyperparameter. The exponential decay rate for the 2nd
        moment estimates.
      epsilon: A float hyperparameter. This epsilon is "epsilon hat" in the
        Kingma and Ba paper (in the formula just before Section 2.1), not the
        epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".
    """
    super(AdamOptimizer, self).__init__(use_locking, name)

    for hyper_name, value in [("learning_rate", learning_rate),
                              ("beta1", beta1),
                              ("beta2", beta2),
                              ("epsilon", epsilon)]:
      self._set_hyper(hyper_name, value)

  def _get_beta_accumulators(self, state=None):
    # Returns the (beta1_power, beta2_power) non-slot variables.
    state = self._get_per_graph_state() if state is None else state
    return (state.get_non_slot("beta1_power"),
            state.get_non_slot("beta2_power"))

  def _create_vars(self, var_list, state):
    # Non-slot variables end up on the same device(s).
    for beta_name in ("beta1", "beta2"):
      state.create_non_slot(initial_value=state.get_hyper(beta_name),
                            name=beta_name + "_power")

    # Create slots for the first and second moments.
    for v in var_list:
      state.zeros_slot(v, "m")
      state.zeros_slot(v, "v")

  def _apply_dense(self, grad, var, state):
    dtype = var.dtype.base_dtype
    m = state.get_slot(var, "m")
    v = state.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators(state)
    # Fused C++ kernel performs the full Adam step.
    return training_ops.apply_adam(
        var, m, v,
        math_ops.cast(beta1_power, dtype),
        math_ops.cast(beta2_power, dtype),
        state.get_hyper("learning_rate", dtype),
        state.get_hyper("beta1", dtype),
        state.get_hyper("beta2", dtype),
        state.get_hyper("epsilon", dtype),
        grad, use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var, state):
    dtype = grad.dtype.base_dtype
    m = state.get_slot(var, "m")
    v = state.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators(state)
    return training_ops.resource_apply_adam(
        var.handle, m.handle, v.handle,
        math_ops.cast(beta1_power, dtype),
        math_ops.cast(beta2_power, dtype),
        state.get_hyper("learning_rate", dtype),
        state.get_hyper("beta1", dtype),
        state.get_hyper("beta2", dtype),
        state.get_hyper("epsilon", dtype),
        grad, use_locking=self._use_locking)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
    dtype = var.dtype.base_dtype
    beta1_power, beta2_power = self._get_beta_accumulators(state)
    beta1_power = math_ops.cast(beta1_power, dtype)
    beta2_power = math_ops.cast(beta2_power, dtype)
    lr_t = state.get_hyper("learning_rate", dtype)
    beta1_t = state.get_hyper("beta1", dtype)
    beta2_t = state.get_hyper("beta2", dtype)
    epsilon_t = state.get_hyper("epsilon", dtype)
    # Bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = state.get_slot(var, "m")
    scaled_grad_m = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t,
                           use_locking=self._use_locking)
    # The scatter must observe the decayed accumulator, hence the control dep.
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, scaled_grad_m)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = state.get_slot(var, "v")
    scaled_grad_v = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, scaled_grad_v)
    denom = math_ops.sqrt(v_t) + epsilon_t
    var_update = state_ops.assign_sub(var,
                                      lr * m_t / denom,
                                      use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)

  def _apply_sparse(self, grad, var, state):
    def _scatter_add(x, i, v):
      # Ref-variable scatter, honoring the optimizer's locking setting.
      return state_ops.scatter_add(x, i, v, use_locking=self._use_locking)

    return self._apply_sparse_shared(grad.values, var, grad.indices,
                                     _scatter_add, state)

  def _resource_scatter_add(self, x, i, v):
    scatter_op = resource_variable_ops.resource_scatter_add(x.handle, i, v)
    # Read the variable only after the scatter has taken effect.
    with ops.control_dependencies([scatter_op]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices, state):
    return self._apply_sparse_shared(
        grad, var, indices, self._resource_scatter_add, state)

  def _finish(self, state):
    # Update the power accumulators once per apply_gradients call.
    beta1_power, beta2_power = self._get_beta_accumulators(state)
    beta1_update = beta1_power.assign(
        beta1_power * state.get_hyper("beta1"),
        use_locking=self._use_locking)
    beta2_update = beta2_power.assign(
        beta2_power * state.get_hyper("beta2"),
        use_locking=self._use_locking)
    return control_flow_ops.group(beta1_update, beta2_update)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# ============================================================================== +"""Tests for Adam optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.optimizer_v2 import adam +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = adam.AdamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def testSparseDevicePlacement(self): + for index_dtype in [dtypes.int32, dtypes.int64]: + with 
self.test_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = adam.AdamOptimizer(3.0) + minimize_op = optimizer.minimize(gathered_sum) + variables.global_variables_initializer().run() + minimize_op.run() + + def testSparseRepeatedIndices(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update = adam.AdamOptimizer().apply_gradients( + [(grad_repeated_index, repeated_index_update_var)]) + aggregated_update = adam.AdamOptimizer().apply_gradients( + [(grad_aggregated, aggregated_update_var)]) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = adam.AdamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertTrue(beta2_power is not None) + self.assertIn(beta1_power, opt_variables) + self.assertIn(beta2_power, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. 
+ self.assertEqual(0, len(opt.variables())) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.test_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam.AdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam.AdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = adam.AdamOptimizer() + g = ops.Graph() + with g.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. 
+ optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = adam.AdamOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. + self.assertEqual(6, len(set(opt.variables()))) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py new file mode 100644 index 0000000000..08f9699e85 --- /dev/null +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -0,0 +1,686 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# TODO(josh11b): Forked from contrib/eager/python to test OptimizerV2 the same way +# OptimizerV1 is tested. This file should be removed once the fork is resolved. 
+ +import functools +import os + +import six + +from tensorflow.contrib.eager.python import checkpointable_utils +from tensorflow.contrib.optimizer_v2 import adam +from tensorflow.python.client import session as session_lib +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.layers import core +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpointable +from tensorflow.python.training import saver as core_saver +from tensorflow.python.training import training_util + + +class NonLayerCheckpointable(checkpointable.Checkpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + self.a_variable = checkpointable_utils.add_variable( + self, name="a_variable", shape=[]) + + +# pylint: disable=not-callable +class MyModel(training.Model): + """A concrete Model for testing.""" + + def __init__(self): + super(MyModel, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) + # We can still track Checkpointables which aren't Layers. 
+ self._non_layer = NonLayerCheckpointable() + + def call(self, values): + ret = self._second(self._named_dense(values)) + return ret + + +class _MirroringSaveable( + core_saver.BaseSaverBuilder.ResourceVariableSaveable): + + def __init__(self, primary_variable, mirrored_variable, name): + self._primary_variable = primary_variable + self._mirrored_variable = mirrored_variable + super(_MirroringSaveable, self).__init__( + self._primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return control_flow_ops.group( + self._primary_variable.assign(tensor), + self._mirrored_variable.assign(tensor)) + + +class _OwnsMirroredVariables(checkpointable.CheckpointableBase): + """A Checkpointable object which returns a more complex SaveableObject.""" + + def __init__(self): + self.non_dep_variable = variable_scope.get_variable( + name="non_dep_variable", initializer=6., use_resource=True) + self.mirrored = variable_scope.get_variable( + name="mirrored", initializer=15., use_resource=True) + + def _gather_saveables_for_checkpoint(self): + def _saveable_factory(name=self.non_dep_variable.name): + return _MirroringSaveable( + primary_variable=self.non_dep_variable, + mirrored_variable=self.mirrored, + name=name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + # The Saver sorts by name before parsing, so we need a name property. + @property + def name(self): + return self.non_dep_variable.name + + +class CheckpointingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNamingWithOptimizer(self): + input_value = constant_op.constant([[3.]]) + model = MyModel() + # A nuisance Model using the same optimizer. Its slot variables should not + # go in the checkpoint, since it is never depended on. 
+ other_model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + optimizer_step = training_util.get_or_create_global_step() + root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=optimizer_step) + if context.executing_eagerly(): + optimizer.minimize( + lambda: model(input_value), + global_step=optimizer_step) + optimizer.minimize( + lambda: other_model(input_value), + global_step=optimizer_step) + else: + train_op = optimizer.minimize( + model(input_value), global_step=optimizer_step) + optimizer.minimize( + other_model(input_value), + global_step=optimizer_step) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) + self.evaluate(train_op) + named_variables, serialized_graph = ( + checkpointable_utils._serialize_object_graph(root_checkpointable)) + expected_checkpoint_names = ( + # Created in the root node, so no prefix. + "optimizer_step", + "model/_second/kernel", + "model/_named_dense/kernel", + "model/_named_dense/bias", + # non-Layer dependency of the model + "model/_non_layer/a_variable", + # The optimizer creates two non-slot variables + "optimizer/beta1_power", + "optimizer/beta2_power", + # Slot variables + "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m", + "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v", + "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m", + "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v", + "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m", + "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v", + ) + suffix = "/.ATTRIBUTES/VARIABLE_VALUE" + expected_checkpoint_names = [ + name + suffix for name in expected_checkpoint_names] + six.assertCountEqual(self, expected_checkpoint_names, + named_variables.keys()) + # Check that we've mapped to the right variable objects (not exhaustive) + self.assertEqual( + "global_step:0", + named_variables["optimizer_step" + suffix].name) + self.assertEqual( + "my_model/dense_1/kernel:0", + 
named_variables["model/_second/kernel" + suffix].name) + self.assertEqual( + "my_model/dense/kernel:0", + named_variables["model/_named_dense/kernel" + suffix].name) + self.assertEqual( + "beta1_power:0", + named_variables["optimizer/beta1_power" + suffix].name) + self.assertEqual( + "beta2_power:0", + named_variables["optimizer/beta2_power" + suffix].name) + # Spot check the generated protocol buffers. + self.assertEqual("optimizer", + serialized_graph.nodes[0].children[1].local_name) + optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ + 1].node_id] + self.assertEqual("beta1_power", + optimizer_node.children[0].local_name) + self.assertEqual("beta1_power", + serialized_graph.nodes[optimizer_node.children[0].node_id] + .attributes[0].full_name) + self.assertEqual( + "my_model/dense/kernel", + serialized_graph.nodes[optimizer_node.slot_variables[0] + .original_variable_node_id] + .attributes[0].full_name) + # We strip off the :0 suffix, as variable.name-based saving does. 
+ self.assertEqual( + "my_model/dense/kernel/Adam", + serialized_graph.nodes[optimizer_node.slot_variables[0] + .slot_variable_node_id] + .attributes[0].full_name) + self.assertEqual( + "my_model/dense/kernel/Adam:0", + optimizer.get_slot( + var=named_variables["model/_named_dense/kernel" + suffix], + name="m").name) + self.assertEqual( + "model/_named_dense/kernel" + suffix, + serialized_graph.nodes[ + optimizer_node.slot_variables[0] + .original_variable_node_id].attributes[0].checkpoint_key) + self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + self.assertEqual( + "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix, + serialized_graph.nodes[ + optimizer_node.slot_variables[0] + .slot_variable_node_id].attributes[0].checkpoint_key) + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model) + input_value = constant_op.constant([[3.]]) + if context.executing_eagerly(): + optimizer.minimize( + lambda: model(input_value)) + else: + train_op = optimizer.minimize(model(input_value)) + # TODO(allenl): Make initialization more pleasant when graph building. 
+ root_checkpointable.save_counter # pylint: disable=pointless-statement + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) + self.evaluate(train_op) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.])) + m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m") + self.evaluate(state_ops.assign(m_bias_slot, [1.5])) + save_path = root_checkpointable.save(file_prefix=prefix) + self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.])) + self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) + optimizer_variables = self.evaluate(optimizer.variables()) + self.evaluate(state_ops.assign(m_bias_slot, [-2.])) + # Immediate restoration + status = root_checkpointable.restore(save_path=save_path).assert_consumed() + status.run_restore_ops() + self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1])) + self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) + self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) + if not context.executing_eagerly(): + return # Restore-on-create is only supported when executing eagerly + on_create_model = MyModel() + on_create_optimizer = adam.AdamOptimizer( + 0.001, + # Preserve beta1_power and beta2_power when appying gradients so we can + # test that they've been restored correctly. 
+ beta1=1.0, beta2=1.0) + on_create_root = checkpointable_utils.Checkpoint( + optimizer=on_create_optimizer, model=on_create_model) + # Deferred restoration + status = on_create_root.restore(save_path=save_path) + on_create_model(constant_op.constant([[3.]])) # create variables + self.assertAllEqual(1, self.evaluate(on_create_root.save_counter)) + self.assertAllEqual([42.], + self.evaluate( + on_create_model._named_dense.variables[1])) + on_create_m_bias_slot = on_create_optimizer.get_slot( + on_create_model._named_dense.variables[1], "m") + # Optimizer slot variables are created when the original variable is + # restored. + self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) + self.assertAllEqual(optimizer_variables[2:], + self.evaluate(on_create_optimizer.variables())) + dummy_var = resource_variable_ops.ResourceVariable([1.]) + on_create_optimizer.minimize(loss=dummy_var.read_value) + status.assert_consumed() + beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() + self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) + self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) + + # TODO(allenl): Debug garbage created by this test in python3. + def testDeferredRestorationUsageEager(self): + """An idiomatic eager execution example.""" + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + root = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, + optimizer_step=training_util.get_or_create_global_step()) + root.restore(core_saver.latest_checkpoint(checkpoint_directory)) + for _ in range(num_training_steps): + # TODO(allenl): Use a Dataset and serialize/checkpoint it. 
  def testUsageGraph(self):
    """Expected usage when graph building."""
    with context.graph_mode():
      num_training_steps = 10
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      # Three "continuations": each builds a fresh graph, restores the latest
      # checkpoint (None on the first pass), trains, and saves again.
      for training_continuation in range(3):
        with ops.Graph().as_default():
          model = MyModel()
          optimizer = adam.AdamOptimizer(0.001)
          root = checkpointable_utils.Checkpoint(
              optimizer=optimizer, model=model,
              global_step=training_util.get_or_create_global_step())
          input_value = constant_op.constant([[3.]])
          train_op = optimizer.minimize(
              model(input_value),
              global_step=root.global_step)
          checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
          with self.test_session(graph=ops.get_default_graph()) as session:
            status = root.restore(save_path=checkpoint_path)
            status.initialize_or_restore(session=session)
            if checkpoint_path is None:
              # First continuation: nothing to restore, so assert_consumed
              # must fail.
              self.assertEqual(0, training_continuation)
              with self.assertRaises(AssertionError):
                status.assert_consumed()
            else:
              status.assert_consumed()
            for _ in range(num_training_steps):
              session.run(train_op)
            root.save(file_prefix=checkpoint_prefix, session=session)
            # global_step accumulates across continuations via the checkpoint.
            self.assertEqual((training_continuation + 1) * num_training_steps,
                             session.run(root.global_step))
            self.assertEqual(training_continuation + 1,
                             session.run(root.save_counter))

  @test_util.run_in_graph_and_eager_modes()
  def testAgnosticUsage(self):
    """Graph/eager agnostic usage."""
    # Does create garbage when executing eagerly due to ops.Graph() creation.
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      with ops.Graph().as_default(), self.test_session(
          graph=ops.get_default_graph()), test_util.device(use_gpu=True):
        model = MyModel()
        optimizer = adam.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
        checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
        status = root.restore(save_path=checkpoint_path)
        input_value = constant_op.constant([[3.]])
        # functools.partial keeps the same call working in both modes: eager
        # calls the partial directly, graph mode evaluates the built op.
        train_fn = functools.partial(
            optimizer.minimize,
            functools.partial(model, input_value),
            global_step=root.global_step)
        if not context.executing_eagerly():
          train_fn = functools.partial(self.evaluate, train_fn())
        status.initialize_or_restore()
        for _ in range(num_training_steps):
          train_fn()
        root.save(file_prefix=checkpoint_prefix)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         self.evaluate(root.global_step))
        self.assertEqual(training_continuation + 1,
                         self.evaluate(root.save_counter))

  def _get_checkpoint_name(self, name):
    """Return the checkpoint key generated for a variable named `name`."""
    root = checkpointable.Checkpointable()
    checkpointable_utils.add_variable(
        root, name=name, shape=[1, 2], dtype=dtypes.float64)
    named_variables, _ = checkpointable_utils._serialize_object_graph(root)
    checkpoint_name, = named_variables.keys()
    with ops.name_scope("root/" + checkpoint_name):
      pass  # Make sure we can use this as an op name if we prefix it.
    return checkpoint_name

  def testAnonymousVarsInInit(self):
    """Variables created directly in __init__ (no name) are checkpointable."""

    class Model(training.Model):

      def __init__(self):
        super(Model, self).__init__()
        self.w = resource_variable_ops.ResourceVariable(0.0)
        self.b = resource_variable_ops.ResourceVariable(0.0)
        self.vars = [self.w, self.b]

      def call(self, x):
        return x * self.w + self.b

    with context.eager_mode():
      model = Model()
      optimizer = adam.AdamOptimizer(learning_rate=0.05)
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      checkpoint = checkpointable_utils.Checkpoint(
          model=model, optimizer=optimizer)
      for _ in range(2):
        checkpoint.save(checkpoint_prefix)
        with backprop.GradientTape() as tape:
          loss = (constant_op.constant(1.)
                  - model(constant_op.constant(1.))) ** 2
        grad = tape.gradient(loss, model.vars)
        optimizer.apply_gradients(
            [(g, v) for g, v in zip(grad, model.vars)])

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testDeferredSlotRestoration(self):
    """Slot variables restore lazily, after the optimizer is attached."""
    checkpoint_directory = self.get_temp_dir()

    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.executing_eagerly():
      optimizer.minimize(root.var.read_value)
    else:
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
      # get initialized too.
      self.evaluate(checkpointable_utils.gather_initializers(
          checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
    self.evaluate(state_ops.assign(root.var, 12.))
    # First checkpoint: optimizer not attached to root, so no slots saved.
    no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    # Second checkpoint: optimizer (and its "m" slot) now included.
    slots_path = checkpointable_utils.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "with_slots"))
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = checkpointable_utils.CheckpointableSaver(
        new_root).restore(slots_path)
    no_slot_status = checkpointable_utils.CheckpointableSaver(
        new_root).restore(no_slots_path)
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    # The later restore (no_slots, var == 12) wins for the variable value.
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.executing_eagerly():
      new_root.optimizer.minimize(new_root.var.read_value)
    else:
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    slot_status.assert_consumed()

  def testManySavesGraph(self):
    """Saves after the first should not modify the graph."""
    with context.graph_mode():
      graph = ops.Graph()
      with graph.as_default(), self.test_session(graph):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        obj = checkpointable.Checkpointable()
        obj.var = variable_scope.get_variable(name="v", initializer=0.)
        obj.opt = adam.AdamOptimizer(0.1)
        obj.opt.minimize(obj.var.read_value())
        self.evaluate(checkpointable_utils.gather_initializers(obj))
        saver = checkpointable_utils.CheckpointableSaver(obj)
        saver.save(checkpoint_prefix)
        before_ops = graph.get_operations()
        saver.save(checkpoint_prefix)
        # The second save must reuse the ops built by the first one.
        self.assertEqual(before_ops, graph.get_operations())

  def testManyRestoresGraph(self):
    """Restores after the first should not modify the graph."""
    with context.graph_mode():
      graph = ops.Graph()
      with graph.as_default(), self.test_session(graph):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        obj = checkpointable.Checkpointable()
        obj.var = variable_scope.get_variable(name="v", initializer=0.)
        obj.opt = adam.AdamOptimizer(0.1)
        obj.opt.minimize(obj.var.read_value())
        self.evaluate(checkpointable_utils.gather_initializers(obj))
        saver = checkpointable_utils.CheckpointableSaver(obj)
        save_path = saver.save(checkpoint_prefix)
        saver.restore(save_path)
        before_ops = graph.get_operations()
        saver.restore(save_path)
        # The second restore must reuse the ops built by the first one.
        self.assertEqual(before_ops, graph.get_operations())

  def testMultipleGraphsNonSlotVariables(self):
    """One optimizer shared across graphs keeps per-graph state separate."""
    with context.graph_mode():
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      optimizer = adam.AdamOptimizer(0.001)
      # Construct a model in one graph
      first_graph = ops.Graph()
      first_session = session_lib.Session(graph=first_graph)
      with first_graph.as_default(), first_session.as_default():
        first_variable = resource_variable_ops.ResourceVariable([1.])
        first_root_checkpointable = checkpointable_utils.Checkpoint(
            optimizer=optimizer, variable=first_variable)
        train_op = optimizer.minimize(first_variable.read_value)
        self.evaluate(checkpointable_utils.gather_initializers(
            first_root_checkpointable))
        self.evaluate(train_op)
        self.evaluate(first_variable.assign([1.]))
        self.evaluate(optimizer.get_slot(
            var=first_variable, name="m").assign([2.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(3.))

      # Save and load in a second graph
      second_graph = ops.Graph()
      with second_graph.as_default(), session_lib.Session(graph=second_graph):
        second_variable = resource_variable_ops.ResourceVariable([1.])
        second_root_checkpointable = checkpointable_utils.Checkpoint(
            optimizer=optimizer, variable=second_variable)
        train_op = optimizer.minimize(second_variable.read_value)
        second_root_checkpointable.restore(None).initialize_or_restore()
        self.evaluate(train_op)
        self.evaluate(second_variable.assign([4.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([5.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(6.))
        save_path = second_root_checkpointable.save(checkpoint_prefix)
        self.evaluate(second_variable.assign([7.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([8.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))
        # Restoring rolls variable, slot, and beta accumulator back to the
        # values captured at save time.
        status = second_root_checkpointable.restore(save_path)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([4.], self.evaluate(second_variable))
        self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
            var=second_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))

      # Check that the first graph is unmolested
      with first_graph.as_default(), first_session.as_default():
        self.assertAllEqual([1.], self.evaluate(first_variable))
        self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
            var=first_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(3., self.evaluate(beta1_power))
class CheckpointCompatibilityTests(test.TestCase):
  """Round-trips between name-based and object-based checkpoints."""

  def _initialized_model(self):
    """Build, initialize, and train one step; set known marker values.

    Returns:
      A `Checkpoint` wrapping model, optimizer, and optimizer_step whose
      bias is [1.], "m" slot is [2.], and beta1 accumulator is 3.
    """
    input_value = constant_op.constant([[3.]])
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    optimizer_step = training_util.get_or_create_global_step()
    root_checkpointable = checkpointable_utils.Checkpoint(
        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
    train_op = optimizer.minimize(
        functools.partial(model, input_value),
        global_step=optimizer_step)
    self.evaluate(checkpointable_utils.gather_initializers(
        root_checkpointable))
    self.evaluate(train_op)
    # A regular variable, a slot variable, and a non-slot Optimizer variable
    # with known values to check when loading.
    self.evaluate(model._named_dense.bias.assign([1.]))
    self.evaluate(optimizer.get_slot(
        var=model._named_dense.bias, name="m").assign([2.]))
    beta1_power, _ = optimizer._get_beta_accumulators()
    self.evaluate(beta1_power.assign(3.))
    return root_checkpointable

  def _set_sentinels(self, root_checkpointable):
    """Clobber the marker values so a later restore is observable."""
    self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
    self.evaluate(
        root_checkpointable.optimizer.get_slot(
            var=root_checkpointable.model._named_dense.bias, name="m")
        .assign([102.]))
    beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
    self.evaluate(beta1_power.assign(103.))

  def _check_sentinels(self, root_checkpointable):
    """Assert the marker values from _initialized_model are back in place."""
    self.assertAllEqual(
        [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
    self.assertAllEqual([2.], self.evaluate(
        root_checkpointable.optimizer.get_slot(
            var=root_checkpointable.model._named_dense.bias, name="m")))
    beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
    self.assertAllEqual(3., self.evaluate(beta1_power))

  def _write_name_based_checkpoint(self):
    """Save with the legacy name-based Saver; return the checkpoint path."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    with context.graph_mode():
      save_graph = ops.Graph()
      with save_graph.as_default(), self.test_session(
          graph=save_graph) as session:
        root = self._initialized_model()
        name_saver = core_saver.Saver()
        return name_saver.save(
            sess=session, save_path=checkpoint_prefix,
            global_step=root.optimizer_step)

  @test_util.run_in_graph_and_eager_modes()
  def testLoadFromNameBasedSaver(self):
    """Save a name-based checkpoint, load it using the object-based API."""
    with test_util.device(use_gpu=True):
      save_path = self._write_name_based_checkpoint()
      root = self._initialized_model()
      self._set_sentinels(root)
      with self.assertRaises(AssertionError):
        self._check_sentinels(root)
      object_saver = checkpointable_utils.CheckpointableSaver(root)
      status = object_saver.restore(save_path)
      # Name-based checkpoints can't satisfy the full object graph.
      with self.assertRaises(AssertionError):
        status.assert_consumed()
      status.run_restore_ops()
      self._check_sentinels(root)
      self._set_sentinels(root)
      status.initialize_or_restore()
      self._check_sentinels(root)

  # TODO(allenl): Test for the core name-based saver loading object-based
  # checkpoints once object-based checkpointing is in core.

  def testSaveGraphLoadEager(self):
    """Object-based checkpoints written in graph mode load eagerly."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    with context.graph_mode():
      save_graph = ops.Graph()
      with save_graph.as_default(), self.test_session(
          graph=save_graph) as session:
        root = self._initialized_model()
        object_saver = checkpointable_utils.CheckpointableSaver(root)
        save_path = object_saver.save(
            session=session, file_prefix=checkpoint_prefix)
    with context.eager_mode():
      root = self._initialized_model()
      self._set_sentinels(root)
      root.restore(save_path).assert_consumed()
      self._check_sentinels(root)

  def testSaveEagerLoadGraph(self):
    """Object-based checkpoints written eagerly load in graph mode."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    with context.eager_mode():
      root = self._initialized_model()
      object_saver = checkpointable_utils.CheckpointableSaver(root)
      save_path = object_saver.save(file_prefix=checkpoint_prefix)
    with context.graph_mode():
      save_graph = ops.Graph()
      with save_graph.as_default(), self.test_session(
          graph=save_graph):
        root = self._initialized_model()
        self._set_sentinels(root)
        root.restore(save_path).assert_consumed().run_restore_ops()
        self._check_sentinels(root)

if __name__ == "__main__":
  test.main()
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""GradientDescent optimizer for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.optimizer_v2 import optimizer_v2
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import training_ops


class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
  """Optimizer that implements the gradient descent algorithm."""

  def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
    """Construct a new gradient descent optimizer.

    The `learning_rate` argument is a hyperparameter: a scalar Tensor, a
    regular Python value, or a callable taking no arguments that returns
    either of those. Callables are evaluated when `apply_gradients` is
    called.

    Args:
      learning_rate: A float hyperparameter. The learning rate to use.
      use_locking: If True use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "GradientDescent".
    """
    super(GradientDescentOptimizer, self).__init__(use_locking, name)
    self._set_hyper("learning_rate", learning_rate)

  def _apply_dense(self, grad, var, state):
    # Dense update for a ref variable: var -= lr * grad as a single fused op.
    lr = state.get_hyper("learning_rate", var.dtype.base_dtype)
    return training_ops.apply_gradient_descent(
        var, lr, grad, use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, handle, state):
    # Dense update for a resource variable.
    lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
    return training_ops.resource_apply_gradient_descent(
        handle.handle, lr, grad, use_locking=self._use_locking)

  def _apply_sparse_duplicate_indices(self, grad, var, state):
    # Sparse update for a ref variable. scatter_sub accumulates repeated
    # indices, so duplicates need no explicit de-duplication here.
    lr = state.get_hyper("learning_rate", var.dtype.base_dtype)
    scaled = ops.IndexedSlices(
        grad.values * lr, grad.indices, grad.dense_shape)
    return var.scatter_sub(scaled, use_locking=self._use_locking)

  def _resource_apply_sparse_duplicate_indices(
      self, grad, handle, indices, state):
    # Sparse update for a resource variable: scatter-add the negated, scaled
    # gradient; repeated indices accumulate.
    lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
    return resource_variable_ops.resource_scatter_add(
        handle.handle, indices, -grad * lr)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for GradientDescent optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.optimizer_v2 import gradient_descent
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class GradientDescentOptimizerTest(test.TestCase):
  """Functional tests for the v2 GradientDescentOptimizer."""

  def testBasic(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        v0 = variables.Variable([1.0, 2.0], dtype=dtype)
        v1 = variables.Variable([3.0, 4.0], dtype=dtype)
        g0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        g1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        optimizer = gradient_descent.GradientDescentOptimizer(3.0)
        step = optimizer.apply_gradients(zip([g0, g1], [v0, v1]))
        variables.global_variables_initializer().run()
        # Params are untouched before the first step.
        self.assertAllCloseAccordingToType([1.0, 2.0], v0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], v1.eval())
        # One step of SGD: var -= lr * grad.
        step.run()
        self.assertAllCloseAccordingToType(
            [1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], v0.eval())
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], v1.eval())
        # Plain SGD creates no slot or non-slot optimizer variables.
        self.assertEqual(0, len(optimizer.variables()))

  def testBasicResourceVariable(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        v0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        v1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        g0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        g1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        step = gradient_descent.GradientDescentOptimizer(3.0).apply_gradients(
            zip([g0, g1], [v0, v1]))
        # TODO(apassos) calling initialize_resources on all resources here
        # doesn't work because the sessions and graph are reused across unit
        # tests and this would mean trying to reinitialize variables. Figure
        # out a long-term solution for this.
        resources.initialize_resources([v0, v1]).run()
        self.assertAllCloseAccordingToType([1.0, 2.0], v0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], v1.eval())
        step.run()
        self.assertAllCloseAccordingToType(
            [1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], v0.eval())
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], v1.eval())

  def testMinimizeResourceVariable(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        weights = resource_variable_ops.ResourceVariable(
            [[1.0, 2.0]], dtype=dtype)
        bias = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(weights, x) + bias
        loss = pred * pred
        step = gradient_descent.GradientDescentOptimizer(1.0).minimize(loss)
        # TODO(apassos) calling initialize_resources on all resources here
        # doesn't work because the sessions and graph are reused across unit
        # tests and this would mean trying to reinitialize variables. Figure
        # out a long-term solution for this.
        resources.initialize_resources([weights, bias]).run()
        self.assertAllCloseAccordingToType([[1.0, 2.0]], weights.eval())
        self.assertAllCloseAccordingToType([3.0], bias.eval())
        step.run()
        # d(loss)/d(pred) = 2 * pred; chain rule gives the per-param grads.
        np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
        np_grad = 2 * np_pred
        self.assertAllCloseAccordingToType(
            [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], weights.eval())
        self.assertAllCloseAccordingToType([3.0 - np_grad], bias.eval())

  def testMinimizeSparseResourceVariable(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        weights = resource_variable_ops.ResourceVariable(
            [[1.0, 2.0]], dtype=dtype)
        bias = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        # Embedding lookup makes the gradient w.r.t. `weights` sparse.
        pred = math_ops.matmul(
            embedding_ops.embedding_lookup([weights], [0]), x)
        pred += bias
        loss = pred * pred
        step = gradient_descent.GradientDescentOptimizer(1.0).minimize(loss)
        # TODO(apassos) calling initialize_resources on all resources here
        # doesn't work because the sessions and graph are reused across unit
        # tests and this would mean trying to reinitialize variables. Figure
        # out a long-term solution for this.
        variables.global_variables_initializer().run()
        self.assertAllCloseAccordingToType([[1.0, 2.0]], weights.eval())
        self.assertAllCloseAccordingToType([3.0], bias.eval())
        step.run()
        np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
        np_grad = 2 * np_pred
        self.assertAllCloseAccordingToType(
            [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], weights.eval())
        self.assertAllCloseAccordingToType([3.0 - np_grad], bias.eval())

  def testTensorLearningRate(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        v0 = variables.Variable([1.0, 2.0], dtype=dtype)
        v1 = variables.Variable([3.0, 4.0], dtype=dtype)
        g0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        g1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        # The learning rate may be supplied as a Tensor, not just a float.
        lr_tensor = constant_op.constant(3.0)
        step = gradient_descent.GradientDescentOptimizer(
            lr_tensor).apply_gradients(zip([g0, g1], [v0, v1]))
        variables.global_variables_initializer().run()
        self.assertAllCloseAccordingToType([1.0, 2.0], v0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], v1.eval())
        step.run()
        self.assertAllCloseAccordingToType(
            [1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], v0.eval())
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], v1.eval())

  def testGradWrtRef(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        opt = gradient_descent.GradientDescentOptimizer(3.0)
        initial_values = [1.0, 3.0]
        params = [variables.Variable([v], dtype=dtype)
                  for v in initial_values]
        # d(p0 + p1)/dp_i == 1 for each ref-typed variable.
        grads_and_vars = opt.compute_gradients(params[0] + params[1], params)
        variables.global_variables_initializer().run()
        for grad, _ in grads_and_vars:
          self.assertAllCloseAccordingToType([1.0], grad.eval())

  def testWithGlobalStep(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        global_step = variables.Variable(0, trainable=False)
        v0 = variables.Variable([1.0, 2.0], dtype=dtype)
        v1 = variables.Variable([3.0, 4.0], dtype=dtype)
        g0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        g1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        step = gradient_descent.GradientDescentOptimizer(3.0).apply_gradients(
            zip([g0, g1], [v0, v1]), global_step=global_step)
        variables.global_variables_initializer().run()
        self.assertAllCloseAccordingToType([1.0, 2.0], v0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], v1.eval())
        step.run()
        # Params advance one SGD step and global_step increments to 1.
        self.assertAllCloseAccordingToType(
            [1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], v0.eval())
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], v1.eval())
        self.assertAllCloseAccordingToType(1, global_step.eval())

  def testSparseBasic(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        v0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
        v1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
        # Each gradient touches a single row of its variable.
        g0 = ops.IndexedSlices(
            constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
            constant_op.constant([0]),
            constant_op.constant([2, 1]))
        g1 = ops.IndexedSlices(
            constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
            constant_op.constant([1]),
            constant_op.constant([2, 1]))
        step = gradient_descent.GradientDescentOptimizer(3.0).apply_gradients(
            zip([g0, g1], [v0, v1]))
        variables.global_variables_initializer().run()
        self.assertAllCloseAccordingToType([[1.0], [2.0]], v0.eval())
        self.assertAllCloseAccordingToType([[3.0], [4.0]], v1.eval())
        step.run()
        # Only the indexed rows move; the others are untouched.
        self.assertAllCloseAccordingToType(
            [[1.0 - 3.0 * 0.1], [2.0]], v0.eval())
        self.assertAllCloseAccordingToType(
            [[3.0], [4.0 - 3.0 * 0.01]], v1.eval())


if __name__ == "__main__":
  test.main()
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Momentum for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.optimizer_v2 import optimizer_v2
from tensorflow.python.training import training_ops


class MomentumOptimizer(optimizer_v2.OptimizerV2):
  """Optimizer that implements the Momentum algorithm.

  With `use_nesterov = False` each step computes:

  ```
  accumulation = momentum * accumulation + gradient
  variable -= learning_rate * accumulation
  ```

  In the dense variant, `accumulation` is updated and applied whatever the
  gradient's value. The sparse variant (gradient is an `IndexedSlices`,
  typically produced by `tf.gather` or an embedding lookup) only touches the
  variable slices — and the matching `accumulation` entries — that were
  actually used in the forward pass.
  """

  def __init__(self, learning_rate, momentum,
               use_locking=False, name="Momentum", use_nesterov=False):
    """Construct a new Momentum optimizer.

    `learning_rate` and `momentum` are hyperparameters: each may be a scalar
    Tensor, a regular Python value, or a callable taking no arguments that
    returns either of those. Callables are evaluated when `apply_gradients`
    is called.

    Args:
      learning_rate: A float hyperparameter. The learning rate.
      momentum: A float hyperparameter. The momentum.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "Momentum".
      use_nesterov: If `True` use Nesterov Momentum.
        See [Sutskever et al., 2013](
        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
        This implementation always computes gradients at the value of the
        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
        variable(s) track the values called `theta_t + mu*v_t` in the paper.

    @compatibility(eager)
    When eager execution is enabled, learning_rate and momentum can each be a
    callable that takes no arguments and returns the actual value to use. This
    can be useful for changing these values across different invocations of
    optimizer functions.
    @end_compatibility
    """
    super(MomentumOptimizer, self).__init__(use_locking, name)
    self._set_hyper("learning_rate", learning_rate)
    self._set_hyper("momentum", momentum)
    self._use_nesterov = use_nesterov

  def _lr_and_momentum(self, state, dtype):
    # Internal helper: fetch both hyperparameters cast to `dtype`.
    return (state.get_hyper("learning_rate", dtype),
            state.get_hyper("momentum", dtype))

  def _create_vars(self, var_list, state):
    # One "momentum" accumulator slot per optimized variable.
    for variable in var_list:
      state.zeros_slot(variable, "momentum")

  def _apply_dense(self, grad, var, state):
    lr, momentum = self._lr_and_momentum(state, var.dtype.base_dtype)
    accum = state.get_slot(var, "momentum")
    return training_ops.apply_momentum(
        var, accum, lr, grad, momentum,
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_dense(self, grad, var, state):
    lr, momentum = self._lr_and_momentum(state, var.dtype.base_dtype)
    accum = state.get_slot(var, "momentum")
    return training_ops.resource_apply_momentum(
        var.handle, accum.handle, lr, grad, momentum,
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)

  def _apply_sparse(self, grad, var, state):
    lr, momentum = self._lr_and_momentum(state, var.dtype.base_dtype)
    accum = state.get_slot(var, "momentum")
    return training_ops.sparse_apply_momentum(
        var, accum, lr, grad.values, grad.indices, momentum,
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_sparse(self, grad, var, indices, state):
    lr, momentum = self._lr_and_momentum(state, var.dtype.base_dtype)
    accum = state.get_slot(var, "momentum")
    return training_ops.resource_sparse_apply_momentum(
        var.handle, accum.handle, lr, grad, indices, momentum,
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Momentum."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.optimizer_v2 import momentum as momentum_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class MomentumOptimizerTest(test.TestCase):
  """Tests for the v2 MomentumOptimizer (dense, sparse, eager and graph)."""

  def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
    # NumPy reference implementation of one Nesterov momentum step, used to
    # cross-check the TF kernels in the Nesterov tests below.
    var = var + accum * lr * momentum
    accum = accum * momentum + g
    var = var - lr * accum
    var = var - accum * lr * momentum
    return var, accum

  def doTestBasic(self, use_resource=False, use_callable_params=False):
    # Two variables with constant gradients; verifies slot creation, the
    # first (accumulator == 0) step and a second (accumulated) step.
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      if use_resource:
        var0 = resource_variable_ops.ResourceVariable(
            [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
        var1 = resource_variable_ops.ResourceVariable(
            [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
      else:
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
      grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
      grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
      learning_rate = lambda: 2.0
      momentum = lambda: 0.9
      if not use_callable_params:
        learning_rate = learning_rate()
        momentum = momentum()
      mom_opt = momentum_lib.MomentumOptimizer(
          learning_rate=learning_rate, momentum=momentum)
      mom_update = mom_opt.apply_gradients(
          zip([grads0, grads1], [var0, var1]))

      if not context.executing_eagerly():
        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

      # Check we have slots
      # NOTE(review): assertEquals is a deprecated alias of assertEqual.
      self.assertEqual(["momentum"], mom_opt.get_slot_names())
      slot0 = mom_opt.get_slot(var0, "momentum")
      self.assertEquals(slot0.get_shape(), var0.get_shape())
      slot1 = mom_opt.get_slot(var1, "momentum")
      self.assertEquals(slot1.get_shape(), var1.get_shape())
      if not context.executing_eagerly():
        # Slot variables must not be trainable themselves.
        self.assertFalse(slot0 in variables.trainable_variables())
        self.assertFalse(slot1 in variables.trainable_variables())

      # Step 1: the momentum accumulators where 0. So we should see a normal
      # update: v -= grad * learning_rate
      if not context.executing_eagerly():
        self.evaluate(mom_update)
      # Check that the momentum accumulators have been updated.
      self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
                                         self.evaluate(slot0))
      self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
                                         self.evaluate(slot1))
      # Check that the parameters have been updated.
      self.assertAllCloseAccordingToType(
          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
          self.evaluate(var0))
      self.assertAllCloseAccordingToType(
          np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
          self.evaluate(var1))
      # Step 2: the momentum accumulators contain the previous update.
      if context.executing_eagerly():
        mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      else:
        self.evaluate(mom_update)
      # Check that the momentum accumulators have been updated.
      self.assertAllCloseAccordingToType(
          np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
          self.evaluate(slot0))
      self.assertAllCloseAccordingToType(
          np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
          self.evaluate(slot1))
      # Check that the parameters have been updated.
      self.assertAllCloseAccordingToType(
          np.array([
              1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
              2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
          ]), self.evaluate(var0))
      self.assertAllCloseAccordingToType(
          np.array([
              2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
                  (0.9 * 0.01 + 0.01) * 2.0)
          ]), self.evaluate(var1))

  def testBasic(self):
    with self.test_session():
      self.doTestBasic(use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTestBasic(use_resource=True)

  def testBasicCallableParams(self):
    # Callable hyperparameters are only supported under eager execution.
    with context.eager_mode():
      self.doTestBasic(use_resource=True, use_callable_params=True)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testVariablesAcrossGraphs(self):
    # One optimizer instance reused across two graphs must create a fresh
    # set of slot variables in each graph.
    optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5)
    with ops.Graph().as_default():
      var0 = resource_variable_ops.ResourceVariable(
          [1.0, 2.0], dtype=dtypes.float32, name="var0")
      var1 = resource_variable_ops.ResourceVariable(
          [3.0, 4.0], dtype=dtypes.float32, name="var1")
      if context.executing_eagerly():
        loss = lambda: math_ops.reduce_sum(var0 + var1)
      else:
        loss = math_ops.reduce_sum(var0 + var1)
      optimizer.minimize(loss)
      optimizer_variables = optimizer.variables()
      self.assertStartsWith(optimizer_variables[0].name, "var0")
      self.assertStartsWith(optimizer_variables[1].name, "var1")
      self.assertEquals(2, len(optimizer_variables))

    with ops.Graph().as_default():
      var2 = resource_variable_ops.ResourceVariable(
          [1.0, 2.0], dtype=dtypes.float32, name="var2")
      var3 = resource_variable_ops.ResourceVariable(
          [3.0, 4.0], dtype=dtypes.float32, name="var3")
      if context.executing_eagerly():
        loss = lambda: math_ops.reduce_sum(var2 + var3)
      else:
        loss = math_ops.reduce_sum(var2 + var3)
      optimizer.minimize(loss)
      optimizer_variables = optimizer.variables()
      # variables() only reports slots from the current graph.
      self.assertStartsWith(optimizer_variables[0].name, "var2")
      self.assertStartsWith(optimizer_variables[1].name, "var3")
      self.assertEquals(2, len(optimizer_variables))

  def testNesterovMomentum(self):
    # Compare four TF Nesterov steps against the NumPy reference above.
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        cost = 5 * var0 * var0 + 3 * var1
        global_step = variables.Variable(
            array_ops.zeros([], dtypes.int64), name="global_step")
        mom_op = momentum_lib.MomentumOptimizer(
            learning_rate=2.0, momentum=0.9, use_nesterov=True)
        opt_op = mom_op.minimize(cost, global_step, [var0, var1])
        variables.global_variables_initializer().run()
        for t in range(1, 5):
          opt_op.run()
          # d(cost)/d(var0) = 10 * var0; d(cost)/d(var1) = 3.
          var0_np, accum0_np = self._update_nesterov_momentum_numpy(
              var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
                                                                    accum1_np,
                                                                    3, 2.0, 0.9)
          self.assertAllClose(var0_np, var0.eval())
          self.assertAllClose(var1_np, var1.eval())

  def testSparseNesterovMomentum(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session():
        # First pass: precompute the gradient sequence with the NumPy model.
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        grads = []
        for t in range(1, 5):
          grads.append(var0_np * 10)
          var0_np, accum0_np = self._update_nesterov_momentum_numpy(
              var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
                                                                    accum1_np,
                                                                    3, 2.0, 0.9)
        # Second pass: feed those gradients as IndexedSlices and compare.
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
        loss = 5 * var0 * var0 + 3 * var1
        mom_op = momentum_lib.MomentumOptimizer(
            learning_rate=2.0, momentum=0.9, use_nesterov=True)
        x_feed = array_ops.placeholder(dtype)
        y_feed = ops.IndexedSlices(
            x_feed, constant_op.constant([0, 1]), constant_op.constant([2]))
        grads_and_vars = [(y_feed, var0), (constant_op.constant(
            [3.0, 3.0], dtype=dtype), var1)]
        opt_update = mom_op.apply_gradients(grads_and_vars)
        variables.global_variables_initializer().run()
        for t in range(1, 5):
          opt_update.run(feed_dict={x_feed: grads[t - 1]})
          var0_np, accum0_np = self._update_nesterov_momentum_numpy(
              var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
          var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
                                                                    accum1_np,
                                                                    3, 2.0, 0.9)
          self.assertAllClose(var0_np, var0.eval())
          self.assertAllClose(var1_np, var1.eval())

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testMinimizeSparseResourceVariable(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)

      # pylint: disable=cell-var-from-loop
      def loss():
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        return pred * pred
      # pylint: enable=cell-var-from-loop

      # momentum=0.0 reduces the update to plain SGD.
      opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
      sgd_op = opt.minimize(loss)
      self.evaluate(variables.global_variables_initializer())
      # Run 1 step of sgd
      self.evaluate(sgd_op)
      # Validate updated params
      self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
    var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))

    def loss():
      return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))

    opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
    sgd_op = opt.minimize(loss)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(sgd_op)
    # Only row 1 was looked up, so only row 1 is updated.
    self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))

  def testTensorLearningRateAndMomentum(self):
    # Same as the basic test, but hyperparameters are passed as Tensors.
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        mom_opt = momentum_lib.MomentumOptimizer(
            learning_rate=constant_op.constant(2.0),
            momentum=constant_op.constant(0.9))
        mom_update = mom_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # Check we have slots
        self.assertEqual(["momentum"], mom_opt.get_slot_names())
        slot0 = mom_opt.get_slot(var0, "momentum")
        self.assertEquals(slot0.get_shape(), var0.get_shape())
        self.assertFalse(slot0 in variables.trainable_variables())
        slot1 = mom_opt.get_slot(var1, "momentum")
        self.assertEquals(slot1.get_shape(), var1.get_shape())
        self.assertFalse(slot1 in variables.trainable_variables())

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Step 1: the momentum accumulators where 0. So we should see a normal
        # update: v -= grad * learning_rate
        mom_update.run()
        # Check that the momentum accumulators have been updated.
        self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
        self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
        # Check that the parameters have been updated.
        self.assertAllCloseAccordingToType(
            np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
        # Step 2: the momentum accumulators contain the previous update.
        mom_update.run()
        # Check that the momentum accumulators have been updated.
        self.assertAllCloseAccordingToType(
            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
        self.assertAllCloseAccordingToType(
            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
        # Check that the parameters have been updated.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
                2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
            ]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([
                2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
                    (0.9 * 0.01 + 0.01) * 2.0)
            ]), var1.eval())

  def _dbParamsMom01(self):
    """Return dist-belief momentum values.

    Return values been generated from the dist-belief momentum unittest,
    running with a learning rate of 0.1 and a momentum of 0.1.

    These values record how a parameter vector of size 10, initialized with
    0.0, gets updated with 10 consecutive momentum steps. It uses random
    gradients.

    Returns:
      db_grad: The gradients to apply
      db_out: The parameters after the momentum update.
    """
    db_grad = [[]] * 10
    db_out = [[]] * 10
    # pylint: disable=line-too-long
    db_grad[0] = [
        0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018,
        0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615
    ]
    db_out[0] = [
        -9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018,
        -0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618
    ]
    db_grad[1] = [
        0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378,
        0.5513742, 0.94687688, 0.16012503, 0.22159521
    ]
    db_out[1] = [
        -0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884,
        -0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544
    ]
    db_grad[2] = [
        0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965,
        0.31168157, 0.43203235, 0.16792089, 0.24644311
    ]
    db_out[2] = [
        -0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978,
        -0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189
    ]
    db_grad[3] = [
        0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098,
        0.81454384, 0.03848977, 0.89759839, 0.93665648
    ]
    db_out[3] = [
        -0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105,
        -0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303
    ]
    db_grad[4] = [
        0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359,
        0.69107032, 0.81897682, 0.5433259, 0.67860287
    ]
    db_out[4] = [
        -0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165,
        -0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544
    ]
    db_grad[5] = [
        0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563,
        0.84163809, 0.41172323, 0.83259648, 0.44941229
    ]
    db_out[5] = [
        -0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094,
        -0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717
    ]
    db_grad[6] = [
        0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221,
        0.73577434, 0.16014607, 0.57500273, 0.071136251
    ]
    db_out[6] = [
        -0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685,
        -0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997
    ]
    db_grad[7] = [
        0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646,
        0.74053431, 0.16033, 0.66625422, 0.73515922
    ]
    db_out[7] = [
        -0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838,
        -0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418
    ]
    db_grad[8] = [
        0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039,
        0.55561525, 0.22567581, 0.93331909, 0.29438227
    ]
    db_out[8] = [
        -0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527,
        -0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781
    ]
    db_grad[9] = [
        0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893,
        0.68593478, 0.50580865, 0.12602448, 0.093537711
    ]
    db_out[9] = [
        -0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302,
        -0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295
    ]
    # pylint: enable=line-too-long
    return db_grad, db_out

  def testLikeDistBeliefMom01(self):
    # Replay the dist-belief reference trajectory step by step.
    with self.test_session():
      db_grad, db_out = self._dbParamsMom01()
      num_samples = len(db_grad)
      var0 = variables.Variable([0.0] * num_samples)
      grads0 = constant_op.constant([0.0] * num_samples)
      mom_opt = momentum_lib.MomentumOptimizer(learning_rate=0.1, momentum=0.1)
      mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
      variables.global_variables_initializer().run()
      for i in xrange(num_samples):
        mom_update.run(feed_dict={grads0: db_grad[i]})
        self.assertAllClose(np.array(db_out[i]), var0.eval())

  def testSparse(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
        var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
        grads0 = ops.IndexedSlices(
            constant_op.constant(
                [[.1, .1]], dtype=dtype),
            constant_op.constant([1]),
            constant_op.constant([4, 2]))
        grads1 = ops.IndexedSlices(
            constant_op.constant(
                [[.01, .01], [.01, .01]], dtype=dtype),
            constant_op.constant([2, 3]),
            constant_op.constant([4, 2]))
        mom_opt = momentum_lib.MomentumOptimizer(
            learning_rate=2.0, momentum=0.9)
        mom_update = mom_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        # Check we have slots
        self.assertEqual(["momentum"], mom_opt.get_slot_names())
        slot0 = mom_opt.get_slot(var0, "momentum")
        self.assertEquals(slot0.get_shape(), var0.get_shape())
        slot1 = mom_opt.get_slot(var1, "momentum")
        self.assertEquals(slot1.get_shape(), var1.get_shape())

        # Fetch params to validate initial values
        self.assertAllClose([0, 0], var0.eval()[0])
        self.assertAllClose([0, 0], var0.eval()[1])
        self.assertAllClose([1, 1], var1.eval()[2])

        # Step 1: the momentum accumulators are 0. So we should see a normal
        # update: v -= grad * learning_rate
        mom_update.run()
        # Check that the momentum accumulators have been updated.
        # Row 0 received no gradient, so neither slot nor variable moves.
        self.assertAllCloseAccordingToType(np.array([0, 0]), slot0.eval()[0])
        self.assertAllCloseAccordingToType(np.array([.1, .1]), slot0.eval()[1])
        self.assertAllCloseAccordingToType(
            np.array([.01, .01]), slot1.eval()[2])
        # Check that the parameters have been updated.
        self.assertAllCloseAccordingToType(np.array([0, 0]), var0.eval()[0])
        self.assertAllCloseAccordingToType(
            np.array([-(0.1 * 2.0), -(0.1 * 2.0)]), var0.eval()[1])
        self.assertAllCloseAccordingToType(
            np.array([1.0 - (0.01 * 2.0), 1.0 - (0.01 * 2.0)]), var1.eval()[2])
        # Step 2: the momentum accumulators contain the previous update.
        mom_update.run()
        # Check that the momentum accumulators have been updated.
        self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
        self.assertAllCloseAccordingToType(
            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()[1])
        self.assertAllCloseAccordingToType(
            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
            slot1.eval()[2])
        # Check that the parameters have been updated.
        self.assertAllClose(np.array([0, 0]), var0.eval()[0])
        self.assertAllCloseAccordingToType(
            np.array([
                -(0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), -(0.1 * 2.0) - (
                    (0.9 * 0.1 + 0.1) * 2.0)
            ]), var0.eval()[1])
        self.assertAllCloseAccordingToType(
            np.array([
                0.98 - ((0.9 * 0.01 + 0.01) * 2.0), 0.98 - (
                    (0.9 * 0.01 + 0.01) * 2.0)
            ]), var1.eval()[2])

  def testSharing(self):
    # Two apply_gradients ops from the same optimizer must share slots.
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        mom_opt = momentum_lib.MomentumOptimizer(
            learning_rate=2.0, momentum=0.9)
        mom_update1 = mom_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        mom_update2 = mom_opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        self.assertEqual(["momentum"], mom_opt.get_slot_names())
        slot0 = mom_opt.get_slot(var0, "momentum")
        self.assertEquals(slot0.get_shape(), var0.get_shape())
        slot1 = mom_opt.get_slot(var1, "momentum")
        self.assertEquals(slot1.get_shape(), var1.get_shape())

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Step 1: the momentum accumulators where 0. So we should see a normal
        # update: v -= grad * learning_rate
        mom_update1.run()
        # Check that the momentum accumulators have been updated.
        self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
        self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
        # Check that the parameters have been updated.
        self.assertAllCloseAccordingToType(
            np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
        # Step 2: the second momentum accumulators contain the previous update.
        mom_update2.run()
        # Check that the momentum accumulators have been updated.
        self.assertAllCloseAccordingToType(
            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
        self.assertAllCloseAccordingToType(
            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
        # Check that the parameters have been updated.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
                2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
            ]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([
                2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
                    (0.9 * 0.01 + 0.01) * 2.0)
            ]), var1.eval())


if __name__ == "__main__":
  test.main()
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Version 2 of class Optimizer.""" +# pylint: disable=g-bad-name + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import checkpointable +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer as optimizer_v1 +from tensorflow.python.training import slot_creator +from tensorflow.python.util import nest + + +class _OptimizableVariable(object): + """Interface for abstracting over variables in the optimizers.""" + + @abc.abstractmethod + def target(self): + """Returns the optimization target for this variable.""" + raise NotImplementedError("Calling an abstract method.") + + @abc.abstractmethod + def update_op(self, optimizer, g, *args): + """Returns the update ops for updating the variable.""" + raise NotImplementedError("Calling an abstract method.") + + +class _RefVariableProcessor(_OptimizableVariable): + """Processor for Variable.""" + + def __init__(self, v): + self._v = v + + def target(self): + return self._v._ref() # pylint: disable=protected-access + + def update_op(self, optimizer, g, *args): + if isinstance(g, ops.Tensor): + update_op = optimizer._apply_dense(g, self._v, *args) # 
pylint: disable=protected-access + if self._v.constraint is not None: + with ops.control_dependencies([update_op]): + return self._v.assign(self._v.constraint(self._v)) + else: + return update_op + else: + assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a " + "tensor nor IndexedSlices.") + if self._v.constraint is not None: + raise RuntimeError( + "Cannot use a constraint function on a sparse variable.") + # pylint: disable=protected-access + return optimizer._apply_sparse_duplicate_indices(g, self._v, *args) + + +class _DenseReadResourceVariableProcessor(_OptimizableVariable): + """Processor for dense ResourceVariables.""" + + def __init__(self, v): + self._v = v + + def target(self): + return self._v + + def update_op(self, optimizer, g, *args): + # pylint: disable=protected-access + update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args) + if self._v.constraint is not None: + with ops.control_dependencies([update_op]): + return self._v.assign(self._v.constraint(self._v)) + else: + return update_op + + +class _DenseResourceVariableProcessor(_OptimizableVariable): + """Processor for dense ResourceVariables.""" + + def __init__(self, v): + self._v = v + + def target(self): + return self._v + + def update_op(self, optimizer, g, *args): + # pylint: disable=protected-access + if isinstance(g, ops.IndexedSlices): + if self._v.constraint is not None: + raise RuntimeError( + "Cannot use a constraint function on a sparse variable.") + return optimizer._resource_apply_sparse_duplicate_indices( + g.values, self._v, g.indices, *args) + update_op = optimizer._resource_apply_dense(g, self._v, *args) + if self._v.constraint is not None: + with ops.control_dependencies([update_op]): + return self._v.assign(self._v.constraint(self._v)) + else: + return update_op + + +class _StreamingModelPortProcessor(_OptimizableVariable): + """Processor for streaming ModelPorts.""" + + def __init__(self, v): + self._v = v + + def target(self): + return 
self._v + + def update_op(self, optimizer, g, *args): + return g + + +class _TensorProcessor(_OptimizableVariable): + """Processor for ordinary Tensors. + + Even though a Tensor can't really be updated, sometimes it is useful to + compute the gradients with respect to a Tensor using the optimizer. Updating + the Tensor is, of course, unsupported. + """ + + def __init__(self, v): + self._v = v + + def target(self): + return self._v + + def update_op(self, optimizer, g, *args): + raise NotImplementedError("Trying to update a Tensor ", self._v) + + +def _get_processor(v): + """The processor of v.""" + if context.executing_eagerly(): + if isinstance(v, ops.Tensor): + return _TensorProcessor(v) + else: + return _DenseResourceVariableProcessor(v) + if v.op.type == "VarHandleOp": + return _DenseResourceVariableProcessor(v) + if isinstance(v, variables.Variable): + return _RefVariableProcessor(v) + if v.op.type == "SubmodelPort": + return _StreamingModelPortProcessor(v) + if isinstance(v, ops.Tensor): + return _TensorProcessor(v) + raise NotImplementedError("Trying to optimize unsupported type ", v) + + +def _var_key_v2(var): + """Key for representing a primary variable, for looking up slots.""" + # pylint: disable=protected-access + if hasattr(var, "_mirrored_container"): + mirrored_container = var._mirrored_container() + assert mirrored_container is not None + if context.executing_eagerly(): + return mirrored_container._unique_id + return mirrored_container._shared_name + if context.executing_eagerly(): + return var._unique_id + return var.op.name + + +def _resolve(value, name): + if callable(value): + value = value() + return ops.convert_to_tensor(value, name=name) + + +def _is_dynamic(value): + """Returns true if __init__ arg `value` should be re-evaluated each step.""" + if callable(value): return True + # Don't need to do anything special in graph mode, since dynamic values + # will propagate correctly automatically. 
+ # TODO(josh11b): Add per-device caching across steps using variables for + # truly static values once we add distributed support. + if context.executing_eagerly() and isinstance( + value, resource_variable_ops.ResourceVariable): + return True + return False + + +class _OptimizerV2State(object): + """Holds per-graph and per-step optimizer state. + + Use _init_with_static_hyper() to create the state for a graph, and then + _copy_with_dynamic_hyper() to convert that to state for a particular step. + The difference between the two is that the former only has hyper + parameter values that are static and the latter also has values that + can change every step (according to _is_dynamic()). + """ + + def __init__(self, op_name): + self._op_name = op_name + + def _init_with_static_hyper(self, hyper): + """Initialize a fresh state object from hyper dict.""" + # self._hyper contains a dict from name to a dict with the Tensor values. + # This dict starts with a single item with key "None" with the hyper + # parameter value converted to a Tensor. Other items have dtype keys + # with that Tensor cast to that dtype. + self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} + for name, (dynamic, value) in hyper.items() if not dynamic} + self._slots = {} + self._non_slot_dict = {} + # Extra state to help Optimizers implement Checkpointable. Holds information + # about variables which will be restored as soon as they're created. 
    self._deferred_dependencies = {}  # Non-slot variables
    self._deferred_slot_restorations = {}  # Slot variables

  def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
    """Create a new state object for a particular step."""
    ret = _OptimizerV2State(self._op_name)
    # pylint: disable=protected-access
    # The slot/non-slot containers and the deferred-restore bookkeeping are
    # shared (not copied) with the per-graph state, so variables created
    # during one step remain visible in later steps.
    ret._slots = self._slots
    ret._non_slot_dict = self._non_slot_dict
    ret._deferred_dependencies = self._deferred_dependencies
    ret._deferred_slot_restorations = self._deferred_slot_restorations
    # Dynamic hyper parameters are re-resolved for this step; static ones are
    # merged in from the per-graph state unchanged.
    ret._hyper = {name: {None: _resolve(value, name)}
                  for name, (dynamic, value) in hyper.items() if dynamic}
    ret._hyper.update(self._hyper)
    ret._non_slot_devices = non_slot_devices
    ret._distribution = distribution
    return ret

  def _variables(self):
    """Returns a list of all variables held by self."""
    optimizer_variables = list(self._non_slot_dict.values())
    for variable_dict in self._slots.values():
      for slot_for_variable in variable_dict.values():
        optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def create_slot(self, var, val, slot_name, optional_op_name=None):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
+ """ + named_slots = self._slot_dict(slot_name) + var_key = _var_key_v2(var) + if var_key not in named_slots: + new_slot_variable = slot_creator.create_slot( + var, val, optional_op_name or self._op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[var_key] = new_slot_variable + return named_slots[var_key] + + def create_slot_with_initializer(self, var, initializer, shape, dtype, + slot_name, optional_op_name=None): + """Find or create a slot for a variable, using an Initializer. + + Args: + var: A `Variable` object. + initializer: An `Initializer`. The initial value of the slot. + shape: Shape of the initial value of the slot. + dtype: Type of the value of the slot. + slot_name: Name for the slot. + optional_op_name: Name to use when scoping the Variable that + needs to be created for the slot. + + Returns: + A `Variable` object. + """ + named_slots = self._slot_dict(slot_name) + var_key = _var_key_v2(var) + if var_key not in named_slots: + new_slot_variable = slot_creator.create_slot_with_initializer( + var, initializer, shape, dtype, optional_op_name or self._op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[var_key] = new_slot_variable + return named_slots[var_key] + + def zeros_slot(self, var, slot_name, optional_op_name=None): + """Find or create a slot initialized with 0.0. + + Args: + var: A `Variable` object. + slot_name: Name for the slot. + optional_op_name: Name to use when scoping the Variable that + needs to be created for the slot. + + Returns: + A `Variable` object. 
+ """ + named_slots = self._slot_dict(slot_name) + var_key = _var_key_v2(var) + if var_key not in named_slots: + new_slot_variable = slot_creator.create_zeros_slot( + var, optional_op_name or self._op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[var_key] = new_slot_variable + return named_slots[var_key] + + def _create_or_restore_slot_variable( + self, slot_variable_position, slot_name, variable, + optional_op_name=None): + """Restore a slot variable's value, possibly creating it. + + Called when a variable which has an associated slot variable is created or + restored. When executing eagerly, we create the slot variable with a + restoring initializer. + + No new variables are created when graph building. Instead, + _restore_slot_variable catches these after normal creation and adds restore + ops to the graph. This method is nonetheless important when graph building + for the case when a slot variable has already been created but `variable` + has just been added to a dependency graph (causing us to realize that the + slot variable needs to be restored). + + Args: + slot_variable_position: A `checkpointable._CheckpointPosition` object + indicating the slot variable `Checkpointable` object to be restored. + slot_name: The name of this `Optimizer`'s slot to restore into. + variable: The variable object this slot is being created for. + optional_op_name: Name to use when scoping the Variable that + needs to be created for the slot. 
+ """ + slot_variable = self.get_slot(var=variable, name=slot_name) + if (slot_variable is None and context.executing_eagerly() and + slot_variable_position.is_simple_variable()): + initializer = checkpointable.CheckpointInitialValue( + checkpoint_position=slot_variable_position) + slot_variable = self.create_slot( + var=variable, + val=initializer, + slot_name=slot_name, + optional_op_name=optional_op_name) + # Optimizers do not have unconditional dependencies on their slot + # variables (nor do any other objects). They are only saved if the + # variables they were created for are also saved. + if slot_variable is not None: + # If we've either made this slot variable, or if we've pulled out an + # existing slot variable, we should restore it. + slot_variable_position.restore(slot_variable) + else: + # We didn't make the slot variable. Defer restoring until it gets created + # normally. We keep a list rather than the one with the highest restore + # UID in case slot variables have their own dependencies, in which case + # those could differ between restores. + variable_key = _var_key_v2(variable) + self._deferred_slot_restorations.setdefault( + slot_name, {}).setdefault(variable_key, []).append( + slot_variable_position) + + def get_slot(self, var, name): + """Return a slot named `name` created for `var` by the Optimizer. + + Some `Optimizer` subclasses use additional variables. For example + `Momentum` and `Adagrad` use variables to accumulate updates. This method + gives access to these `Variable` objects if for some reason you need them. + + Use `get_slot_names()` to get the list of slot names created by the + `Optimizer`. + + Args: + var: A variable passed to `minimize()` or `apply_gradients()`. + name: A string. + + Returns: + The `Variable` for the slot if it was created, `None` otherwise. 
+ """ + named_slots = self._slots.get(name, None) + if not named_slots: + return None + return named_slots.get(_var_key_v2(var), None) + + def get_slot_names(self): + """Return a list of the names of slots created by the `Optimizer`. + + See `get_slot()`. + + Returns: + A list of strings. + """ + return sorted(self._slots.keys()) + + def create_non_slot(self, initial_value, name, colocate_with=None): + """Add an extra variable, not associated with a slot.""" + v = self._non_slot_dict.get(name, None) + if v is None: + if colocate_with is None: colocate_with = self._non_slot_devices + with self._distribution.colocate_vars_with(colocate_with): + # TODO(josh11b): Use get_variable() except for the legacy Adam use case. + v = variable_scope.variable(initial_value, name=name, trainable=False) + self._non_slot_dict[name] = v + deferred_dependencies_list = self._deferred_dependencies.pop(name, ()) + for checkpoint_position in sorted( + deferred_dependencies_list, + key=lambda restore: restore.checkpoint.restore_uid, + reverse=True): + checkpoint_position.restore(v) + return v + + def _restore_slot_variable(self, slot_name, variable, slot_variable): + """Restore a newly created slot variable's value.""" + variable_key = _var_key_v2(variable) + deferred_restorations = self._deferred_slot_restorations.get( + slot_name, {}).pop(variable_key, []) + # Iterate over restores, highest restore UID first to minimize the number + # of assignments. + deferred_restorations.sort(key=lambda position: position.restore_uid, + reverse=True) + for checkpoint_position in deferred_restorations: + checkpoint_position.restore(slot_variable) + + def get_non_slot(self, name): + """Returns the non-slot variable identified by `name`.""" + return self._non_slot_dict.get(name, None) + + def get_hyper(self, name, dtype=None): + """Returns the `name` hyper parameter, optionally cast to `dtype`.""" + dtype_dict = self._hyper[name] + # Do we have the value cast to dtype already cached? 
This should always + # succeed when dtype is None. + if dtype in dtype_dict: + return dtype_dict[dtype] + # Not cached, cast to dtype and save the result in the cache. + result = math_ops.cast(dtype_dict[None], dtype) + dtype_dict[dtype] = result + return result + + +class OptimizerV2(optimizer_v1.Optimizer): + """Updated base class for optimizers. + + This class defines the API to add Ops to train a model. You never use this + class directly, but instead instantiate one of its subclasses such as + `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`. + + ### Usage + + ```python + # Create an optimizer with the desired parameters. + opt = GradientDescentOptimizer(learning_rate=0.1) + # Add Ops to the graph to minimize a cost by updating a list of variables. + # "cost" is a Tensor, and the list of variables contains tf.Variable + # objects. + opt_op = opt.minimize(cost, var_list=<list of variables>) + ``` + + In the training program you will just have to run the returned Op. + + ```python + # Execute opt_op to do one step of training: + opt_op.run() + ``` + + ### Processing gradients before applying them. + + Calling `minimize()` takes care of both computing the gradients and + applying them to the variables. If you want to process the gradients + before applying them you can instead use the optimizer in three steps: + + 1. Compute the gradients with `compute_gradients()`. + 2. Process the gradients as you wish. + 3. Apply the processed gradients with `apply_gradients()`. + + Example: + + ```python + # Create an optimizer. + opt = GradientDescentOptimizer(learning_rate=0.1) + + # Compute the gradients for a list of variables. + grads_and_vars = opt.compute_gradients(loss, <list of variables>) + + # grads_and_vars is a list of tuples (gradient, variable). Do whatever you + # need to the 'gradient' part, for example cap them, etc. 
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results. For example the two gradients of `matmul` depend on the input
  values: With `GATE_NONE` one of the gradients could be applied to one of the
  inputs _before_ the other gradient is computed resulting in non-reproducible
  results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used. This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used. This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
  allocate and manage additional variables associated with the variables to
  train. These are called <i>Slots</i>. Slots have names and you can ask the
  optimizer for the names of the slots that it uses. Once you have a slot name
  you can ask the optimizer for the variable it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.

  ### Non-slot variables

  Some optimizer subclasses, such as `AdamOptimizer` have variables that
  are not associated with the variables to train, just the step itself.

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.

  ### State

  Internal methods are passed a `state` argument with the correct
  values to use for the slot and non-slot variables, and the hyper
  parameters.
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/state.get_hyper()
    facility instead.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
        _create_vars.
    """
    # Note: We intentionally don't call parent __init__.

    # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
    if (self.__class__._create_slots.__code__ is not  # pylint: disable=protected-access
        OptimizerV2._create_slots.__code__):
      raise RuntimeError("Override _create_vars instead of _create_slots when "
                         "descending from OptimizerV2 (class %s)" %
                         self.__class__.__name__)
    if not name:
      raise ValueError("Must specify the optimizer name")

    self._use_locking = use_locking
    self._name = name
    # Map from graph_key to state for that graph. We use the graph_key
    # since it works in both eager and graph mode, and gives the outer
    # graph inside functions.
    tower_context = distribute_lib.get_tower_context()
    if tower_context is None:
      # In a cross-tower context for a DistributionStrategy, which means
      # only one Optimizer will be created, not one per tower.
      self._per_graph_state = {}
    else:
      # We use get_tower_context().merge_call() to get a single dict
      # shared across all model replicas when running with a
      # DistributionStrategy.
      self._per_graph_state = tower_context.merge_call(lambda _: {})

    # Hyper parameters, and whether they should be re-evaluated every step.
    self._hyper = {}

  def _set_hyper(self, name, value):
    """Records hyper parameter `value`, noting whether it is dynamic."""
    self._hyper[name] = (_is_dynamic(value), value)

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None, stop_gradients=None,
               scale_loss_by_num_towers=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradient before applying
    them call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to differentiate
        through.
      scale_loss_by_num_towers: Optional boolean. If true, scale the loss
        down by the number of towers. By default, auto-detects whether this
        is needed.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes elements of `var_list` as arguments and computes the value to be
    minimized. If `var_list` is None, `loss` should take no arguments.
    Minimization (and gradient computation) is done with respect to the
    elements of `var_list` if not None, else with respect to any trainable
    variables created during the execution of the `loss` function.
    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
    `grad_loss` are ignored when eager execution is enabled.
    @end_compatibility
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss, stop_gradients=stop_gradients,
        scale_loss_by_num_towers=scale_loss_by_num_towers)

    # Fail fast if no variable received a gradient at all.
    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None, stop_gradients=None,
                        scale_loss_by_num_towers=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to differentiate
        through.
      scale_loss_by_num_towers: Optional boolean. If true, scale the loss
        down by the number of towers. By default, auto-detects whether this
        is needed.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything else than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss for number of towers (callable-loss case). In this case,
        # we have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        if scale_loss_by_num_towers is None:
          scale_loss_by_num_towers = (
              distribute_lib.get_loss_reduction() == "mean")
        if scale_loss_by_num_towers:
          num_towers = distribute_lib.get_distribution_strategy().num_towers
          if num_towers > 1:
            loss_value *= 1. / num_towers

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Eager execution requires a callable loss so gradients can be recorded
    # with a tape; a plain Tensor cannot be differentiated after the fact.
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss for number of towers (non-callable-loss case).
    if scale_loss_by_num_towers is None:
      scale_loss_by_num_towers = (
          distribute_lib.get_loss_reduction() == "mean")
    if scale_loss_by_num_towers:
      num_towers = distribute_lib.get_distribution_strategy().num_towers
      if num_towers > 1:
        loss *= 1. / num_towers

    if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
                              optimizer_v1.Optimizer.GATE_OP,
                              optimizer_v1.Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    # Each processor knows how to produce an update op for its variable type.
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        stop_gradients=stop_gradients)
    if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers. It relies on the subclass implementing the following
    # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().

    # Filter out variables with gradients of `None`.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
    if not filtered:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v in grads_and_vars],))
    # NOTE(review): merge_call appears to move execution into a cross-tower
    # context so distributed_apply runs once across towers -- confirm against
    # distribute_lib semantics.
    return distribute_lib.get_tower_context().merge_call(
        self.distributed_apply, filtered, global_step=global_step, name=name)

  def _get_or_create_state(self, var_list=None):
    """Either looks up or creates `_OptimizerV2State`.

    If any variables are available, they should be passed via the `var_list`
    argument, and these will be used to determine the graph to create/retrieve
    state for. Otherwise the returned state is for the current default graph.

    Args:
      var_list: A list of variables to extract a graph from.

    Returns:
      An `_OptimizerV2State` object.
    """
    # Determine the graph_key from the current graph.
    eager_execution = context.executing_eagerly()
    if eager_execution or var_list is None:
      graph = ops.get_default_graph()
    else:
      graph = ops._get_graph_from_inputs(var_list)  # pylint: disable=protected-access
    assert graph is not None
    graph_key = graph._graph_key  # pylint: disable=protected-access

    # Get the per graph state by looking up the graph_key.
    if graph_key in self._per_graph_state:
      per_graph_state = self._per_graph_state[graph_key]
    else:
      per_graph_state = _OptimizerV2State(self._name)
      per_graph_state._init_with_static_hyper(self._hyper)  # pylint: disable=protected-access
      self._per_graph_state[graph_key] = per_graph_state
    return per_graph_state

  def distributed_apply(self, distribution, grads_and_vars, global_step, name):
    """`apply_gradients` for use with a `DistributionStrategy`.

    Args:
      distribution: The `DistributionStrategy` coordinating the update.
      grads_and_vars: List of (gradient, variable) pairs; `None` gradients
        have already been filtered out by `apply_gradients()`.
      global_step: Optional `Variable` to increment by one after the update.
      name: Optional name for the returned operation.

    Returns:
      An `Operation` that applies the gradients (and, if `global_step` was
      given, increments it).
    """
    reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
    eager_execution = context.executing_eagerly()
    if eager_execution:
      # Give a clear error in this case instead of "name not supported
      # for Eager Tensors" when we compute non_slot_devices.
      for v in unwrapped_var_list:
        if isinstance(v, ops.Tensor):
          raise NotImplementedError("Trying to update a Tensor ", v)

    with ops.name_scope(name, self._name) as name:
      per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
      # Include the current value of any dynamic hyper parameters in `state`.
      non_slot_devices = distribution.non_slot_devices(var_list)
      state = per_graph_state._copy_with_dynamic_hyper(  # pylint: disable=protected-access
          self._hyper, distribution, non_slot_devices)

    # Create any slot and non-slot variables we need in `state`.
    with ops.init_scope():
      self._create_vars(var_list, state)

    with ops.name_scope(name):  # Re-enter name_scope created above
      # Give the child class a chance to do something before we start
      # applying gradients.
      self._prepare(state)

      def update(v, g):
        """Update variable `v` using gradient `g`."""
        assert v is not None

        # Convert the grad to Tensor or IndexedSlices if necessary, and
        # look up a processor for each variable's type.
        try:
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
        processor = _get_processor(v)

        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        scope_name = "" if eager_execution else v.op.name
        # device_policy is set because non-mirrored tensors will be read in
        # `update_op`.
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        with ops.name_scope("update_" + scope_name), \
            context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          return processor.update_op(self, g, state)

      # Use the processors to update the variables.
      update_ops = []
      for grad, var in grads_and_vars:
        update_ops.extend(distribution.unwrap(distribution.update(
            var, update, grad)))

      # Give the child class a chance to do something after applying
      # gradients.
      def finish():
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          return self._finish(state)

      update_ops = control_flow_ops.group(update_ops)
      with ops.control_dependencies([update_ops]):
        finish_updates = distribution.update_non_slot(non_slot_devices, finish)
      # update_non_slot may return None when there is nothing extra to run;
      # fall back to the grouped variable updates in that case.
      if finish_updates is None:
        finish_updates = update_ops

      # Update `global_step` (if any).
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(distribution.unwrap(finish_updates)):

          def update_global_step(global_step):
            # Resource variables support an assign that skips producing a
            # read value; legacy ref variables use state_ops.assign_add.
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              return global_step.assign_add(
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  read_value=False)
            else:
              return state_ops.assign_add(global_step, 1)

          apply_updates = distribution.group(
              distribution.update(global_step, update_global_step), name=name)

      # Add the training op to the TRAIN_OP graph collection in graph mode.
      if not eager_execution:
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables. For example
    `Momentum` and `Adagrad` use variables to accumulate updates. This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
+ """ + state = self._get_state_for_var(var) + return state.get_slot(var, name) if state is not None else None + + def get_slot_names(self): + """Return a list of the names of slots created by the `Optimizer`. + + See `get_slot()`. + + Returns: + A list of strings. + """ + state = self._get_per_graph_state() + return state.get_slot_names() if state is not None else [] + + def variables(self): + """A list of variables which encode the current state of `Optimizer`. + + Includes slot variables and additional global variables created by the + optimizer in the current default graph. + + Returns: + A list of variables. + """ + state = self._get_per_graph_state() + return state._variables() if state is not None else [] # pylint: disable=protected-access + + # -------------- + # Methods to be implemented by subclasses if they want to use the + # inherited implementation of apply_gradients() or compute_gradients(). + # -------------- + def _create_vars(self, var_list, state): + """Create all slots needed by the variables and any non-slot variables. + + Args: + var_list: A list of `Variable` objects. + state: An object with these methods: + `create_slot(var, val, slot_name, optional_op_name)`, + `create_slot_with_initializer(` + `var, initializer, shape, dtype, slot_name, optional_op_name)`, + `zeros_slot(var, slot_name, optional_op_name)`, + `create_non_slot_variable(initial_value, name, colocate_with)`, + `get_hyper(name)` + """ + # No slots needed by default + pass + + def _prepare(self, state): + """Code to execute before applying gradients. + + Note that most uses of _prepare() in Optimizer have been subsumed + by explicit support for hyper parameters in OptimizerV2 + + Args: + state: An object with a `get_hyper(name)` method. + + Returns: + Return value will be ignored. + """ + pass + + def _apply_dense(self, grad, var, state): + """Add ops to apply dense gradients to `var`. + + Args: + grad: A `Tensor`. + var: A `Variable` object. 
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, + and `get_hyper(name)` methods. + + Returns: + An `Operation`. + """ + raise NotImplementedError() + + def _resource_apply_dense(self, grad, handle, state): + """Add ops to apply dense gradients to the variable `handle`. + + Args: + grad: a `Tensor` representing the gradient. + handle: a `Tensor` of dtype `resource` which points to the variable + to be updated. + state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, + and `get_hyper(name)` methods. + + Returns: + An `Operation` which updates the value of the variable. + """ + raise NotImplementedError() + + def _resource_apply_sparse_duplicate_indices( + self, grad, handle, indices, state): + """Add ops to apply sparse gradients to `handle`, with repeated indices. + + Optimizers which override this method must deal with repeated indices. See + the docstring of `_apply_sparse_duplicate_indices` for details. By default + the correct behavior, to sum non-unique indices and their associated + gradients, is enforced by first pre-processing `grad` and `indices` and + passing them on to `_resource_apply_sparse`. Optimizers which deal correctly + with duplicate indices may instead override this method to avoid the + overhead of summing. + + Args: + grad: a `Tensor` representing the gradient for the affected indices. + handle: a `Tensor` of dtype `resource` which points to the variable + to be updated. + indices: a `Tensor` of integral type representing the indices for + which the gradient is nonzero. Indices may be repeated. + state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, + and `get_hyper(name)` methods. + + Returns: + An `Operation` which updates the value of the variable. 
+ """ + # pylint: disable=protected-access + summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices( + values=grad, indices=indices) + # pylint: enable=protected-access + return self._resource_apply_sparse( + summed_grad, handle, unique_indices, state) + + def _resource_apply_sparse(self, grad, handle, indices, state): + """Add ops to apply sparse gradients to the variable `handle`. + + Similar to `_apply_sparse`, the `indices` argument to this method has been + de-duplicated. Optimizers which deal correctly with non-unique indices may + instead override `_resource_apply_sparse_duplicate_indices` to avoid this + overhead. + + Args: + grad: a `Tensor` representing the gradient for the affected indices. + handle: a `Tensor` of dtype `resource` which points to the variable + to be updated. + indices: a `Tensor` of integral type representing the indices for + which the gradient is nonzero. Indices are unique. + state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, + and `get_hyper(name)` methods. + + Returns: + An `Operation` which updates the value of the variable. + """ + raise NotImplementedError() + + def _apply_sparse_duplicate_indices(self, grad, var, state): + """Add ops to apply sparse gradients to `var`, with repeated sparse indices. + + Optimizers which override this method must deal with IndexedSlices objects + such as the following: + + IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1]) + + The correct interpretation is: + + IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1]) + + Many optimizers deal incorrectly with repeated indices when updating based + on sparse gradients (e.g. summing squares rather than squaring the sum, or + applying momentum terms multiple times). Adding first is always the correct + behavior, so this is enforced here by reconstructing the IndexedSlices to + have only unique indices, then calling _apply_sparse. 
+ + Optimizers which deal correctly with repeated indices may instead override + this method to avoid the overhead of summing indices. + + Args: + grad: `IndexedSlices`. + var: A `Variable` object. + state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, + and `get_hyper(name)` methods. + + Returns: + An `Operation`. + """ + # pylint: disable=protected-access + summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices( + values=grad.values, indices=grad.indices) + # pylint: enable=protected-access + gradient_no_duplicate_indices = ops.IndexedSlices( + indices=unique_indices, + values=summed_values, + dense_shape=grad.dense_shape) + return self._apply_sparse(gradient_no_duplicate_indices, var, state) + + def _apply_sparse(self, grad, var, state): + """Add ops to apply sparse gradients to `var`. + + The IndexedSlices object passed to `grad` in this function is by default + pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate + indices (see its docstring for details). Optimizers which can tolerate or + have correct special cases for duplicate sparse indices may override + `_apply_sparse_duplicate_indices` instead of this function, avoiding that + overhead. + + Args: + grad: `IndexedSlices`, with no repeated indices. + var: A `Variable` object. + state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, + and `get_hyper(name)` methods. + + Returns: + An `Operation`. + """ + raise NotImplementedError() + + def _finish(self, state): + """Do what is needed to finish the update. + + This is called inside a scope colocated with any non-slot variables. + + Args: + state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`, + and `get_hyper(name)` methods. + + Returns: + The operation to apply updates, or None if no updates. + """ + return None + + # -------------- + # Utility methods for subclasses. 
+ # -------------- + def _get_per_graph_state(self): + # pylint: disable=protected-access + return self._per_graph_state.get(ops.get_default_graph()._graph_key, None) + + def _get_state_for_var(self, var): + # pylint: disable=protected-access + return self._per_graph_state.get(var._graph_key, None) + + # -------------- + # Overridden methods from Checkpointable. + # -------------- + + def _track_checkpointable(self, *args, **kwargs): + """Optimizers may not track dependencies. Raises an error.""" + raise NotImplementedError( + "Optimizers may not have dependencies. File a feature request if this " + "limitation bothers you.") + + @property + def _checkpoint_dependencies(self): + """From Checkpointable. Gather graph-specific non-slot variables to save.""" + current_graph_non_slot_variables = [] + state = self._get_per_graph_state() + if state is not None: + for name, variable_object in sorted( + state._non_slot_dict.items(), # pylint: disable=protected-access + # Avoid comparing variables + key=lambda item: item[0]): + current_graph_non_slot_variables.append( + checkpointable.CheckpointableReference( + name=name, ref=variable_object)) + # Note: ignores super(); Optimizers may not have any dependencies outside of + # state objects. + return current_graph_non_slot_variables + + def _lookup_dependency(self, name): + """From Checkpointable. Find a non-slot variable in the current graph.""" + state = self._get_per_graph_state() + if state is None: + return None + else: + return state.get_non_slot(name) + + @property + def _deferred_dependencies(self): + """Lets Checkpointable know where non-slot variables are created. + + If necessary, creates a new state object for the current default graph. + Checkpointable will then add entries to that state's deferred dependency + dictionary. The state object will check that dictionary when creating + non-slot variables, restoring their value if an entry is found. 
+ + Returns: + A dictionary which holds deferred dependencies for the current default + graph. + """ + state = self._get_or_create_state() + return state._deferred_dependencies # pylint: disable=protected-access + + def _create_or_restore_slot_variable( + self, slot_variable_position, slot_name, variable): + """Checkpointable: Restore a slot variable's value, possibly creating it. + + Called when a variable which has an associated slot variable is created or + restored. + + Args: + slot_variable_position: A `checkpointable._CheckpointPosition` object + indicating the slot variable `Checkpointable` object to be restored. + slot_name: The name of this `Optimizer`'s slot to restore into. + variable: The variable object this slot is being created for. + """ + state = self._get_or_create_state(var_list=[variable]) + state._create_or_restore_slot_variable( # pylint: disable=protected-access + slot_variable_position=slot_variable_position, + slot_name=slot_name, + variable=variable, + optional_op_name=self._name) + + # -------------- + # Unsupported parent methods + # -------------- + def _slot_dict(self, slot_name): + raise NotImplementedError( + "_slot_dict() method unsupported in OptimizerV2") + + def _get_or_make_slot(self, var, val, slot_name, op_name): + raise NotImplementedError( + "_get_or_make_slot() method unsupported in OptimizerV2") + + def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype, + slot_name, op_name): + raise NotImplementedError( + "_get_or_make_slot_with_initializer() method unsupported in " + "OptimizerV2") + + def _create_non_slot_variable(self, initial_value, name, colocate_with): + raise NotImplementedError( + "_create_non_slot_variable() method unsupported in OptimizerV2") + + def _get_non_slot_variable(self, name, graph=None): + raise NotImplementedError( + "_get_non_slot_variable() method unsupported in OptimizerV2") + + def _non_slot_variables(self): + raise NotImplementedError( + "_non_slot_variables() method 
unsupported in OptimizerV2") diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_symbols.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_symbols.py new file mode 100644 index 0000000000..24eada06cc --- /dev/null +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_symbols.py @@ -0,0 +1,42 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Distribution-aware version of Optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.optimizer_v2.adadelta import AdadeltaOptimizer +from tensorflow.contrib.optimizer_v2.adagrad import AdagradOptimizer +from tensorflow.contrib.optimizer_v2.adam import AdamOptimizer +from tensorflow.contrib.optimizer_v2.gradient_descent import GradientDescentOptimizer +from tensorflow.contrib.optimizer_v2.momentum import MomentumOptimizer +from tensorflow.contrib.optimizer_v2.optimizer_v2 import OptimizerV2 +from tensorflow.contrib.optimizer_v2.rmsprop import RMSPropOptimizer + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'AdadeltaOptimizer', + 'AdagradOptimizer', + 'AdamOptimizer', + 'GradientDescentOptimizer', + 'MomentumOptimizer', + 'OptimizerV2', + 'RMSPropOptimizer', +] + +remove_undocumented(__name__, 
_allowed_symbols) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py new file mode 100644 index 0000000000..8599af32f6 --- /dev/null +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -0,0 +1,294 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional test for OptimizerV2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.optimizer_v2 import gradient_descent +from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class OptimizerTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testBasic(self): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + # Note that we name the 
variables uniquely here since the variables don't + # seem to be getting deleted at the end of the loop. + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, + name='a_%d' % i) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, + name='b_%d' % i) + def loss(): + return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + # Note that for eager execution, minimize expects a function instead of a + # Tensor. + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name='global_step_%d' % i) + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + # Run 1 step of sgd through optimizer + opt_op = sgd_op.minimize(loss, global_step, [var0, var1]) + self.evaluate(opt_op) + # Validate updated params + self.assertAllClose([-14., -13.], self.evaluate(var0)) + self.assertAllClose([-6., -5.], self.evaluate(var1)) + + def testAggregationMethod(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + cost = 5 * var0 + 3 * var1 + global_step = variables.Variable( + array_ops.zeros([], dtypes.int64), name='global_step') + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + opt_op = sgd_op.minimize( + cost, + global_step, [var0, var1], + aggregation_method=gradients_impl.AggregationMethod. 
+ EXPERIMENTAL_ACCUMULATE_N) + + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 1 step of sgd through optimizer + opt_op.run() + # Validate updated params + self.assertAllClose([-14., -13.], var0.eval()) + self.assertAllClose([-6., -5.], var1.eval()) + + def testPrecomputedGradient(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + cost = 5 * var0 + 3 * var1 + grad_loss = constant_op.constant([42, -42], dtype=dtype) + global_step = variables.Variable( + array_ops.zeros([], dtypes.int64), name='global_step') + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + opt_op = sgd_op.minimize( + cost, global_step, [var0, var1], grad_loss=grad_loss) + + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 1 step of sgd through optimizer + opt_op.run() + # Validate updated params + self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)], + var0.eval()) + self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], + var1.eval()) + + @test_util.run_in_graph_and_eager_modes() + def testNoVariables(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + # pylint: disable=cell-var-from-loop + def loss(): + var0 = resource_variable_ops.ResourceVariable( + [1.0, 2.0], dtype=dtype, trainable=False, name='a') + var1 = resource_variable_ops.ResourceVariable( + [3.0, 4.0], dtype=dtype, trainable=False, name='b') + return 5 * var0 + var1 + # pylint: enable=cell-var-from-loop + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + with self.assertRaisesRegexp(ValueError, 'No.*variables'): + sgd_op.minimize(loss) + + 
@test_util.run_in_graph_and_eager_modes() + def testNoGradients(self): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + # Note that we name the variables uniquely here since the variables don't + # seem to be getting deleted at the end of the loop. + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, + name='a%d' % i) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, + name='b%d' % i) + # pylint: disable=cell-var-from-loop + def loss(): + return 5 * var0 + # pylint: enable=cell-var-from-loop + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + with self.assertRaisesRegexp(ValueError, 'No gradients'): + # var1 has no gradient + sgd_op.minimize(loss, var_list=[var1]) + + @test_util.run_in_graph_and_eager_modes() + def testNoGradientsForAnyVariables_Minimize(self): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + # Note that we name the variables uniquely here since the variables don't + # seem to be getting deleted at the end of the loop. + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, + name='a_%d' % i) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, + name='b_%d' % i) + def loss(): + return constant_op.constant(5.0) + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + with self.assertRaisesRegexp(ValueError, + 'No gradients provided for any variable'): + sgd_op.minimize(loss, var_list=[var0, var1]) + + @test_util.run_in_graph_and_eager_modes() + def testNoGradientsForAnyVariables_ApplyGradients(self): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + # Note that we name the variables uniquely here since the variables don't + # seem to be getting deleted at the end of the loop. 
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, + name='a_%d' % i) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, + name='b_%d' % i) + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + with self.assertRaisesRegexp(ValueError, + 'No gradients provided for any variable'): + sgd_op.apply_gradients([(None, var0), (None, var1)]) + + @test_util.run_in_graph_and_eager_modes() + def testGradientsAsVariables(self): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + # Note that we name the variables uniquely here since the variables don't + # seem to be getting deleted at the end of the loop. + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, + name='a%d' % i) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, + name='b%d' % i) + def loss(): + return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1]) + # Convert gradients to tf.Variables + converted_grads = [ + resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype), + name='c_%d_%d' % (i, j)) + for j, gv in enumerate(grads_and_vars) + ] + convert_ops = [ + state_ops.assign(converted_grads[j], gv[0]) + for j, gv in enumerate(grads_and_vars) + ] + + self.evaluate(variables.global_variables_initializer()) + # Run convert_ops to achieve the gradietns converting + self.evaluate(convert_ops) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 1 step of sgd through optimizer + converted_grads_and_vars = list(zip(converted_grads, [var0, var1])) + opt_op = sgd_op.apply_gradients(converted_grads_and_vars) + self.evaluate(opt_op) + + # Validate updated params + self.assertAllClose([-14., -13.], self.evaluate(var0)) + self.assertAllClose([-6., -5.], 
self.evaluate(var1)) + + @test_util.run_in_graph_and_eager_modes() + def testComputeGradientsWithTensors(self): + x = ops.convert_to_tensor(1.0) + def f(): + return x * x + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + grads_and_vars = sgd_op.compute_gradients(f, [x]) + self.assertEqual(1, len(grads_and_vars)) + grad, x_as_var = grads_and_vars[0] + self.assertIs(x, x_as_var) + self.assertEqual(2.0, self.evaluate(grad)) + + with self.assertRaises(NotImplementedError): + sgd_op.apply_gradients(grads_and_vars) + + def testTrainOp(self): + with self.test_session(): + var0 = variables.Variable([1.0, 2.0]) + var1 = variables.Variable([3.0, 4.0]) + cost = 5 * var0 + 3 * var1 + global_step = variables.Variable( + array_ops.zeros([], dtypes.int64), name='global_step') + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + opt_op = sgd_op.minimize(cost, global_step, [var0, var1]) + self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP)) + + def testConstraint(self): + constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.) + constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.) 
+ with self.test_session(): + var0 = variables.Variable([1.0, 2.0], + constraint=constraint_01) + var1 = variables.Variable([3.0, 4.0], + constraint=constraint_0) + cost = 5 * var0 + 3 * var1 + global_step = variables.Variable( + array_ops.zeros([], dtypes.int64), name='global_step') + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + opt_op = sgd_op.minimize(cost, global_step, [var0, var1]) + + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Run 1 step of sgd through optimizer + opt_op.run() + # Validate updated params + self.assertAllClose([-0.1, -0.1], var0.eval()) + self.assertAllClose([0., 0.], var1.eval()) + + def testStopGradients(self): + with self.test_session(): + var0 = variables.Variable([1.0, 2.0], name='var0') + var1 = variables.Variable([3.0, 4.0], name='var1') + var0_id = array_ops.identity(var0) + cost = 5 * var0_id + 3 * var1 + sgd_op = gradient_descent.GradientDescentOptimizer(3.0) + grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1], + stop_gradients=[var0_id]) + grad_dict = {var.op.name: grad for grad, var in grads_and_vars} + self.assertIsNone(grad_dict['var0']) + self.assertIsNotNone(grad_dict['var1']) + + def testDoNotOverrideCreateSlots(self): + class ShouldNotOverrideCreateSlots(optimizer_v2.OptimizerV2): + + def _create_slots(self, var_list): + """In OptimizerV2 _create_slots was renamed _create_vars.""" + return var_list + + with self.assertRaises(RuntimeError): + ShouldNotOverrideCreateSlots(True, 'name') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py new file mode 100644 index 0000000000..164ff0ea06 --- /dev/null +++ b/tensorflow/contrib/optimizer_v2/rmsprop.py @@ -0,0 +1,233 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RMSprop optimizer for Tensorflow. + +rmsprop algorithm [tieleman2012rmsprop] + +A detailed description of rmsprop. + +- maintain a moving (discounted) average of the square of gradients +- divide gradient by the root of this average + +mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2 +mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon) +delta = - mom + +This implementation of RMSProp uses plain momentum, not Nesterov momentum. + +The centered version additionally maintains a moving (discounted) average of the +gradients, and uses that average to estimate the variance: + +mean_grad = decay * mean_square{t-1} + (1-decay) * gradient +mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2 +mom = momentum * mom{t-1} + learning_rate * g_t / + sqrt(mean_square - mean_grad**2 + epsilon) +delta = - mom +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops + +from tensorflow.python.training import training_ops + + +class RMSPropOptimizer(optimizer_v2.OptimizerV2): + """Optimizer that implements the RMSProp algorithm. 
+ + See the + [paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). + """ + + def __init__(self, + learning_rate, + decay=0.9, + momentum=0.0, + epsilon=1e-10, + use_locking=False, + centered=False, + name="RMSProp"): + """Construct a new RMSProp optimizer. + + Note that in the dense implementation of this algorithm, variables and their + corresponding accumulators (momentum, gradient moving average, square + gradient moving average) will be updated even if the gradient is zero + (i.e. accumulators will decay, momentum will be applied). The sparse + implementation (used when the gradient is an `IndexedSlices` object, + typically because of `tf.gather` or an embedding lookup in the forward pass) + will not update variable slices or their accumulators unless those slices + were used in the forward pass (nor is there an "eventual" correction to + account for these omitted updates). This leads to more efficient updates for + large embedding lookup tables (where most of the slices are not accessed in + a particular graph execution), but differs from the published algorithm. + + Some of the args below are hyperparameters, where a hyperparameter is + defined as a scalar Tensor, a regular Python value or a callable (which + will be evaluated when `apply_gradients` is called) returning a scalar + Tensor or a Python value. + + Args: + learning_rate: A float hyperparameter. The learning rate. + decay: A float hyperparameter. Discounting factor for the history/coming + gradient. + momentum: A float hyperparameter. + epsilon: A float hyperparameter. Small value to avoid zero denominator. + use_locking: If True use locks for update operation. + centered: If True, gradients are normalized by the estimated variance of + the gradient; if False, by the uncentered second moment. Setting this to + True may help with training, but is slightly more expensive in terms of + computation and memory. Defaults to False. 
+ name: Optional name prefix for the operations created when applying + gradients. Defaults to "RMSProp". + """ + super(RMSPropOptimizer, self).__init__(use_locking, name) + self._set_hyper("learning_rate", learning_rate) + self._set_hyper("decay", decay) + self._set_hyper("momentum", momentum) + self._set_hyper("epsilon", epsilon) + + self._centered = centered + + def _create_vars(self, var_list, state): + for v in var_list: + if v.get_shape().is_fully_defined(): + init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype) + else: + init_rms = array_ops.ones_like(v) + state.create_slot_with_initializer(v, init_rms, v.get_shape(), + v.dtype.base_dtype, "rms") + if self._centered: + state.zeros_slot(v, "mg") + state.zeros_slot(v, "momentum") + + def _apply_dense(self, grad, var, state): + rms = state.get_slot(var, "rms") + mom = state.get_slot(var, "momentum") + if self._centered: + mg = state.get_slot(var, "mg") + return training_ops.apply_centered_rms_prop( + var, + mg, + rms, + mom, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad, + use_locking=self._use_locking).op + else: + return training_ops.apply_rms_prop( + var, + rms, + mom, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad, + use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var, state): + rms = state.get_slot(var, "rms") + mom = state.get_slot(var, "momentum") + if self._centered: + mg = state.get_slot(var, "mg") + return training_ops.resource_apply_centered_rms_prop( + var.handle, + mg.handle, + rms.handle, + mom.handle, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + 
state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad, + use_locking=self._use_locking) + else: + return training_ops.resource_apply_rms_prop( + var.handle, + rms.handle, + mom.handle, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad, + use_locking=self._use_locking) + + def _apply_sparse(self, grad, var, state): + rms = state.get_slot(var, "rms") + mom = state.get_slot(var, "momentum") + if self._centered: + mg = state.get_slot(var, "mg") + return training_ops.sparse_apply_centered_rms_prop( + var, + mg, + rms, + mom, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad.values, + grad.indices, + use_locking=self._use_locking) + else: + return training_ops.sparse_apply_rms_prop( + var, + rms, + mom, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad.values, + grad.indices, + use_locking=self._use_locking) + + def _resource_apply_sparse(self, grad, var, indices, state): + rms = state.get_slot(var, "rms") + mom = state.get_slot(var, "momentum") + if self._centered: + mg = self.get_slot(var, "mg") + return training_ops.resource_sparse_apply_centered_rms_prop( + var.handle, + mg.handle, + rms.handle, + mom.handle, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad, + indices, + use_locking=self._use_locking) + else: + return 
training_ops.resource_sparse_apply_rms_prop( + var.handle, + rms.handle, + mom.handle, + state.get_hyper("learning_rate", var.dtype.base_dtype), + state.get_hyper("decay", var.dtype.base_dtype), + state.get_hyper("momentum", var.dtype.base_dtype), + state.get_hyper("epsilon", var.dtype.base_dtype), + grad, + indices, + use_locking=self._use_locking) diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py new file mode 100644 index 0000000000..ed68f6afbf --- /dev/null +++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py @@ -0,0 +1,449 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for rmsprop optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import itertools
import math

import numpy as np

from tensorflow.contrib.optimizer_v2 import rmsprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

# dtypes exercised by the parameterized tests; half checks reduced precision.
_DATA_TYPES = [dtypes.half, dtypes.float32]

_TEST_PARAM_VALUES = [
    # learning_rate, decay, momentum, epsilon, centered, use_resource
    [0.5, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.9, 0.0, 1e-3, False, False],
    [0.5, 0.9, 0.0, 1e-3, True, True],
    [0.5, 0.9, 0.0, 1e-3, False, True],
    [0.1, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.95, 0.0, 1e-3, False, False],
    [0.5, 0.95, 0.0, 1e-5, True, False],
    [0.5, 0.95, 0.9, 1e-5, True, False],
]

# Full cross-product of dtype x hyperparameter settings.
_TESTPARAMS = [
    [data_type] + values
    for data_type, values in itertools.product(_DATA_TYPES, _TEST_PARAM_VALUES)
]


class RMSPropOptimizerTest(test.TestCase):
  """Checks RMSPropOptimizer against step-by-step numpy reference updates."""

  def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum,
                            epsilon, centered):
    """Reference numpy implementation of one dense RMSProp step.

    Returns the updated (var, mg, rms, mom) arrays. When `centered` is
    False, `mg` is passed through unchanged.
    """
    rms_t = rms * decay + (1 - decay) * g * g
    denom_t = rms_t + epsilon
    if centered:
      mg_t = mg * decay + (1 - decay) * g
      denom_t -= mg_t * mg_t
    else:
      mg_t = mg
    mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
    var_t = var - mom_t
    return var_t, mg_t, rms_t, mom_t

  def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
                                   lr, decay, momentum, epsilon, centered):
    """Reference numpy implementation of one sparse RMSProp step.

    Only rows listed in `gindexs` are updated; all other rows are copied
    through unchanged. Returns the updated (var, mg, rms, mom) arrays.
    """
    mg_t = copy.deepcopy(mg)
    rms_t = copy.deepcopy(rms)
    mom_t = copy.deepcopy(mom)
    var_t = copy.deepcopy(var)
    for i in range(len(gindexs)):
      gindex = gindexs[i]
      gvalue = gvalues[i]
      rms_t[gindex] = rms[gindex] * decay + (1 - decay) * gvalue * gvalue
      denom_t = rms_t[gindex] + epsilon
      if centered:
        mg_t[gindex] = mg_t[gindex] * decay + (1 - decay) * gvalue
        denom_t -= mg_t[gindex] * mg_t[gindex]
      mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
      var_t[gindex] = var[gindex] - mom_t[gindex]
    return var_t, mg_t, rms_t, mom_t

  def testDense(self):
    """Dense gradients track the numpy reference over 4 update steps."""
    # TODO(yori): Use ParameterizedTest when available
    for (dtype, learning_rate, decay, momentum,
         epsilon, centered, use_resource) in _TESTPARAMS:
      with self.test_session(use_gpu=True):
        # Initialize variables for numpy implementation.
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)

        if use_resource:
          var0 = resource_variable_ops.ResourceVariable(var0_np)
          var1 = resource_variable_ops.ResourceVariable(var1_np)
        else:
          var0 = variables.Variable(var0_np)
          var1 = variables.Variable(var1_np)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)
        opt = rmsprop.RMSPropOptimizer(
            learning_rate=learning_rate,
            decay=decay,
            momentum=momentum,
            epsilon=epsilon,
            centered=centered)

        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        # The "mg" (mean-gradient) slot exists only for the centered variant.
        mg0 = opt.get_slot(var0, "mg")
        self.assertEqual(mg0 is not None, centered)
        mg1 = opt.get_slot(var1, "mg")
        self.assertEqual(mg1 is not None, centered)
        rms0 = opt.get_slot(var0, "rms")
        self.assertTrue(rms0 is not None)
        rms1 = opt.get_slot(var1, "rms")
        self.assertTrue(rms1 is not None)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertTrue(mom0 is not None)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertTrue(mom1 is not None)

        # Mirror the optimizer's initial slot values: mg/mom start at 0,
        # rms starts at 1 (matching the TF slot initializers).
        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        # Run 4 steps of RMSProp
        for _ in range(1, 5):
          update.run()

          var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
              var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate,
              decay, momentum, epsilon, centered)
          var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
              var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate,
              decay, momentum, epsilon, centered)

          # Validate updated params
          if centered:
            self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
            self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
          self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
          self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
          self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
          self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
          self.assertAllCloseAccordingToType(var0_np, var0.eval())
          self.assertAllCloseAccordingToType(var1_np, var1.eval())

  def testMinimizeSparseResourceVariable(self):
    """minimize() works through embedding_lookup on a resource variable."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        sgd_op = rmsprop.RMSPropOptimizer(
            learning_rate=1.0,
            decay=0.0,
            momentum=0.0,
            epsilon=0.0,
            centered=False).minimize(loss)
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params
        self.assertAllCloseAccordingToType(
            [[0., 1.]], var0.eval(), atol=0.01)

  def testMinimizeSparseResourceVariableCentered(self):
    """Same as above but exercising the centered (mean-gradient) variant."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        sgd_op = rmsprop.RMSPropOptimizer(
            learning_rate=1.0,
            decay=0.0,
            momentum=0.0,
            epsilon=1.0,
            centered=True).minimize(loss)
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params
        self.assertAllCloseAccordingToType(
            [[-111, -138]], var0.eval(), atol=0.01)

  def testSparse(self):
    """Sparse (IndexedSlices) gradients track the numpy reference."""
    # TODO(yori): Use ParameterizedTest when available
    for (dtype, learning_rate, decay,
         momentum, epsilon, centered, _) in _TESTPARAMS:
      with self.test_session(use_gpu=True):
        # Initialize variables for numpy implementation.
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)

        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
        # Each gradient touches a single (different) row of its variable.
        grads0_np_indices = np.array([0], dtype=np.int32)
        grads0 = ops.IndexedSlices(
            constant_op.constant(grads0_np),
            constant_op.constant(grads0_np_indices), constant_op.constant([1]))
        grads1_np_indices = np.array([1], dtype=np.int32)
        grads1 = ops.IndexedSlices(
            constant_op.constant(grads1_np),
            constant_op.constant(grads1_np_indices), constant_op.constant([1]))
        opt = rmsprop.RMSPropOptimizer(
            learning_rate=learning_rate,
            decay=decay,
            momentum=momentum,
            epsilon=epsilon,
            centered=centered)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        # The "mg" slot exists only for the centered variant.
        mg0 = opt.get_slot(var0, "mg")
        self.assertEqual(mg0 is not None, centered)
        mg1 = opt.get_slot(var1, "mg")
        self.assertEqual(mg1 is not None, centered)
        rms0 = opt.get_slot(var0, "rms")
        self.assertTrue(rms0 is not None)
        rms1 = opt.get_slot(var1, "rms")
        self.assertTrue(rms1 is not None)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertTrue(mom0 is not None)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertTrue(mom1 is not None)

        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        # Run 4 steps of RMSProp
        for _ in range(1, 5):
          update.run()

          var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
              var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
              learning_rate, decay, momentum, epsilon, centered)
          var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
              var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
              learning_rate, decay, momentum, epsilon, centered)

          # Validate updated params
          if centered:
            self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
            self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
          self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
          self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
          self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
          self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
          self.assertAllCloseAccordingToType(var0_np, var0.eval())
          self.assertAllCloseAccordingToType(var1_np, var1.eval())

  def testWithoutMomentum(self):
    """Two steps with momentum=0 match hand-computed closed-form values."""
    for dtype in [dtypes.half, dtypes.float32]:
      with self.test_session(use_gpu=True):
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        opt = rmsprop.RMSPropOptimizer(
            learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        rms0 = opt.get_slot(var0, "rms")
        self.assertTrue(rms0 is not None)
        rms1 = opt.get_slot(var1, "rms")
        self.assertTrue(rms1 is not None)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertTrue(mom0 is not None)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertTrue(mom1 is not None)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Step 1: the rms accumulators where 1. So we should see a normal
        # update: v -= grad * learning_rate
        update.run()
        # Check the root mean square accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901, 0.901]), rms0.eval())
        self.assertAllCloseAccordingToType(
            np.array([0.90001, 0.90001]), rms1.eval())
        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
            ]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
            ]), var1.eval())
        # Step 2: the root mean square accumulators contain the previous update.
        update.run()
        # Check the rms accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
        self.assertAllCloseAccordingToType(
            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
            ]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
            ]), var1.eval())

  def testWithMomentum(self):
    """Two steps with momentum=0.5 match hand-computed closed-form values."""
    for dtype in [dtypes.half, dtypes.float32]:
      with self.test_session(use_gpu=True):
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)

        opt = rmsprop.RMSPropOptimizer(
            learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        rms0 = opt.get_slot(var0, "rms")
        self.assertTrue(rms0 is not None)
        rms1 = opt.get_slot(var1, "rms")
        self.assertTrue(rms1 is not None)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertTrue(mom0 is not None)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertTrue(mom1 is not None)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())
        # Step 1: rms = 1, mom = 0. So we should see a normal
        # update: v -= grad * learning_rate
        update.run()
        # Check the root mean square accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901, 0.901]), rms0.eval())
        self.assertAllCloseAccordingToType(
            np.array([0.90001, 0.90001]), rms1.eval())
        # Check the momentum accumulators
        self.assertAllCloseAccordingToType(
            np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
                      (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval())
        self.assertAllCloseAccordingToType(
            np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
                      (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval())

        # Check that the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))
            ]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))
            ]), var1.eval())

        # Step 2: the root mean square accumulators contain the previous update.
        update.run()
        # Check the rms accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
        self.assertAllCloseAccordingToType(
            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
        self.assertAllCloseAccordingToType(
            np.array([
                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)),
                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))
            ]), mom0.eval())
        self.assertAllCloseAccordingToType(
            np.array([
                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)),
                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))
            ]), mom1.eval())

        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)))
            ]), var0.eval())

        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)))
            ]), var1.eval())


if __name__ == "__main__":
  test.main()