diff options
author | 2018-08-09 12:12:53 -0700 | |
---|---|---|
committer | 2018-08-09 12:17:06 -0700 | |
commit | 0771f37819c1077067340febad5a0d3abe8e561b (patch) | |
tree | b89ae358a9b925bbd57e840e1d5b763d16b00033 /tensorflow/contrib/opt | |
parent | 52f43f299ee6e79394313cbae6654be12a3ee455 (diff) |
Make sure Shampoo preconditions grads of matrices of size max_matrix_size
PiperOrigin-RevId: 208090007
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/python/training/shampoo.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py index a98866b180..294627f42a 100644 --- a/tensorflow/contrib/opt/python/training/shampoo.py +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -139,7 +139,7 @@ class ShampooOptimizer(optimizer.Optimizer): shape = np.array(v.get_shape()) for i, d in enumerate(shape): d_tensor = ops.convert_to_tensor(d) - if d < self._max_matrix_size: + if d <= self._max_matrix_size: mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor)) if self._svd_interval > 1: _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor), @@ -163,7 +163,7 @@ class ShampooOptimizer(optimizer.Optimizer): return self._apply_sparse_shared(grad.values, grad.indices, var) def _apply_sparse_shared(self, grad_values, grad_indices, var): - if var.get_shape()[0] < self._max_matrix_size or self._gbar_decay != 0.0: + if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0: # The dimension is small enough, we can make the variable dense and # do a dense update dense_grad = array_ops.scatter_nd( @@ -408,7 +408,7 @@ class ShampooOptimizer(optimizer.Optimizer): for i, mat_g in enumerate(mat_g_list): # axes is the list of indices to reduce - everything but the current i. axes = list(range(i)) + list(range(i+1, v_rank)) - if shape[i] < self._max_matrix_size: + if shape[i] <= self._max_matrix_size: # If the tensor size is sufficiently small perform full Shampoo update # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this # is not strictly correct. However we will use it for now, and |