path: root/tensorflow/contrib/opt
author Alexandre Passos <apassos@google.com> 2018-06-25 12:39:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-25 12:43:56 -0700
commit aedd096b83886e0ea99611cc5488284a98fb9b01 (patch)
tree ffff736b60b3b6edc7965d15422a7a15021b1618 /tensorflow/contrib/opt
parent adc03426f11e7d22dca8ee59146423eb4d9668eb (diff)
Uses resource variables by default for the global step.
Relnotes: hooks will now deterministically see the value of the global step before the update, instead of the value after the update.

PiperOrigin-RevId: 202000826
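For context, a minimal graph-mode sketch of the semantics this change relies on (the variable name, dtype, and session setup below are illustrative assumptions, not taken from this commit): on a resource variable, an explicit .read_value() creates a read op that control dependencies can order relative to the update, so a value read before an increment is deterministically the pre-update value.

import tensorflow as tf

# Illustrative stand-in for the global step; with this change,
# the global step is created as a resource variable by default.
global_step = tf.get_variable(
    "global_step", shape=[], dtype=tf.int64,
    initializer=tf.zeros_initializer(), trainable=False,
    use_resource=True)

# Explicit read: pins the snapshot point in the op ordering.
step_before = global_step.read_value()

# Force the update to run after the read.
with tf.control_dependencies([step_before]):
    train_op = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before, _ = sess.run([step_before, train_op])
    print(before)  # 0: deterministically the pre-update value.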
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py | 8
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py
index 4a905b1b2a..c5de3188e7 100644
--- a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer.py
@@ -63,7 +63,7 @@ class DropStaleGradientOptimizer(optimizer.Optimizer):
   def compute_gradients(self, loss, *args, **kwargs):
     # Record current global step for worker.
     with ops.colocate_with(loss):
-      self._local_step = training_util.get_global_step() + 0
+      self._local_step = training_util.get_global_step().read_value() + 0
 
     with ops.control_dependencies([self._local_step]):
       loss = gen_array_ops.identity(loss)
@@ -102,13 +102,13 @@ class DropStaleGradientOptimizer(optimizer.Optimizer):
     with ops.control_dependencies(gradients), ops.colocate_with(global_step):
       staleness = gen_array_ops.reshape(
-          global_step - self._local_step, shape=())
+          global_step.read_value() - self._local_step, shape=())
 
     conditional_update = stale_counter.assign_add(control_flow_ops.cond(
         gen_math_ops.less_equal(staleness, self._staleness),
         _AcceptGradientOp, _DropGradientOp))
 
     summary.scalar(
-        "Gradient staleness percentage",
-        stale_counter / (math_ops.cast(global_step + 1, dtypes.float32)))
+        "Gradient staleness percentage", stale_counter / (math_ops.cast(
+            global_step.read_value() + 1, dtypes.float32)))
     return conditional_update
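The release note about hooks follows from the same ordering: a hook that fetches an explicit read of the global step observes the pre-update value. A hedged sketch (the hook class and its logging are hypothetical; only the read_value() call mirrors the pattern in this change, and it assumes the global step is a resource variable, as it now is by default):

class LogStepBeforeUpdateHook(tf.train.SessionRunHook):
  """Illustrative hook that fetches the global step's pre-update value."""

  def begin(self):
    # read_value() on a resource variable gives a deterministic snapshot.
    self._step = tf.train.get_global_step().read_value()

  def before_run(self, run_context):
    return tf.train.SessionRunArgs(self._step)

  def after_run(self, run_context, run_values):
    print("global step before update:", run_values.results)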