diff options
-rw-r--r-- | tensorflow/python/training/evaluation.py | 10 | ||||
-rw-r--r-- | tensorflow/python/training/evaluation_test.py | 36 |
2 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py index 3baf1541aa..fdcb9c2e90 100644 --- a/tensorflow/python/training/evaluation.py +++ b/tensorflow/python/training/evaluation.py @@ -83,7 +83,8 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook): """Constructs the run hook. Args: - num_evals: The number of evaluations to run for. + num_evals: The number of evaluations to run for. if set to None, will + iterate the dataset until all inputs are exhausted. log_progress: Whether to log evaluation progress, defaults to True. """ # The number of evals to run for. @@ -102,8 +103,11 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook): def after_run(self, run_context, run_values): evals_completed = run_values.results['evals_completed'] if self._log_progress: - logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) - if evals_completed >= self._num_evals: + if self._num_evals is None: + logging.info('Evaluation [%d]', evals_completed) + else: + logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) + if self._num_evals is not None and evals_completed >= self._num_evals: run_context.request_stop() diff --git a/tensorflow/python/training/evaluation_test.py b/tensorflow/python/training/evaluation_test.py index 78efaa8439..3de4ceda75 100644 --- a/tensorflow/python/training/evaluation_test.py +++ b/tensorflow/python/training/evaluation_test.py @@ -129,6 +129,42 @@ class EvaluateOnceTest(test.TestCase): hooks=[evaluation._StopAfterNEvalsHook(1),]) self.assertTrue(final_ops_values['accuracy'] > .99) + def testEvaluateWithFiniteInputs(self): + checkpoint_dir = os.path.join(self.get_temp_dir(), + 'evaluate_with_finite_inputs') + + # Train a Model to completion: + self._train_model(checkpoint_dir, num_steps=300) + + # Run evaluation. Inputs are fed through input producer for one epoch. + all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) + all_labels = constant_op.constant(self._labels, dtype=dtypes.float32) + + single_input, single_label = training.slice_input_producer( + [all_inputs, all_labels], num_epochs=1) + inputs, labels = training.batch([single_input, single_label], batch_size=6, + allow_smaller_final_batch=True) + + logits = logistic_classifier(inputs) + predictions = math_ops.round(logits) + + accuracy, update_op = metrics.accuracy( + predictions=predictions, labels=labels) + + checkpoint_path = saver.latest_checkpoint(checkpoint_dir) + + final_ops_values = evaluation._evaluate_once( + checkpoint_path=checkpoint_path, + eval_ops=update_op, + final_ops={'accuracy': accuracy, + 'eval_steps': evaluation._get_or_create_eval_step()}, + hooks=[evaluation._StopAfterNEvalsHook(None),]) + self.assertTrue(final_ops_values['accuracy'] > .99) + # Runs evaluation for 4 iterations. First 2 evaluate full batch of 6 inputs + # each; the 3rd iter evaluates the remaining 4 inputs, and the last one + # triggers an error which stops evaluation. + self.assertEqual(final_ops_values['eval_steps'], 4) + def testEvalOpAndFinalOp(self): checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops') |