author     Xin Jin <jinxin900924@gmail.com>          2018-09-04 20:44:43 +0800
committer  GitHub <noreply@github.com>               2018-09-04 20:44:43 +0800
commit     ce035c2493c060b38e53ca7a63c66b26e265b210 (patch)
tree       85af4fc680847ffbe06037b429a013ade32c4ce4 /tensorflow/contrib/opt
parent     16c42f0d4826b12a5359281997ee3f8e27fd5a87 (diff)
parent     1c3d02eb3594e9d92cd26562e797142ee34505b2 (diff)
Merge branch 'master' into ma_easgd
Diffstat (limited to 'tensorflow/contrib/opt')
26 files changed, 3215 insertions, 102 deletions
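
This merge brings several new optimizers into `tf.contrib.opt` (GGT, LARS, Shampoo, decoupled weight-decay wrappers such as `AdamWOptimizer`, plus matrix utilities and an EASGD `swapping_saver`). For orientation, here is a hedged usage sketch of two of the newly exported symbols; it assumes TensorFlow 1.x graph mode and the constructor signatures visible in the diff below, and is not part of the commit itself.

```python
# Hedged usage sketch, not part of the commit: assumes TensorFlow 1.x graph
# mode and the constructor signatures shown in the diff below.
import tensorflow as tf

# A toy quadratic loss so the snippet is self-contained.
x = tf.get_variable("x", initializer=[1.0, 2.0])
loss = tf.reduce_sum(tf.square(x))

# AdamW: Adam with decoupled weight decay (weight_decay_optimizers.py).
adamw = tf.contrib.opt.AdamWOptimizer(weight_decay=1e-4, learning_rate=1e-3)
train_adamw = adamw.minimize(loss)

# LARS: layer-wise adaptive rate scaling for large-batch training
# (lars_optimizer.py); batch-norm and bias variables are typically skipped.
lars = tf.contrib.opt.LARSOptimizer(
    learning_rate=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    eeta=0.001,
    skip_list=["batch_normalization", "bias"])
train_lars = lars.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_adamw)
```
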
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 13aa1d7e7a..93e589907e 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -19,24 +19,32 @@ py_library( "python/training/drop_stale_gradient_optimizer.py", "python/training/elastic_average_optimizer.py", "python/training/external_optimizer.py", + "python/training/ggt.py", + "python/training/lars_optimizer.py", "python/training/lazy_adam_optimizer.py", + "python/training/matrix_functions.py", "python/training/model_average_optimizer.py", "python/training/moving_average_optimizer.py", "python/training/multitask_optimizer_wrapper.py", "python/training/nadam_optimizer.py", "python/training/powersign.py", "python/training/reg_adagrad_optimizer.py", + "python/training/shampoo.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", + "python/training/weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", @@ -194,6 +202,25 @@ py_test( ], ) +py_test( + name = "weight_decay_optimizers_test", + srcs = ["python/training/weight_decay_optimizers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], @@ -302,3 +329,71 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "ggt_test", + srcs = ["python/training/ggt_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "shampoo_test", + size = "large", + srcs = ["python/training/shampoo_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "lars_optimizer_test", + srcs = ["python/training/lars_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_test( + name = "matrix_functions_test", + 
srcs = ["python/training/matrix_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 4c13c8e247..ad7d7cfa6e 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,15 +22,20 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * +from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * +from tensorflow.contrib.opt.python.training.lars_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.reg_adagrad_optimizer import * +from tensorflow.contrib.opt.python.training.shampoo import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * -from tensorflow.contrib.opt.python.training.model_average_optimizer import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -43,9 +48,14 @@ _allowed_symbols = [ 'DelayCompensatedGradientDescentOptimizer', 'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', + 'LARSOptimizer', 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', + 'MomentumWOptimizer', + 'AdamWOptimizer', + 'DecoupledWeightDecayExtension', + 'extend_with_decoupled_weight_decay', 'ScipyOptimizerInterface', 'VariableClippingOptimizer', 'MultitaskOptimizerWrapper', @@ -53,7 +63,10 @@ _allowed_symbols = [ 'ElasticAverageOptimizer', 'ElasticAverageCustomGetter', 'ModelAverageOptimizer', - 'ModelAverageCustomGetter' + 'ModelAverageCustomGetter', + 'GGTOptimizer', + 'ShampooOptimizer', + 'RegAdagradOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index 21bf3f5313..61d8b94eca 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -74,7 +74,7 @@ class AdaMaxOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. 
zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype) m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots() @@ -142,7 +142,7 @@ class AdaMaxOptimizerTest(test.TestCase): def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable( @@ -172,7 +172,7 @@ class AdaMaxOptimizerTest(test.TestCase): def doTestBasic(self, use_resource=False): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - with self.test_session(graph=ops.Graph()): + with self.session(graph=ops.Graph()): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -224,14 +224,16 @@ class AdaMaxOptimizerTest(test.TestCase): var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), + rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), + rtol=1e-2) if use_resource: self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) def testBasic(self): - with self.test_session(): + with self.cached_session(): self.doTestBasic(use_resource=False) @test_util.run_in_graph_and_eager_modes(reset_test=True) @@ -240,7 +242,7 @@ class AdaMaxOptimizerTest(test.TestCase): def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -276,7 +278,7 @@ class AdaMaxOptimizerTest(test.TestCase): def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. 
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py index 08d45ed73f..628a735e72 100644 --- a/tensorflow/contrib/opt/python/training/addsign_test.py +++ b/tensorflow/contrib/opt/python/training/addsign_test.py @@ -214,7 +214,7 @@ class AddSignTest(test.TestCase): # Run 7 steps of AddSign # first 4 steps with positive gradient # last 3 steps with negative gradient (sign(gm) should be -1) - for t in range(1, 4): + for t in range(1, 8): if t < 5: update.run() else: @@ -222,7 +222,7 @@ class AddSignTest(test.TestCase): var0_np, m0 = addsign_update_numpy( var0_np, - grads0_np, + grads0_np if t < 5 else -grads0_np, m0, learning_rate, alpha=alpha, @@ -232,7 +232,7 @@ class AddSignTest(test.TestCase): ) var1_np, m1 = addsign_update_numpy( var1_np, - grads1_np, + grads1_np if t < 5 else -grads1_np, m1, learning_rate, alpha=alpha, diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 209c4611f3..6c203e5519 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -17,22 +17,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops - -from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import optimizer +from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import data_flow_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import constant_op LOCAL_VARIABLE_NAME = 'local_center_variable' GLOBAL_VARIABLE_NAME = 'global_center_variable' +GLOBAL_STEP = 'global_step' class ElasticAverageCustomGetter(object): @@ -52,16 +53,32 @@ class ElasticAverageCustomGetter(object): with tf.device( tf.train.replica_device_setter( worker_device=worker_device, - ps_device="/job:ps/cpu:0", + ps_device="/job:ps", cluster=cluster)), tf.variable_scope('',custom_getter=ea_custom_getter): - hid_w = tf.get_variable( - initializer=tf.truncated_normal( - [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], - stddev=1.0 / IMAGE_PIXELS), - name="hid_w") - hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]), - name="hid_b") + ... + create your model here + ... + with tf.device(worker_device): + opt = tf.train.MomentumOptimizer(...) + optimizer = ElasticAverageOptimizer( + opt, + num_worker=2, + moving_rate=0.01, # or use default value + communication_period=20, + ea_custom_getter=ea_custom_getter) + ... + train_op = optimizer.apply_gradients( + grads_vars, + global_step=global_step) + ... + hooks = [optimizer.make_session_run_hook(is_chief, task_index)] + ... 
+ with tf.train.MonitoredTrainingSession(master=server.target, + is_chief=is_chief, + checkpoint_dir=("...), + save_checkpoint_secs=600, + hooks=hooks) as mon_sess: """ def __init__(self, worker_device): @@ -83,21 +100,32 @@ class ElasticAverageCustomGetter(object): collections=[ops.GraphKeys.LOCAL_VARIABLES], *args, **kwargs) - global_center_variable = variable_scope.variable( + if kwargs['reuse'] == True: + return local_var + global_center_variable = getter( name='%s/%s' % (GLOBAL_VARIABLE_NAME, name), - initial_value=local_var.initialized_value(), trainable=False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + *args, + **kwargs) with ops.device(self._worker_device): - local_center_variable = variable_scope.variable( + local_center_variable = getter( name='%s/%s' % (LOCAL_VARIABLE_NAME, name), - initial_value=local_var.initialized_value(), trainable=False, - collections=[ops.GraphKeys.LOCAL_VARIABLES]) - - self._local_map[local_var] = local_center_variable - self._global_map[local_var] = global_center_variable + collections=[ops.GraphKeys.LOCAL_VARIABLES], + *args, + **kwargs) + if kwargs['partitioner'] is None: + self._local_map[local_var] = local_center_variable + self._global_map[local_var] = global_center_variable + else: + v_list = list(local_var) + for i in range(len(v_list)): + self._local_map[v_list[i]] \ + = list(local_center_variable)[i] + self._global_map[v_list[i]] \ + = list(global_center_variable)[i] return local_var else: kwargs['trainable'] = trainable @@ -132,6 +160,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): moving_rate=None, rho=None, use_locking=True, + synchronous=False, name='ElasticAverageOptimizer'): """Construct a new gradient descent optimizer. @@ -143,9 +172,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer): communication_period: An int point value to controls the frequency of the communication between every worker and the ps. moving_rate: A floating point value to control the elastic difference. - rho: the amount of exploration we allow ine the model. The default + rho: the amount of exploration we allow in the model. The default value is moving_rate/learning_rate + rho=0.0 is suggested in async mode. use_locking: If True use locks for update operations. + synchronous: Add_sync_queues_and_barrier or not. + True: all workers will wait for each other before start training + False: worker can start training when its initilization is done, + no need to wait for everyone is ready. + in case one worker is restarted, it can join and continue + training without being blocked. name: Optional name prefix for the operations created when applying gradients. Defaults to "ElasticAverageOptimizer". """ @@ -155,6 +191,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): self._period = communication_period self._local_map = ea_custom_getter._local_map self._global_map = ea_custom_getter._global_map + self._synchronous = synchronous if moving_rate is None: self._moving_rate = self.BETA / communication_period / num_worker @@ -248,11 +285,29 @@ class ElasticAverageOptimizer(optimizer.Optimizer): TypeError: If `grads_and_vars` is malformed. ValueError: If none of the variables have gradients. 
""" + global_old = set(n.op.name for n in variables.global_variables()) apply_updates = self._opt.apply_gradients(grads_and_vars) + global_new = set(n.op.name for n in variables.global_variables()) with ops.control_dependencies([apply_updates]): local_update = state_ops.assign_add( self._local_step, 1, name='local_step_update').op + # this is for place the variables created by optimizer to local collection + # e.g., AdamOptimizer will create beta as global variables + def _adjust_optimizer_variable_collection(opt_vars): + g = ops.get_default_graph() + idx = 0 + for _ in range(len(g._collections[ops.GraphKeys.GLOBAL_VARIABLES])): + var = g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx] + name = var.op.name + if name in opt_vars: + ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var) + del g.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)[idx] + else: + idx += 1 + + _adjust_optimizer_variable_collection(global_new - global_old) + # update global variables. def _Update_global_variables(): local_vars = [v for g, v in grads_and_vars if g is not None] @@ -297,7 +352,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): variables equal to the global center variables before the training begins""" def _Add_sync_queues_and_barrier(enqueue_after_list): - """Adds ops to enqueu on all worker queues""" + """Adds ops to enqueue on all worker queues""" sync_queues = [ data_flow_ops.FIFOQueue( self._num_worker, [dtypes.bool], @@ -331,6 +386,9 @@ class ElasticAverageOptimizer(optimizer.Optimizer): init_ops.append(state_ops.assign(lc_var, gc_var)) init_op = control_flow_ops.group(*(init_ops)) + if self._synchronous == False: + return init_op + sync_queue_op = _Add_sync_queues_and_barrier([init_op]) return sync_queue_op @@ -338,6 +396,51 @@ class ElasticAverageOptimizer(optimizer.Optimizer): """Creates a hook to handle ElasticAverageOptimizerHook ops such as initialization.""" return _ElasticAverageOptimizerHook(self, is_chief, task_index) + def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): + """Create a saver copy global_center_variable to trainable variables + Please call this function after all your variables created with + ElasticAverageCustomGetter. For evaluations or inference, use this saver + during training. It will save the global_center_variable of the trained + parameters under the original parameter names. + Args: + var_list: List of variables to save, as per `Saver()`. + If set to None, save all the trainable_variables that have + been created before this call. + name: The name of the saver. + **kwargs: Keyword arguments of `Saver()`. + Returns: + A `tf.train.Saver` object. 
+ Raises: + RuntimeError: global_center_variable is empty, please make sure + this is called after model created and + ElasticAverageCustomGetter is used when declaring you model + """ + if not self._global_map: + raise RuntimeError('global_center_variable is empty, please make sure ' + 'this is called after model created and ' + 'ElasticAverageCustomGetter is used when declaring ' + 'you model') + + if var_list is None: + var_list = variables.trainable_variables() + if not isinstance(var_list, dict): + var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + + swapped_var_list = {} + for key, var in var_list.items(): + tensor = var + + if not isinstance(var, list): + for tvar in variables.trainable_variables(): + if tvar.op.name == var.op.name: + tensor = self._global_map.get(tvar, var) + break + else: #partitioned variable + tensor = [self._global_map.get(lvar, lvar) for lvar in var] + + swapped_var_list[key] = tensor + + return saver.Saver(swapped_var_list, name=name, **kwargs) class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): @@ -358,3 +461,7 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): if self._is_chief: self._global_init_op = variables.global_variables_initializer() self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index) + + def after_create_session(self, session, coord): + """Run initialization ops""" + session.run(self._variable_init_op)
\ No newline at end of file diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py index 9d57dc08f6..5bf6a08de1 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py @@ -17,17 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import portpicker +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import device_setter from tensorflow.python.training import gradient_descent +from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.training import training from tensorflow.python.training import training_util -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import device_setter from tensorflow.contrib.opt.python.training.elastic_average_optimizer import \ ElasticAverageOptimizer, ElasticAverageCustomGetter, GLOBAL_VARIABLE_NAME @@ -59,42 +64,72 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"): # Creates the workers and return their sessions, graphs, train_ops. # Chief worker will update at last -def _get_workers(num_workers, period, workers, moving_rate): +def _get_workers(num_workers, period, workers, moving_rate, num_ps=1): sessions = [] graphs = [] train_ops = [] + savers = [] for worker_id in range(num_workers): graph = ops.Graph() is_chief = (worker_id == 0) with graph.as_default(): worker_device = "/job:worker/task:%d/cpu:0" % (worker_id) - ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device) + ea_custom = ElasticAverageCustomGetter(worker_device=worker_device) with variable_scope.variable_scope( - "", custom_getter=ea_coustom), ops.device( + "", custom_getter=ea_custom), ops.device( device_setter.replica_device_setter( worker_device=worker_device, ps_device="/job:ps/task:0/cpu:0", ps_tasks=1)): - global_step = variables.Variable(0, name="global_step", trainable=False) + global_step = training_util.get_or_create_global_step() var_0 = variable_scope.get_variable(initializer=0.0, name="v0") var_1 = variable_scope.get_variable(initializer=1.0, name="v1") - - with ops.device("/job:worker/task:" + str(worker_id)): - grads_0 = constant_op.constant(-1.0) - grads_1 = constant_op.constant(-1.0) - - sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) - opt = ElasticAverageOptimizer( - opt=sgd_opt, - num_worker=num_workers, - moving_rate=moving_rate, - communication_period=period, - ea_custom_getter=ea_coustom) + if num_ps > 1: + with variable_scope.variable_scope( + "", + partitioner=partitioned_variables.fixed_size_partitioner( + num_ps, axis=0), + custom_getter=ea_custom), ops.device( + device_setter.replica_device_setter( + worker_device=worker_device, + ps_device="/job:ps/task:0/cpu:0", + ps_tasks=num_ps)): + + partition_var = variable_scope.get_variable( + 'partition_var', + shape=[2, 4], + initializer=init_ops.ones_initializer) + part_0 = list(partition_var)[0] + part_1 = list(partition_var)[1] + + with 
ops.device("/job:worker/task:" + str(worker_id)): + grads_0 = constant_op.constant(-1.0) + grads_1 = constant_op.constant(-1.0) + grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]]) + grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]]) + + sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) + opt = ElasticAverageOptimizer( + opt=sgd_opt, + num_worker=num_workers, + moving_rate=moving_rate, + communication_period=period, + ea_custom_getter=ea_custom) + if num_ps == 1: + train_op = [ + opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]), + global_step) + ] + else: train_op = [ - opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]), + opt.apply_gradients(([grads_0, var_0], + [grads_1, var_1], + [grads_part_0, part_0], + [grads_part_1, part_1]), global_step) ] easgd_hook = opt.make_session_run_hook(is_chief, worker_id) + saver = opt.swapping_saver() # Creates MonitoredSession sess = training.MonitoredTrainingSession( workers[worker_id].target, hooks=[easgd_hook]) @@ -102,8 +137,9 @@ def _get_workers(num_workers, period, workers, moving_rate): sessions.append(sess) graphs.append(graph) train_ops.append(train_op) + savers.append(saver) - return sessions, graphs, train_ops + return sessions, graphs, train_ops, savers class ElasticAverageOptimizerTest(test.TestCase): @@ -118,7 +154,7 @@ class ElasticAverageOptimizerTest(test.TestCase): cluster, workers, _ = create_local_cluster( num_workers=num_workers, num_ps=num_ps) - sessions, graphs, train_ops = _get_workers( + sessions, graphs, train_ops, savers = _get_workers( num_workers, communication_period, workers, 1.0) var_0 = graphs[0].get_tensor_by_name("v0:0") @@ -158,6 +194,21 @@ class ElasticAverageOptimizerTest(test.TestCase): self.assertAllEqual(2.0, sessions[0].run(var_0_g)) self.assertAllEqual(3.0, sessions[0].run(var_1_g)) self.assertAllEqual(1, sessions[0].run(global_step)) + sessions[0].run(train_ops[0]) + + # save, data will be global value + outfile = os.path.join(test.get_temp_dir(), "model") + savers[0].save(sessions[0]._sess._sess._sess._sess, + save_path=outfile) + ops.reset_default_graph() # restore on a new graph + with session.Session() as sess: + v0 = variable_scope.get_variable(initializer=0.0, name="v0") + v1 = variable_scope.get_variable(initializer=1.0, name="v1") + sess.run(variables.local_variables_initializer()) + saver_opt = saver.Saver(var_list=[v1, v0]) + saver_opt.restore(sess, outfile) + self.assertAllEqual(2.0, sess.run(v0)) + self.assertAllEqual(3.0, sess.run(v1)) def test2Worker1Period(self): num_workers = 2 @@ -166,8 +217,8 @@ class ElasticAverageOptimizerTest(test.TestCase): cluster, workers, _ = create_local_cluster( num_workers=num_workers, num_ps=num_ps) - sessions, graphs, train_ops = _get_workers( - num_workers, communication_period, workers, 0.5) + sessions, graphs, train_ops, savers = _get_workers( + num_workers, communication_period, workers, 0.5, num_ps=2) var_0 = graphs[0].get_tensor_by_name("v0:0") var_1 = graphs[0].get_tensor_by_name("v1:0") @@ -177,6 +228,9 @@ class ElasticAverageOptimizerTest(test.TestCase): var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0") var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0") + part_0_g = graphs[0].get_tensor_by_name( + GLOBAL_VARIABLE_NAME + "/partition_var/part_0:0") + # Verify the initialized value. 
self.assertAllEqual(0.0, sessions[0].run(var_0)) self.assertAllEqual(1.0, sessions[0].run(var_1)) @@ -194,22 +248,45 @@ class ElasticAverageOptimizerTest(test.TestCase): self.assertAllEqual(1.75, sessions[0].run(var_1_g)) self.assertAllEqual(0.75, sessions[1].run(var_0_1)) self.assertAllEqual(1.75, sessions[1].run(var_1_1)) + # part_0 of global_center copy + part_0_g = sessions[0].run(part_0_g) + + outfile = os.path.join(test.get_temp_dir(), "model") + savers[0].save(sessions[0]._sess._sess._sess._sess, + save_path=outfile) + + # verify restore of partitioned_variables + ops.reset_default_graph() # restore on a new graph + g = ops.get_default_graph() + with session.Session() as sess, g.as_default(): + with variable_scope.variable_scope( + "", + partitioner=partitioned_variables.fixed_size_partitioner( + num_ps, axis=0)): + partition_var = variable_scope.get_variable( + 'partition_var', + shape=[2, 4], + initializer=init_ops.ones_initializer) + s = saver.Saver(var_list=[partition_var]) + s.restore(sess, outfile) + part_0 = g.get_tensor_by_name('partition_var/part_0:0') + self.assertAllEqual(part_0_g, sess.run(part_0)) def testPS2TasksWithClusterSpecClass(self): cluster_spec = server_lib.ClusterSpec({ "ps": ["ps0:2222", "ps1:2222"], "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] }) - ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0") + ea_custom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0") from tensorflow.python.training import device_setter with ops.device( device_setter.replica_device_setter(cluster=cluster_spec, worker_device="/job:worker/task:0", ps_device="/job:ps")), \ - variable_scope.variable_scope("", custom_getter=ea_coustom): + variable_scope.variable_scope("", custom_getter=ea_custom): v = variable_scope.get_variable(initializer=[1, 2], name="v") w = variable_scope.get_variable(initializer=[2, 1], name="w") - v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w] + v_g, w_g = ea_custom._global_map[v], ea_custom._global_map[w] self.assertDeviceEqual("/job:worker/task:0", v.device) self.assertDeviceEqual("job:ps/task:0", v_g.device) self.assertDeviceEqual("/job:worker/task:0", w.device) diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py index 953586ee70..9997103016 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py @@ -85,7 +85,7 @@ class ExternalOptimizerInterfaceTest(TestCase): optimizer = MockOptimizerInterface(loss) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) @@ -107,7 +107,7 @@ class ExternalOptimizerInterfaceTest(TestCase): optimizer = MockOptimizerInterface(loss) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) initial_vector_val = sess.run(vector) @@ -164,7 +164,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( self._objective(x), method=method, options=options) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) @@ -176,7 +176,7 @@ class ScipyOptimizerInterfaceTest(TestCase): x = variables.Variable(array_ops.zeros(dimension)) optimizer = 
external_optimizer.ScipyOptimizerInterface(self._objective(x)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) @@ -242,7 +242,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, equalities=equalities, inequalities=inequalities, method='SLSQP') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) self.assertAllClose(np.ones(2), sess.run(vector)) @@ -260,7 +260,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, var_to_bounds=var_to_bounds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) self.assertAllClose(np.ones(2), sess.run(vector)) @@ -277,7 +277,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, var_to_bounds=var_to_bounds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) self.assertAllClose([0., 2.], sess.run(vector)) @@ -293,7 +293,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, method='SLSQP') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) method = optimizer.optimizer_kwargs.get('method') @@ -312,7 +312,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface(loss, method='SLSQP') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) initial_vector_val = sess.run(vector) diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py new file mode 100644 index 0000000000..cae952d8f5 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -0,0 +1,312 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GGT for Tensorflow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np +from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops + + +class GGTOptimizer(optimizer_v2.OptimizerV2): + """Optimizer that implements the GGT algorithm. 
+ + GGT has an advantage over sgd and adam on large models with poor conditioning, + for example language models and CNNs, + see [[ABCHSZZ 2018]](https://arxiv.org/pdf/1806.02958.pdf). + """ + + def __init__(self, + learning_rate=0.001, + beta1=0.9, + use_locking=False, + name="GGT", + window=10, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Construct a new GGT optimizer. + + Initialization: + + ``` + t <- 0 (Initialize timestep) + grad_buffer <- 0 (Initialize buffer for keeping past gradients) + flat_grad <- 0 (Initialize flattened gradient that contains gradients of all + variables) + m_0 <- 0 (Initialize 1st moment vector) + ``` + + Suppose all variables and their gradients are concatenated into vectors + `flat_vars` and `flat_grad`. The update rule for `flat_vars` + uses an optimization described at the beginning of section 2 of the paper: + + ``` + t <- t + 1 + + m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad + grad_buffer[(t-1) % window, :] <- m_t + + M <- grad_buffer^T / sqrt(min(t, window)) + U, sigma, _ <- SVD(M^TM + I * svd_eps) + + sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3) + sigma_sqrt_min <- min(sqrt(sigma)) + + if sigma_sqrt_min > eps: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min + else: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + flat_vars <- flat_vars - learning_rate * new_step + ``` + + GGT provides the power of full-matrix adaptive regularization at a cost not + much larger than SGD. As a result it is suited for large models where the + gradient covariance matrix has a poor condition number that slows down first + order methods. + GGT uses the preconditioner from full-matrix AdaGrad, with gradient history + attenuated exponentially as in Adam, and truncated to a window parameter. + It has provable guarantees even for non-convex optimization that is never + significantly worse than SGD and in some cases better. + + Args: + learning_rate: A float hyperparameter. The learning rate. + beta1: A float hyperparameter. The exponential decay rate for the 1st + moment estimates. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "GGT". + window: An integer hyperparameter. The number of first moments to keep in + computing the adaptive preconditioner. + eps: A float hyperparameter. Used to truncate small eigenvalues of the + gradient covariance matrix. + svd_eps: A float hyperparameter. Used to stabilize SVD. + sigma_eps: A float hyperparameter. Used to regularize matrix inversion. + """ + super(GGTOptimizer, self).__init__(use_locking, name) + self._set_hyper("lr", learning_rate) + self._set_hyper("beta1", beta1) + self._set_hyper("window", window) + self._set_hyper("eps", eps) + self._set_hyper("svd_eps", svd_eps) + self._set_hyper("sigma_eps", sigma_eps) + + self.index_dict = {} + self.shape_dict = {} + + def _create_vars(self, var_list, state): + # Construct ordered dictionary for variable dimensions, sorted by name. + shape_dict = {} + for v in var_list: + shape_dict[v.name] = np.prod(v.get_shape()).value + self.shape_dict = collections.OrderedDict( + sorted(shape_dict.items(), key=lambda t: t[0])) + + # Assign each variable its location in flat_grad. The locations are based on + # the order of sorted names. 
+ idx = 0 + for v_name, v_dim in self.shape_dict.items(): + self.index_dict[v_name] = idx + idx += v_dim + + state.create_non_slot( + initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype), + name="global_step") + + # Buffer for keeping past gradients. + window = state.get_hyper("window") + grad_buffer_init = array_ops.zeros( + [window, idx], dtype=var_list[0].dtype.base_dtype) + state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer") + + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="moment1") + + # Flattened gradient that contains gradients for all variables in the model. + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="flat_grad") + + def _get_global_step(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("global_step") + + def _get_moment1(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("moment1") + + def _get_grad_buffer(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("grad_buffer") + + def _get_flat_grad(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("flat_grad") + + def _apply_sparse(self, grad, var): + raise NotImplementedError("Sparse gradient updates are not supported.") + + def _prepare(self, state): + self._variables = [] + + def _apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _resource_apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _finish(self, state): + var_dtype = self._variables[0].dtype.base_dtype + # Update global step. + global_step = self._get_global_step(state) + update_global_step = state_ops.assign_add(global_step, 1.) + + # Update the first moment estimate. + beta1 = state.get_hyper("beta1", dtype=var_dtype) + moment1 = self._get_moment1(state) + flat_grad = self._get_flat_grad(state) + # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t + update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad) + + # Update the gradient buffer. + window = state.get_hyper("window") + grad_buffer = self._get_grad_buffer(state) + next_grad_index = math_ops.floormod( + math_ops.to_int32(update_global_step - 1.), window) + # grad_buffer[(t-1) % window] := moment1_t + update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index, + update_moment1) + + # Compute the update step. 
+ eps = state.get_hyper("eps", dtype=var_dtype) + svd_eps = state.get_hyper("svd_eps", dtype=var_dtype) + sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype) + lr = state.get_hyper("lr", dtype=var_dtype) + denom = math_ops.sqrt( + math_ops.minimum( + ops.convert_to_tensor(update_global_step), + ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype)))) + moment1_2d = array_ops.expand_dims(update_moment1, -1) + + # m = grad_buffer^T / sqrt(min(t, window)) + # m has shape [model dimension, window], where model dimension is the sum + # of the dimensions of the flattened variables. + m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom)) + + # sigma, u, _ = SVD(m^Tm + I * svd_eps) + mm = math_ops.matmul(m, m, transpose_a=True) + damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps + sigma, u, _ = linalg_ops.svd(mm + damping) + sigma_sqrt = math_ops.sqrt(sigma) + sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt) + + # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3 + # We add sigma_eps to alleviate numerical instability. + # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T. + sigma_sqrt_inv = math_ops.divide( + math_ops.cast(1.0, dtype=var_dtype), + math_ops.pow(sigma_sqrt + sigma_eps, 3)) + + # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the + # inversion of a model dimension by model dimension matrix is needed. To + # speed up this computation we calculate the following instead: + # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1. + new_step = array_ops.expand_dims( + array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1) + head = math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag(sigma_sqrt_inv), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + + # When inverting (mm^t)^(1/2), we also add epsilon * I regularization for + # degenerate cases. We expand ((mm^t)^(1/2) + epsilon * I)^(-1) using + # Woodbury's identity. + # For full derivation please see paper at + # https://arxiv.org/pdf/1806.02958.pdf + tail = moment1_2d - math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag( + math_ops.divide(math_ops.cast(1.0, dtype=var_dtype), + sigma)), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + scaled_tail = math_ops.divide(tail, sigma_sqrt_min) + + update_new_step = control_flow_ops.cond( + sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail), + lambda: math_ops.add(new_step, head)) + + # Update each variable. + update_step = [] + for var in self._variables: + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + var_update_correct_shape = array_ops.reshape( + update_new_step[start_index:end_index], var.get_shape()) + var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape) + update_step.append(var_updated) + + return control_flow_ops.group(update_step) diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py new file mode 100644 index 0000000000..1775edabb3 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt_test.py @@ -0,0 +1,183 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GGTOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def ggt_update_numpy(param, + g_t, + lr, + grad_buffer, + m, + window, + t, + beta1=0.9, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Tests the correctness of one step of GGT.""" + m_t = m * beta1 + (1 - beta1) * g_t + grad_buffer[((t - 1) % window), :] = m_t + m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window))) + mm = np.dot(np.transpose(m_matrix), m_matrix) + damping = np.eye(window) * svd_eps + u, sigma, _ = np.linalg.svd(mm + damping) + + sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3) + new_step = np.linalg.multi_dot([ + m_matrix, u, + np.diag(sigma_sqrt_inv), + np.transpose(u), + np.transpose(m_matrix), m_t + ]) + + sigma_sqrt_min = np.sqrt(sigma).min() + + if sigma_sqrt_min > eps: + new_step += (m_t - np.linalg.multi_dot([ + m_matrix, u, + np.diag(1.0 / sigma), + np.transpose(u), + np.transpose(m_matrix), m_t + ])) * (1.0 / sigma_sqrt_min) + + param_t = param - lr * new_step + return param_t, m_t, grad_buffer + + +class GGTOptimizerTest(test.TestCase): + + def doTestBasic(self, use_resource=False): + # SVD does not support float16 + for i, dtype in enumerate([dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0 = 0.0 + window = 3 + grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype) + lr = 0.001 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np, name="var0") + var1 = variables.Variable(var1_np, name="var1") + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = GGTOptimizer(learning_rate=lr, window=window) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + self.assertTrue(m_t is not None) + self.assertTrue(grad_buffer_t is not None) + self.assertTrue(g_t is not None) + self.assertIn(m_t, opt_variables) + self.assertIn(grad_buffer_t, opt_variables) + self.assertIn(g_t, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + + # Run 3 steps of GGT + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if t == 1: + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.], + [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t)) + elif t == 2: + self.assertAllCloseAccordingToType( + np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], + [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]), + self.evaluate(grad_buffer_t)) + else: + self.assertAllCloseAccordingToType( + np.array([0.0271, 0.0271, 0.00271, 0.00271]), + self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, + 0.001], [0.019, 0.019, 0.0019, 0.0019], + [0.0271, 0.0271, 0.00271, 0.00271]]), + self.evaluate(grad_buffer_t)) + + self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01], + self.evaluate(g_t)) + + var_np = np.append(var0_np, var1_np) + grads_np = np.append(grads0_np, grads1_np) + var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr, + grad_buffer, m0, window, t) + + var0_np = var_np[:2] + var1_np = var_np[2:] + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py 
b/tensorflow/contrib/opt/python/training/lars_optimizer.py new file mode 100644 index 0000000000..a8dafd9a4c --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py @@ -0,0 +1,164 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class LARSOptimizer(optimizer.Optimizer): + """Layer-wise Adaptive Rate Scaling for large batch training. + + Introduced by "Large Batch Training of Convolutional Networks" by Y. You, + I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888) + + Implements the LARS learning rate scheme presented in the paper above. This + optimizer is useful when scaling the batch size to up to 32K without + significant performance degradation. It is recommended to use the optimizer + in conjunction with: + - Gradual learning rate warm-up + - Linear learning rate scaling + - Poly rule learning rate decay + + Note, LARS scaling is currently only enabled for dense tensors. Sparse tensors + use the default momentum optimizer. + """ + + def __init__( + self, + learning_rate, + momentum=0.9, + weight_decay=0.0001, + # The LARS coefficient is a hyperparameter + eeta=0.001, + epsilon=0.0, + name="LARSOptimizer", + # Enable skipping variables from LARS scaling. + # TODO(sameerkm): Enable a direct mechanism to pass a + # subset of variables to the optimizer. + skip_list=None, + use_nesterov=False): + """Construct a new LARS Optimizer. + + Args: + learning_rate: A `Tensor` or floating point value. The base learning rate. + momentum: A floating point value. Momentum hyperparameter. + weight_decay: A floating point value. Weight decay hyperparameter. + eeta: LARS coefficient as used in the paper. Dfault set to LARS + coefficient from the paper. (eeta / weight_decay) determines the highest + scaling factor in LARS. + epsilon: Optional epsilon parameter to be set in models that have very + small gradients. Default set to 0.0. + name: Optional name prefix for variables and ops created by LARSOptimizer. + skip_list: List of strings to enable skipping variables from LARS scaling. + If any of the strings in skip_list is a subset of var.name, variable + 'var' is skipped from LARS scaling. For a typical classification model + with batch normalization, the skip_list is ['batch_normalization', + 'bias'] + use_nesterov: when set to True, nesterov momentum will be enabled + + Raises: + ValueError: If a hyperparameter is set to a non-sensical value. 
+ """ + if momentum < 0.0: + raise ValueError("momentum should be positive: %s" % momentum) + if weight_decay < 0.0: + raise ValueError("weight_decay should be positive: %s" % weight_decay) + super(LARSOptimizer, self).__init__(use_locking=False, name=name) + + self._learning_rate = learning_rate + self._momentum = momentum + self._weight_decay = weight_decay + self._eeta = eeta + self._epsilon = epsilon + self._name = name + self._skip_list = skip_list + self._use_nesterov = use_nesterov + + def _create_slots(self, var_list): + for v in var_list: + self._zeros_slot(v, "momentum", self._name) + + def compute_lr(self, grad, var): + scaled_lr = self._learning_rate + if self._skip_list is None or not any(v in var.name + for v in self._skip_list): + w_norm = linalg_ops.norm(var, ord=2) + g_norm = linalg_ops.norm(grad, ord=2) + trust_ratio = array_ops.where( + math_ops.greater(w_norm, 0), + array_ops.where( + math_ops.greater(g_norm, 0), + (self._eeta * w_norm / + (g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0), + 1.0) + scaled_lr = self._learning_rate * trust_ratio + return scaled_lr + + def _apply_dense(self, grad, var): + scaled_lr = self.compute_lr(grad, var) + mom = self.get_slot(var, "momentum") + return training_ops.apply_momentum( + var, + mom, + scaled_lr, + grad, + self._momentum, + use_locking=False, + use_nesterov=self._use_nesterov) + + def _resource_apply_dense(self, grad, var): + scaled_lr = self.compute_lr(grad, var) + mom = self.get_slot(var, "momentum") + return training_ops.resource_apply_momentum( + var.handle, + mom.handle, + scaled_lr, + grad, + self._momentum, + use_locking=False, + use_nesterov=self._use_nesterov) + + # Fallback to momentum optimizer for sparse tensors + def _apply_sparse(self, grad, var): + mom = self.get_slot(var, "momentum") + return training_ops.sparse_apply_momentum( + var, + mom, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + grad.values, + grad.indices, + math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), + use_locking=self._use_locking, + use_nesterov=self._use_nesterov).op + + def _resource_apply_sparse(self, grad, var, indices): + mom = self.get_slot(var, "momentum") + return training_ops.resource_sparse_apply_momentum( + var.handle, + mom.handle, + math_ops.cast(self._learning_rate_tensor, grad.dtype), + grad, + indices, + math_ops.cast(self._momentum_tensor, grad.dtype), + use_locking=self._use_locking, + use_nesterov=self._use_nesterov) diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py new file mode 100644 index 0000000000..b76db763da --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py @@ -0,0 +1,127 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0. Licensed to the Apache +# Software Foundation. You may not use this file except in compliance with the +# License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test for Layer-wise Adaptive Rate Scaling optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import lars_optimizer as lo +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class LARSOptimizerTest(test.TestCase): + + def testLARSGradientOneStep(self): + for _ in range(10): + for dtype in [dtypes.float32, dtypes.float64]: + with self.cached_session() as sess: + shape = [3, 3] + var_np = np.ones(shape) + grad_np = np.ones(shape) + lr_np = 0.1 + m_np = 0.9 + wd_np = 0.1 + ep_np = 1e-5 + eeta = 0.1 + vel_np = np.zeros(shape) + + var = variables.Variable(var_np, dtype=dtype) + grad = variables.Variable(grad_np, dtype=dtype) + opt = lo.LARSOptimizer( + learning_rate=lr_np, + momentum=m_np, + weight_decay=wd_np, + eeta=eeta, + epsilon=ep_np) + + step = opt.apply_gradients([(grad, var)]) + variables.global_variables_initializer().run() + + pre_var = sess.run(var) + pre_vel = sess.run(opt.get_slot(var, 'momentum')) + self.assertAllClose(var_np, pre_var) + self.assertAllClose(vel_np, pre_vel) + + step.run() + post_var = sess.run(var) + post_vel = sess.run(opt.get_slot(var, 'momentum')) + + w_norm = np.linalg.norm(var_np.flatten(), ord=2) + g_norm = np.linalg.norm(grad_np.flatten(), ord=2) + trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np) + scaled_lr = lr_np * trust_ratio + + vel_np = m_np * vel_np + grad_np + var_np -= scaled_lr * vel_np + + self.assertAllClose(var_np, post_var) + self.assertAllClose(vel_np, post_vel) + + def testLARSGradientMultiStep(self): + for _ in range(10): + for dtype in [dtypes.float32, dtypes.float64]: + with self.cached_session() as sess: + shape = [3, 3] + var_np = np.ones(shape) + grad_np = np.ones(shape) + lr_np = 0.1 + m_np = 0.9 + wd_np = 0.1 + ep_np = 1e-5 + eeta = 0.1 + vel_np = np.zeros(shape) + + var = variables.Variable(var_np, dtype=dtype) + grad = variables.Variable(grad_np, dtype=dtype) + opt = lo.LARSOptimizer( + learning_rate=lr_np, + momentum=m_np, + eeta=eeta, + weight_decay=wd_np, + epsilon=ep_np) + + step = opt.apply_gradients([(grad, var)]) + variables.global_variables_initializer().run() + + pre_var = sess.run(var) + pre_vel = sess.run(opt.get_slot(var, 'momentum')) + self.assertAllClose(var_np, pre_var) + self.assertAllClose(vel_np, pre_vel) + + for _ in range(10): + step.run() + + post_var = sess.run(var) + post_vel = sess.run(opt.get_slot(var, 'momentum')) + + w_norm = np.linalg.norm(var_np.flatten(), ord=2) + g_norm = np.linalg.norm(grad_np.flatten(), ord=2) + trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np) + scaled_lr = lr_np * trust_ratio + + vel_np = m_np * vel_np + grad_np + var_np -= scaled_lr * vel_np + + self.assertAllClose(var_np, post_var) + self.assertAllClose(vel_np, post_vel) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py index a16857db7d..dc4c462ce4 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py @@ -53,7 +53,7 @@ class AdamOptimizerTest(test.TestCase): def testSparse(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - 
with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -109,7 +109,7 @@ class AdamOptimizerTest(test.TestCase): def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable( diff --git a/tensorflow/contrib/opt/python/training/matrix_functions.py b/tensorflow/contrib/opt/python/training/matrix_functions.py new file mode 100644 index 0000000000..baab577638 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/matrix_functions.py @@ -0,0 +1,155 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Matrix functions contains iterative methods for M^p.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + + +def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4): + """Iterative method to get matrix square root. + + Stable iterations for the matrix square root, Nicholas J. Higham + + Page 231, Eq 2.6b + http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf + + Args: + mat_a: the symmetric PSD matrix whose matrix square root be computed + mat_a_size: size of mat_a. + iter_count: Maximum number of iterations. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + + Returns: + mat_a^0.5 + """ + + def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z, + unused_old_mat_z, err, old_err): + # This method require that we check for divergence every step. + return math_ops.logical_and(i < iter_count, err < old_err) + + def _iter_body(i, mat_y, unused_old_mat_y, mat_z, unused_old_mat_z, err, + unused_old_err): + current_iterate = 0.5 * (3.0 * identity - math_ops.matmul(mat_z, mat_y)) + current_mat_y = math_ops.matmul(mat_y, current_iterate) + current_mat_z = math_ops.matmul(current_iterate, mat_z) + # Compute the error in approximation. 
+ mat_sqrt_a = current_mat_y * math_ops.sqrt(norm) + mat_a_approx = math_ops.matmul(mat_sqrt_a, mat_sqrt_a) + residual = mat_a - mat_a_approx + current_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm + return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err + + identity = linalg_ops.eye(math_ops.to_int32(mat_a_size)) + mat_a = mat_a + ridge_epsilon * identity + norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a)) + mat_init_y = mat_a / norm + mat_init_z = identity + init_err = norm + + _, _, prev_mat_y, _, _, _, _ = control_flow_ops.while_loop( + _iter_condition, _iter_body, [ + 0, mat_init_y, mat_init_y, mat_init_z, mat_init_z, init_err, + init_err + 1.0 + ]) + return prev_mat_y * math_ops.sqrt(norm) + + +def matrix_inverse_pth_root(mat_g, + mat_g_size, + alpha, + iter_count=100, + epsilon=1e-6, + ridge_epsilon=1e-6): + """Computes mat_g^alpha, where alpha = -1/p, p a positive integer. + + We use an iterative Schur-Newton method from equation 3.2 on page 9 of: + + A Schur-Newton Method for the Matrix p-th Root and its Inverse + by Chun-Hua Guo and Nicholas J. Higham + SIAM Journal on Matrix Analysis and Applications, + 2006, Vol. 28, No. 3 : pp. 788-804 + https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf + + Args: + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g. + alpha: exponent, must be -1/p for p a positive integer. + iter_count: Maximum number of iterations. + epsilon: accuracy indicator, useful for early termination. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + + Returns: + mat_g^alpha + """ + + identity = linalg_ops.eye(math_ops.to_int32(mat_g_size)) + + def mat_power(mat_m, p): + """Computes mat_m^p, for p a positive integer. + + Power p is known at graph compile time, so no need for loop and cond. + Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + assert p == int(p) and p > 0 + power = None + while p > 0: + if p % 2 == 1: + power = math_ops.matmul(mat_m, power) if power is not None else mat_m + p //= 2 + mat_m = math_ops.matmul(mat_m, mat_m) + return power + + def _iter_condition(i, mat_m, _): + return math_ops.logical_and( + i < iter_count, + math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon) + + def _iter_body(i, mat_m, mat_x): + mat_m_i = (1 - alpha) * identity + alpha * mat_m + return (i + 1, math_ops.matmul(mat_power(mat_m_i, -1.0 / alpha), mat_m), + math_ops.matmul(mat_x, mat_m_i)) + + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + ridge_epsilon, alpha) + else: + damped_mat_g = mat_g + ridge_epsilon * identity + z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g)) + # The best value for z is + # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) / + # (c_max^{1-alpha} - c_min^{1-alpha}) + # where c_max and c_min are the largest and smallest singular values of + # damped_mat_g. + # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha) + # Can replace above line by the one below, but it is less accurate, + # hence needs more iterations to converge. + # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g) + # If we want the method to always converge, use z = 1 / norm(damped_mat_g) + # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many + # extra iterations. 
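A numpy transcription of the coupled iteration above may help; this is a sketch under the same assumptions (symmetric PSD input, alpha = -1/p), not a drop-in replacement for the graph version:

import numpy as np

def inverse_pth_root(mat_g, p, iter_count=100, epsilon=1e-6, ridge_epsilon=1e-6):
  # Coupled Newton iteration for mat_g^(-1/p); mirrors _iter_body/_iter_condition.
  n = mat_g.shape[0]
  identity = np.eye(n)
  alpha = -1.0 / p
  damped = mat_g + ridge_epsilon * identity
  z = (1.0 + p) / (2.0 * np.linalg.norm(damped))
  mat_m = damped * z                     # converges to the identity
  mat_x = identity * (z ** (1.0 / p))    # converges to damped^(-1/p)
  for _ in range(iter_count):
    if np.max(np.abs(mat_m - identity)) <= epsilon:
      break
    mat_m_i = (1.0 - alpha) * identity + alpha * mat_m
    mat_x = mat_x.dot(mat_m_i)
    mat_m = np.linalg.matrix_power(mat_m_i, p).dot(mat_m)
  return mat_x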
+ _, _, mat_h = control_flow_ops.while_loop( + _iter_condition, _iter_body, + [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)]) + return mat_h diff --git a/tensorflow/contrib/opt/python/training/matrix_functions_test.py b/tensorflow/contrib/opt/python/training/matrix_functions_test.py new file mode 100644 index 0000000000..518fa38233 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/matrix_functions_test.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for Matrix functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import matrix_functions +from tensorflow.python.platform import test + +TOLERANCE = 1e-3 + + +def np_power(mat_g, alpha): + """Computes mat_g^alpha for a square symmetric matrix mat_g.""" + + mat_u, diag_d, mat_v = np.linalg.svd(mat_g) + diag_d = np.power(diag_d, alpha) + return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v) + + +class MatrixFunctionTests(test.TestCase): + + def testMatrixSquareRootFunction(self): + """Tests for matrix square roots.""" + + size = 20 + mat_a = np.random.rand(size, size) + mat = np.dot(mat_a, mat_a.T) + expected_mat = np_power(mat, 0.5) + mat_root = matrix_functions.matrix_square_root(mat, size) + self.assertAllCloseAccordingToType( + expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE) + + def testMatrixInversePthRootFunction(self): + """Tests for matrix inverse pth roots.""" + + size = 20 + mat_a = np.random.rand(size, size) + mat = np.dot(mat_a, mat_a.T) + expected_mat = np_power(mat, -0.125) + mat_root = matrix_functions.matrix_inverse_pth_root(mat, size, -0.125) + self.assertAllCloseAccordingToType( + expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py index ac04ad9911..f22e724528 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py @@ -46,7 +46,7 @@ class MovingAverageOptimizerTest(test.TestCase): def _helpTestRun(self, use_resource=False): for sequential_update in [True, False]: for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(graph=ops.Graph()) as sess: + with self.session(graph=ops.Graph()) as sess: orig_val0 = [1.0, 2.0] orig_val1 = [3.0, 4.0] var0 = variable_scope.get_variable( @@ -165,7 +165,7 @@ class MovingAverageOptimizerTest(test.TestCase): self.assertLess(avg_val1[i], orig_val1[i]) def testFailWhenSaverCreatedBeforeInitialized(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable([1.0], name='var', 
dtype=dtypes.float32) opt = moving_average_optimizer.MovingAverageOptimizer( gradient_descent.GradientDescentOptimizer(learning_rate=2.0)) @@ -187,7 +187,7 @@ class MovingAverageOptimizerTest(test.TestCase): self.apply_gradients_called = True return super(WrapperOptimizer, self).apply_gradients(*args, **kwargs) - with self.test_session() as sess: + with self.cached_session() as sess: var = variables.Variable([1.2], name='var', dtype=dtypes.float32) loss = var ** 2 wrapper_opt = WrapperOptimizer(learning_rate=2.0) diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py index 618d8eb18d..904aa9ab13 100644 --- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py +++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py @@ -34,7 +34,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase): """ def testWrapper(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32) var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32) @@ -92,7 +92,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase): self.evaluate(slot1)) def testGradientClipping(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32) var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) var2 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py index 825c08a09a..85e05ce71c 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -53,7 +53,7 @@ class NadamOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -106,7 +106,7 @@ class NadamOptimizerTest(test.TestCase): def doTestBasic(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. 
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) diff --git a/tensorflow/contrib/opt/python/training/powersign.py b/tensorflow/contrib/opt/python/training/powersign.py index 828f3c51c9..b4aa19264d 100644 --- a/tensorflow/contrib/opt/python/training/powersign.py +++ b/tensorflow/contrib/opt/python/training/powersign.py @@ -65,7 +65,7 @@ class PowerSignOptimizer(optimizer.Optimizer): Example usage for PowerSign-cd (PowerSign with cosine sign decay) ``` decay_steps = 1000 - linear_decay_fn = sign_decays.get_linear_decay_fn(decay_steps) + linear_decay_fn = sign_decays.get_cosine_decay_fn(decay_steps) opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=linear_decay_fn) ``` diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py index 5214082dd6..0bcf5d230a 100644 --- a/tensorflow/contrib/opt/python/training/powersign_test.py +++ b/tensorflow/contrib/opt/python/training/powersign_test.py @@ -216,7 +216,7 @@ class PowerSignTest(test.TestCase): self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([3.0, 4.0], var1.eval()) - # Run 3 steps of powersign + # Run 7 steps of powersign # first 4 steps with positive gradient # last 3 steps with negative gradient (sign(gm) should be -1) for t in range(1, 8): diff --git a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py index ea56e1646a..c09e2ac76d 100644 --- a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py @@ -36,7 +36,7 @@ class RegAdagradOptimizerTest(test.TestCase): def doTestBasic(self, use_locking=False, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): if use_resource: var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) @@ -73,7 +73,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable( [[1.0, 2.0], [3.0, 4.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) @@ -92,7 +92,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -116,7 +116,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseBasic(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([[1.0], [2.0]], dtype=dtype) var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) grads0 = ops.IndexedSlices( @@ -144,7 +144,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable([[1.0], [2.0]], dtype=dtype) @@ 
-170,7 +170,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseRepeatedIndicesResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var_repeated = resource_variable_ops.ResourceVariable( [1.0, 2.0], dtype=dtype) loss_repeated = math_ops.reduce_sum( @@ -194,7 +194,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseStability(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): shape = [1, 6] var0 = variables.Variable( [[ @@ -230,7 +230,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -263,7 +263,7 @@ class RegAdagradOptimizerTest(test.TestCase): np.array([2.715679168701172, 3.715679168701172]), var1.eval()) def testDynamicShapeVariable_Ok(self): - with self.test_session(): + with self.cached_session(): v = variable_scope.get_variable( "v", initializer=constant_op.constant(1.), validate_shape=False) self.assertFalse(v.shape.is_fully_defined()) @@ -274,7 +274,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSkipUpdatingSlots(self): iav = 0.130005 # A value that works with float16 for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -306,7 +306,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseSkipUpdatingSlots(self): iav = 0.130005 # A value that works with float16 for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([[1.0], [2.0]], dtype=dtype) var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) grads0 = ops.IndexedSlices( diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py new file mode 100644 index 0000000000..f161521b97 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -0,0 +1,420 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""The Shampoo Optimizer. + +Variant of Adagrad using one preconditioner matrix per variable dimension. 
+For details, see https://arxiv.org/abs/1802.09568 +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.opt.python.training import matrix_functions +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import optimizer + + +def GetParam(var, timestep): + if callable(var): + return var(timestep) + else: + return var + + +class ShampooOptimizer(optimizer.Optimizer): + """The Shampoo Optimizer + + Variant of Adagrad using one preconditioner matrix per variable dimension. + For details, see https://arxiv.org/abs/1802.09568 + + gbar is time-weighted accumulated gradient: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + + mat_gbar is time-weighted accumulated gradient square: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation) + + Update rule: + w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t] + Again, mat_gbar_j[t]^(-alpha) gbar[t] is a tensor contraction along the + j'th dimension of gbar[t] with the first dimension of + mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter, + and n = rank of the variable. + Prod_j represents doing this contraction for all j in 0..n-1. + + Typically learning_rate is constant, but could be time dependent by passing + a lambda function that depends on step. + """ + + def __init__(self, + global_step=0, + max_matrix_size=768, + gbar_decay=0.0, + gbar_weight=1.0, + mat_gbar_decay=1.0, + mat_gbar_weight=1.0, + learning_rate=1.0, + svd_interval=1, + precond_update_interval=1, + epsilon=1e-4, + alpha=0.5, + use_iterative_root=False, + use_locking=False, + name="Shampoo"): + """Default values of the various hyper-parameters. + + gbar_decay, gbar_weight etc. can be a float or a time varying parameter. + For time-varying parameters use e.g. "lambda T: T / (T + 1.0)" + where the expression in the lambda is a tensorflow expression + + Args: + global_step: tensorflow variable indicating the step. + max_matrix_size: We do not perform SVD for matrices larger than this. + gbar_decay: + gbar_weight: Used to update gbar: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + mat_gbar_decay: + mat_gbar_weight: Used to update mat_gbar: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + learning_rate: Similar to SGD + svd_interval: We should do SVD after this many steps. Default = 1, i.e. + every step. Usually 20 leads to no loss of accuracy, and + 50 or 100 is also OK. May also want more often early, + and less often later - set in caller as for example: + "svd_interval = lambda(T): tf.cond( + T < 2000, lambda: 20.0, lambda: 1000.0)" + precond_update_interval: We should update the preconditioners after + this many steps. Default = 1. Usually less than + svd_interval. + epsilon: epsilon * I_n is added to each mat_gbar_j for stability + alpha: total power of the preconditioners. + use_iterative_root: should the optimizer use SVD (faster) or the + iterative root method (for TPU) for finding the + roots of PSD matrices. + use_locking: + name: name of optimizer. 
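As a rough usage sketch (hypothetical values; `loss` is assumed to be defined elsewhere in the graph), hyperparameters such as the learning rate can be passed either as constants or as callables of the step, as described above:

import tensorflow as tf
from tensorflow.contrib.opt.python.training import shampoo

global_step = tf.train.get_or_create_global_step()
opt = shampoo.ShampooOptimizer(
    global_step,
    learning_rate=lambda t: 0.1 / (1.0 + t / 1000.0),  # time-varying; t is the (float) step
    svd_interval=20,                # recompute matrix roots every 20 steps
    precond_update_interval=10,     # refresh preconditioner statistics every 10 steps
    use_iterative_root=False)       # SVD-based roots; True selects the iterative method
train_op = opt.minimize(loss, global_step=global_step)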
+ """ + + super(ShampooOptimizer, self).__init__(use_locking, name) + + self._global_step = math_ops.to_float(global_step) + self._max_matrix_size = max_matrix_size + self._gbar_decay = gbar_decay + self._gbar_weight = gbar_weight + self._mat_gbar_decay = mat_gbar_decay + self._mat_gbar_weight = mat_gbar_weight + self._learning_rate = learning_rate + self._svd_interval = svd_interval + self._precond_update_interval = precond_update_interval + self._epsilon = epsilon + self._alpha = alpha + self._use_iterative_root = use_iterative_root + self._name = name + + def _create_slots(self, var_list): + for v in var_list: + with ops.colocate_with(v): + _ = self._zeros_slot(v, "gbar", self._name) + shape = np.array(v.get_shape()) + for i, d in enumerate(shape): + d_tensor = ops.convert_to_tensor(d) + if d <= self._max_matrix_size: + mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor)) + if self._svd_interval > 1: + _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor), + "H_" + str(i), self._name) + else: + mat_g_init = array_ops.zeros([d_tensor]) + + _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i), + self._name) + + def _resource_apply_dense(self, grad, var): + return self._apply_dense(grad, var) + + def _apply_dense(self, grad, var): + return self._apply_gradient(grad, var) + + def _resource_apply_sparse(self, grad_values, var, grad_indices): + return self._apply_sparse_shared(grad_values, grad_indices, var) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared(grad.values, grad.indices, var) + + def _apply_sparse_shared(self, grad_values, grad_indices, var): + if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0: + # The dimension is small enough, we can make the variable dense and + # do a dense update + dense_grad = array_ops.scatter_nd( + array_ops.expand_dims(grad_indices, axis=1), grad_values, + array_ops.shape(var, out_type=grad_indices.dtype)) + return self._apply_gradient(dense_grad, var) + return self._apply_gradient(grad_values, var, grad_indices) + + def _weighted_average(self, var, weight, weight_t, rest): + """Computes exponential weighted average: var = weight_t * var + rest. + + Important to ensure that var does not occur in rest, otherwise + we can get race conditions in a distributed setting. + + Args: + var: variable to be updated + weight: parameter to be checked. If it is a constant, we can optimize. + weight_t: current value of parameter, used for weighting + rest: the remaining tensor to be added + + Returns: + updated variable. + """ + if weight == 0.0: + return rest # no need to update var, we will never use it. + if weight == 1.0: # common case + return state_ops.assign_add(var, rest) + # The op below can cause race conditions in a distributed setting, + # since computing weight_t * var + rest can take some time, during + # which var may be set by another worker. To prevent this, it should + # be implemented as a C++ op. + return var.assign_add((weight_t - 1) * var + rest) + + def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay, + mat_gbar_weight, i): + """Updates the cumulative outer products of the gradients. + + Args: + mat_g: the matrix to be updated + grad: the gradient of the variable + axes: a list of k-1 integers 0 to k-1, except i + mat_gbar_decay: constant for weighted average: + mat_g = mat_g * decay + grad * weight + mat_gbar_weight: constant for weighted average + i: index of dimension to be updated. 
+ + Returns: + updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight + + In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd + thus grad_outer is a matrix d_i x d_i, where d_i is the size of the + i'th dimension of g. + Alternate view: If mat_i(grad) is the flattening of grad to a + d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then + grad_outer = mat_i(grad) mat_i(grad).transpose + """ + grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes), + name="grad_outer_" + str(i)) + return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay, + mat_gbar_weight * grad_outer) + + def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name): + """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix. + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g + alpha: a real number + mat_h_slot_name: name of slot to store the power, if needed. + + Returns: + mat_h = mat_g^alpha + + Stores mat_h in the appropriate slot, if it exists. + Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig. + """ + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size)) + diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) + mat_h = math_ops.matmul( + mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), + array_ops.transpose(mat_u)) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name, + iter_count=100, epsilon=1e-6): + """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.""" + + mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size, + iter_count, self._epsilon) + mat_h = matrix_functions.matrix_inverse_pth_root( + mat_g_sqrt, + mat_g_size, + 2 * alpha, + iter_count, + epsilon, + ridge_epsilon=0.0) + + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None): + """Just a switch between the iterative power vs svd.""" + with ops.name_scope("matrix_iterative_power"): + if self._use_iterative_root: + return self._compute_power_iter(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + else: + return self._compute_power_svd(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + + def _apply_gradient(self, grad, var, indices=None): + """The main function to update a variable. + + Args: + grad: A Tensor containing gradient to apply. + var: A Tensor containing the variable to update. + indices: An array of integers, for sparse update. + + Returns: + Updated variable var = var - learning_rate * preconditioner * grad + + If the gradient is dense, var and grad have the same shape. + If the update is sparse, then the first dimension of the gradient and var + may differ, others are all the same. In this case the indices array + provides the set of indices of the variable which are to be updated with + each row of the gradient. 
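To make the rule concrete, here is a small numpy sketch of the dense rank-2 case (learning rate 1, no momentum, SVD-based roots); it mirrors the expected values computed in the unit tests further below and is illustrative only:

import numpy as np

def np_root(mat, power):
  # mat is symmetric PSD, so the SVD factors give mat^power.
  u, d, v = np.linalg.svd(mat)
  return u.dot(np.diag(np.power(d, power))).dot(v)

epsilon = 1e-4
var = np.zeros((10, 5))
grad = np.random.rand(10, 5)
mat_g1 = grad.dot(grad.T) / grad.shape[0]               # statistics for dimension 0
mat_g2 = grad.T.dot(grad) / grad.shape[1]               # statistics for dimension 1
left = np_root(mat_g1 + epsilon * np.eye(10), -0.25)    # -alpha / rank = -0.5 / 2
right = np_root(mat_g2 + epsilon * np.eye(5), -0.25)
var -= left.dot(grad).dot(right)                        # preconditioned gradient step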
+ """ + global_step = self._global_step + 1 + + # Update accumulated weighted average of gradients + gbar = self.get_slot(var, "gbar") + gbar_decay_t = GetParam(self._gbar_decay, global_step) + gbar_weight_t = GetParam(self._gbar_weight, global_step) + if indices is not None: + # Note - the sparse update is not easily implemented, since the + # algorithm needs all indices of gbar to be updated + # if mat_gbar_decay != 1 or mat_gbar_decay != 0. + # One way to make mat_gbar_decay = 1 is by rescaling. + # If we want the update: + # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t + # define: + # r_{t+1} = a_{t+1} * r_t + # h_t = G_t / r_t + # Then: + # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t + # So we get the mat_gbar_decay = 1 as desired. + # We can implement this in a future version as needed. + # However we still need gbar_decay = 0, otherwise all indices + # of the variable will need to be updated. + if self._gbar_decay != 0.0: + tf_logging.warning("Not applying momentum for variable: %s" % var.name) + gbar_updated = grad + else: + gbar_updated = self._weighted_average(gbar, self._gbar_decay, + gbar_decay_t, + gbar_weight_t * grad) + + # Update the preconditioners and compute the preconditioned gradient + shape = var.get_shape() + mat_g_list = [] + for i in range(len(shape)): + mat_g_list.append(self.get_slot(var, "Gbar_" + str(i))) + mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step) + mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step) + + preconditioned_grad = gbar_updated + v_rank = len(mat_g_list) + neg_alpha = - GetParam(self._alpha, global_step) / v_rank + svd_interval = GetParam(self._svd_interval, global_step) + precond_update_interval = GetParam(self._precond_update_interval, + global_step) + for i, mat_g in enumerate(mat_g_list): + # axes is the list of indices to reduce - everything but the current i. + axes = list(range(i)) + list(range(i+1, v_rank)) + if shape[i] <= self._max_matrix_size: + # If the tensor size is sufficiently small perform full Shampoo update + # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this + # is not strictly correct. However we will use it for now, and + # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg) + + # pylint: disable=g-long-lambda,cell-var-from-loop + mat_g_updated = control_flow_ops.cond( + math_ops.mod(global_step, precond_update_interval) < 1, + lambda: self._update_mat_g( + mat_g, grad, axes, mat_gbar_decay_t, + mat_gbar_weight_t * precond_update_interval, i), + lambda: mat_g) + + mat_g_updated = mat_g_updated / float(shape[i].value) + + if self._svd_interval == 1: + mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha) + else: + mat_h = control_flow_ops.cond( + math_ops.mod(global_step, svd_interval) < 1, + lambda: self._compute_power(var, mat_g_updated, shape[i], + neg_alpha, "H_" + str(i)), + lambda: self.get_slot(var, "H_" + str(i))) + + # mat_h is a square matrix of size d_i x d_i + # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor + # After contraction with a d_i x d_i tensor + # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor + # (the first dimension is contracted out, and the second dimension of + # mat_h is appended). After going through all the indices, it becomes + # a d_0 x ... x d_n tensor again. + preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h, + axes=([0], [0]), + name="precond_" + str(i)) + else: + # Tensor size is too large -- perform diagonal Shampoo update + # Only normalize non-vector cases. 
+ if axes: + normalizer = 1.0 if indices is not None else float(shape[i].value) + grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer + else: + grad_outer = grad * grad + + if i == 0 and indices is not None: + assert self._mat_gbar_decay == 1.0 + mat_g_updated = state_ops.scatter_add(mat_g, indices, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow( + array_ops.gather(mat_g_updated, indices) + self._epsilon, + neg_alpha) + else: + mat_g_updated = self._weighted_average(mat_g, + self._mat_gbar_decay, + mat_gbar_decay_t, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha) + + # Need to do the transpose to ensure that the tensor becomes + # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above. + preconditioned_grad = array_ops.transpose( + preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h + + # Update the variable based on the Shampoo update + learning_rate_t = GetParam(self._learning_rate, global_step) + if indices is not None: + var_updated = state_ops.scatter_add( + var, indices, -learning_rate_t * preconditioned_grad) + else: + var_updated = state_ops.assign_sub(var, + learning_rate_t * preconditioned_grad) + return var_updated diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py new file mode 100644 index 0000000000..05bcf2cfa3 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -0,0 +1,772 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
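The diagonal fallback in _apply_gradient above (dimensions larger than max_matrix_size keep only per-coordinate statistics) can be sketched in numpy as follows; shapes and values are illustrative and mirror the large-matrix test below:

import numpy as np

epsilon = 1e-4
grad = np.random.rand(2000, 3)                          # dim 0 too large for a full preconditioner
mat_g1 = np.sum(grad * grad, axis=1, keepdims=True) / grad.shape[0]  # diagonal statistics, dim 0
left = np.power(mat_g1 + epsilon, -0.25)                # element-wise inverse root
mat_g2 = grad.T.dot(grad) / grad.shape[1]               # full 3x3 preconditioner, dim 1
u, d, v = np.linalg.svd(mat_g2 + epsilon * np.eye(3))
right = u.dot(np.diag(np.power(d, -0.25))).dot(v)       # symmetric PSD, so this is mat_g2^-0.25
update = (grad * left).dot(right)                       # broadcasted diagonal on the left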
+# ============================================================================== + +"""Functional tests for AdaMoo optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import shampoo +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +TOLERANCE = 1e-3 +RIDGE_EPSILON = 1e-4 + + +def np_power(mat_g, alpha): + """Computes mat_g^alpha for a square symmetric matrix mat_g.""" + + mat_u, diag_d, mat_v = np.linalg.svd(mat_g) + diag_d = np.power(diag_d, alpha) + return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v) + + +class ShampooTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicVector(self, use_resource_var): + """Similar to the full Adagrad update.""" + + size = 20 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g^{-0.5} * grad + # lr = 1 + mat_g = np.outer(grad_np, grad_np) / grad_np.shape[0] + mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5) + new_val_np = init_var_np - np.dot(mat_h, grad_np) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += np.outer(grad_np_2, grad_np_2) / grad_np.shape[0] + mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5) + new_val_np -= np.dot(mat_h, grad_np_2) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicMatrix(self, use_resource_var): + """Check update when gradient is a matrix.""" + size = [10, 5] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = 
opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25} + # lr = 1 + mat_g1 = np.dot(grad_np, grad_np.transpose()) / grad_np.shape[0] + mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) / grad_np_2.shape[0] + mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) + new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testBasicTensor(self, use_iterative_root, use_resource_var): + """Check update when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = ( + np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) / + grad_np.shape[0]) + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 = ( + np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) / + grad_np.shape[1]) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 = ( + np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) / + grad_np.shape[2]) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) + + precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) 
+ + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += ( + np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / + grad_np_2.shape[0]) + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 += ( + np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / + grad_np_2.shape[1]) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 += ( + np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / + grad_np_2.shape[2]) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) + + precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensor(self, use_iterative_root, use_resource_var): + self._testBasicTensor(use_iterative_root, use_resource_var) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeVector(self, use_resource_var): + """This is just the diagonal Adagrad update.""" + + size = 2000 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * gg^{-0.5} * grad + # lr = 1 + mat_g = (grad_np * grad_np) + new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np + + self.assertAllCloseAccordingToType( + new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE) + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += (grad_np_2 * grad_np_2) + new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2 + + self.assertAllCloseAccordingToType( + new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE) + + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeMatrix(self, use_resource_var): + """Gradient is a matrix, one of whose dimensions is large. + + We do diagonal updates for large dimensions. + + Args: + use_resource_var: use resource var as variables. 
+ """ + + size = [2000, 3] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + + mat_g1 = np.sum( + grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0] + mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.sum( + grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0] + mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) + new_val_np -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False)) + def testSparseUpdateLarge(self, use_resource_var): + """Check update when gradient is of type IndexSlices. + + We do diagonal updates for the first dimension, unless it is very small. + + Args: + use_resource_var: use resource var as variables. 
+ """ + size = [2000, 3] + sample_size_1 = 100 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size_1, + replace=False)) + grad_np = np.random.rand(sample_size_1, size[1]) + + sample_size_2 = 7 + grad_indices_2 = np.sort(np.random.choice(np.arange(size[0]), sample_size_2, + replace=False)) + grad_np_2 = np.random.rand(sample_size_2, size[1]) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + grad_2 = ops.IndexedSlices( + constant_op.constant(grad_np_2, dtype=dtypes.float32), + constant_op.constant(grad_indices_2), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + # In this case the update lr * mat_left * grad * mat_right is + # of size 10 x 2. + # So the correct indices of var need to be updated. + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_g1_acc = np.zeros((size[0], 1)) + mat_g1_acc[grad_indices] += mat_g1 + mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) + new_val_np = init_var_np + new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_g1_acc[grad_indices_2] += mat_g1 + mat_left = np.power(mat_g1_acc[grad_indices_2] + RIDGE_EPSILON, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) + new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + """Gradient is of type IndexSlices, but the first dimension is small. + + We create dense gradient and do the full update with SVD etc. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. 
+ """ + + size = [100, 3, 5] + sample_size = 10 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size, + replace=False)) + grad_np = np.random.rand(sample_size, size[1], size[2]) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.125} grad + # lr = 1 + grad_dense = np.zeros_like(init_var_np) + grad_dense[grad_indices] = grad_np + + mat_g1 = np.tensordot( + grad_dense, grad_dense, axes=([1, 2], [1, 2])) / grad_dense.shape[0] + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 = np.tensordot( + grad_dense, grad_dense, axes=([0, 2], [0, 2])) / grad_dense.shape[1] + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 = np.tensordot( + grad_dense, grad_dense, axes=([0, 1], [0, 1])) / grad_dense.shape[2] + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) + + precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + self._testSparseUpdateSmall(use_iterative_root, use_resource_var) + + def _testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + """Check update with momentum when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. 
+ """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + gbar_decay = 0.9 + gbar_weight = 0.1 + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, gbar_decay=gbar_decay, + gbar_weight=gbar_weight, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot( + grad_np, grad_np, axes=([1, 2], [1, 2])) / grad_np.shape[0] + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 = np.tensordot( + grad_np, grad_np, axes=([0, 2], [0, 2])) / grad_np.shape[1] + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 = np.tensordot( + grad_np, grad_np, axes=([0, 1], [0, 1])) / grad_np.shape[2] + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) + + gbar_np = gbar_weight * grad_np + precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot( + grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / grad_np_2.shape[0] + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 += np.tensordot( + grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / grad_np_2.shape[1] + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 += np.tensordot( + grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / grad_np_2.shape[2] + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) + + gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2 + precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + self._testBasicTensorWithMomentum(use_iterative_root, use_resource_var) + + def _testDelayedSVD(self, use_iterative_root, use_resource_var): + """Performing the SVD every nth step. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. 
+ """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 20 + svd_interval = 5 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer(global_step, svd_interval=svd_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 += np.tensordot( + grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / grad_np[i].shape[0] + mat_g2 += np.tensordot( + grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / grad_np[i].shape[1] + mat_g3 += np.tensordot( + grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / grad_np[i].shape[2] + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), + -0.5 / 3.0) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), + -0.5 / 3.0) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), + -0.5 / 3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedSVD(self, use_iterative_root, use_resource_var): + self._testDelayedSVD(use_iterative_root, use_resource_var) + + def _testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + """Update the squared sum every nth step, drop the other steps. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. 
+ """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 100 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + svd_interval = 20 + precond_update_interval = 5 + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.cached_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer( + global_step, svd_interval=svd_interval, + precond_update_interval=precond_update_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + if (i + 1) % precond_update_interval == 0: + mat_g1 += ( + np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / + grad_np[i].shape[0] * precond_update_interval) + mat_g2 += ( + np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / + grad_np[i].shape[1] * precond_update_interval) + mat_g3 += ( + np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / + grad_np[i].shape[2] * precond_update_interval) + + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), + -0.5 / 3.0) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), + -0.5 / 3.0) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), + -0.5 / 3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + self._testDelayedPrecondUpdate(use_iterative_root, use_resource_var) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/opt/python/training/sign_decay_test.py b/tensorflow/contrib/opt/python/training/sign_decay_test.py index c31cb924ea..3a84789afd 100644 --- a/tensorflow/contrib/opt/python/training/sign_decay_test.py +++ b/tensorflow/contrib/opt/python/training/sign_decay_test.py @@ -66,7 +66,7 @@ class SignDecaysTest(test.TestCase): linear_decay_fn = sign_decay.get_linear_decay_fn(num_training_steps) for step in range(0, 1000, 100): - with self.test_session(): + with self.cached_session(): tf_decayed = linear_decay_fn(step).eval() py_decayed = py_linear_decay_fn(num_training_steps)(step) self.assertAlmostEqual(tf_decayed, py_decayed, places=4) @@ -78,7 +78,7 @@ class SignDecaysTest(test.TestCase): num_training_steps, num_periods=5, 
zero_after=2) for step in range(0, 1000, 100): - with self.test_session(): + with self.cached_session(): tf_decayed = cosine_decay_fn(step).eval() py_decayed = py_cosine_decay_fn(num_training_steps)(step) self.assertAlmostEqual(tf_decayed, py_decayed, places=4) @@ -95,7 +95,7 @@ class SignDecaysTest(test.TestCase): num_training_steps, num_periods=5, zero_after=2) for step in range(0, 1000, 100): - with self.test_session(): + with self.cached_session(): tf_decayed = restart_decay_fn(step).eval() py_decayed = py_restart_decay_fn(num_training_steps)(step) self.assertAlmostEqual(tf_decayed, py_decayed, places=4) diff --git a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py index fdda86b0b5..ff0ea8d766 100644 --- a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py @@ -158,7 +158,7 @@ class VariableClippingOptimizerTest(test.TestCase): def testDenseLocal(self): for dtype in [dtypes.float32, dtypes.float64, dtypes.half]: - with self.test_session(): + with self.cached_session(): var0, var1, update_op = self._setupDense(False, dtype) self._assertDenseCorrect(var0, var1, update_op) @@ -171,7 +171,7 @@ class VariableClippingOptimizerTest(test.TestCase): def testSparseLocal(self): for dtype in [dtypes.float64, dtypes.float32, dtypes.half]: - with self.test_session(): + with self.cached_session(): var0, var1, update_op = self._setupSparse(False, dtype) self._assertSparseCorrect(var0, var1, update_op) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py new file mode 100644 index 0000000000..200b0d2008 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -0,0 +1,435 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.opt.python.training import shampoo +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import adam +from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.training import optimizer +from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.ops import array_ops + + +class DecoupledWeightDecayExtension(object): + """This class allows to extend optimizers with decoupled weight decay. 
+
+  It implements the decoupled weight decay described by Loshchilov & Hutter
+  (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
+  decoupled from the optimization steps w.r.t. the loss function.
+  For SGD variants, this simplifies hyperparameter search since it decouples
+  the settings of weight decay and learning rate.
+  For adaptive gradient algorithms, it regularizes variables with large
+  gradients more than L2 regularization would, which was shown to yield better
+  training loss and generalization error in the paper above.
+
+  This class alone is not an optimizer but rather extends existing
+  optimizers with decoupled weight decay. We explicitly define the two examples
+  used in the above paper (SGDW and AdamW), but in general this can extend
+  any OptimizerX by using
+  `extend_with_decoupled_weight_decay(OptimizerX)`.
+  In order for it to work, it must be the first class the optimizer with
+  weight decay inherits from, e.g.
+
+  ```python
+  class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
+    def __init__(self, weight_decay, *args, **kwargs):
+      super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs)
+  ```
+
+  Note that this extension decays weights BEFORE applying the update based
+  on the gradient, i.e. this extension only has the desired behaviour for
+  optimizers which do not depend on the value of 'var' in the update step!
+  """
+
+  def __init__(self, weight_decay, **kwargs):
+    """Construct the extension class that adds weight decay to an optimizer.
+
+    Args:
+      weight_decay: A `Tensor` or a floating point value, the factor by which
+        a variable is decayed in the update step.
+      **kwargs: Optional keyword arguments that are passed on to the
+        constructor of the base optimizer.
+    """
+    self._decay_var_list = None  # is set in minimize or apply_gradients
+    self._weight_decay = weight_decay
+    # The tensors are initialized in call to _prepare
+    self._weight_decay_tensor = None
+    super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
+
+  def minimize(self, loss, global_step=None, var_list=None,
+               gate_gradients=optimizer.Optimizer.GATE_OP,
+               aggregation_method=None, colocate_gradients_with_ops=False,
+               name=None, grad_loss=None, decay_var_list=None):
+    """Add operations to minimize `loss` by updating `var_list` with decay.
+
+    This function is the same as Optimizer.minimize except that it allows
+    specifying the variables that should be decayed using `decay_var_list`.
+    If `decay_var_list` is None, all variables in `var_list` are decayed.
+
+    For more information see the documentation of Optimizer.minimize.
+
+    Args:
+      loss: A `Tensor` containing the value to minimize.
+      global_step: Optional `Variable` to increment by one after the
+        variables have been updated.
+      var_list: Optional list or tuple of `Variable` objects to update to
+        minimize `loss`.  Defaults to the list of variables collected in
+        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+      gate_gradients: How to gate the computation of gradients.  Can be
+        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+      aggregation_method: Specifies the method used to combine gradient terms.
+        Valid values are defined in the class `AggregationMethod`.
+      colocate_gradients_with_ops: If True, try colocating gradients with
+        the corresponding op.
+      name: Optional name for the returned operation.
+      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+      decay_var_list: Optional list of variables to be decayed.
+
+    Returns:
+      An Operation that updates the variables in `var_list`.
+      If `global_step` was not `None`, that operation also increments
+      `global_step`.
+
+    """
+    self._decay_var_list = set(decay_var_list) if decay_var_list else False
+    return super(DecoupledWeightDecayExtension, self).minimize(
+        loss, global_step=global_step, var_list=var_list,
+        gate_gradients=gate_gradients, aggregation_method=aggregation_method,
+        colocate_gradients_with_ops=colocate_gradients_with_ops, name=name,
+        grad_loss=grad_loss)
+
+  def apply_gradients(self, grads_and_vars, global_step=None, name=None,
+                      decay_var_list=None):
+    """Apply gradients to variables and decay the variables.
+
+    This function is the same as Optimizer.apply_gradients except that it
+    allows specifying the variables that should be decayed using
+    `decay_var_list`. If `decay_var_list` is None, all variables in `var_list`
+    are decayed.
+
+    For more information see the documentation of Optimizer.apply_gradients.
+
+    Args:
+      grads_and_vars: List of (gradient, variable) pairs as returned by
+        `compute_gradients()`.
+      global_step: Optional `Variable` to increment by one after the
+        variables have been updated.
+      name: Optional name for the returned operation.  Defaults to the
+        name passed to the `Optimizer` constructor.
+      decay_var_list: Optional list of variables to be decayed.
+
+    Returns:
+      An `Operation` that applies the specified gradients. If `global_step`
+      was not None, that operation also increments `global_step`.
+    """
+    self._decay_var_list = set(decay_var_list) if decay_var_list else False
+    return super(DecoupledWeightDecayExtension, self).apply_gradients(
+        grads_and_vars, global_step=global_step, name=name)
+
+  def _prepare(self):
+    weight_decay = self._weight_decay
+    if callable(weight_decay):
+      weight_decay = weight_decay()
+    self._weight_decay_tensor = ops.convert_to_tensor(
+        weight_decay, name="weight_decay")
+    # Call the optimizer's _prepare function.
+    super(DecoupledWeightDecayExtension, self)._prepare()
+
+  def _decay_weights_op(self, var):
+    if not self._decay_var_list or var in self._decay_var_list:
+      return var.assign_sub(self._weight_decay * var, self._use_locking)
+    return control_flow_ops.no_op()
+
+  def _decay_weights_sparse_op(self, var, indices, scatter_add):
+    if not self._decay_var_list or var in self._decay_var_list:
+      update = -self._weight_decay * array_ops.gather(var, indices)
+      return scatter_add(var, indices, update, self._use_locking)
+    return control_flow_ops.no_op()
+
+  # Here, we overwrite the apply functions that the base optimizer calls.
+  # super().apply_x resolves to the apply_x function of the BaseOptimizer.
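+  # Each override first creates the decay op for `var` and then runs the base
+  # optimizer's update under a control dependency on it, so the gradient step
+  # is applied to the already-decayed value. Schematically:
+  #
+  #   var -= weight_decay * var        # _decay_weights_op / sparse variant
+  #   var -= <base optimizer delta>    # super()._apply_*(grad, var)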
+  def _apply_dense(self, grad, var):
+    with ops.control_dependencies([self._decay_weights_op(var)]):
+      return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var)
+
+  def _resource_apply_dense(self, grad, var):
+    with ops.control_dependencies([self._decay_weights_op(var)]):
+      return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
+          grad, var)
+
+  def _apply_sparse(self, grad, var):
+    scatter_add = state_ops.scatter_add
+    decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
+    with ops.control_dependencies([decay_op]):
+      return super(DecoupledWeightDecayExtension, self)._apply_sparse(
+          grad, var)
+
+  def _resource_scatter_add(self, x, i, v, _=None):
+    # The unused last argument keeps the same function signature as
+    # state_ops.scatter_add.
+    with ops.control_dependencies(
+        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
+      return x.value()
+
+  def _resource_apply_sparse(self, grad, var, indices):
+    scatter_add = self._resource_scatter_add
+    decay_op = self._decay_weights_sparse_op(var, indices, scatter_add)
+    with ops.control_dependencies([decay_op]):
+      return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse(
+          grad, var, indices)
+
+
+def extend_with_decoupled_weight_decay(base_optimizer):
+  """Factory function returning an optimizer class with decoupled weight decay.
+
+  Returns an optimizer class. An instance of the returned class computes the
+  update step of `base_optimizer` and additionally decays the weights.
+  E.g., the class returned by
+  `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to
+  `tf.contrib.opt.AdamWOptimizer`.
+
+  The API of the new optimizer class slightly differs from the API of the
+  base optimizer:
+  - The first argument to the constructor is the weight decay rate.
+  - `minimize` and `apply_gradients` accept the optional keyword argument
+    `decay_var_list`, which specifies the variables that should be decayed.
+    If `None`, all variables that are optimized are decayed.
+
+  Usage example:
+  ```python
+  # MyAdamW is a new class
+  MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
+  # Create a MyAdamW object
+  optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
+  sess.run(optimizer.minimize(loss, decay_var_list=[var1, var2]))
+  ```
+
+  Note that this extension decays weights BEFORE applying the update based
+  on the gradient, i.e. this extension only has the desired behaviour for
+  optimizers which do not depend on the value of 'var' in the update step!
+
+  Args:
+    base_optimizer: An optimizer class that inherits from tf.train.Optimizer.
+
+  Returns:
+    A new optimizer class that inherits from DecoupledWeightDecayExtension
+    and base_optimizer.
+  """
+
+  class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
+                                          base_optimizer):
+    """Base_optimizer with decoupled weight decay.
+
+    This class computes the update step of `base_optimizer` and
+    additionally decays the variable with the weight decay being decoupled from
+    the optimization steps w.r.t. the loss function, as described by
+    Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf).
+    For SGD variants, this simplifies hyperparameter search since
+    it decouples the settings of weight decay and learning rate.
+    For adaptive gradient algorithms, it regularizes variables with large
+    gradients more than L2 regularization would, which was shown to yield
+    better training loss and generalization error in the paper above.
+    """
+
+    def __init__(self, weight_decay, *args, **kwargs):
+      # super delegation is necessary here
+      # pylint: disable=useless-super-delegation
+      super(OptimizerWithDecoupledWeightDecay, self).__init__(
+          weight_decay, *args, **kwargs)
+      # pylint: enable=useless-super-delegation
+
+  return OptimizerWithDecoupledWeightDecay
+
+
+@tf_export("contrib.opt.MomentumWOptimizer")
+class MomentumWOptimizer(DecoupledWeightDecayExtension,
+                         momentum_opt.MomentumOptimizer):
+  """Optimizer that implements the Momentum algorithm with weight decay.
+
+  This is an implementation of the SGDW optimizer described in "Fixing
+  Weight Decay Regularization in Adam" by Loshchilov & Hutter
+  (https://arxiv.org/abs/1711.05101)
+  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
+  It computes the update step of `train.MomentumOptimizer` and additionally
+  decays the variable. Note that this is different from adding
+  L2 regularization on the variables to the loss. Decoupling the weight decay
+  from other hyperparameters (in particular the learning rate) simplifies
+  hyperparameter search.
+
+  For further information see the documentation of the Momentum Optimizer.
+
+  Note that this optimizer can also be instantiated as
+  ```python
+  MomentumW = extend_with_decoupled_weight_decay(tf.train.MomentumOptimizer)
+  optimizer = MomentumW(weight_decay=weight_decay,
+                        learning_rate=learning_rate, momentum=momentum)
+  ```
+  """
+
+  def __init__(self, weight_decay, learning_rate, momentum,
+               use_locking=False, name="MomentumW", use_nesterov=False):
+    """Construct a new MomentumW optimizer.
+
+    For further information see the documentation of the Momentum Optimizer.
+
+    Args:
+      weight_decay: A `Tensor` or a floating point value. The weight decay.
+      learning_rate: A `Tensor` or a floating point value. The learning rate.
+      momentum: A `Tensor` or a floating point value. The momentum.
+      use_locking: If `True` use locks for update operations.
+      name: Optional name prefix for the operations created when applying
+        gradients.  Defaults to "MomentumW".
+      use_nesterov: If `True` use Nesterov Momentum.
+        See [Sutskever et al., 2013](
+        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
+        This implementation always computes gradients at the value of the
+        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
+        variable(s) track the values called `theta_t + mu*v_t` in the paper.
+
+    @compatibility(eager)
+    When eager execution is enabled, learning_rate, weight_decay and momentum
+    can each be a callable that takes no arguments and returns the actual value
+    to use. This can be useful for changing these values across different
+    invocations of optimizer functions.
+    @end_compatibility
+    """
+    super(MomentumWOptimizer, self).__init__(
+        weight_decay, learning_rate=learning_rate, momentum=momentum,
+        use_locking=use_locking, name=name, use_nesterov=use_nesterov)
+
+
+@tf_export("contrib.opt.AdamWOptimizer")
+class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
+  """Optimizer that implements the Adam algorithm with weight decay.
+
+  This is an implementation of the AdamW optimizer described in "Fixing
+  Weight Decay Regularization in Adam" by Loshchilov & Hutter
+  (https://arxiv.org/abs/1711.05101)
+  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
+
+  It computes the update step of `train.AdamOptimizer` and additionally
+  decays the variable.
+  Note that this is different from adding L2 regularization on
+  the variables to the loss: it regularizes variables with large
+  gradients more than L2 regularization would, which was shown to yield better
+  training loss and generalization error in the paper above.
+
+  For further information see the documentation of the Adam Optimizer.
+
+  Note that this optimizer can also be instantiated as
+  ```python
+  AdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
+  optimizer = AdamW(weight_decay=weight_decay, learning_rate=learning_rate)
+  ```
+  """
+
+  def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999,
+               epsilon=1e-8, use_locking=False, name="AdamW"):
+    """Construct a new AdamW optimizer.
+
+    For further information see the documentation of the Adam Optimizer.
+
+    Args:
+      weight_decay: A `Tensor` or a floating point value. The weight decay.
+      learning_rate: A Tensor or a floating point value. The learning rate.
+      beta1: A float value or a constant float tensor.
+        The exponential decay rate for the 1st moment estimates.
+      beta2: A float value or a constant float tensor.
+        The exponential decay rate for the 2nd moment estimates.
+      epsilon: A small constant for numerical stability. This epsilon is
+        "epsilon hat" in the Kingma and Ba paper (in the formula just before
+        Section 2.1), not the epsilon in Algorithm 1 of the paper.
+      use_locking: If True use locks for update operations.
+      name: Optional name for the operations created when applying gradients.
+        Defaults to "AdamW".
+    """
+    super(AdamWOptimizer, self).__init__(
+        weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
+        epsilon=epsilon, use_locking=use_locking, name=name)
+
+
+@tf_export("contrib.opt.ShampooWOptimizer")
+class ShampooWOptimizer(DecoupledWeightDecayExtension,
+                        shampoo.ShampooOptimizer):
+  """Optimizer that implements the Shampoo algorithm with weight decay.
+
+  For further information see the documentation of the Shampoo Optimizer.
+  """
+
+  def __init__(self,
+               weight_decay,
+               global_step,
+               max_matrix_size=768,
+               gbar_decay=0.0,
+               gbar_weight=1.0,
+               mat_gbar_decay=1.0,
+               mat_gbar_weight=1.0,
+               learning_rate=1.0,
+               svd_interval=1,
+               precond_update_interval=1,
+               epsilon=1e-4,
+               alpha=0.5,
+               use_iterative_root=False,
+               use_locking=False,
+               name="ShampooW"):
+    """Construct a new ShampooW optimizer.
+
+    For further information see the documentation of the Shampoo Optimizer.
+
+    Args:
+      weight_decay: A `Tensor` or a floating point value. The weight decay.
+      global_step: tensorflow variable indicating the step.
+      max_matrix_size: We do not perform SVD for matrices larger than this.
+      gbar_decay: Decay factor for the gradient moving average; used together
+        with `gbar_weight` (see below).
+      gbar_weight: Used to update gbar:
+        gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
+      mat_gbar_decay: Decay factor for the preconditioner statistics; used
+        together with `mat_gbar_weight` (see below).
+      mat_gbar_weight: Used to update mat_gbar:
+        mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
+                        + mat_gbar_weight[t] * gg_j[t]
+      learning_rate: Similar to SGD.
+      svd_interval: We should do SVD after this many steps. Default = 1, i.e.
+        every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
+        also OK. May also want more often early, and less often later, e.g.
+        set in the caller as
+        "svd_interval = lambda T: tf.cond(T < 2000, lambda: 20.0,
+                                          lambda: 1000.0)"
+      precond_update_interval: We should update the preconditioners after this
+        many steps. Default = 1. Usually less than svd_interval.
+      epsilon: epsilon * I_n is added to each mat_gbar_j for stability.
+      alpha: total power of the preconditioners.
+      use_iterative_root: should the optimizer use SVD (faster) or the
+        iterative root method (for TPU) for finding the roots of PSD matrices.
+      use_locking: If `True` use locks for update operations.
+      name: name of the optimizer.
+    """
+    super(ShampooWOptimizer, self).__init__(
+        weight_decay,
+        global_step=global_step,
+        max_matrix_size=max_matrix_size,
+        gbar_decay=gbar_decay,
+        gbar_weight=gbar_weight,
+        mat_gbar_decay=mat_gbar_decay,
+        mat_gbar_weight=mat_gbar_weight,
+        learning_rate=learning_rate,
+        svd_interval=svd_interval,
+        precond_update_interval=precond_update_interval,
+        epsilon=epsilon,
+        alpha=alpha,
+        use_iterative_root=use_iterative_root,
+        use_locking=use_locking,
+        name=name)
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
new file mode 100644
index 0000000000..9c91078301
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
@@ -0,0 +1,188 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for optimizers with weight decay."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import weight_decay_optimizers
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+
+WEIGHT_DECAY = 0.01
+
+
+def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9,
+                       beta2=0.999, epsilon=1e-8):
+  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+  m_t = beta1 * m + (1 - beta1) * g_t
+  v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+  param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) -
+             (param * WEIGHT_DECAY))
+  return param_t, m_t, v_t
+
+
+def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
+  # v, t are not needed for the momentum optimizer.
+  m = momentum * m + g_t
+  param_t = param - lr * m - param * WEIGHT_DECAY
+  return param_t, m, None
+
+
+class WeightDecayOptimizerTest(test.TestCase):
+
+  def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
+             use_resource=False, do_sparse=False):
+    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+      with self.session(graph=ops.Graph()):
+        # Initialize variables for numpy implementation.
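+        # The expected values below come from the pure-numpy `update_fn`
+        # (adamw_update_numpy or momentumw_update_numpy above), which applies
+        # the plain optimizer step and then subtracts WEIGHT_DECAY * param
+        # using the parameter value from before the gradient step.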
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices(constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), + constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices(constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), + constant_op.constant([2])) + else: + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = optimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of the optimizer + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0) + var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/%s:0" % (i, optimizer_name), + opt.get_slot(var=var0, name=slot_name).name) + + +class AdamWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY) + + def testSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True) + + +class MomentumWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9) + + def testSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def 
testResourceBasic(self):
+    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+                "momentum", use_resource=True)
+
+
+class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
+
+  @staticmethod
+  def get_optimizer():
+    adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
+        adam.AdamOptimizer)
+    return adamw(WEIGHT_DECAY)
+
+  def testBasic(self):
+    self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
+                use_resource=False)
+
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
+  def testResourceBasic(self):
+    self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
+                use_resource=True)
+
+
+if __name__ == "__main__":
+  test.main()
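For orientation, the snippet below sketches how the optimizers added in this patch are meant to be used together. It is a minimal sketch under TF 1.x: the toy quadratic loss, the variable, the constants, and the unused `opt2` are illustrative only, and the import path follows the test file above.

```python
import tensorflow as tf

from tensorflow.contrib.opt.python.training import weight_decay_optimizers

# A toy problem so the snippet is self-contained.
var = tf.Variable([2.0, -3.0], name="var")
loss = tf.reduce_sum(tf.square(var - 1.0))

# Predefined AdamW wrapper; the weight decay rate is the first argument.
opt = weight_decay_optimizers.AdamWOptimizer(
    weight_decay=0.01, learning_rate=0.1)

# The same extension built from an arbitrary optimizer class.
MomentumW = weight_decay_optimizers.extend_with_decoupled_weight_decay(
    tf.train.MomentumOptimizer)
opt2 = MomentumW(weight_decay=0.01, learning_rate=0.1, momentum=0.9)

# Only variables in `decay_var_list` are decayed; passing None decays all.
train_op = opt.minimize(loss, decay_var_list=[var])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(5):
    sess.run(train_op)
  print(sess.run(var))
```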