diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-26 06:00:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-26 06:04:42 -0700 |
commit | 544436bbdf279dd4be68ad71536ea0488258aa07 (patch) | |
tree | ff85ac9a4f472be34b739981c7388e218a4e46ec /tensorflow/contrib/opt | |
parent | ed37c8a8a07734a4eb13e14d7d7b67c81a2968b7 (diff) |
Automated g4 rollback of changelist 202000826
PiperOrigin-RevId: 202115471
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py index c5de3188e7..4a905b1b2a 100644 --- a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py +++ b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py @@ -63,7 +63,7 @@ class DropStaleGradientOptimizer(optimizer.Optimizer): def compute_gradients(self, loss, *args, **kwargs): # Record current global step for worker. with ops.colocate_with(loss): - self._local_step = training_util.get_global_step().read_value() + 0 + self._local_step = training_util.get_global_step() + 0 with ops.control_dependencies([self._local_step]): loss = gen_array_ops.identity(loss) @@ -102,13 +102,13 @@ class DropStaleGradientOptimizer(optimizer.Optimizer): with ops.control_dependencies(gradients), ops.colocate_with(global_step): staleness = gen_array_ops.reshape( - global_step.read_value() - self._local_step, shape=()) + global_step - self._local_step, shape=()) conditional_update = stale_counter.assign_add(control_flow_ops.cond( gen_math_ops.less_equal(staleness, self._staleness), _AcceptGradientOp, _DropGradientOp)) summary.scalar( - "Gradient staleness percentage", stale_counter / (math_ops.cast( - global_step.read_value() + 1, dtypes.float32))) + "Gradient staleness percentage", + stale_counter / (math_ops.cast(global_step + 1, dtypes.float32))) return conditional_update |