author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-31 15:07:05 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-07-31 15:10:30 -0700
commit     7d4c0998147fded8a2291dc097186634f017fedc (patch)
tree       3768af2fc1795b563c076c8b97b654a5cc987c43 /tensorflow/contrib/opt
parent     f29635ccb73337d20b9f3e79b5d7917ae2bb56cc (diff)
This is an initial submission of the Shampoo Optimizer to TensorFlow contrib.
Paper link: https://arxiv.org/pdf/1802.09568.pdf
PiperOrigin-RevId: 206834923
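For context, here is a minimal usage sketch (not part of this commit) showing how the new optimizer can be driven in a TF 1.x graph-mode session, following the pattern used in shampoo_test.py below; the toy loss, variable shape, and reliance on the default hyper-parameters are illustrative assumptions:

import tensorflow as tf
from tensorflow.contrib.opt.python.training import shampoo

# Toy problem: minimize ||var - 1||^2 for a 10 x 5 variable.
global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
var = tf.Variable(tf.zeros([10, 5]))
loss = tf.reduce_sum(tf.square(var - 1.0))
grads_and_vars = list(zip(tf.gradients(loss, [var]), [var]))

# Defaults from the new class: learning_rate=1.0, alpha=0.5, epsilon=0.1.
opt = shampoo.ShampooOptimizer(global_step)
train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)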
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--  tensorflow/contrib/opt/BUILD                              19
-rw-r--r--  tensorflow/contrib/opt/__init__.py                         3
-rw-r--r--  tensorflow/contrib/opt/python/training/shampoo.py        463
-rw-r--r--  tensorflow/contrib/opt/python/training/shampoo_test.py   669
4 files changed, 1153 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index bbdf962d04..280d4a5492 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -27,6 +27,7 @@ py_library(
"python/training/nadam_optimizer.py",
"python/training/powersign.py",
"python/training/reg_adagrad_optimizer.py",
+ "python/training/shampoo.py",
"python/training/sign_decay.py",
"python/training/variable_clipping_optimizer.py",
"python/training/weight_decay_optimizers.py",
@@ -344,3 +345,21 @@ py_test(
"//third_party/py/numpy",
],
)
+
+py_test(
+ name = "shampoo_test",
+ srcs = ["python/training/shampoo_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
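With the BUILD additions above, the new target can be run from a TensorFlow source checkout with, e.g., "bazel test //tensorflow/contrib/opt:shampoo_test" (invocation shown for illustration; flags depend on the local build configuration).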
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 3e63e99030..9471fb0181 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -30,10 +30,10 @@ from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
+from tensorflow.contrib.opt.python.training.shampoo import *
from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
-from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -62,6 +62,7 @@ _allowed_symbols = [
'ModelAverageOptimizer',
'ModelAverageCustomGetter',
'GGTOptimizer',
+ 'ShampooOptimizer',
]
remove_undocumented(__name__, _allowed_symbols)
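Because of the wildcard import and the new '_allowed_symbols' entry above, the optimizer is also reachable through the contrib namespace; a one-line construction sketch (all hyper-parameters left at their defaults, graph mode assumed):

import tensorflow as tf

opt = tf.contrib.opt.ShampooOptimizer()  # equivalent to the direct module import used in the tests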
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
new file mode 100644
index 0000000000..7afa0998f4
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -0,0 +1,463 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""The Shampoo Optimizer.
+
+Variant of Adagrad using one preconditioner matrix per variable dimension.
+For details, see https://arxiv.org/abs/1802.09568
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import optimizer
+
+
+def GetParam(var, timestep):
+ if callable(var):
+ return var(timestep)
+ else:
+ return var
+
+
+class ShampooOptimizer(optimizer.Optimizer):
+ """The Shampoo Optimizer
+
+ Variant of Adagrad using one preconditioner matrix per variable dimension.
+ For details, see https://arxiv.org/abs/1802.09568
+
+ gbar is time-weighted accumulated gradient:
+ gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
+
+ mat_gbar_j is the time-weighted accumulated gradient outer product for dimension j:
+ mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
+ + mat_gbar_weight[t] * gg_j[t]
+ where, if g[t] = g_abcd, then (gg_0[t])_aa' = g_abcd g_a'bcd
+ (Einstein summation over the repeated indices b, c, d)
+
+ Update rule:
+ w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t]
+ Here, mat_gbar_j[t]^(-alpha/n) gbar[t] is a tensor contraction along the
+ j'th dimension of gbar[t] with the first dimension of
+ mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter,
+ and n = rank of the variable.
+ Prod_j represents doing this contraction for all j in 0..n-1.
+
+ Typically learning_rate is constant, but could be time dependent by passing
+ a lambda function that depends on step.
+ """
+
+ def __init__(self, global_step=0,
+ max_matrix_size=500,
+ gbar_decay=0.0,
+ gbar_weight=1.0,
+ mat_gbar_decay=1.0,
+ mat_gbar_weight=1.0,
+ learning_rate=1.0,
+ svd_interval=1,
+ precond_update_interval=1,
+ epsilon=0.1,
+ alpha=0.5,
+ use_iterative_root=False,
+ use_locking=False,
+ name="Shampoo"):
+ """Default values of the various hyper-parameters.
+
+ gbar_decay, gbar_weight etc. can be a float or a time-varying parameter.
+ For time-varying parameters use e.g. "lambda T: T / (T + 1.0)"
+ where the expression in the lambda is a TensorFlow expression.
+
+ Args:
+ global_step: tensorflow variable indicating the step.
+ max_matrix_size: We do not perform SVD for matrices larger than this.
+ gbar_decay:
+ gbar_weight: Used to update gbar:
+ gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
+ mat_gbar_decay:
+ mat_gbar_weight: Used to update mat_gbar:
+ mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
+ + mat_gbar_weight[t] * gg_j[t]
+ learning_rate: Similar to SGD
+ svd_interval: We should do SVD after this many steps. Default = 1, i.e.
+ every step. Usually 20 leads to no loss of accuracy, and
+ 50 or 100 is also OK. May also want more often early,
+ and less often later - set in caller as for example:
+ "svd_interval = lambda(T): tf.cond(
+ T < 2000, lambda: 20.0, lambda: 1000.0)"
+ precond_update_interval: We should update the preconditioners after
+ this many steps. Default = 1. Usually less than
+ svd_interval.
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ alpha: total power of the preconditioners.
+ use_iterative_root: whether to use the iterative root method (better
+ suited to TPU) instead of SVD (faster) for finding
+ the roots of PSD matrices.
+ use_locking:
+ name: name of optimizer.
+ """
+
+ super(ShampooOptimizer, self).__init__(use_locking, name)
+
+ self._global_step = math_ops.to_float(global_step)
+ self._max_matrix_size = max_matrix_size
+ self._gbar_decay = gbar_decay
+ self._gbar_weight = gbar_weight
+ self._mat_gbar_decay = mat_gbar_decay
+ self._mat_gbar_weight = mat_gbar_weight
+ self._learning_rate = learning_rate
+ self._svd_interval = svd_interval
+ self._precond_update_interval = precond_update_interval
+ self._epsilon = epsilon
+ self._alpha = alpha
+ self._use_iterative_root = use_iterative_root
+ self._name = name
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ with ops.colocate_with(v):
+ _ = self._zeros_slot(v, "gbar", self._name)
+ shape = np.array(v.get_shape())
+ for i, d in enumerate(shape):
+ d_tensor = ops.convert_to_tensor(d)
+ if d < self._max_matrix_size:
+ mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor))
+ if self._svd_interval > 1:
+ _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor),
+ "H_" + str(i), self._name)
+ else:
+ mat_g_init = array_ops.zeros([d_tensor])
+
+ _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i),
+ self._name)
+
+ def _apply_dense(self, grad, var):
+ return self._apply_gradient(grad, var)
+
+ def _apply_sparse(self, grad, var):
+ if var.get_shape()[0] < self._max_matrix_size or self._gbar_decay != 0.0:
+ # The dimension is small enough, we can make the variable dense and
+ # do a dense update
+ dense_grad = array_ops.scatter_nd(
+ array_ops.expand_dims(grad.indices, axis=1),
+ grad.values, array_ops.shape(var, out_type=grad.indices.dtype))
+ return self._apply_gradient(dense_grad, var)
+ return self._apply_gradient(grad.values, var, grad.indices)
+
+ def _weighted_average(self, var, weight, weight_t, rest):
+ """Computes exponential weighted average: var = weight_t * var + rest.
+
+ Important to ensure that var does not occur in rest, otherwise
+ we can get race conditions in a distributed setting.
+
+ Args:
+ var: variable to be updated
+ weight: parameter to be checked. If it is a constant, we can optimize.
+ weight_t: current value of parameter, used for weighting
+ rest: the remaining tensor to be added
+
+ Returns:
+ updated variable.
+ """
+ if weight == 0.0:
+ return rest # no need to update var, we will never use it.
+ if weight == 1.0: # common case
+ return state_ops.assign_add(var, rest)
+ # The op below can cause race conditions in a distributed setting,
+ # since computing weight_t * var + rest can take some time, during
+ # which var may be set by another worker. To prevent this, it should
+ # be implemented as a C++ op.
+ return var.assign_add((weight_t - 1) * var + rest)
+
+ def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay,
+ mat_gbar_weight, i):
+ """Updates the cumulative outer products of the gradients.
+
+ Args:
+ mat_g: the matrix to be updated
+ grad: the gradient of the variable
+ axes: a list of k-1 integers 0 to k-1, except i
+ mat_gbar_decay: constant for weighted average:
+ mat_g = mat_g * decay + grad * weight
+ mat_gbar_weight: constant for weighted average
+ i: index of dimension to be updated.
+
+ Returns:
+ updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight
+
+ In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd
+ thus grad_outer is a matrix d_i x d_i, where d_i is the size of the
+ i'th dimension of g.
+ Alternate view: If mat_i(grad) is the flattening of grad to a
+ d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then
+ grad_outer = mat_i(grad) mat_i(grad).transpose
+ """
+ grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes),
+ name="grad_outer_" + str(i))
+ return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay,
+ mat_gbar_weight * grad_outer)
+
+ def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name):
+ """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix.
+
+ Args:
+ var: the variable we are updating.
+ mat_g: the symmetric PSD matrix whose power is to be computed
+ mat_g_size: size of mat_g
+ alpha: a real number
+ mat_h_slot_name: name of slot to store the power, if needed.
+
+ Returns:
+ mat_h = mat_g^alpha
+
+ Stores mat_h in the appropriate slot, if it exists.
+ Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig.
+ """
+ if mat_g_size == 1:
+ mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
+ else:
+ damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size))
+ diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True)
+ mat_h = math_ops.matmul(
+ mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha),
+ array_ops.transpose(mat_u))
+ if mat_h_slot_name is not None:
+ return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
+ return mat_h
+
+ def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
+ iter_count=100, epsilon=1e-6):
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+
+ We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
+
+ A Schur-Newton Method for the Matrix p-th Root and its Inverse
+ by Chun-Hua Guo and Nicholas J. Higham
+ SIAM Journal on Matrix Analysis and Applications,
+ 2006, Vol. 28, No. 3 : pp. 788-804
+ https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
+
+ Args:
+ var: the variable we are updating.
+ mat_g: the symmetric PSD matrix whose power is to be computed
+ mat_g_size: size of mat_g.
+ alpha: exponent, must be -1/p for p a positive integer.
+ mat_h_slot_name: name of slot to store the power, if needed.
+ iter_count: Maximum number of iterations.
+ epsilon: accuracy indicator, useful for early termination.
+
+ Returns:
+ mat_g^alpha
+ """
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
+
+ def MatPower(mat_m, p):
+ """Computes mat_m^p, for p a positive integer.
+
+ Power p is known at graph compile time, so no need for loop and cond.
+ Args:
+ mat_m: a square matrix
+ p: a positive integer
+
+ Returns:
+ mat_m^p
+ """
+ assert p == int(p) and p > 0
+ power = None
+ while p > 0:
+ if p % 2 == 1:
+ power = math_ops.matmul(mat_m, power) if power is not None else mat_m
+ p //= 2
+ mat_m = math_ops.matmul(mat_m, mat_m)
+ return power
+
+ def IterCondition(i, mat_m, _):
+ return math_ops.logical_and(
+ i < iter_count,
+ math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
+
+ def IterBody(i, mat_m, mat_x):
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
+ return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m),
+ math_ops.matmul(mat_x, mat_m_i))
+
+ if mat_g_size == 1:
+ mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
+ else:
+ damped_mat_g = mat_g + self._epsilon * identity
+ z = (1 - 1/alpha) / (2 * linalg_ops.norm(damped_mat_g, ord=2))
+ # The best value for z is
+ # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
+ # (c_max^{1-alpha} - c_min^{1-alpha})
+ # where c_max and c_min are the largest and smallest singular values of
+ # damped_mat_g.
+ # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
+ # Can replace above line by the one below, but it is less accurate,
+ # hence needs more iterations to converge.
+ # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
+ # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
+ # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
+ # extra iterations.
+ _, _, mat_h = control_flow_ops.while_loop(
+ IterCondition, IterBody,
+ [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
+ if mat_h_slot_name is not None:
+ return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
+ return mat_h
+
+ def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None):
+ """Just a switch between the iterative power vs svd."""
+ if self._use_iterative_root:
+ return self._compute_power_iter(var, mat_g, mat_g_size, alpha,
+ mat_h_slot_name)
+ else:
+ return self._compute_power_svd(var, mat_g, mat_g_size, alpha,
+ mat_h_slot_name)
+
+ def _apply_gradient(self, grad, var, indices=None):
+ """The main function to update a variable.
+
+ Args:
+ grad: A Tensor containing gradient to apply.
+ var: A Tensor containing the variable to update.
+ indices: An array of integers, for sparse update.
+
+ Returns:
+ Updated variable var = var - learning_rate * preconditioner * grad
+
+ If the gradient is dense, var and grad have the same shape.
+ If the update is sparse, then the first dimension of the gradient and var
+ may differ, others are all the same. In this case the indices array
+ provides the set of indices of the variable which are to be updated with
+ each row of the gradient.
+ """
+ global_step = self._global_step + 1
+
+ # Update accumulated weighted average of gradients
+ gbar = self.get_slot(var, "gbar")
+ gbar_decay_t = GetParam(self._gbar_decay, global_step)
+ gbar_weight_t = GetParam(self._gbar_weight, global_step)
+ if indices is not None:
+ # Note - the sparse update is not easily implemented, since the
+ # algorithm needs all indices of gbar to be updated
+ # if gbar_decay != 0 or mat_gbar_decay != 1.
+ # One way to make mat_gbar_decay = 1 is by rescaling.
+ # If we want the update:
+ # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
+ # define:
+ # r_{t+1} = a_{t+1} * r_t
+ # h_t = G_t / r_t
+ # Then:
+ # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
+ # So we get the mat_gbar_decay = 1 as desired.
+ # We can implement this in a future version as needed.
+ # However we still need gbar_decay = 0, otherwise all indices
+ # of the variable will need to be updated.
+ if self._gbar_decay != 0.0:
+ tf_logging.warning("Not applying momentum for variable: %s" % var.name)
+ gbar_updated = grad
+ else:
+ gbar_updated = self._weighted_average(gbar, self._gbar_decay,
+ gbar_decay_t,
+ gbar_weight_t * grad)
+
+ # Update the preconditioners and compute the preconditioned gradient
+ shape = var.get_shape()
+ mat_g_list = []
+ for i in range(len(shape)):
+ mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
+ mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
+ mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)
+
+ preconditioned_grad = gbar_updated
+ v_rank = len(mat_g_list)
+ neg_alpha = - GetParam(self._alpha, global_step) / v_rank
+ svd_interval = GetParam(self._svd_interval, global_step)
+ precond_update_interval = GetParam(self._precond_update_interval,
+ global_step)
+ for i, mat_g in enumerate(mat_g_list):
+ # axes is the list of indices to reduce - everything but the current i.
+ axes = list(range(i)) + list(range(i+1, v_rank))
+ if shape[i] < self._max_matrix_size:
+ # If the tensor size is sufficiently small perform full Shampoo update
+ # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
+ # is not strictly correct. However we will use it for now, and
+ # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)
+
+ # pylint: disable=g-long-lambda,cell-var-from-loop
+ mat_g_updated = control_flow_ops.cond(
+ math_ops.mod(global_step, precond_update_interval) < 1,
+ lambda: self._update_mat_g(
+ mat_g, grad, axes, mat_gbar_decay_t,
+ mat_gbar_weight_t * precond_update_interval, i),
+ lambda: mat_g)
+
+ if self._svd_interval == 1:
+ mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
+ else:
+ mat_h = control_flow_ops.cond(
+ math_ops.mod(global_step, svd_interval) < 1,
+ lambda: self._compute_power(var, mat_g_updated, shape[i],
+ neg_alpha, "H_" + str(i)),
+ lambda: self.get_slot(var, "H_" + str(i)))
+
+ # mat_h is a square matrix of size d_i x d_i
+ # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
+ # After contraction with a d_i x d_i tensor
+ # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
+ # (the first dimension is contracted out, and the second dimension of
+ # mat_h is appended). After going through all the indices, it becomes
+ # a d_0 x ... x d_n tensor again.
+ preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
+ axes=([0], [0]),
+ name="precond_" + str(i))
+ else:
+ # Tensor size is too large -- perform diagonal Shampoo update
+ grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
+ if i == 0 and indices is not None:
+ assert self._mat_gbar_decay == 1.0
+ mat_g_updated = state_ops.scatter_add(mat_g, indices,
+ mat_gbar_weight_t * grad_outer)
+ mat_h = math_ops.pow(
+ array_ops.gather(mat_g_updated, indices) + self._epsilon,
+ neg_alpha)
+ else:
+ mat_g_updated = self._weighted_average(mat_g,
+ self._mat_gbar_decay,
+ mat_gbar_decay_t,
+ mat_gbar_weight_t * grad_outer)
+ mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha)
+
+ # Need to do the transpose to ensure that the tensor becomes
+ # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
+ preconditioned_grad = array_ops.transpose(
+ preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h
+
+ # Update the variable based on the Shampoo update
+ learning_rate_t = GetParam(self._learning_rate, global_step)
+ if indices is not None:
+ var_updated = state_ops.scatter_sub(var, indices,
+ learning_rate_t * preconditioned_grad)
+ else:
+ var_updated = state_ops.assign_sub(var,
+ learning_rate_t * preconditioned_grad)
+ return var_updated
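To make the update rule above concrete, here is a small NumPy sketch (not part of this commit) of one dense Shampoo step for a rank-2 variable, mirroring the class docstring and the reference computation in shampoo_test.py; the function names and the choice mat_gbar_decay = 1, learning_rate = 1, alpha = 0.5, epsilon = 0.1 are illustrative assumptions:

import numpy as np

def np_power(mat_g, alpha):
  # mat_g^alpha for a symmetric PSD matrix, via SVD.
  mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
  return np.dot(mat_u * np.power(diag_d, alpha), mat_v)

def shampoo_matrix_step(var, grad, mat_g1, mat_g2,
                        lr=1.0, epsilon=0.1, alpha=0.5):
  # Accumulate one preconditioner statistic per dimension (in place).
  mat_g1 += np.dot(grad, grad.T)   # d_0 x d_0
  mat_g2 += np.dot(grad.T, grad)   # d_1 x d_1
  # Preconditioned gradient: G1^{-alpha/2} * grad * G2^{-alpha/2}.
  mat_left = np_power(mat_g1 + epsilon * np.eye(grad.shape[0]), -alpha / 2)
  mat_right = np_power(mat_g2 + epsilon * np.eye(grad.shape[1]), -alpha / 2)
  return var - lr * np.dot(np.dot(mat_left, grad), mat_right)

# One step on a 10 x 5 variable.
var = np.zeros((10, 5))
grad = np.random.rand(10, 5)
mat_g1, mat_g2 = np.zeros((10, 10)), np.zeros((5, 5))
var = shampoo_matrix_step(var, grad, mat_g1, mat_g2)

testBasicMatrix in the test file below performs essentially this check against the TensorFlow implementation.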
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
new file mode 100644
index 0000000000..3148d02296
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -0,0 +1,669 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional tests for AdaMoo optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import shampoo
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+TOLERANCE = 1e-3
+
+
+def np_power(mat_g, alpha):
+ """Computes mat_g^alpha for a square symmetric matrix mat_g."""
+
+ mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
+ diag_d = np.power(diag_d, alpha)
+ return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v)
+
+
+class ShampooTest(test.TestCase):
+
+ def testBasicVector(self):
+ """Similar to the full Adagrad update."""
+
+ size = 20
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size)
+ grad_np_2 = np.random.rand(size)
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * mat_g^{-0.5} * grad
+ # lr = 1
+ mat_g = np.outer(grad_np, grad_np)
+ mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ new_val_np = init_var_np - np.dot(mat_h, grad_np)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g += np.outer(grad_np_2, grad_np_2)
+ mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ new_val_np -= np.dot(mat_h, grad_np_2)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testBasicMatrix(self):
+ """Check update when gradient is a matrix."""
+ size = [10, 5]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1])
+ grad_np_2 = np.random.rand(size[0], size[1])
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25}
+ # lr = 1
+ mat_g1 = np.dot(grad_np, grad_np.transpose())
+ mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np)
+ mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += np.dot(grad_np_2, grad_np_2.transpose())
+ mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
+ mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def _testBasicTensor(self, use_iterative_root):
+ """Check update when gradient is a tensor."""
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1], size[2])
+ grad_np_2 = np.random.rand(size[0], size[1], size[2])
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
+ mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
+ mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
+ mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
+ mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+
+ precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np = init_var_np - precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
+ mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
+ mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
+ mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
+ mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+
+ precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testBasicTensor(self):
+ for use_iterative_root in [True, False]:
+ self._testBasicTensor(use_iterative_root)
+
+ def testLargeVector(self):
+ """This is just the diagonal Adagrad update."""
+
+ size = 2000
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size)
+ grad_np_2 = np.random.rand(size)
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * gg^{-0.5} * grad
+ # lr = 1
+ mat_g = grad_np * grad_np + 0.1
+ new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g += grad_np_2 * grad_np_2
+ new_val_np -= np.power(mat_g, -0.5) * grad_np_2
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val)
+
+ def testLargeMatrix(self):
+ """Gradient is a matrix, one of whose dimensions is large.
+
+ We do diagonal updates for large dimensions.
+ """
+
+ size = [2000, 3]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1])
+ grad_np_2 = np.random.rand(size[0], size[1])
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * mat_left * grad * mat_right
+ # where the mat_left * grad is just element-wise product,
+ # with broadcasting
+ # lr = 1
+
+ mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
+ mat_left = np.power(mat_g1 + 0.1, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np)
+ mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
+ mat_left = np.power(mat_g1 + 0.1, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
+ mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testSparseUpdateLarge(self):
+ """Check update when gradient is of type IndexSlices.
+
+ We do diagonal updates for the first dimension, unless it is very small.
+ """
+
+ size = [2000, 3]
+ sample_size_1 = 100
+ init_var_np = np.zeros(size)
+ grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size_1,
+ replace=False))
+ grad_np = np.random.rand(sample_size_1, size[1])
+
+ sample_size_2 = 7
+ grad_indices_2 = np.sort(np.random.choice(np.arange(size[0]), sample_size_2,
+ replace=False))
+ grad_np_2 = np.random.rand(sample_size_2, size[1])
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = ops.IndexedSlices(
+ constant_op.constant(grad_np, dtype=dtypes.float32),
+ constant_op.constant(grad_indices),
+ constant_op.constant(size))
+ grad_2 = ops.IndexedSlices(
+ constant_op.constant(grad_np_2, dtype=dtypes.float32),
+ constant_op.constant(grad_indices_2),
+ constant_op.constant(size))
+
+ opt = shampoo.ShampooOptimizer(global_step)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * mat_left * grad * mat_right
+ # where the mat_left * grad is just element-wise product,
+ # with broadcasting
+ # lr = 1
+ # In this case the update lr * mat_left * grad * mat_right is
+ # of size sample_size_1 x 3, i.e. 100 x 3 here.
+ # So the correct indices of var need to be updated.
+
+ mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
+ mat_g1_acc = np.zeros((size[0], 1))
+ mat_g1_acc[grad_indices] += mat_g1
+ mat_left = np.power(mat_g1 + 0.1, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np)
+ mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ new_val_np = init_var_np
+ new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
+ mat_g1_acc[grad_indices_2] += mat_g1
+ mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
+ mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right)
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def _testSparseUpdateSmall(self, use_iterative_root):
+ """Gradient is of type IndexSlices, but the first dimension is small.
+
+ We create dense gradient and do the full update with SVD etc.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ """
+
+ size = [100, 3, 5]
+ sample_size = 10
+ init_var_np = np.zeros(size)
+ grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size,
+ replace=False))
+ grad_np = np.random.rand(sample_size, size[1], size[2])
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = ops.IndexedSlices(
+ constant_op.constant(grad_np, dtype=dtypes.float32),
+ constant_op.constant(grad_indices),
+ constant_op.constant(size))
+
+ opt = shampoo.ShampooOptimizer(global_step,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.125} grad
+ # lr = 1
+ grad_dense = np.zeros_like(init_var_np)
+ grad_dense[grad_indices] = grad_np
+
+ mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2]))
+ mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
+ mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2]))
+ mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
+ mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1]))
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+
+ precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np = init_var_np - precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testSparseUpdateSmall(self):
+ for use_iterative_root in [True, False]:
+ self._testSparseUpdateSmall(use_iterative_root)
+
+ def _testBasicTensorWithMomentum(self, use_iterative_root):
+ """Check update with momentum when gradient is a tensor.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ """
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size)
+ grad_np = np.random.rand(size[0], size[1], size[2])
+ grad_np_2 = np.random.rand(size[0], size[1], size[2])
+ gbar_decay = 0.9
+ gbar_weight = 0.1
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = constant_op.constant(grad_np, dtype=dtypes.float32)
+ grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
+
+ opt = shampoo.ShampooOptimizer(global_step, gbar_decay=gbar_decay,
+ gbar_weight=gbar_weight,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ update_2 = opt.apply_gradients(zip([grad_2], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ # Run a step of Shampoo
+ update.run()
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
+ mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
+ mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
+ mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
+ mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+
+ gbar_np = gbar_weight * grad_np
+ precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np = init_var_np - precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ # Run another step of Shampoo
+ update_2.run()
+ new_val = sess.run(var)
+
+ mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
+ mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
+ mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
+ mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
+ mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+
+ gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2
+ precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testBasicTensorWithMomentum(self):
+ for use_iterative_root in [True, False]:
+ self._testBasicTensorWithMomentum(use_iterative_root)
+
+ def _testDelayedSVD(self, use_iterative_root):
+ """Performing the SVD every nth step.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ """
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size).astype(np.float32)
+ iterations = 20
+ svd_interval = 5
+ grad_np = np.random.rand(
+ iterations, size[0], size[1], size[2]).astype(np.float32)
+ mat_g1_a = np.eye(size[0])
+ mat_g1 = np.zeros_like(mat_g1_a)
+ mat_g2_a = np.eye(size[1])
+ mat_g2 = np.zeros_like(mat_g2_a)
+ mat_g3_a = np.eye(size[2])
+ mat_g3 = np.zeros_like(mat_g3_a)
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = array_ops.placeholder(dtypes.float32, shape=size)
+
+ opt = shampoo.ShampooOptimizer(global_step, svd_interval=svd_interval,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+ new_val_np = init_var_np
+
+ # Run n steps of Shampoo
+ for i in range(iterations):
+ _ = sess.run(update, feed_dict={grad: grad_np[i]})
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
+ mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
+ mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
+ if (i + 1) % svd_interval == 0:
+ mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
+ mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+
+ precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testDelayedSVD(self):
+ for use_iterative_root in [True, False]:
+ self._testDelayedSVD(use_iterative_root)
+
+ def _testDelayedPrecondUpdate(self, use_iterative_root):
+ """Update the squared sum every nth step, drop the other steps.
+
+ Args:
+ use_iterative_root: use iterative power method or SVD to find nth roots.
+ """
+ size = [10, 5, 7]
+ init_var_np = np.zeros(size).astype(np.float32)
+ iterations = 100
+ grad_np = np.random.rand(
+ iterations, size[0], size[1], size[2]).astype(np.float32)
+ svd_interval = 20
+ precond_update_interval = 5
+ mat_g1_a = np.eye(size[0])
+ mat_g1 = np.zeros_like(mat_g1_a)
+ mat_g2_a = np.eye(size[1])
+ mat_g2 = np.zeros_like(mat_g2_a)
+ mat_g3_a = np.eye(size[2])
+ mat_g3 = np.zeros_like(mat_g3_a)
+
+ with self.test_session() as sess:
+ global_step = variables.Variable(0, dtype=dtypes.int64)
+ var = variables.Variable(init_var_np, dtype=dtypes.float32)
+ grad = array_ops.placeholder(dtypes.float32, shape=size)
+
+ opt = shampoo.ShampooOptimizer(
+ global_step, svd_interval=svd_interval,
+ precond_update_interval=precond_update_interval,
+ use_iterative_root=use_iterative_root)
+ update = opt.apply_gradients(zip([grad], [var]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ init_val = sess.run(var)
+ self.assertAllCloseAccordingToType(init_var_np, init_val)
+ new_val_np = init_var_np
+
+ # Run n steps of Shampoo
+ for i in range(iterations):
+ _ = sess.run(update, feed_dict={grad: grad_np[i]})
+ new_val = sess.run(var)
+
+ # let us compute this in numpy
+ # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
+ # lr = 1
+ if (i + 1) % precond_update_interval == 0:
+ mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
+ * precond_update_interval)
+ mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
+ * precond_update_interval)
+ mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
+ * precond_update_interval)
+
+ if (i + 1) % svd_interval == 0:
+ mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
+ mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
+ mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+
+ precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
+ precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0]))
+ new_val_np -= precond_grad
+
+ self.assertAllCloseAccordingToType(new_val_np, new_val,
+ atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testDelayedPrecondUpdate(self):
+ for use_iterative_root in [True, False]:
+ self._testDelayedPrecondUpdate(use_iterative_root)
+
+
+if __name__ == '__main__':
+ test.main()