author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-04 22:30:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-04 22:33:42 -0700
commit    57d31aa599c83014397a22bbb8f1a27a33b0ade3 (patch)
tree      0735d1c9cd3d6fc3437c48414cafd6b217535977 /tensorflow/contrib/opt
parent    176e6993c5e11631389e05f82b3d71a3a367e392 (diff)
Remove dependency on epsilon for diagonal shampoo.
PiperOrigin-RevId: 215857772
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--  tensorflow/contrib/opt/python/training/shampoo.py       | 16
-rw-r--r--  tensorflow/contrib/opt/python/training/shampoo_test.py  |  8
2 files changed, 15 insertions(+), 9 deletions(-)
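
In short: the diagonal variant of Shampoo previously regularized its second-moment accumulator as (mat_g + epsilon)**neg_alpha before preconditioning; after this commit it takes mat_g**neg_alpha where the accumulator is positive and 0 elsewhere, so epsilon only matters for the non-diagonal (full-matrix) preconditioners. A hedged NumPy sketch of the two behaviors (function names and values here are illustrative, not part of the codebase):

    import numpy as np

    def diag_precond_old(mat_g, neg_alpha, epsilon):
        # Pre-commit: ridge term keeps the negative power finite everywhere.
        return np.power(mat_g + epsilon, neg_alpha)

    def diag_precond_new(mat_g, neg_alpha):
        # Post-commit: zero accumulator entries map to a zero preconditioner.
        out = np.zeros_like(mat_g)
        mask = mat_g > 0
        out[mask] = np.power(mat_g[mask], neg_alpha)
        return out

    mat_g = np.array([0.0, 4.0])
    print(diag_precond_old(mat_g, -0.5, 1e-4))  # [100.   0.5] (approx)
    print(diag_precond_new(mat_g, -0.5))        # [  0.   0.5]

Note the first entry: the old path scaled a never-updated coordinate by 1/sqrt(epsilon) = 100, while the new path leaves it untouched.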
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index f161521b97..e542f46892 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -108,7 +108,8 @@ class ShampooOptimizer(optimizer.Optimizer):
precond_update_interval: We should update the preconditioners after
this many steps. Default = 1. Usually less than
svd_interval.
- epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability in the
+ non-diagonal version of shampoo.
alpha: total power of the preconditioners.
use_iterative_root: should the optimizer use SVD (faster) or the
iterative root method (for TPU) for finding the
@@ -394,15 +395,20 @@ class ShampooOptimizer(optimizer.Optimizer):
assert self._mat_gbar_decay == 1.0
mat_g_updated = state_ops.scatter_add(mat_g, indices,
mat_gbar_weight_t * grad_outer)
- mat_h = math_ops.pow(
- array_ops.gather(mat_g_updated, indices) + self._epsilon,
- neg_alpha)
+ mat_g_updated_slice = array_ops.gather(mat_g_updated, indices)
+ mat_h = array_ops.where(
+ math_ops.greater(mat_g_updated_slice, 0),
+ math_ops.pow(mat_g_updated_slice, neg_alpha),
+ array_ops.zeros_like(mat_g_updated_slice))
else:
mat_g_updated = self._weighted_average(mat_g,
self._mat_gbar_decay,
mat_gbar_decay_t,
mat_gbar_weight_t * grad_outer)
- mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha)
+ mat_h = array_ops.where(
+ math_ops.greater(mat_g_updated, 0),
+ math_ops.pow(mat_g_updated, neg_alpha),
+ array_ops.zeros_like(mat_g_updated))
# Need to do the transpose to ensure that the tensor becomes
# a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
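
This hunk applies the same guard on both the sparse (scatter_add/gather over indices) and dense accumulator paths: entries of mat_g that are still zero now yield a zero preconditioner entry rather than an epsilon-dependent one. A minimal standalone sketch of the guard, using TF 1.x public ops (tf.where, tf.greater, tf.pow, tf.zeros_like) in place of the internal array_ops/math_ops aliases, with made-up accumulator values:

    import tensorflow as tf  # TF 1.x, as in the contrib-era codebase

    mat_g_updated = tf.constant([0.0, 4.0, 9.0])  # made-up accumulator values
    neg_alpha = -0.5                               # alpha = 0.5

    # Elementwise guard: pow() yields inf at 0, so select 0 there instead.
    mat_h = tf.where(tf.greater(mat_g_updated, 0.),
                     tf.pow(mat_g_updated, neg_alpha),
                     tf.zeros_like(mat_g_updated))

    with tf.Session() as sess:
      print(sess.run(mat_h))  # [0.  0.5  0.33333334]

tf.where evaluates both branches, so tf.pow still produces inf at the zero entries, but the selection discards those values; the same holds for array_ops.where in the patch.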
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index a2fd8fbd87..e88c8221a0 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -279,7 +279,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# Update rule is var = var - lr * gg^{-0.5} * grad
# lr = 1
mat_g = (grad_np * grad_np)
- new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np
+ new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
self.assertAllCloseAccordingToType(
new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
@@ -288,7 +288,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
new_val = sess.run(var)
mat_g += (grad_np_2 * grad_np_2)
- new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2
+ new_val_np -= np.power(mat_g, -0.5) * grad_np_2
self.assertAllCloseAccordingToType(
new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
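
In the scalar/vector tests above, mat_g accumulates grad * grad elementwise, so with lr = 1 the expected step np.power(mat_g, -0.5) * grad_np is exactly the sign of each nonzero gradient entry; RIDGE_EPSILON no longer appears in the expectation. A quick NumPy check of that arithmetic (values are made up, not the ones in the test file):

    import numpy as np

    init_var_np = np.array([1.0, -2.0])
    grad_np = np.array([0.5, -0.25])

    mat_g = grad_np * grad_np                  # gg accumulator after one step
    new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
    print(new_val_np)                          # [ 0. -1.] : unit-magnitude steps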
@@ -339,7 +339,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(
grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
- mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_left = np.power(mat_g1, -0.25)
mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
@@ -353,7 +353,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 += np.sum(
grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0]
- mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_left = np.power(mat_g1, -0.25)
mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
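
Note the asymmetry the last two hunks leave in place: the per-row statistic mat_g1 (a diagonal accumulator) now drops RIDGE_EPSILON, while the dense column statistic mat_g2 still gets the RIDGE_EPSILON * np.eye(size[1]) ridge before the matrix root, consistent with the docstring change above. A hedged sketch of the expected update; np_power here is an illustrative stand-in for the test helper (assumed to be a symmetric-PSD matrix power via eigendecomposition), and the RIDGE_EPSILON value and shapes are assumptions:

    import numpy as np

    def np_power(mat, power):
        # Assumed behavior of the test helper: matrix power of a
        # symmetric PSD matrix via eigendecomposition.
        w, v = np.linalg.eigh(mat)
        return np.dot(v, np.dot(np.diag(np.power(w, power)), v.T))

    RIDGE_EPSILON = 1e-4                      # assumed value
    grad_np = np.random.randn(4, 3)           # made-up gradient, size = (4, 3)
    init_var_np = np.zeros((4, 3))

    mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
    mat_left = np.power(mat_g1, -0.25)        # no ridge on the diagonal statistic
    mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
    mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(3), -0.25)
    new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
    print(new_val_np.shape)                   # (4, 3)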