-rw-r--r--  tensorflow/python/training/evaluation.py       | 10 +++++++---
-rw-r--r--  tensorflow/python/training/evaluation_test.py  | 36 ++++++++++++++++++++++++++++++++++++
2 files changed, 43 insertions(+), 3 deletions(-)
diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py
index 3baf1541aa..fdcb9c2e90 100644
--- a/tensorflow/python/training/evaluation.py
+++ b/tensorflow/python/training/evaluation.py
@@ -83,7 +83,8 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
     """Constructs the run hook.
 
     Args:
-      num_evals: The number of evaluations to run for.
+      num_evals: The number of evaluations to run for. If set to None, will
+        iterate the dataset until all inputs are exhausted.
       log_progress: Whether to log evaluation progress, defaults to True.
     """
     # The number of evals to run for.
@@ -102,8 +103,11 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
   def after_run(self, run_context, run_values):
     evals_completed = run_values.results['evals_completed']
     if self._log_progress:
-      logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
-    if evals_completed >= self._num_evals:
+      if self._num_evals is None:
+        logging.info('Evaluation [%d]', evals_completed)
+      else:
+        logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
+    if self._num_evals is not None and evals_completed >= self._num_evals:
       run_context.request_stop()
diff --git a/tensorflow/python/training/evaluation_test.py b/tensorflow/python/training/evaluation_test.py
index 78efaa8439..3de4ceda75 100644
--- a/tensorflow/python/training/evaluation_test.py
+++ b/tensorflow/python/training/evaluation_test.py
@@ -129,6 +129,42 @@ class EvaluateOnceTest(test.TestCase):
         hooks=[evaluation._StopAfterNEvalsHook(1),])
     self.assertTrue(final_ops_values['accuracy'] > .99)
 
+  def testEvaluateWithFiniteInputs(self):
+    checkpoint_dir = os.path.join(self.get_temp_dir(),
+                                  'evaluate_with_finite_inputs')
+
+    # Train a Model to completion:
+    self._train_model(checkpoint_dir, num_steps=300)
+
+    # Run evaluation. Inputs are fed through input producer for one epoch.
+    all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
+    all_labels = constant_op.constant(self._labels, dtype=dtypes.float32)
+
+    single_input, single_label = training.slice_input_producer(
+        [all_inputs, all_labels], num_epochs=1)
+    inputs, labels = training.batch([single_input, single_label], batch_size=6,
+                                    allow_smaller_final_batch=True)
+
+    logits = logistic_classifier(inputs)
+    predictions = math_ops.round(logits)
+
+    accuracy, update_op = metrics.accuracy(
+        predictions=predictions, labels=labels)
+
+    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
+
+    final_ops_values = evaluation._evaluate_once(
+        checkpoint_path=checkpoint_path,
+        eval_ops=update_op,
+        final_ops={'accuracy': accuracy,
+                   'eval_steps': evaluation._get_or_create_eval_step()},
+        hooks=[evaluation._StopAfterNEvalsHook(None),])
+    self.assertTrue(final_ops_values['accuracy'] > .99)
+    # Runs evaluation for 4 iterations. The first 2 each evaluate a full
+    # batch of 6 inputs; the 3rd evaluates the remaining 4 inputs, and the
+    # 4th raises OutOfRangeError, which stops evaluation.
+    self.assertEqual(final_ops_values['eval_steps'], 4)
+
   def testEvalOpAndFinalOp(self):
     checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')
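
Note on the num_evals=None termination path: _StopAfterNEvalsHook(None) never calls
run_context.request_stop(), so evaluation only ends when the exhausted input producer
raises OutOfRangeError. Below is a minimal sketch of the surrounding loop, assuming the
TF 1.x behavior that MonitoredSession's context manager treats OutOfRangeError as a
clean stop; it is an illustration, not the actual _evaluate_once code, and
_run_until_inputs_exhausted and its arguments are made-up names.

from tensorflow.python.training import monitored_session


def _run_until_inputs_exhausted(session_creator, eval_ops, hooks):
  # Each session.run() is one evaluation step. With num_evals=None the hook
  # only logs progress, so the loop exits when the finite input pipeline
  # raises OutOfRangeError; MonitoredSession swallows that error on exit.
  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    while not session.should_stop():
      session.run(eval_ops)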