aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-07 18:07:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 18:10:15 -0700
commitb941a031e8a2eb67e0083d8aa6ffe5a3ffe96f7b (patch)
tree38117b2d1d36c67e94a04103cb4c32f7205794e9 /tensorflow/contrib/learn
parentfba60ec27f4d415dafdf2ee916e2aa2004fa9635 (diff)
Pass checkpoint_path to predicate functions for experiment.continuous_eval even in the case of falsy eval_results
PiperOrigin-RevId: 199728382
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py2
2 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 541da90617..f8a3709ee5 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -505,7 +505,7 @@ class Experiment(object):
eval_result = None
last_warning_time = 0
while (not predicate_fn or predicate_fn(
- eval_result, checkpoint_path=previous_path if eval_result else None)):
+ eval_result, checkpoint_path=previous_path)):
# Exit if we have already reached number of steps to train.
if self._has_training_stopped(eval_result):
logging.info("Exiting continuous eval, global_step=%s >= "
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index d10927a0cd..fb16c94c29 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase):
noop_hook = _NoopHook()
def _predicate_fn(eval_result, checkpoint_path):
- self.assertEqual(not eval_result,
+ self.assertEqual(eval_result is None,
checkpoint_path is None)
return est.eval_count < 3 # pylint: disable=cell-var-from-loop