Diffstat (limited to 'tensorflow/contrib/opt/python/training/shampoo.py')
-rw-r--r--  tensorflow/contrib/opt/python/training/shampoo.py | 98
1 file changed, 22 insertions(+), 76 deletions(-)
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index 294627f42a..f161521b97 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -23,6 +23,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from tensorflow.contrib.opt.python.training import matrix_functions
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -76,7 +77,7 @@ class ShampooOptimizer(optimizer.Optimizer):
                learning_rate=1.0,
                svd_interval=1,
                precond_update_interval=1,
-               epsilon=0.1,
+               epsilon=1e-4,
                alpha=0.5,
                use_iterative_root=False,
                use_locking=False,
@@ -255,81 +256,18 @@ class ShampooOptimizer(optimizer.Optimizer):
 
   def _compute_power_iter(self, var, mat_g, mat_g_size, alpha,
                           mat_h_slot_name, iter_count=100, epsilon=1e-6):
-    """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+    """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""
+
+    mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
+                                                     iter_count, self._epsilon)
+    mat_h = matrix_functions.matrix_inverse_pth_root(
+        mat_g_sqrt,
+        mat_g_size,
+        2 * alpha,
+        iter_count,
+        epsilon,
+        ridge_epsilon=0.0)
-
-    We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
-
-    A Schur-Newton Method for the Matrix p-th Root and its Inverse
-    by Chun-Hua Guo and Nicholas J. Higham
-    SIAM Journal on Matrix Analysis and Applications,
-    2006, Vol. 28, No. 3 : pp. 788-804
-    https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
-
-    Args:
-      var: the variable we are updating.
-      mat_g: the symmetric PSD matrix whose power it to be computed
-      mat_g_size: size of mat_g.
-      alpha: exponent, must be -1/p for p a positive integer.
-      mat_h_slot_name: name of slot to store the power, if needed.
-      iter_count: Maximum number of iterations.
-      epsilon: accuracy indicator, useful for early termination.
-
-    Returns:
-      mat_g^alpha
-    """
-
-    identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
-
-    def MatPower(mat_m, p):
-      """Computes mat_m^p, for p a positive integer.
-
-      Power p is known at graph compile time, so no need for loop and cond.
-      Args:
-        mat_m: a square matrix
-        p: a positive integer
-
-      Returns:
-        mat_m^p
-      """
-      assert p == int(p) and p > 0
-      power = None
-      while p > 0:
-        if p % 2 == 1:
-          power = math_ops.matmul(mat_m, power) if power is not None else mat_m
-        p //= 2
-        mat_m = math_ops.matmul(mat_m, mat_m)
-      return power
-
-    def IterCondition(i, mat_m, _):
-      return math_ops.logical_and(
-          i < iter_count,
-          math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
-
-    def IterBody(i, mat_m, mat_x):
-      mat_m_i = (1 - alpha) * identity + alpha * mat_m
-      return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m),
-              math_ops.matmul(mat_x, mat_m_i))
-
-    if mat_g_size == 1:
-      mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
-    else:
-      damped_mat_g = mat_g + self._epsilon * identity
-      z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
-      # The best value for z is
-      # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
-      #                 (c_max^{1-alpha} - c_min^{1-alpha})
-      # where c_max and c_min are the largest and smallest singular values of
-      # damped_mat_g.
-      # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
-      # Can replace above line by the one below, but it is less accurate,
-      # hence needs more iterations to converge.
-      # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
-      # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
-      # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
-      # extra iterations.
-      _, _, mat_h = control_flow_ops.while_loop(
-          IterCondition, IterBody,
-          [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
     if mat_h_slot_name is not None:
       return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
     return mat_h
@@ -422,6 +360,8 @@ class ShampooOptimizer(optimizer.Optimizer):
                               mat_gbar_weight_t * precond_update_interval, i),
           lambda: mat_g)
 
+      mat_g_updated = mat_g_updated / float(shape[i].value)
+
       if self._svd_interval == 1:
         mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
       else:
@@ -443,7 +383,13 @@ class ShampooOptimizer(optimizer.Optimizer):
                                 name="precond_" + str(i))
       else:
         # Tensor size is too large -- perform diagonal Shampoo update
-        grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
+        # Only normalize non-vector cases.
+        if axes:
+          normalizer = 1.0 if indices is not None else float(shape[i].value)
+          grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
+        else:
+          grad_outer = grad * grad
+
         if i == 0 and indices is not None:
           assert self._mat_gbar_decay == 1.0
           mat_g_updated = state_ops.scatter_add(mat_g, indices,
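
Context for the refactor: the deleted while_loop implemented the coupled Schur-Newton iteration of Guo & Higham (2006, eq. 3.2) for the inverse p-th root, which now lives in matrix_functions. The new _compute_power_iter body composes two such primitives: since alpha = -1/p, (mat_g^(1/2))^(2*alpha) = mat_g^alpha, so the square root is damped with self._epsilon first and matrix_inverse_pth_root is then called with ridge_epsilon=0.0, presumably so the ridge is not applied twice. For reference, here is a minimal, self-contained NumPy sketch of the removed algorithm; the function name, defaults, and sanity check are invented for this illustration and are not the matrix_functions API.

import numpy as np

def inverse_pth_root(mat_g, p, ridge_epsilon=1e-4, iter_count=100, tol=1e-6):
  # Approximates (mat_g + ridge_epsilon * I)^(-1/p) for a symmetric PSD
  # mat_g and a positive integer p, using the coupled Schur-Newton
  # iteration that the deleted TF while_loop implemented.
  alpha = -1.0 / p
  identity = np.eye(mat_g.shape[0])
  damped = mat_g + ridge_epsilon * identity
  # Same starting scale as the deleted code; per the removed comment, this
  # estimate assumes the spectrum is not too spread out (c_max > c_min * 2^p).
  z = (1.0 - 1.0 / alpha) / (2.0 * np.linalg.norm(damped))
  mat_m = damped * z
  mat_x = identity * z ** (-alpha)
  for _ in range(iter_count):
    if np.max(np.abs(mat_m - identity)) <= tol:  # early termination
      break
    mat_m_i = (1.0 - alpha) * identity + alpha * mat_m
    # p = -1/alpha is a known positive integer, so the power is a plain
    # repeated-squaring product (the deleted MatPower helper).
    mat_m = np.linalg.matrix_power(mat_m_i, p) @ mat_m
    mat_x = mat_x @ mat_m_i
  return mat_x  # converges to damped^(-1/p) as mat_m -> identity

# Sanity check: for p = 4, h^4 should invert the damped matrix.
rng = np.random.RandomState(0)
a = rng.randn(6, 6)
g = a @ a.T  # symmetric PSD
h = inverse_pth_root(g, p=4)
print(np.max(np.abs(
    np.linalg.matrix_power(h, 4) @ (g + 1e-4 * np.eye(6)) - np.eye(6))))

As the removed comments noted, the z used here is the more accurate heuristic; z = 1 / norm(damped) (or 1 / trace(damped)) always converges but can require many extra iterations.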
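The other half of the change normalizes the second-moment statistics by the size of the dimension being preconditioned: the full-matrix branch divides mat_g_updated by float(shape[i].value), and the diagonal fallback divides the summed squared gradient by the same factor in the dense, non-vector case, while sparse row updates (indices is not None) and rank-1 variables (axes == []) are left unnormalized. A small NumPy sketch of that fallback branch, with hypothetical names (diag_grad_outer, is_sparse_update) standing in for the surrounding method's locals:

import numpy as np

def diag_grad_outer(grad, i, is_sparse_update):
  # Mirrors the new diagonal-Shampoo branch for dimension i: sum the
  # squared gradient over every other axis, dividing by shape[i] only in
  # the dense, non-vector case. Vectors keep the plain elementwise square,
  # and sparse row updates use normalizer 1.0, as in the diff.
  axes = [a for a in range(grad.ndim) if a != i]
  if axes:
    normalizer = 1.0 if is_sparse_update else float(grad.shape[i])
    return (grad * grad).sum(axis=tuple(axes)) / normalizer
  return grad * grad  # rank-1 variable: no normalization

grad = np.arange(6.0).reshape(2, 3)
print(diag_grad_outer(grad, i=0, is_sparse_update=False))  # shape (2,), divided by 2.0
print(diag_grad_outer(grad[0], i=0, is_sparse_update=False))  # plain squares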