| author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-31 15:07:05 -0700 |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-31 15:10:30 -0700 |
| commit | 7d4c0998147fded8a2291dc097186634f017fedc (patch) | |
| tree | 3768af2fc1795b563c076c8b97b654a5cc987c43 /tensorflow/contrib/opt | |
| parent | f29635ccb73337d20b9f3e79b5d7917ae2bb56cc (diff) | |
This is an initial submission of the Shampoo optimizer to TensorFlow contrib.
Paper link: https://arxiv.org/pdf/1802.09568.pdf
PiperOrigin-RevId: 206834923
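
For orientation before the diff itself, here is a minimal usage sketch of the optimizer this commit adds, written in the TF 1.x graph-mode style used by the new shampoo_test.py below. The shapes and values are illustrative and not part of the change.

```python
# Illustrative only: drives the new ShampooOptimizer the same way the
# tests in this commit do (TF 1.x graph mode). Shapes are arbitrary.
import numpy as np
import tensorflow as tf
from tensorflow.contrib.opt.python.training import shampoo

global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
var = tf.Variable(np.zeros([10, 5]), dtype=tf.float32)
grad = tf.constant(np.random.rand(10, 5), dtype=tf.float32)

# learning_rate (like the other hyper-parameters) may also be a callable
# taking the step, e.g. lambda t: 1.0 / (t + 1.0).
opt = shampoo.ShampooOptimizer(global_step, learning_rate=1.0)
update = opt.apply_gradients(zip([grad], [var]), global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update)
  print(sess.run(var))
```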
Diffstat (limited to 'tensorflow/contrib/opt')
| -rw-r--r-- | tensorflow/contrib/opt/BUILD | 19 |
| -rw-r--r-- | tensorflow/contrib/opt/__init__.py | 3 |
| -rw-r--r-- | tensorflow/contrib/opt/python/training/shampoo.py | 463 |
| -rw-r--r-- | tensorflow/contrib/opt/python/training/shampoo_test.py | 669 |
4 files changed, 1153 insertions, 1 deletion
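
To make the update rule documented in shampoo.py below concrete: for a matrix-shaped variable, Shampoo keeps one accumulated Gram matrix per dimension and multiplies the gradient by their inverse roots, with the total power alpha split evenly across dimensions. The NumPy sketch below mirrors one step of the testBasicMatrix case in shampoo_test.py (epsilon = 0.1, learning rate 1, alpha = 0.5); it is an illustration, not code from the commit.

```python
# Illustrative NumPy version of one Shampoo step on a 10 x 5 variable.
import numpy as np

def sym_power(mat_g, alpha):
  # mat_g^alpha for a symmetric PSD matrix, via SVD (same idea as np_power in the tests).
  mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
  return np.dot(mat_u * np.power(diag_d, alpha), mat_v)

var = np.zeros([10, 5])
grad = np.random.rand(10, 5)
epsilon, lr, alpha = 0.1, 1.0, 0.5

# One accumulated Gram matrix ("preconditioner") per dimension.
mat_g1 = grad.dot(grad.T)   # 10 x 10, over rows
mat_g2 = grad.T.dot(grad)   #  5 x 5,  over columns

# Total power alpha is split across the 2 dimensions: exponent -alpha/2 = -0.25 each.
mat_left = sym_power(mat_g1 + epsilon * np.eye(10), -alpha / 2)
mat_right = sym_power(mat_g2 + epsilon * np.eye(5), -alpha / 2)

# Shampoo update: w <- w - lr * G1^{-1/4} g G2^{-1/4}
var -= lr * mat_left.dot(grad).dot(mat_right)
```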
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index bbdf962d04..280d4a5492 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -27,6 +27,7 @@ py_library( "python/training/nadam_optimizer.py", "python/training/powersign.py", "python/training/reg_adagrad_optimizer.py", + "python/training/shampoo.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", "python/training/weight_decay_optimizers.py", @@ -344,3 +345,21 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "shampoo_test", + srcs = ["python/training/shampoo_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 3e63e99030..9471fb0181 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -30,10 +30,10 @@ from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.shampoo import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -62,6 +62,7 @@ _allowed_symbols = [ 'ModelAverageOptimizer', 'ModelAverageCustomGetter', 'GGTOptimizer', + 'ShampooOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py new file mode 100644 index 0000000000..7afa0998f4 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -0,0 +1,463 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""The Shampoo Optimizer. + +Variant of Adagrad using one preconditioner matrix per variable dimension. 
+For details, see https://arxiv.org/abs/1802.09568 +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import optimizer + + +def GetParam(var, timestep): + if callable(var): + return var(timestep) + else: + return var + + +class ShampooOptimizer(optimizer.Optimizer): + """The Shampoo Optimizer + + Variant of Adagrad using one preconditioner matrix per variable dimension. + For details, see https://arxiv.org/abs/1802.09568 + + gbar is time-weighted accumulated gradient: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + + mat_gbar is time-weighted accumulated gradient square: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation) + + Update rule: + w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t] + Again, mat_gbar_j[t]^(-alpha) gbar[t] is a tensor contraction along the + j'th dimension of gbar[t] with the first dimension of + mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter, + and n = rank of the variable. + Prod_j represents doing this contraction for all j in 0..n-1. + + Typically learning_rate is constant, but could be time dependent by passing + a lambda function that depends on step. + """ + + def __init__(self, global_step=0, + max_matrix_size=500, + gbar_decay=0.0, + gbar_weight=1.0, + mat_gbar_decay=1.0, + mat_gbar_weight=1.0, + learning_rate=1.0, + svd_interval=1, + precond_update_interval=1, + epsilon=0.1, + alpha=0.5, + use_iterative_root=False, + use_locking=False, + name="Shampoo"): + """Default values of the various hyper-parameters. + + gbar_decay, gbar_weight etc. can be a float or a time varying parameter. + For time-varying parameters use e.g. "lambda T: T / (T + 1.0)" + where the expression in the lambda is a tensorflow expression + + Args: + global_step: tensorflow variable indicating the step. + max_matrix_size: We do not perform SVD for matrices larger than this. + gbar_decay: + gbar_weight: Used to update gbar: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + mat_gbar_decay: + mat_gbar_weight: Used to update mat_gbar: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + learning_rate: Similar to SGD + svd_interval: We should do SVD after this many steps. Default = 1, i.e. + every step. Usually 20 leads to no loss of accuracy, and + 50 or 100 is also OK. May also want more often early, + and less often later - set in caller as for example: + "svd_interval = lambda(T): tf.cond( + T < 2000, lambda: 20.0, lambda: 1000.0)" + precond_update_interval: We should update the preconditioners after + this many steps. Default = 1. Usually less than + svd_interval. + epsilon: epsilon * I_n is added to each mat_gbar_j for stability + alpha: total power of the preconditioners. + use_iterative_root: should the optimizer use SVD (faster) or the + iterative root method (for TPU) for finding the + roots of PSD matrices. + use_locking: + name: name of optimizer. 
+ """ + + super(ShampooOptimizer, self).__init__(use_locking, name) + + self._global_step = math_ops.to_float(global_step) + self._max_matrix_size = max_matrix_size + self._gbar_decay = gbar_decay + self._gbar_weight = gbar_weight + self._mat_gbar_decay = mat_gbar_decay + self._mat_gbar_weight = mat_gbar_weight + self._learning_rate = learning_rate + self._svd_interval = svd_interval + self._precond_update_interval = precond_update_interval + self._epsilon = epsilon + self._alpha = alpha + self._use_iterative_root = use_iterative_root + self._name = name + + def _create_slots(self, var_list): + for v in var_list: + with ops.colocate_with(v): + _ = self._zeros_slot(v, "gbar", self._name) + shape = np.array(v.get_shape()) + for i, d in enumerate(shape): + d_tensor = ops.convert_to_tensor(d) + if d < self._max_matrix_size: + mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor)) + if self._svd_interval > 1: + _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor), + "H_" + str(i), self._name) + else: + mat_g_init = array_ops.zeros([d_tensor]) + + _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i), + self._name) + + def _apply_dense(self, grad, var): + return self._apply_gradient(grad, var) + + def _apply_sparse(self, grad, var): + if var.get_shape()[0] < self._max_matrix_size or self._gbar_decay != 0.0: + # The dimension is small enough, we can make the variable dense and + # do a dense update + dense_grad = array_ops.scatter_nd( + array_ops.expand_dims(grad.indices, axis=1), + grad.values, array_ops.shape(var, out_type=grad.indices.dtype)) + return self._apply_gradient(dense_grad, var) + return self._apply_gradient(grad.values, var, grad.indices) + + def _weighted_average(self, var, weight, weight_t, rest): + """Computes exponential weighted average: var = weight_t * var + rest. + + Important to ensure that var does not occur in rest, otherwise + we can get race conditions in a distributed setting. + + Args: + var: variable to be updated + weight: parameter to be checked. If it is a constant, we can optimize. + weight_t: current value of parameter, used for weighting + rest: the remaining tensor to be added + + Returns: + updated variable. + """ + if weight == 0.0: + return rest # no need to update var, we will never use it. + if weight == 1.0: # common case + return state_ops.assign_add(var, rest) + # The op below can cause race conditions in a distributed setting, + # since computing weight_t * var + rest can take some time, during + # which var may be set by another worker. To prevent this, it should + # be implemented as a C++ op. + return var.assign_add((weight_t - 1) * var + rest) + + def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay, + mat_gbar_weight, i): + """Updates the cumulative outer products of the gradients. + + Args: + mat_g: the matrix to be updated + grad: the gradient of the variable + axes: a list of k-1 integers 0 to k-1, except i + mat_gbar_decay: constant for weighted average: + mat_g = mat_g * decay + grad * weight + mat_gbar_weight: constant for weighted average + i: index of dimension to be updated. + + Returns: + updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight + + In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd + thus grad_outer is a matrix d_i x d_i, where d_i is the size of the + i'th dimension of g. 
+ Alternate view: If mat_i(grad) is the flattening of grad to a + d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then + grad_outer = mat_i(grad) mat_i(grad).transpose + """ + grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes), + name="grad_outer_" + str(i)) + return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay, + mat_gbar_weight * grad_outer) + + def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name): + """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix. + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g + alpha: a real number + mat_h_slot_name: name of slot to store the power, if needed. + + Returns: + mat_h = mat_g^alpha + + Stores mat_h in the appropriate slot, if it exists. + Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig. + """ + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size)) + diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) + mat_h = math_ops.matmul( + mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), + array_ops.transpose(mat_u)) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name, + iter_count=100, epsilon=1e-6): + """Computes mat_g^alpha, where alpha = -1/p, p a positive integer. + + We use an iterative Schur-Newton method from equation 3.2 on page 9 of: + + A Schur-Newton Method for the Matrix p-th Root and its Inverse + by Chun-Hua Guo and Nicholas J. Higham + SIAM Journal on Matrix Analysis and Applications, + 2006, Vol. 28, No. 3 : pp. 788-804 + https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g. + alpha: exponent, must be -1/p for p a positive integer. + mat_h_slot_name: name of slot to store the power, if needed. + iter_count: Maximum number of iterations. + epsilon: accuracy indicator, useful for early termination. + + Returns: + mat_g^alpha + """ + + identity = linalg_ops.eye(math_ops.to_int32(mat_g_size)) + + def MatPower(mat_m, p): + """Computes mat_m^p, for p a positive integer. + + Power p is known at graph compile time, so no need for loop and cond. 
+ Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + assert p == int(p) and p > 0 + power = None + while p > 0: + if p % 2 == 1: + power = math_ops.matmul(mat_m, power) if power is not None else mat_m + p //= 2 + mat_m = math_ops.matmul(mat_m, mat_m) + return power + + def IterCondition(i, mat_m, _): + return math_ops.logical_and( + i < iter_count, + math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon) + + def IterBody(i, mat_m, mat_x): + mat_m_i = (1 - alpha) * identity + alpha * mat_m + return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m), + math_ops.matmul(mat_x, mat_m_i)) + + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damped_mat_g = mat_g + self._epsilon * identity + z = (1 - 1/alpha) / (2 * linalg_ops.norm(damped_mat_g, ord=2)) + # The best value for z is + # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) / + # (c_max^{1-alpha} - c_min^{1-alpha}) + # where c_max and c_min are the largest and smallest singular values of + # damped_mat_g. + # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha) + # Can replace above line by the one below, but it is less accurate, + # hence needs more iterations to converge. + # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g) + # If we want the method to always converge, use z = 1 / norm(damped_mat_g) + # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many + # extra iterations. + _, _, mat_h = control_flow_ops.while_loop( + IterCondition, IterBody, + [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)]) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None): + """Just a switch between the iterative power vs svd.""" + if self._use_iterative_root: + return self._compute_power_iter(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + else: + return self._compute_power_svd(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + + def _apply_gradient(self, grad, var, indices=None): + """The main function to update a variable. + + Args: + grad: A Tensor containing gradient to apply. + var: A Tensor containing the variable to update. + indices: An array of integers, for sparse update. + + Returns: + Updated variable var = var - learning_rate * preconditioner * grad + + If the gradient is dense, var and grad have the same shape. + If the update is sparse, then the first dimension of the gradient and var + may differ, others are all the same. In this case the indices array + provides the set of indices of the variable which are to be updated with + each row of the gradient. + """ + global_step = self._global_step + 1 + + # Update accumulated weighted average of gradients + gbar = self.get_slot(var, "gbar") + gbar_decay_t = GetParam(self._gbar_decay, global_step) + gbar_weight_t = GetParam(self._gbar_weight, global_step) + if indices is not None: + # Note - the sparse update is not easily implemented, since the + # algorithm needs all indices of gbar to be updated + # if mat_gbar_decay != 1 or mat_gbar_decay != 0. + # One way to make mat_gbar_decay = 1 is by rescaling. + # If we want the update: + # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t + # define: + # r_{t+1} = a_{t+1} * r_t + # h_t = G_t / r_t + # Then: + # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t + # So we get the mat_gbar_decay = 1 as desired. + # We can implement this in a future version as needed. 
+ # However we still need gbar_decay = 0, otherwise all indices + # of the variable will need to be updated. + if self._gbar_decay != 0.0: + tf_logging.warning("Not applying momentum for variable: %s" % var.name) + gbar_updated = grad + else: + gbar_updated = self._weighted_average(gbar, self._gbar_decay, + gbar_decay_t, + gbar_weight_t * grad) + + # Update the preconditioners and compute the preconditioned gradient + shape = var.get_shape() + mat_g_list = [] + for i in range(len(shape)): + mat_g_list.append(self.get_slot(var, "Gbar_" + str(i))) + mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step) + mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step) + + preconditioned_grad = gbar_updated + v_rank = len(mat_g_list) + neg_alpha = - GetParam(self._alpha, global_step) / v_rank + svd_interval = GetParam(self._svd_interval, global_step) + precond_update_interval = GetParam(self._precond_update_interval, + global_step) + for i, mat_g in enumerate(mat_g_list): + # axes is the list of indices to reduce - everything but the current i. + axes = list(range(i)) + list(range(i+1, v_rank)) + if shape[i] < self._max_matrix_size: + # If the tensor size is sufficiently small perform full Shampoo update + # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this + # is not strictly correct. However we will use it for now, and + # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg) + + # pylint: disable=g-long-lambda,cell-var-from-loop + mat_g_updated = control_flow_ops.cond( + math_ops.mod(global_step, precond_update_interval) < 1, + lambda: self._update_mat_g( + mat_g, grad, axes, mat_gbar_decay_t, + mat_gbar_weight_t * precond_update_interval, i), + lambda: mat_g) + + if self._svd_interval == 1: + mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha) + else: + mat_h = control_flow_ops.cond( + math_ops.mod(global_step, svd_interval) < 1, + lambda: self._compute_power(var, mat_g_updated, shape[i], + neg_alpha, "H_" + str(i)), + lambda: self.get_slot(var, "H_" + str(i))) + + # mat_h is a square matrix of size d_i x d_i + # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor + # After contraction with a d_i x d_i tensor + # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor + # (the first dimension is contracted out, and the second dimension of + # mat_h is appended). After going through all the indices, it becomes + # a d_0 x ... x d_n tensor again. + preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h, + axes=([0], [0]), + name="precond_" + str(i)) + else: + # Tensor size is too large -- perform diagonal Shampoo update + grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) + if i == 0 and indices is not None: + assert self._mat_gbar_decay == 1.0 + mat_g_updated = state_ops.scatter_add(mat_g, indices, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow( + array_ops.gather(mat_g_updated, indices) + self._epsilon, + neg_alpha) + else: + mat_g_updated = self._weighted_average(mat_g, + self._mat_gbar_decay, + mat_gbar_decay_t, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha) + + # Need to do the transpose to ensure that the tensor becomes + # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above. 
+ preconditioned_grad = array_ops.transpose( + preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h + + # Update the variable based on the Shampoo update + learning_rate_t = GetParam(self._learning_rate, global_step) + if indices is not None: + var_updated = state_ops.scatter_sub(var, indices, + learning_rate_t * preconditioned_grad) + else: + var_updated = state_ops.assign_sub(var, + learning_rate_t * preconditioned_grad) + return var_updated diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py new file mode 100644 index 0000000000..3148d02296 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -0,0 +1,669 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional tests for AdaMoo optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import shampoo +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +TOLERANCE = 1e-3 + + +def np_power(mat_g, alpha): + """Computes mat_g^alpha for a square symmetric matrix mat_g.""" + + mat_u, diag_d, mat_v = np.linalg.svd(mat_g) + diag_d = np.power(diag_d, alpha) + return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v) + + +class ShampooTest(test.TestCase): + + def testBasicVector(self): + """Similar to the full Adagrad update.""" + + size = 20 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g^{-0.5} * grad + # lr = 1 + mat_g = np.outer(grad_np, grad_np) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np = init_var_np - np.dot(mat_h, grad_np) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + 
mat_g += np.outer(grad_np_2, grad_np_2) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np -= np.dot(mat_h, grad_np_2) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def testBasicMatrix(self): + """Check update when gradient is a matrix.""" + size = [10, 5] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25} + # lr = 1 + mat_g1 = np.dot(grad_np, grad_np.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testBasicTensor(self, use_iterative_root): + """Check update when gradient is a tensor.""" + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) 
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def testBasicTensor(self): + for use_iterative_root in [True, False]: + self._testBasicTensor(use_iterative_root) + + def testLargeVector(self): + """This is just the diagonal Adagrad update.""" + + size = 2000 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * gg^{-0.5} * grad + # lr = 1 + mat_g = grad_np * grad_np + 0.1 + new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += grad_np_2 * grad_np_2 + new_val_np -= np.power(mat_g, -0.5) * grad_np_2 + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + def testLargeMatrix(self): + """Gradient is a matrix, one of whose dimensions is large. + + We do diagonal updates for large dimensions. 
+ """ + + size = [2000, 3] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def testSparseUpdateLarge(self): + """Check update when gradient is of type IndexSlices. + + We do diagonal updates for the first dimension, unless it is very small. 
+ """ + + size = [2000, 3] + sample_size_1 = 100 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size_1, + replace=False)) + grad_np = np.random.rand(sample_size_1, size[1]) + + sample_size_2 = 7 + grad_indices_2 = np.sort(np.random.choice(np.arange(size[0]), sample_size_2, + replace=False)) + grad_np_2 = np.random.rand(sample_size_2, size[1]) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + grad_2 = ops.IndexedSlices( + constant_op.constant(grad_np_2, dtype=dtypes.float32), + constant_op.constant(grad_indices_2), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + # In this case the update lr * mat_left * grad * mat_right is + # of size 10 x 2. + # So the correct indices of var need to be updated. + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_g1_acc = np.zeros((size[0], 1)) + mat_g1_acc[grad_indices] += mat_g1 + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np + new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_g1_acc[grad_indices_2] += mat_g1 + mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testSparseUpdateSmall(self, use_iterative_root): + """Gradient is of type IndexSlices, but the first dimension is small. + + We create dense gradient and do the full update with SVD etc. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. 
+ """ + + size = [100, 3, 5] + sample_size = 10 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size, + replace=False)) + grad_np = np.random.rand(sample_size, size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.125} grad + # lr = 1 + grad_dense = np.zeros_like(init_var_np) + grad_dense[grad_indices] = grad_np + + mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def testSparseUpdateSmall(self): + for use_iterative_root in [True, False]: + self._testSparseUpdateSmall(use_iterative_root) + + def _testBasicTensorWithMomentum(self, use_iterative_root): + """Check update with momentum when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. 
+ """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + gbar_decay = 0.9 + gbar_weight = 0.1 + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, gbar_decay=gbar_decay, + gbar_weight=gbar_weight, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np = gbar_weight * grad_np + precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2 + precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def testBasicTensorWithMomentum(self): + for use_iterative_root in [True, False]: + self._testBasicTensorWithMomentum(use_iterative_root) + + def _testDelayedSVD(self, use_iterative_root): + """Performing the SVD every nth step. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. 
+ """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 20 + svd_interval = 5 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer(global_step, svd_interval=svd_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def testDelayedSVD(self): + for use_iterative_root in [True, False]: + self._testDelayedSVD(use_iterative_root) + + def _testDelayedPrecondUpdate(self, use_iterative_root): + """Update the squared sum every nth step, drop the other steps. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. 
+ """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 100 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + svd_interval = 20 + precond_update_interval = 5 + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable(0, dtype=dtypes.int64) + var = variables.Variable(init_var_np, dtype=dtypes.float32) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer( + global_step, svd_interval=svd_interval, + precond_update_interval=precond_update_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + if (i + 1) % precond_update_interval == 0: + mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + * precond_update_interval) + mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + * precond_update_interval) + mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + * precond_update_interval) + + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def testDelayedPrecondUpdate(self): + for use_iterative_root in [True, False]: + self._testDelayedPrecondUpdate(use_iterative_root) + + +if __name__ == '__main__': + test.main() |