diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-31 16:24:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 16:38:15 -0700 |
commit | b1341d049a74b7054450a9641b1c186f23df501f (patch) | |
tree | 0c0fe23c80d517baf5315804e25ea40b8468357b /tensorflow/contrib/opt/python | |
parent | c72dc92100be6169ed5ae5b59ba7a91b34cb2db4 (diff) |
Fix normalization in Shampoo when dealing with differently sized tensors.
Add M^1/2 to reduce condition numbers, before computing inverse pth root.
PiperOrigin-RevId: 211162032
Diffstat (limited to 'tensorflow/contrib/opt/python')
-rw-r--r-- | tensorflow/contrib/opt/python/training/shampoo.py | 22 | ||||
-rw-r--r-- | tensorflow/contrib/opt/python/training/shampoo_test.py | 194 |
2 files changed, 133 insertions, 83 deletions
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py index d897ede0c7..f161521b97 100644 --- a/tensorflow/contrib/opt/python/training/shampoo.py +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -77,7 +77,7 @@ class ShampooOptimizer(optimizer.Optimizer): learning_rate=1.0, svd_interval=1, precond_update_interval=1, - epsilon=0.1, + epsilon=1e-4, alpha=0.5, use_iterative_root=False, use_locking=False, @@ -257,13 +257,17 @@ class ShampooOptimizer(optimizer.Optimizer): def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name, iter_count=100, epsilon=1e-6): """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.""" + + mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size, + iter_count, self._epsilon) mat_h = matrix_functions.matrix_inverse_pth_root( - mat_g, + mat_g_sqrt, mat_g_size, - alpha, + 2 * alpha, iter_count, epsilon, - ridge_epsilon=self._epsilon) + ridge_epsilon=0.0) + if mat_h_slot_name is not None: return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) return mat_h @@ -356,6 +360,8 @@ class ShampooOptimizer(optimizer.Optimizer): mat_gbar_weight_t * precond_update_interval, i), lambda: mat_g) + mat_g_updated = mat_g_updated / float(shape[i].value) + if self._svd_interval == 1: mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha) else: @@ -377,7 +383,13 @@ class ShampooOptimizer(optimizer.Optimizer): name="precond_" + str(i)) else: # Tensor size is too large -- perform diagonal Shampoo update - grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) + # Only normalize non-vector cases. + if axes: + normalizer = 1.0 if indices is not None else float(shape[i].value) + grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer + else: + grad_outer = grad * grad + if i == 0 and indices is not None: assert self._mat_gbar_decay == 1.0 mat_g_updated = state_ops.scatter_add(mat_g, indices, diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py index b3688ab181..05bcf2cfa3 100644 --- a/tensorflow/contrib/opt/python/training/shampoo_test.py +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test TOLERANCE = 1e-3 +RIDGE_EPSILON = 1e-4 def np_power(mat_g, alpha): @@ -77,8 +78,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # let up compute this in numpy # Update rule is var = var - lr * mat_g^{-0.5} * grad # lr = 1 - mat_g = np.outer(grad_np, grad_np) - mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + mat_g = np.outer(grad_np, grad_np) / grad_np.shape[0] + mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5) new_val_np = init_var_np - np.dot(mat_h, grad_np) self.assertAllCloseAccordingToType(new_val_np, new_val, @@ -88,8 +89,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase): update_2.run() new_val = sess.run(var) - mat_g += np.outer(grad_np_2, grad_np_2) - mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + mat_g += np.outer(grad_np_2, grad_np_2) / grad_np.shape[0] + mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5) new_val_np -= np.dot(mat_h, grad_np_2) self.assertAllCloseAccordingToType(new_val_np, new_val, @@ -128,10 +129,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # let up compute this in numpy # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25} # lr = 1 - mat_g1 = np.dot(grad_np, grad_np.transpose()) - mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) - mat_g2 = np.dot(grad_np.transpose(), grad_np) - mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + mat_g1 = np.dot(grad_np, grad_np.transpose()) / grad_np.shape[0] + mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right) self.assertAllCloseAccordingToType(new_val_np, new_val, @@ -141,10 +142,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase): update_2.run() new_val = sess.run(var) - mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) - mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) - mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) - mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) / grad_np_2.shape[0] + mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right) self.assertAllCloseAccordingToType(new_val_np, new_val, @@ -188,12 +189,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # let up compute this in numpy # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad # lr = 1 - mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) - mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) - mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) - mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) - mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) - mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + mat_g1 = ( + np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) / + grad_np.shape[0]) + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 = ( + np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) / + grad_np.shape[1]) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 = ( + np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) / + grad_np.shape[2]) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0])) precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) @@ -207,12 +214,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase): update_2.run() new_val = sess.run(var) - mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) - mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) - mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) - mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) - mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) - mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + mat_g1 += ( + np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / + grad_np_2.shape[0]) + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 += ( + np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / + grad_np_2.shape[1]) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 += ( + np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / + grad_np_2.shape[2]) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0])) precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) @@ -265,19 +278,21 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # let up compute this in numpy # Update rule is var = var - lr * gg^{-0.5} * grad # lr = 1 - mat_g = grad_np * grad_np + 0.1 - new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np - - self.assertAllCloseAccordingToType(new_val_np, new_val) + mat_g = (grad_np * grad_np) + new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np + self.assertAllCloseAccordingToType( + new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE) # Run another step of Shampoo update_2.run() new_val = sess.run(var) - mat_g += grad_np_2 * grad_np_2 - new_val_np -= np.power(mat_g, -0.5) * grad_np_2 + mat_g += (grad_np_2 * grad_np_2) + new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2 + + self.assertAllCloseAccordingToType( + new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE) - self.assertAllCloseAccordingToType(new_val_np, new_val) @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) def testLargeMatrix(self, use_resource_var): @@ -322,10 +337,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # with broadcasting # lr = 1 - mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) - mat_left = np.power(mat_g1 + 0.1, -0.25) - mat_g2 = np.dot(grad_np.transpose(), grad_np) - mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + mat_g1 = np.sum( + grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0] + mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right) self.assertAllCloseAccordingToType(new_val_np, new_val, @@ -335,10 +351,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase): update_2.run() new_val = sess.run(var) - mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) - mat_left = np.power(mat_g1 + 0.1, -0.25) - mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) - mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + mat_g1 += np.sum( + grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0] + mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) new_val_np -= np.dot(grad_np_2 * mat_left, mat_right) self.assertAllCloseAccordingToType(new_val_np, new_val, @@ -405,9 +422,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) mat_g1_acc = np.zeros((size[0], 1)) mat_g1_acc[grad_indices] += mat_g1 - mat_left = np.power(mat_g1 + 0.1, -0.25) - mat_g2 = np.dot(grad_np.transpose(), grad_np) - mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) new_val_np = init_var_np new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right) @@ -420,9 +437,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) mat_g1_acc[grad_indices_2] += mat_g1 - mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25) - mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) - mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + mat_left = np.power(mat_g1_acc[grad_indices_2] + RIDGE_EPSILON, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1] + mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25) new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right) self.assertAllCloseAccordingToType(new_val_np, new_val, @@ -474,12 +491,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_dense = np.zeros_like(init_var_np) grad_dense[grad_indices] = grad_np - mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2])) - mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) - mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2])) - mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) - mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1])) - mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + mat_g1 = np.tensordot( + grad_dense, grad_dense, axes=([1, 2], [1, 2])) / grad_dense.shape[0] + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 = np.tensordot( + grad_dense, grad_dense, axes=([0, 2], [0, 2])) / grad_dense.shape[1] + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 = np.tensordot( + grad_dense, grad_dense, axes=([0, 1], [0, 1])) / grad_dense.shape[2] + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0])) precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) @@ -536,12 +556,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # let up compute this in numpy # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad # lr = 1 - mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) - mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) - mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) - mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) - mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) - mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + mat_g1 = np.tensordot( + grad_np, grad_np, axes=([1, 2], [1, 2])) / grad_np.shape[0] + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 = np.tensordot( + grad_np, grad_np, axes=([0, 2], [0, 2])) / grad_np.shape[1] + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 = np.tensordot( + grad_np, grad_np, axes=([0, 1], [0, 1])) / grad_np.shape[2] + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) gbar_np = gbar_weight * grad_np precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0])) @@ -556,12 +579,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase): update_2.run() new_val = sess.run(var) - mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) - mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) - mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) - mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) - mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) - mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + mat_g1 += np.tensordot( + grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / grad_np_2.shape[0] + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0) + mat_g2 += np.tensordot( + grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / grad_np_2.shape[1] + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0) + mat_g3 += np.tensordot( + grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / grad_np_2.shape[2] + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0) gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2 precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0])) @@ -626,13 +652,19 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # let up compute this in numpy # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad # lr = 1 - mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) - mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) - mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + mat_g1 += np.tensordot( + grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / grad_np[i].shape[0] + mat_g2 += np.tensordot( + grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / grad_np[i].shape[1] + mat_g3 += np.tensordot( + grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / grad_np[i].shape[2] if (i + 1) % svd_interval == 0: - mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) - mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) - mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), + -0.5 / 3.0) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), + -0.5 / 3.0) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), + -0.5 / 3.0) precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) @@ -700,17 +732,23 @@ class ShampooTest(test.TestCase, parameterized.TestCase): # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad # lr = 1 if (i + 1) % precond_update_interval == 0: - mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) - * precond_update_interval) - mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) - * precond_update_interval) - mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) - * precond_update_interval) + mat_g1 += ( + np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / + grad_np[i].shape[0] * precond_update_interval) + mat_g2 += ( + np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / + grad_np[i].shape[1] * precond_update_interval) + mat_g3 += ( + np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / + grad_np[i].shape[2] * precond_update_interval) if (i + 1) % svd_interval == 0: - mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) - mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) - mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), + -0.5 / 3.0) + mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), + -0.5 / 3.0) + mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), + -0.5 / 3.0) precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) |