diff options
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 398ff20b6b..a1659156a6 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -618,12 +618,20 @@ class EstimatorEvaluateTest(test.TestCase): class EstimatorPredictTest(test.TestCase): - def test_no_trained_model(self): + def test_no_trained_model_in_model_dir(self): est = estimator.Estimator(model_fn=model_fn_global_step_incrementer) with self.assertRaisesRegexp(ValueError, 'Could not find trained model in model_dir'): next(est.predict(dummy_input_fn)) + def test_no_trained_model_invalid_checkpoint_path(self): + est = estimator.Estimator(model_fn=model_fn_global_step_incrementer) + with self.assertRaises(ValueError): + next( + est.predict( + dummy_input_fn, + checkpoint_path=saver.latest_checkpoint('fakedir'))) + def test_tensor_predictions(self): def _model_fn(features, labels, mode): @@ -828,6 +836,28 @@ class EstimatorPredictTest(test.TestCase): est2 = estimator.Estimator(model_fn=_model_fn, model_dir=est1.model_dir) self.assertEqual([32.], next(est2.predict(dummy_input_fn))) + def test_predict_from_checkpoint_path(self): + + def _model_fn(features, labels, mode): + _, _ = features, labels + v = variables.Variable([[16.]], name='weight') + prediction = v * 2 + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(0.), + train_op=constant_op.constant(0.), + predictions=prediction) + + est1 = estimator.Estimator(model_fn=_model_fn) + est1.train(dummy_input_fn, steps=1) + est2 = estimator.Estimator(model_fn=_model_fn, model_dir=est1.model_dir) + self.assertEqual( + [32.], + next( + est2.predict( + dummy_input_fn, + checkpoint_path=saver.latest_checkpoint(est1.model_dir)))) + def test_scaffold_is_used(self): def _model_fn_scaffold(features, labels, mode): |