author     A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-29 10:17:53 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-08-29 10:22:42 -0700
commit     aca93368a979419360c1fd84b53b1766b19ba81a (patch)
tree       2312ef53a30251ec2f5538d43ba066550679f6d9 /tensorflow/contrib/estimator
parent     8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff)
Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global
step counter. This allows us to get rid of the increment_var() function
and just use a standard assign_add().

PiperOrigin-RevId: 210743165
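The idea behind the change, in a minimal sketch: a variable created with the new ONLY_FIRST_TOWER aggregation mode applies updates from the first tower only, so the global step can be bumped with a plain assign_add() instead of the distribute-specific increment_var() helper. The snippet below assumes the TF 1.x API of this commit's era (tf.VariableAggregation.ONLY_FIRST_TOWER and the aggregation keyword of tf.get_variable); the variable name is illustrative, not taken from this diff.

    import tensorflow as tf

    # Hypothetical global-step variable. With ONLY_FIRST_TOWER aggregation,
    # only the first tower's update is applied when running under a
    # distribution strategy, so no cross-tower reduction is needed.
    global_step = tf.get_variable(
        'global_step',
        shape=[],
        dtype=tf.int64,
        initializer=tf.zeros_initializer(),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)

    # A standard assign_add now suffices; .op extracts the tf.Operation,
    # matching what increment_var() used to return.
    increment_op = tf.assign_add(global_step, 1).op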
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/baseline_test.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index 505c94e971..513feb03b6 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -37,13 +37,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
@@ -339,7 +339,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
      self.assertEquals(0, loss.shape.ndims)
      if expected_loss is None:
        if global_step is not None:
-          return distribute_lib.increment_var(global_step)
+          return state_ops.assign_add(global_step, 1).op
        return control_flow_ops.no_op()
      assert_loss = assert_close(
          math_ops.to_float(expected_loss, name='expected'),
@@ -347,7 +347,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
          name='assert_loss')
      with ops.control_dependencies((assert_loss,)):
        if global_step is not None:
-          return distribute_lib.increment_var(global_step)
+          return state_ops.assign_add(global_step, 1).op
        return control_flow_ops.no_op()

    mock_optimizer = test.mock.NonCallableMock(
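A note on the .op suffix in the replacement lines: state_ops.assign_add() returns a Tensor holding the updated value, while the fallback branch returns control_flow_ops.no_op(), which is an Operation, so taking .op keeps the helper's return type consistent. A minimal sketch of the same pattern with the public TF 1.x graph-mode API (variable name illustrative):

    import tensorflow as tf

    step = tf.Variable(0, trainable=False, name='step')
    update = tf.assign_add(step, 1)  # Tensor: the incremented value
    update_op = update.op            # Operation: run for the side effect only

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(update_op)            # increments step without fetching a value
      print(sess.run(step))          # 1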