aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/estimator_test.py')
-rw-r--r--tensorflow/python/estimator/estimator_test.py32
1 files changed, 31 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 398ff20b6b..a1659156a6 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -618,12 +618,20 @@ class EstimatorEvaluateTest(test.TestCase):
class EstimatorPredictTest(test.TestCase):
- def test_no_trained_model(self):
+ def test_no_trained_model_in_model_dir(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
with self.assertRaisesRegexp(ValueError,
'Could not find trained model in model_dir'):
next(est.predict(dummy_input_fn))
+ def test_no_trained_model_invalid_checkpoint_path(self):
+ est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
+ with self.assertRaises(ValueError):
+ next(
+ est.predict(
+ dummy_input_fn,
+ checkpoint_path=saver.latest_checkpoint('fakedir')))
+
def test_tensor_predictions(self):
def _model_fn(features, labels, mode):
@@ -828,6 +836,28 @@ class EstimatorPredictTest(test.TestCase):
est2 = estimator.Estimator(model_fn=_model_fn, model_dir=est1.model_dir)
self.assertEqual([32.], next(est2.predict(dummy_input_fn)))
+ def test_predict_from_checkpoint_path(self):
+
+ def _model_fn(features, labels, mode):
+ _, _ = features, labels
+ v = variables.Variable([[16.]], name='weight')
+ prediction = v * 2
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(0.),
+ train_op=constant_op.constant(0.),
+ predictions=prediction)
+
+ est1 = estimator.Estimator(model_fn=_model_fn)
+ est1.train(dummy_input_fn, steps=1)
+ est2 = estimator.Estimator(model_fn=_model_fn, model_dir=est1.model_dir)
+ self.assertEqual(
+ [32.],
+ next(
+ est2.predict(
+ dummy_input_fn,
+ checkpoint_path=saver.latest_checkpoint(est1.model_dir))))
+
def test_scaffold_is_used(self):
def _model_fn_scaffold(features, labels, mode):