diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-29 10:17:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-29 10:22:42 -0700 |
commit | aca93368a979419360c1fd84b53b1766b19ba81a (patch) | |
tree | 2312ef53a30251ec2f5538d43ba066550679f6d9 /tensorflow/contrib/estimator | |
parent | 8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff) |
Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global
step counter. This allows us to get rid of the increment_var()
function and just use a standard assign_add().
PiperOrigin-RevId: 210743165
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/baseline_test.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py index 505c94e971..513feb03b6 100644 --- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -37,13 +37,13 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer from tensorflow.python.training import saver @@ -339,7 +339,7 @@ class BaselineEstimatorTrainingTest(test.TestCase): self.assertEquals(0, loss.shape.ndims) if expected_loss is None: if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() assert_loss = assert_close( math_ops.to_float(expected_loss, name='expected'), @@ -347,7 +347,7 @@ class BaselineEstimatorTrainingTest(test.TestCase): name='assert_loss') with ops.control_dependencies((assert_loss,)): if global_step is not None: - return distribute_lib.increment_var(global_step) + return state_ops.assign_add(global_step, 1).op return control_flow_ops.no_op() mock_optimizer = test.mock.NonCallableMock( |