author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-09 12:12:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-09 12:17:06 -0700
commit    0771f37819c1077067340febad5a0d3abe8e561b (patch)
tree      b89ae358a9b925bbd57e840e1d5b763d16b00033 /tensorflow/contrib/opt
parent    52f43f299ee6e79394313cbae6654be12a3ee455 (diff)
Make sure Shampoo preconditions grads of matrices of size max_matrix_size
PiperOrigin-RevId: 208090007
Diffstat (limited to 'tensorflow/contrib/opt')
 -rw-r--r-- tensorflow/contrib/opt/python/training/shampoo.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)
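The change itself is a one-character boundary fix: a dimension exactly equal to max_matrix_size previously fell through to the fallback path and was never preconditioned. A minimal sketch of the before/after semantics (standalone Python; the threshold value 768 is a hypothetical example, not a TensorFlow default):

    # Illustration of the off-by-one this commit fixes.
    max_matrix_size = 768  # example threshold, chosen for illustration
    d = 768                # a dimension exactly at the threshold

    precondition_before = d < max_matrix_size   # False: dimension skipped
    precondition_after = d <= max_matrix_size   # True: dimension preconditioned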
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index a98866b180..294627f42a 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -139,7 +139,7 @@ class ShampooOptimizer(optimizer.Optimizer):
         shape = np.array(v.get_shape())
         for i, d in enumerate(shape):
           d_tensor = ops.convert_to_tensor(d)
-          if d < self._max_matrix_size:
+          if d <= self._max_matrix_size:
             mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor))
             if self._svd_interval > 1:
               _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor),
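For context, the hunk above sits in the slot-creation loop: Shampoo keeps one square statistic per tensor dimension. A hedged numpy sketch of the same allocation pattern (variable shape and threshold are made up; the real code uses TensorFlow slot variables, not a dict):

    import numpy as np

    var_shape = (768, 10)     # hypothetical variable; threshold below is 768
    max_matrix_size = 768
    mat_g = {}
    for i, d in enumerate(var_shape):
        if d <= max_matrix_size:         # inclusive check, as in this commit
            mat_g[i] = np.zeros((d, d))  # mirrors zeros_like(eye(d)) above
    # With the old strict '<', dimension 0 (d == 768) would get no statistic,
    # so that axis of the gradient would never be preconditioned.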
@@ -163,7 +163,7 @@ class ShampooOptimizer(optimizer.Optimizer):
     return self._apply_sparse_shared(grad.values, grad.indices, var)
 
   def _apply_sparse_shared(self, grad_values, grad_indices, var):
-    if var.get_shape()[0] < self._max_matrix_size or self._gbar_decay != 0.0:
+    if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
       # The dimension is small enough, we can make the variable dense and
       # do a dense update
       dense_grad = array_ops.scatter_nd(
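The branch above densifies the sparse gradient so the dense update path can handle it. A minimal standalone sketch of that scatter_nd step, using the public tf.scatter_nd API with made-up shapes:

    import tensorflow as tf

    grad_indices = tf.constant([0, 2])               # rows carrying gradients
    grad_values = tf.constant([[1., 1.], [2., 2.]])  # one row of values each
    var_shape = [4, 2]                               # full (dense) shape

    # Scatter the sparse rows into a dense zero tensor.
    dense_grad = tf.scatter_nd(
        tf.expand_dims(grad_indices, axis=1),        # [[0], [2]]
        grad_values, var_shape)
    # dense_grad == [[1., 1.], [0., 0.], [2., 2.], [0., 0.]]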
@@ -408,7 +408,7 @@ class ShampooOptimizer(optimizer.Optimizer):
     for i, mat_g in enumerate(mat_g_list):
       # axes is the list of indices to reduce - everything but the current i.
       axes = list(range(i)) + list(range(i+1, v_rank))
-      if shape[i] < self._max_matrix_size:
+      if shape[i] <= self._max_matrix_size:
         # If the tensor size is sufficiently small perform full Shampoo update
         # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
         # is not strictly correct. However we will use it for now, and
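For readers unfamiliar with the full update this check gates: the statistic for axis i contracts the gradient with itself over every other axis, producing a [shape[i], shape[i]] matrix. A hedged numpy sketch (gradient shape is illustrative; the real code accumulates into the mat_g slots with decay):

    import numpy as np

    grad = np.random.randn(4, 5, 6)   # hypothetical gradient tensor
    v_rank = grad.ndim
    for i in range(v_rank):
        # Reduce over everything but the current axis i, as in the loop above.
        axes = list(range(i)) + list(range(i + 1, v_rank))
        g_i = np.tensordot(grad, grad, axes=(axes, axes))
        assert g_i.shape == (grad.shape[i], grad.shape[i])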