path: root/tensorflow/python/training/proximal_gradient_descent.py
author     Frank Li <lif@google.com>  2016-07-19 20:35:20 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-07-19 21:47:38 -0700
commit     82e53ebecdac677b37cd9316f8d1be5b1627eec3 (patch)
tree       d85df1cc9065727038f9a2a23d71888574197b29 /tensorflow/python/training/proximal_gradient_descent.py
parent     35eaa24772e69e8dfddd284300851699e12f46ab (diff)
Fix the sparse case of ProximalGradientDescent not being plumbed correctly: _apply_sparse previously did a plain scaled scatter_sub of the gradient and skipped the L1/L2 regularization; it now calls the dedicated sparse_apply_proximal_gradient_descent op.
Change: 127908886
Diffstat (limited to 'tensorflow/python/training/proximal_gradient_descent.py')
-rw-r--r--  tensorflow/python/training/proximal_gradient_descent.py  |  11 ++++++++---
1 file changed, 8 insertions, 3 deletions
diff --git a/tensorflow/python/training/proximal_gradient_descent.py b/tensorflow/python/training/proximal_gradient_descent.py
index 299c6fa1c7..a5ff8a6cea 100644
--- a/tensorflow/python/training/proximal_gradient_descent.py
+++ b/tensorflow/python/training/proximal_gradient_descent.py
@@ -68,9 +68,14 @@ class ProximalGradientDescentOptimizer(optimizer.Optimizer):
         use_locking=self._use_locking).op
 
   def _apply_sparse(self, grad, var):
-    delta = ops.IndexedSlices(grad.values * self._learning_rate_tensor,
-                              grad.indices, grad.dense_shape)
-    return var.scatter_sub(delta, use_locking=self._use_locking)
+    return training_ops.sparse_apply_proximal_gradient_descent(
+        var,
+        self._learning_rate_tensor,
+        self._l1_regularization_strength_tensor,
+        self._l2_regularization_strength_tensor,
+        grad.values,
+        grad.indices,
+        use_locking=self._use_locking).op
 
   def _prepare(self):
     self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
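
For context, the sketch below is not part of the commit; it spells out in NumPy what the sparse kernel computes for each updated row, following the documented update rule of TensorFlow's proximal gradient descent ops (prox_v = var - alpha * delta; var = sign(prox_v) / (1 + alpha * l2) * max(|prox_v| - alpha * l1, 0)). The helper name sparse_proximal_gd_update and the toy values are made up for illustration. The old _apply_sparse only performed the plain scaled subtraction (the first line of that rule), which is why the L1/L2 terms never took effect on the sparse path.

  # NumPy sketch of the per-row proximal update (illustrative, not the TF kernel).
  import numpy as np

  def sparse_proximal_gd_update(var, lr, l1, l2, grad_values, grad_indices):
    """Apply the proximal gradient step only to the rows named by grad_indices."""
    for row, g in zip(grad_indices, grad_values):
      prox = var[row] - lr * g                            # plain gradient step
      shrunk = np.maximum(np.abs(prox) - lr * l1, 0.0)    # L1 soft-threshold
      var[row] = np.sign(prox) * shrunk / (1.0 + lr * l2) # L2 shrinkage
    return var

  # Toy example: only rows 0 and 3 receive gradients; other rows stay untouched,
  # mirroring how IndexedSlices gradients drive _apply_sparse.
  var = np.ones((5, 2), dtype=np.float32)
  grad_values = np.array([[0.5, 0.5], [1.0, 1.0]], dtype=np.float32)
  grad_indices = np.array([0, 3])
  sparse_proximal_gd_update(var, lr=0.1, l1=0.01, l2=0.01,
                            grad_values=grad_values, grad_indices=grad_indices)
  print(var)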