| Field | Value |
|---|---|
| author | 2018-08-22 10:34:09 +0200 |
| committer | 2018-08-22 10:35:29 +0200 |
| commit | 5fae570b594485b7288c7c556359058d04a5b845 (patch) |
| tree | 3c8b2fd2e01deab50ed53fe1063ab26b95a277a1 /tensorflow/contrib/opt |
| parent | 05f8ea8e9522a3027d4f3f7a54d716bfafed427a (diff) |
Fix sparse updates for optimizer using DecoupledWeightDecay.
Diffstat (limited to 'tensorflow/contrib/opt')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | tensorflow/contrib/opt/python/training/weight_decay_optimizers.py | 5 |

1 file changed, 3 insertions(+), 2 deletions(-)
```diff
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
index b9cf40eb7b..29acfc602e 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -26,6 +26,7 @@ from tensorflow.python.training import adam
 from tensorflow.python.training import momentum as momentum_opt
 from tensorflow.python.training import optimizer
 from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.ops import array_ops
 
 
 class DecoupledWeightDecayExtension(object):
@@ -159,8 +160,8 @@ class DecoupledWeightDecayExtension(object):
 
   def _decay_weights_sparse_op(self, var, indices, scatter_add):
     if not self._decay_var_list or var in self._decay_var_list:
-      return scatter_add(var, indices, -self._weight_decay * var,
-                         self._use_locking)
+      update = -self._weight_decay * array_ops.gather(var, indices)
+      return scatter_add(var, indices, update, self._use_locking)
     return control_flow_ops.no_op()
 
   # Here, we overwrite the apply functions that the base optimizer calls.
```
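Before this change, `_decay_weights_sparse_op` passed the full `var` tensor as the update while only supplying `indices`, so the update's leading dimension did not match the number of indices that `scatter_add` expects. Gathering the rows at `indices` first makes the update shape consistent and decays only the slices touched by the sparse step. Below is a minimal NumPy sketch of that idea; `scatter_add_rows`, `weight_decay`, and the toy shapes are illustrative stand-ins, not the TensorFlow internals used by the optimizer.

```python
# Minimal sketch of the fix (illustrative only; not the TensorFlow internals).
import numpy as np

def scatter_add_rows(var, indices, updates):
    """Stand-in for the `scatter_add` callable: adds `updates` to var[indices]."""
    # As with tf.scatter_add, updates must have shape [len(indices)] + var.shape[1:].
    assert updates.shape == (len(indices),) + var.shape[1:], "update/indices shape mismatch"
    np.add.at(var, indices, updates)  # accumulates correctly even with duplicate indices
    return var

weight_decay = 0.1                        # hypothetical decay rate
var = np.ones((4, 3), dtype=np.float32)   # e.g. a small embedding table
indices = np.array([0, 2])                # rows touched by this sparse step

# Old (buggy) form: the update is the *full* variable, so its leading dimension
# (4) does not match the number of indices (2) and the scatter is ill-formed:
#   scatter_add_rows(var, indices, -weight_decay * var)   # would fail the shape check

# Fixed form (what the commit does via array_ops.gather): decay only the rows
# selected by `indices`, so the update shape matches the indices.
update = -weight_decay * var[indices]     # NumPy analogue of array_ops.gather(var, indices)
scatter_add_rows(var, indices, update)
print(var)  # rows 0 and 2 decayed to 0.9; rows 1 and 3 left untouched
```

Since `DecoupledWeightDecayExtension` overrides the apply functions that the base optimizer calls, this op runs alongside the sparse gradient update, so only the slices actually updated in a step receive the decoupled decay term.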